Skip to main content

rlx_sam/
preprocess.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! SAM v1 host-side preprocessing.
//!
//! Two host-side tensor manipulations live here instead of in the IR
//! graph:
//!   1. **Image preprocess** — resize the long side to 1024 (preserving
//!      aspect ratio), normalize, and zero-pad to 1024×1024 NCHW. The
//!      normalize + zero-pad steps mirror `sam.rs::preprocess()` in
//!      candle; the bilinear resize is done here on the host.
//!   2. **Patch embedding** — Conv2d(in=3, out=embed_dim, k=16, s=16)
//!      with no padding, which is equivalent to a per-patch matmul. We do
//!      it on the CPU for the same reason as DINOv2: rlx-ir has no f32
//!      forward Conv2d. The output is the input to the encoder graph,
//!      already in `[B, H, W, C]` BHWC layout (matching SAM's internal
//!      convention).

30use super::config::{
31    SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD, SamEncoderConfig,
32};
33use anyhow::{Result, ensure};
34use rlx_core::weight_map::WeightMap;
35
/// Host-side tensors taken from the safetensors checkpoint and consumed
/// *before* the encoder graph executes.
pub struct SamPreprocessWeights {
    /// Patch-projection kernel. The checkpoint stores it as
    /// `[E, 3, 16, 16]`; it is kept here flattened and transposed to
    /// `[3·16·16, E]` so it can be applied as a row-major sgemm.
    pub patch_proj_w: Vec<f32>,
    /// Patch-projection bias, one value per output channel: `[E]`.
    pub patch_proj_b: Vec<f32>,
    /// Absolute positional embedding `[1, hw, hw, E]`, flattened to
    /// `hw · hw · E` values. It is `None` when the encoder config disables
    /// `use_abs_pos`. When present, it is added to the patch embeddings
    /// before they are handed to the IR graph.
    pub pos_embed: Option<Vec<f32>>,
    /// Embedding width `E` (output channels of the patch projection).
    pub embed_dim: usize,
    /// Side length of the patch grid: the image is cut into `hw × hw`
    /// patches.
    pub hw: usize,
}

52pub(super) fn extract_preprocess_weights(
53    weights: &mut WeightMap,
54    cfg: &SamEncoderConfig,
55) -> Result<SamPreprocessWeights> {
56    let e = cfg.embed_dim;
57    let hw = cfg.num_patches_per_side();
58    let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;
59
60    // image_encoder.patch_embed.proj.weight  [E, 3, 16, 16]
61    let (proj_raw, proj_shape) = weights.take("image_encoder.patch_embed.proj.weight")?;
62    ensure!(
63        proj_shape == vec![e, 3, SAM_PATCH_SIZE, SAM_PATCH_SIZE],
64        "patch_embed.proj.weight expected [{e}, 3, {SAM_PATCH_SIZE}, {SAM_PATCH_SIZE}], got {proj_shape:?}"
65    );
66    // Flatten [E, 3, 16, 16] → [E, patch_dim] (already contiguous) then
67    // transpose to [patch_dim, E].
68    let mut patch_proj_w = vec![0f32; e * pd];
69    for ei in 0..e {
70        for d in 0..pd {
71            patch_proj_w[d * e + ei] = proj_raw[ei * pd + d];
72        }
73    }
74    let (patch_proj_b, _) = weights.take("image_encoder.patch_embed.proj.bias")?;
75
76    let pos_embed = if cfg.use_abs_pos {
77        let (data, shape) = weights.take("image_encoder.pos_embed")?;
78        ensure!(
79            shape == vec![1, hw, hw, e],
80            "pos_embed expected [1, {hw}, {hw}, {e}], got {shape:?}"
81        );
82        Some(data)
83    } else {
84        None
85    };
86
87    Ok(SamPreprocessWeights {
88        patch_proj_w,
89        patch_proj_b,
90        pos_embed,
91        embed_dim: e,
92        hw,
93    })
94}
95
96/// Resize an RGB u8 image to fit within `SAM_IMG_SIZE` on the long
97/// side (aspect-ratio preserved), normalize with SAM's pixel stats,
98/// and zero-pad to a square `[3, 1024, 1024]` NCHW f32 tensor.
99///
100/// `rgb` is `H_in · W_in · 3` row-major (u8). Returns `(nchw, (h, w))`
101/// where `(h, w)` are the resized (pre-pad) dimensions — needed at the
102/// decoder to crop predicted masks back to the original aspect ratio.
103pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> (Vec<f32>, (usize, usize)) {
104    let scale = (SAM_IMG_SIZE as f32) / (h_in.max(w_in) as f32);
105    let new_h = ((h_in as f32) * scale).round() as usize;
106    let new_w = ((w_in as f32) * scale).round() as usize;
107    // Bilinear resize.
108    let mut resized = vec![0f32; 3 * new_h * new_w];
109    let sx = (w_in as f32 - 1.0) / (new_w.max(1) as f32 - 1.0).max(1.0);
110    let sy = (h_in as f32 - 1.0) / (new_h.max(1) as f32 - 1.0).max(1.0);
111    for y in 0..new_h {
112        let fy = y as f32 * sy;
113        let y0 = fy.floor() as usize;
114        let y1 = (y0 + 1).min(h_in - 1);
115        let dy = fy - y0 as f32;
116        for x in 0..new_w {
117            let fx = x as f32 * sx;
118            let x0 = fx.floor() as usize;
119            let x1 = (x0 + 1).min(w_in - 1);
120            let dx = fx - x0 as f32;
121            for c in 0..3 {
122                let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32;
123                let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32;
124                let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32;
125                let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32;
126                let top = p00 * (1.0 - dx) + p01 * dx;
127                let bot = p10 * (1.0 - dx) + p11 * dx;
128                let v = top * (1.0 - dy) + bot * dy;
129                // SAM normalises raw pixel values (NOT /255 first).
130                resized[c * new_h * new_w + y * new_w + x] =
131                    (v - SAM_PIXEL_MEAN[c]) / SAM_PIXEL_STD[c];
132            }
133        }
134    }
135    // Zero-pad to [3, 1024, 1024].
136    let mut padded = vec![0f32; 3 * SAM_IMG_SIZE * SAM_IMG_SIZE];
137    for c in 0..3 {
138        for y in 0..new_h {
139            let src_row = c * new_h * new_w + y * new_w;
140            let dst_row = c * SAM_IMG_SIZE * SAM_IMG_SIZE + y * SAM_IMG_SIZE;
141            padded[dst_row..dst_row + new_w].copy_from_slice(&resized[src_row..src_row + new_w]);
142        }
143    }
144    (padded, (new_h, new_w))
145}
146
147/// Run the patch embedding (Conv2d k=16 s=16 no padding) on the host
148/// and add the absolute positional embedding. Output is `[1, hw, hw,
149/// E]` BHWC (SAM's internal convention) flattened to a contiguous
150/// f32 buffer for the encoder graph.
151pub fn assemble_patch_tokens(pre: &SamPreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
152    let e = pre.embed_dim;
153    let hw = pre.hw;
154    let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;
155    ensure!(
156        image_nchw.len() == 3 * SAM_IMG_SIZE * SAM_IMG_SIZE,
157        "image must be [3, {SAM_IMG_SIZE}, {SAM_IMG_SIZE}] NCHW, got len {}",
158        image_nchw.len()
159    );
160
161    let mut out = vec![0f32; hw * hw * e];
162    let mut patch_buf = vec![0f32; pd];
163    for py in 0..hw {
164        for px in 0..hw {
165            // Fill patch_buf in CHW order matching the Conv2d weight
166            // layout `[E, C=3, ph, pw]` that we flattened earlier.
167            for c in 0..3 {
168                for ry in 0..SAM_PATCH_SIZE {
169                    let src_y = py * SAM_PATCH_SIZE + ry;
170                    for rx in 0..SAM_PATCH_SIZE {
171                        let src_x = px * SAM_PATCH_SIZE + rx;
172                        let src = c * SAM_IMG_SIZE * SAM_IMG_SIZE + src_y * SAM_IMG_SIZE + src_x;
173                        let dst = c * SAM_PATCH_SIZE * SAM_PATCH_SIZE + ry * SAM_PATCH_SIZE + rx;
174                        patch_buf[dst] = image_nchw[src];
175                    }
176                }
177            }
178            // patch_buf @ proj_w + proj_b → embed_dim vector.
179            let row = py * hw + px;
180            let dst = &mut out[row * e..(row + 1) * e];
181            dst.copy_from_slice(&pre.patch_proj_b);
182            for d in 0..pd {
183                let v = patch_buf[d];
184                if v == 0.0 {
185                    continue;
186                }
187                let w_row = &pre.patch_proj_w[d * e..(d + 1) * e];
188                for k in 0..e {
189                    dst[k] += v * w_row[k];
190                }
191            }
192        }
193    }
194
195    // Add absolute positional embedding (broadcast over batch).
196    if let Some(pos) = &pre.pos_embed {
197        ensure!(pos.len() == hw * hw * e, "pos_embed size mismatch");
198        for i in 0..hw * hw * e {
199            out[i] += pos[i];
200        }
201    }
202    Ok(out)
203}