Skip to main content

rlx_sam2/
preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 host-side preprocessing.
17//!
18//! Three pieces live on the host (outside the IR graph):
19//!   1. **Image preprocess** — square resize to 1024×1024 (bilinear, no
20//!      aspect-ratio preservation), /255, ImageNet normalize. Differs
21//!      from SAM v1 in two ways: SAM 2 *does* divide by 255 first, and
22//!      *does not* keep aspect ratio (the reference's
23//!      `SAM2Transforms.__call__` is just `Resize((1024,1024))` +
24//!      `Normalize`).
25//!   2. **Patch embedding** — Conv2d(in=3, out=embed_dim, k=7, s=4,
26//!      p=3). Overlapping kernel (k > s) so we can't reduce to a plain
27//!      per-patch matmul like SAM v1. Runs as a direct host-side
28//!      Conv2d; cheap (once per image vs. per-block).
29//!   3. **Stage-0 position embedding** — bicubic-interpolated
30//!      `pos_embed` table + tiled `pos_embed_window`, summed into the
31//!      patch tokens before they enter the encoder body. Materialised
32//!      host-side because IR has no bicubic-resample op.
33
34use super::config::{
35    SAM2_IMG_SIZE, SAM2_PATCH_GRID, SAM2_PATCH_KERNEL, SAM2_PATCH_PADDING, SAM2_PATCH_STRIDE,
36    SAM2_PIXEL_MEAN, SAM2_PIXEL_STD, Sam2HieraConfig,
37};
38use anyhow::{Result, ensure};
39use rlx_core::weight_map::WeightMap;
40
/// Weights extracted from the safetensors checkpoint that the host
/// uses *before* the encoder graph runs.
pub struct Sam2PreprocessWeights {
    /// Patch projection weight in the checkpoint's raw `[E, 3, k, k]`
    /// (out-channel, in-channel, kernel-y, kernel-x) layout, flattened
    /// row-major. Consumed by `assemble_patch_tokens`.
    pub patch_proj_w: Vec<f32>,
    /// Patch projection bias `[E]`.
    pub patch_proj_b: Vec<f32>,
    /// Stage-0 position embedding, already interpolated + tiled to
    /// `[grid · grid · E]` BHWC. Added to patch tokens before the
    /// encoder body.
    pub pos_embed_full: Vec<f32>,
    /// Stage-0 embedding width `E` (the config's `embed_dim`).
    pub embed_dim: usize,
    /// Side length of the square patch-token grid (`SAM2_PATCH_GRID`,
    /// i.e. 256).
    pub grid: usize,
}
57
58pub(super) fn extract_preprocess_weights(
59    weights: &mut WeightMap,
60    cfg: &Sam2HieraConfig,
61) -> Result<Sam2PreprocessWeights> {
62    let e = cfg.embed_dim;
63    let k = SAM2_PATCH_KERNEL;
64    let grid = SAM2_PATCH_GRID;
65
66    // image_encoder.trunk.patch_embed.proj.weight [E, 3, k, k]
67    let (patch_proj_w, w_shape) = weights.take("image_encoder.trunk.patch_embed.proj.weight")?;
68    ensure!(
69        w_shape == vec![e, 3, k, k],
70        "patch_embed.proj.weight expected [{e}, 3, {k}, {k}], got {w_shape:?}"
71    );
72    let (patch_proj_b, _) = weights.take("image_encoder.trunk.patch_embed.proj.bias")?;
73
74    // image_encoder.trunk.pos_embed         [1, E, Ph, Pw]
75    // image_encoder.trunk.pos_embed_window  [1, E, mu, mu]
76    let (pe_raw, pe_shape) = weights.take("image_encoder.trunk.pos_embed")?;
77    let [ph, pw] = cfg.window_pos_embed_bkg_spatial_size;
78    ensure!(
79        pe_shape == vec![1, e, ph, pw],
80        "pos_embed expected [1, {e}, {ph}, {pw}], got {pe_shape:?}"
81    );
82
83    let mu = cfg.window_size_at_stage(0);
84    let (pew_raw, pew_shape) = weights.take("image_encoder.trunk.pos_embed_window")?;
85    ensure!(
86        pew_shape == vec![1, e, mu, mu],
87        "pos_embed_window expected [1, {e}, {mu}, {mu}], got {pew_shape:?}"
88    );
89
90    let pos_embed_full = build_full_pos_embed(&pe_raw, &pew_raw, e, ph, pw, mu, grid);
91
92    Ok(Sam2PreprocessWeights {
93        patch_proj_w,
94        patch_proj_b,
95        pos_embed_full,
96        embed_dim: e,
97        grid,
98    })
99}
100
101/// Replicates the reference's `Hiera._get_pos_embed`:
102///   - bicubic-interpolate `pos_embed` from `[Ph, Pw]` to `[grid, grid]`
103///   - tile `pos_embed_window` to `[grid, grid]`
104///   - sum, permute to NHWC (`[grid, grid, E]`), flatten
105fn build_full_pos_embed(
106    pe: &[f32],
107    pew: &[f32],
108    e: usize,
109    ph: usize,
110    pw: usize,
111    mu: usize,
112    grid: usize,
113) -> Vec<f32> {
114    debug_assert_eq!(pe.len(), e * ph * pw);
115    debug_assert_eq!(pew.len(), e * mu * mu);
116    debug_assert_eq!(
117        grid % mu,
118        0,
119        "Hiera pos_embed_window must tile grid evenly (grid={grid}, mu={mu})"
120    );
121
122    // 1) bicubic-interpolate pe per channel into `interp_pe` [E, grid, grid].
123    let mut interp_pe = vec![0f32; e * grid * grid];
124    for c in 0..e {
125        let src = &pe[c * ph * pw..(c + 1) * ph * pw];
126        let dst = &mut interp_pe[c * grid * grid..(c + 1) * grid * grid];
127        bicubic_resize_2d(src, ph, pw, dst, grid, grid);
128    }
129
130    // 2) Tile pew across grid (it tiles by integer factor since
131    //    grid is a multiple of mu) and sum.
132    let mut out_nchw = interp_pe; // reuse
133    for c in 0..e {
134        for y in 0..grid {
135            let ty = y % mu;
136            for x in 0..grid {
137                let tx = x % mu;
138                let w_val = pew[c * mu * mu + ty * mu + tx];
139                out_nchw[c * grid * grid + y * grid + x] += w_val;
140            }
141        }
142    }
143
144    // 3) Permute NCHW → BHWC (just a single sample, B=1) and flatten.
145    let mut out_bhwc = vec![0f32; grid * grid * e];
146    for y in 0..grid {
147        for x in 0..grid {
148            for c in 0..e {
149                out_bhwc[(y * grid + x) * e + c] = out_nchw[c * grid * grid + y * grid + x];
150            }
151        }
152    }
153    out_bhwc
154}
155
/// Bicubic resize of a single-channel, row-major `[h_in, w_in]` plane into
/// `[h_out, w_out]`.
///
/// Follows the convention of PyTorch's
/// `F.interpolate(mode="bicubic", align_corners=False)`:
///   - source coordinate `(i_out + 0.5) * (n_in / n_out) - 0.5`, with no
///     clamping of the coordinate itself;
///   - Keys cubic-convolution kernel with `a = -0.75` (the PyTorch / OpenCV
///     default; this is *not* Catmull-Rom, which uses `a = -0.5`);
///   - taps that fall outside the source replicate the nearest edge sample.
///
/// Only used for the `pos_embed` → patch-grid interpolation (once per model
/// load, not per inference), so a straightforward loop is adequate.
///
/// # Panics
/// If the output is non-empty but the source has zero rows or columns.
fn bicubic_resize_2d(
    src: &[f32],
    h_in: usize,
    w_in: usize,
    dst: &mut [f32],
    h_out: usize,
    w_out: usize,
) {
    /// Keys cubic-convolution kernel, `a = -0.75`.
    fn cubic(t: f32) -> f32 {
        let a = -0.75_f32;
        let t = t.abs();
        if t < 1.0 {
            ((a + 2.0) * t - (a + 3.0)) * t * t + 1.0
        } else if t < 2.0 {
            (((t - 5.0) * t + 8.0) * t - 4.0) * a
        } else {
            0.0
        }
    }

    /// The four `(source index, weight)` taps for output coordinate `i_out`
    /// along an axis of `n_in` samples (`align_corners=False` mapping,
    /// replicate-edge clamping). Requires `n_in > 0`.
    fn taps(i_out: usize, scale: f32, n_in: usize) -> [(usize, f32); 4] {
        let pos = (i_out as f32 + 0.5) * scale - 0.5;
        let base = pos.floor();
        let t = pos - base;
        let base = base as isize;
        let last = n_in as isize - 1;
        let at = |offset: isize| (base + offset).clamp(0, last) as usize;
        [
            (at(-1), cubic(1.0 + t)),
            (at(0), cubic(t)),
            (at(1), cubic(1.0 - t)),
            (at(2), cubic(2.0 - t)),
        ]
    }

    debug_assert_eq!(src.len(), h_in * w_in);
    debug_assert_eq!(dst.len(), h_out * w_out);
    if h_out == 0 || w_out == 0 {
        return;
    }
    assert!(
        h_in > 0 && w_in > 0,
        "bicubic_resize_2d: cannot resample an empty {h_in}x{w_in} source"
    );

    let sx = (w_in as f32) / (w_out as f32);
    let sy = (h_in as f32) / (h_out as f32);

    // Column taps are identical for every output row; compute them once.
    let x_taps: Vec<[(usize, f32); 4]> = (0..w_out).map(|x_o| taps(x_o, sx, w_in)).collect();

    for y_o in 0..h_out {
        let y_taps = taps(y_o, sy, h_in);
        let dst_row = &mut dst[y_o * w_out..(y_o + 1) * w_out];
        for (out_px, x_tap) in dst_row.iter_mut().zip(&x_taps) {
            // Same summation order as a plain (jy, jx) double loop:
            // acc += src * wy * wx.
            let mut acc = 0f32;
            for &(iy, wy) in &y_taps {
                let src_row = &src[iy * w_in..(iy + 1) * w_in];
                for &(ix, wx) in x_tap {
                    acc += src_row[ix] * wy * wx;
                }
            }
            *out_px = acc;
        }
    }
}
214
215/// Square-resize an RGB u8 image to 1024×1024 (bilinear, no aspect-
216/// ratio preservation), /255, then ImageNet-normalise. Returns a
217/// contiguous `[3, 1024, 1024]` NCHW f32 buffer.
218///
219/// Matches `SAM2Transforms` in the reference exactly:
220/// `Resize((1024, 1024))` (PIL bilinear) → `ToTensor` (/255) →
221/// `Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])`.
222pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> Vec<f32> {
223    debug_assert_eq!(rgb.len(), h_in * w_in * 3);
224    let out_size = SAM2_IMG_SIZE;
225    let mut nchw = vec![0f32; 3 * out_size * out_size];
226
227    // PIL Resize uses `align_corners=False` bilinear.
228    let sx = (w_in as f32) / (out_size as f32);
229    let sy = (h_in as f32) / (out_size as f32);
230
231    for y_o in 0..out_size {
232        let yf = (y_o as f32 + 0.5) * sy - 0.5;
233        let y0 = yf.floor().max(0.0) as usize;
234        let y1 = (y0 + 1).min(h_in - 1);
235        let dy = (yf - yf.floor()).clamp(0.0, 1.0);
236        for x_o in 0..out_size {
237            let xf = (x_o as f32 + 0.5) * sx - 0.5;
238            let x0 = xf.floor().max(0.0) as usize;
239            let x1 = (x0 + 1).min(w_in - 1);
240            let dx = (xf - xf.floor()).clamp(0.0, 1.0);
241            for c in 0..3 {
242                let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32;
243                let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32;
244                let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32;
245                let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32;
246                let top = p00 * (1.0 - dx) + p01 * dx;
247                let bot = p10 * (1.0 - dx) + p11 * dx;
248                let v01 = (top * (1.0 - dy) + bot * dy) / 255.0;
249                nchw[c * out_size * out_size + y_o * out_size + x_o] =
250                    (v01 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
251            }
252        }
253    }
254    nchw
255}
256
257/// Run Hiera's patch embedding (Conv2d k=7 s=4 p=3) on the host, then
258/// add the stage-0 position embedding. Output is `[grid, grid, E]`
259/// BHWC (the layout Hiera operates on internally), flattened.
260///
261/// `image_nchw` is the `[3, 1024, 1024]` tensor from `preprocess_image`.
262pub fn assemble_patch_tokens(pre: &Sam2PreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
263    let e = pre.embed_dim;
264    let grid = pre.grid;
265    let k = SAM2_PATCH_KERNEL;
266    let s = SAM2_PATCH_STRIDE;
267    let pad = SAM2_PATCH_PADDING;
268    ensure!(
269        image_nchw.len() == 3 * SAM2_IMG_SIZE * SAM2_IMG_SIZE,
270        "image must be [3, {}, {}] NCHW, got len {}",
271        SAM2_IMG_SIZE,
272        SAM2_IMG_SIZE,
273        image_nchw.len()
274    );
275
276    let h = SAM2_IMG_SIZE;
277    let w = SAM2_IMG_SIZE;
278    let mut out = vec![0f32; grid * grid * e];
279
280    // Direct Conv2d. Per-output-pixel cost is k·k·in_c·E = 7·7·3·E.
281    // For E=112, grid=256 this is ~256² · 7² · 3 · 112 ≈ 1.1 G fmas —
282    // about the same as a single transformer block, run once.
283    for py in 0..grid {
284        for px in 0..grid {
285            // dst row in BHWC
286            let dst = &mut out[(py * grid + px) * e..(py * grid + px + 1) * e];
287            // Start with bias.
288            dst.copy_from_slice(&pre.patch_proj_b);
289            // Convolve.
290            for ky in 0..k {
291                let iy = (py * s) as isize + ky as isize - pad as isize;
292                if iy < 0 || iy >= h as isize {
293                    continue;
294                }
295                let iy = iy as usize;
296                for kx in 0..k {
297                    let ix = (px * s) as isize + kx as isize - pad as isize;
298                    if ix < 0 || ix >= w as isize {
299                        continue;
300                    }
301                    let ix = ix as usize;
302                    for c in 0..3 {
303                        let v = image_nchw[c * h * w + iy * w + ix];
304                        // weight is [E, 3, k, k]: row-major
305                        let w_base = c * k * k + ky * k + kx;
306                        let stride = 3 * k * k;
307                        for ei in 0..e {
308                            dst[ei] += v * pre.patch_proj_w[ei * stride + w_base];
309                        }
310                    }
311                }
312            }
313        }
314    }
315
316    // Add stage-0 position embedding (already in BHWC).
317    ensure!(
318        pre.pos_embed_full.len() == grid * grid * e,
319        "pos_embed_full size mismatch"
320    );
321    for i in 0..grid * grid * e {
322        out[i] += pre.pos_embed_full[i];
323    }
324    Ok(out)
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn preprocess_shape_and_range() {
        // 50×30 RGB → 1024×1024 NCHW.
        let img = vec![128u8; 50 * 30 * 3];
        let nchw = preprocess_image(&img, 50, 30);
        assert_eq!(nchw.len(), 3 * 1024 * 1024);
        // 128/255 ≈ 0.5019; per-channel normalised ≈ (0.502 - mean) / std.
        for c in 0..3 {
            let expected = (128.0 / 255.0 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
            let mid = nchw[c * 1024 * 1024 + 512 * 1024 + 512];
            assert!(
                (mid - expected).abs() < 1e-4,
                "channel {c}: {mid} vs {expected}"
            );
        }
    }

    #[test]
    fn bicubic_identity() {
        // 8×8 → 8×8 should be (close to) identity for bicubic with
        // align_corners=False.
        let src: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let mut dst = vec![0f32; 64];
        bicubic_resize_2d(&src, 8, 8, &mut dst, 8, 8);
        for i in 0..64 {
            assert!(
                (src[i] - dst[i]).abs() < 1e-4,
                "identity broken at {i}: {} vs {}",
                src[i],
                dst[i]
            );
        }
    }

    #[test]
    fn bicubic_reference_values() {
        // Hand-derived for a 1×2 → 1×4 upsample with the a = -0.75 kernel,
        // align_corners=False and edge-clamped taps (the convention of
        // PyTorch's `F.interpolate(mode="bicubic")`).
        let src = [0.0f32, 1.0];
        let mut dst = [0f32; 4];
        bicubic_resize_2d(&src, 1, 2, &mut dst, 1, 4);
        let expected = [-0.105_468_75_f32, 0.226_562_5, 0.773_437_5, 1.105_468_75];
        for (i, (got, want)) in dst.iter().zip(expected).enumerate() {
            assert!((got - want).abs() < 1e-5, "tap {i}: {got} vs {want}");
        }
    }

    #[test]
    fn pos_embed_tiles_window_and_is_channel_last() {
        // E = 2, 1×1 pos_embed per channel (bicubic keeps constants),
        // 2×2 window table, 4×4 grid.
        let pe = [1.0f32, 10.0];
        let pew = [0.0f32, 1.0, 2.0, 3.0, 100.0, 200.0, 300.0, 400.0];
        let out = build_full_pos_embed(&pe, &pew, 2, 1, 1, 2, 4);
        assert_eq!(out.len(), 4 * 4 * 2);
        for y in 0..4 {
            for x in 0..4 {
                let win = (y % 2) * 2 + x % 2;
                let cell = (y * 4 + x) * 2;
                assert!((out[cell] - (pe[0] + pew[win])).abs() < 1e-4);
                assert!((out[cell + 1] - (pe[1] + pew[4 + win])).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn patch_embed_center_tap_samples_stride_grid() {
        // E = 1 with a single unit weight at (c=0, ky=pad, kx=pad): each
        // token equals the channel-0 pixel under the kernel centre, plus
        // bias and position embedding.
        let k = SAM2_PATCH_KERNEL;
        let pad = SAM2_PATCH_PADDING;
        let s = SAM2_PATCH_STRIDE;
        let n = SAM2_IMG_SIZE;
        let grid = SAM2_PATCH_GRID;
        let mut w = vec![0f32; 3 * k * k];
        w[pad * k + pad] = 1.0;
        let pre = Sam2PreprocessWeights {
            patch_proj_w: w,
            patch_proj_b: vec![0.5],
            pos_embed_full: vec![0.25; grid * grid],
            embed_dim: 1,
            grid,
        };
        let mut image = vec![0f32; 3 * n * n];
        for (i, v) in image[..n * n].iter_mut().enumerate() {
            *v = i as f32;
        }
        let out = assemble_patch_tokens(&pre, &image).unwrap();
        assert_eq!(out.len(), grid * grid);
        for &(py, px) in &[(0, 0), (10, 20), (grid - 1, grid - 1)] {
            let expected = ((py * s) * n + px * s) as f32 + 0.75;
            assert_eq!(out[py * grid + px], expected, "token ({py}, {px})");
        }
    }
}