Skip to main content

rlx_sam3/
preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Host-side SAM3 preprocessing and patch embedding.
17//!
18//! SAM3's public builder uses a ViT patch-14 backbone at 1008x1008.
19//! RLX does patch projection on the host, matching the existing SAM and
20//! DINOv2 ports, because the IR surface does not currently include a
21//! general f32 Conv2d forward.
22
23use super::config::{
24    SAM3_IMG_SIZE, SAM3_PATCH_GRID, SAM3_PIXEL_MEAN, SAM3_PIXEL_STD, Sam3VitConfig,
25};
26use anyhow::{Result, ensure};
27use rlx_core::weight_map::WeightMap;
28
/// Host-side tensors for SAM3's patch embedding, as built by `extract_preprocess_weights`
/// and consumed by `assemble_patch_tokens`.
#[derive(Clone)]
pub struct Sam3PreprocessWeights {
    /// Patch projection `[patch_dim, embed_dim]` with `patch_dim = 3 * patch_size^2`,
    /// transposed from the conv weight for row-major matmul.
    pub patch_proj_w: Vec<f32>,
    /// Patch projection bias, `embed_dim` values (all zero when the config disables it).
    pub patch_proj_b: Vec<f32>,
    /// Absolute position embedding laid out as `[grid, grid, embed_dim]`, if used.
    pub pos_embed: Option<Vec<f32>>,
    /// Width of each output token.
    pub embed_dim: usize,
    /// Side length of a square patch, in pixels.
    pub patch_size: usize,
    /// Patches per side of the square token grid.
    pub grid: usize,
}

impl std::fmt::Debug for Sam3PreprocessWeights {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The weight buffers can be very large, so summarise them by length instead of
        // dumping every value.
        f.debug_struct("Sam3PreprocessWeights")
            .field("patch_proj_w_len", &self.patch_proj_w.len())
            .field("patch_proj_b_len", &self.patch_proj_b.len())
            .field("pos_embed_len", &self.pos_embed.as_ref().map(Vec::len))
            .field("embed_dim", &self.embed_dim)
            .field("patch_size", &self.patch_size)
            .field("grid", &self.grid)
            .finish()
    }
}
39
40pub(crate) fn extract_preprocess_weights(
41    weights: &mut WeightMap,
42    cfg: &Sam3VitConfig,
43) -> Result<Sam3PreprocessWeights> {
44    let e = cfg.embed_dim;
45    let ps = cfg.patch_size;
46    let grid = cfg.patch_grid();
47    let pd = 3 * ps * ps;
48
49    let (proj_raw, proj_shape) = take_first(
50        weights,
51        &[
52            "detector.backbone.vision_backbone.trunk.patch_embed.proj.weight",
53            "detector.backbone.visual.trunk.patch_embed.proj.weight",
54            "backbone.vision_backbone.trunk.patch_embed.proj.weight",
55            "backbone.visual.trunk.patch_embed.proj.weight",
56            "visual.trunk.patch_embed.proj.weight",
57            "trunk.patch_embed.proj.weight",
58        ],
59    )?;
60    ensure!(
61        proj_shape == vec![e, 3, ps, ps],
62        "SAM3 patch_embed.proj.weight expected [{e}, 3, {ps}, {ps}], got {proj_shape:?}"
63    );
64
65    let mut patch_proj_w = vec![0f32; e * pd];
66    for ei in 0..e {
67        for d in 0..pd {
68            patch_proj_w[d * e + ei] = proj_raw[ei * pd + d];
69        }
70    }
71
72    let patch_proj_b = if cfg.bias_patch_embed {
73        let (data, shape) = take_first(
74            weights,
75            &[
76                "detector.backbone.vision_backbone.trunk.patch_embed.proj.bias",
77                "detector.backbone.visual.trunk.patch_embed.proj.bias",
78                "backbone.vision_backbone.trunk.patch_embed.proj.bias",
79                "backbone.visual.trunk.patch_embed.proj.bias",
80                "visual.trunk.patch_embed.proj.bias",
81                "trunk.patch_embed.proj.bias",
82            ],
83        )?;
84        ensure!(
85            shape == vec![e],
86            "SAM3 patch bias expected [{e}], got {shape:?}"
87        );
88        data
89    } else {
90        vec![0.0; e]
91    };
92
93    let pos_embed = if cfg.use_abs_pos {
94        take_optional_first(
95            weights,
96            &[
97                "detector.backbone.vision_backbone.trunk.pos_embed",
98                "detector.backbone.visual.trunk.pos_embed",
99                "backbone.vision_backbone.trunk.pos_embed",
100                "backbone.visual.trunk.pos_embed",
101                "visual.trunk.pos_embed",
102                "trunk.pos_embed",
103            ],
104        )?
105        .map(|(data, shape)| materialize_pos_embed(&data, &shape, cfg, grid, e))
106        .transpose()?
107    } else {
108        None
109    };
110
111    Ok(Sam3PreprocessWeights {
112        patch_proj_w,
113        patch_proj_b,
114        pos_embed,
115        embed_dim: e,
116        patch_size: ps,
117        grid,
118    })
119}
120
121/// Resize an RGB u8 image to fit in SAM3's square canvas, normalize, and pad.
122pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> (Vec<f32>, (usize, usize)) {
123    let scale = (SAM3_IMG_SIZE as f32) / (h_in.max(w_in) as f32);
124    let new_h = ((h_in as f32) * scale).round() as usize;
125    let new_w = ((w_in as f32) * scale).round() as usize;
126
127    let mut resized = vec![0f32; 3 * new_h * new_w];
128    let sx = (w_in as f32 - 1.0) / (new_w.max(1) as f32 - 1.0).max(1.0);
129    let sy = (h_in as f32 - 1.0) / (new_h.max(1) as f32 - 1.0).max(1.0);
130    for y in 0..new_h {
131        let fy = y as f32 * sy;
132        let y0 = fy.floor() as usize;
133        let y1 = (y0 + 1).min(h_in - 1);
134        let dy = fy - y0 as f32;
135        for x in 0..new_w {
136            let fx = x as f32 * sx;
137            let x0 = fx.floor() as usize;
138            let x1 = (x0 + 1).min(w_in - 1);
139            let dx = fx - x0 as f32;
140            for c in 0..3 {
141                let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32 / 255.0;
142                let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32 / 255.0;
143                let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32 / 255.0;
144                let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32 / 255.0;
145                let top = p00 * (1.0 - dx) + p01 * dx;
146                let bot = p10 * (1.0 - dx) + p11 * dx;
147                let v = top * (1.0 - dy) + bot * dy;
148                resized[c * new_h * new_w + y * new_w + x] =
149                    (v - SAM3_PIXEL_MEAN[c]) / SAM3_PIXEL_STD[c];
150            }
151        }
152    }
153
154    let mut padded = vec![0f32; 3 * SAM3_IMG_SIZE * SAM3_IMG_SIZE];
155    for c in 0..3 {
156        for y in 0..new_h {
157            let src_row = c * new_h * new_w + y * new_w;
158            let dst_row = c * SAM3_IMG_SIZE * SAM3_IMG_SIZE + y * SAM3_IMG_SIZE;
159            padded[dst_row..dst_row + new_w].copy_from_slice(&resized[src_row..src_row + new_w]);
160        }
161    }
162    (padded, (new_h, new_w))
163}
164
165pub fn assemble_patch_tokens(pre: &Sam3PreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
166    let e = pre.embed_dim;
167    let ps = pre.patch_size;
168    let grid = pre.grid;
169    let pd = 3 * ps * ps;
170    ensure!(
171        image_nchw.len() == 3 * SAM3_IMG_SIZE * SAM3_IMG_SIZE,
172        "SAM3 image must be [3, {SAM3_IMG_SIZE}, {SAM3_IMG_SIZE}] NCHW, got len {}",
173        image_nchw.len()
174    );
175    ensure!(
176        grid == SAM3_PATCH_GRID,
177        "SAM3 base grid must be {SAM3_PATCH_GRID}"
178    );
179
180    let mut out = vec![0f32; grid * grid * e];
181    let mut patch_buf = vec![0f32; pd];
182    for py in 0..grid {
183        for px in 0..grid {
184            for c in 0..3 {
185                for ry in 0..ps {
186                    let src_y = py * ps + ry;
187                    for rx in 0..ps {
188                        let src_x = px * ps + rx;
189                        let src = c * SAM3_IMG_SIZE * SAM3_IMG_SIZE + src_y * SAM3_IMG_SIZE + src_x;
190                        let dst = c * ps * ps + ry * ps + rx;
191                        patch_buf[dst] = image_nchw[src];
192                    }
193                }
194            }
195            let row = py * grid + px;
196            let dst = &mut out[row * e..(row + 1) * e];
197            dst.copy_from_slice(&pre.patch_proj_b);
198            for d in 0..pd {
199                let v = patch_buf[d];
200                if v == 0.0 {
201                    continue;
202                }
203                let w_row = &pre.patch_proj_w[d * e..(d + 1) * e];
204                for k in 0..e {
205                    dst[k] += v * w_row[k];
206                }
207            }
208        }
209    }
210
211    if let Some(pos) = &pre.pos_embed {
212        ensure!(pos.len() == out.len(), "SAM3 pos_embed size mismatch");
213        for i in 0..out.len() {
214            out[i] += pos[i];
215        }
216    }
217
218    Ok(out)
219}
220
221/// Materialise the absolute positional embedding so the trunk can add it
222/// directly to the [grid, grid, embed_dim] patch tokens. Upstream stores a
223/// `[1, num_positions, embed_dim]` pretrain table: when
224/// `pretrain_use_cls_token` is set the first row is the CLS position and the
225/// rest is a `pretrain_grid x pretrain_grid` table. We then tile (or, when
226/// `tile_abs_pos=False`, bicubic-interpolate) to the deployment grid.
227fn materialize_pos_embed(
228    data: &[f32],
229    shape: &[usize],
230    cfg: &Sam3VitConfig,
231    grid: usize,
232    e: usize,
233) -> Result<Vec<f32>> {
234    if shape == [1, grid, grid, e] || shape == [grid, grid, e] {
235        return Ok(data.to_vec());
236    }
237    ensure!(
238        shape.len() == 3 && shape[0] == 1 && shape[2] == e,
239        "SAM3 pos_embed expected [1, *, {e}], got {shape:?}"
240    );
241    let num_positions = shape[1];
242    let has_cls = num_positions % 2 == 1;
243    let spatial = if has_cls {
244        num_positions - 1
245    } else {
246        num_positions
247    };
248    let pretrain_grid = (spatial as f64).sqrt().round() as usize;
249    ensure!(
250        pretrain_grid * pretrain_grid == spatial,
251        "SAM3 pos_embed spatial portion not square: {spatial} positions"
252    );
253
254    let src = if has_cls { &data[e..] } else { data };
255    let mut out = vec![0f32; grid * grid * e];
256
257    if cfg.tile_abs_pos {
258        for y in 0..grid {
259            for x in 0..grid {
260                let sy = y % pretrain_grid;
261                let sx = x % pretrain_grid;
262                let src_row = (sy * pretrain_grid + sx) * e;
263                let dst_row = (y * grid + x) * e;
264                out[dst_row..dst_row + e].copy_from_slice(&src[src_row..src_row + e]);
265            }
266        }
267    } else {
268        // Bicubic interpolation (matches torch.nn.functional.interpolate
269        // mode="bicubic", align_corners=False).
270        bicubic_interp_nhwc(src, pretrain_grid, pretrain_grid, &mut out, grid, grid, e);
271    }
272
273    Ok(out)
274}
275
276fn bicubic_interp_nhwc(
277    src: &[f32],
278    src_h: usize,
279    src_w: usize,
280    dst: &mut [f32],
281    dst_h: usize,
282    dst_w: usize,
283    c: usize,
284) {
285    // Convert to [C, H, W] for sampling, then back.
286    let mut src_chw = vec![0f32; c * src_h * src_w];
287    for y in 0..src_h {
288        for x in 0..src_w {
289            for ch in 0..c {
290                src_chw[ch * src_h * src_w + y * src_w + x] = src[(y * src_w + x) * c + ch];
291            }
292        }
293    }
294    let scale_y = src_h as f32 / dst_h as f32;
295    let scale_x = src_w as f32 / dst_w as f32;
296    for y in 0..dst_h {
297        let fy = (y as f32 + 0.5) * scale_y - 0.5;
298        let y_floor = fy.floor() as i32;
299        let dy = fy - y_floor as f32;
300        let wy = cubic_weights(dy);
301        for x in 0..dst_w {
302            let fx = (x as f32 + 0.5) * scale_x - 0.5;
303            let x_floor = fx.floor() as i32;
304            let dx = fx - x_floor as f32;
305            let wx = cubic_weights(dx);
306            for ch in 0..c {
307                let plane = &src_chw[ch * src_h * src_w..(ch + 1) * src_h * src_w];
308                let mut v = 0.0f32;
309                for j in -1..=2 {
310                    let sy = (y_floor + j).clamp(0, src_h as i32 - 1) as usize;
311                    let mut row_acc = 0.0f32;
312                    for i in -1..=2 {
313                        let sx = (x_floor + i).clamp(0, src_w as i32 - 1) as usize;
314                        row_acc += plane[sy * src_w + sx] * wx[(i + 1) as usize];
315                    }
316                    v += row_acc * wy[(j + 1) as usize];
317                }
318                dst[(y * dst_w + x) * c + ch] = v;
319            }
320        }
321    }
322}
323
324fn cubic_weights(t: f32) -> [f32; 4] {
325    // Cubic convolution kernel with a=-0.75 (matches PyTorch's bicubic).
326    let a = -0.75f32;
327    let t1 = 1.0 + t; // distance to leftmost
328    let t2 = t; // distance to next
329    let t3 = 1.0 - t; // distance to next
330    let t4 = 2.0 - t; // distance to rightmost
331    [
332        cubic_kernel(t1, a),
333        cubic_kernel(t2, a),
334        cubic_kernel(t3, a),
335        cubic_kernel(t4, a),
336    ]
337}
338
/// Piecewise-cubic convolution kernel evaluated at distance `x` with coefficient `a`.
///
/// It is symmetric in `x`, equals 1 at 0, 0 at ±1, and 0 for `|x| >= 2`. A NaN distance
/// falls through to 0.
fn cubic_kernel(x: f32, a: f32) -> f32 {
    match x.abs() {
        d if d < 1.0 => (a + 2.0) * d * d * d - (a + 3.0) * d * d + 1.0,
        d if d < 2.0 => a * d * d * d - 5.0 * a * d * d + 8.0 * a * d - 4.0 * a,
        _ => 0.0,
    }
}
349
350fn take_first(weights: &mut WeightMap, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
351    for key in keys {
352        if weights.has(key) {
353            return weights.take(key);
354        }
355    }
356    anyhow::bail!("none of the SAM3 weight keys were found: {keys:?}")
357}
358
359fn take_optional_first(
360    weights: &mut WeightMap,
361    keys: &[&str],
362) -> Result<Option<(Vec<f32>, Vec<usize>)>> {
363    for key in keys {
364        if weights.has(key) {
365            return weights.take(key).map(Some);
366        }
367    }
368    Ok(None)
369}