Skip to main content

rlx_sam3/
neck.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native SAM3 detection neck (`Sam3DualViTDetNeck` without the SAM2 head).
17//!
18//! Per-level branch on the last trunk feature `[B, 1024, 72, 72]`:
19//!
20//! | scale | branch                                                           |
21//! |-------|------------------------------------------------------------------|
22//! | 4.0   | dconv2x2(1024→512)·GELU·dconv2x2(512→256)·conv1x1(256→256)·conv3x3 |
23//! | 2.0   | dconv2x2(1024→512)·conv1x1(512→256)·conv3x3                       |
24//! | 1.0   | conv1x1(1024→256)·conv3x3                                         |
25//! | 0.5   | maxpool2x2·conv1x1(1024→256)·conv3x3                              |
26//!
27//! Each branch also emits a sinusoidal positional encoding of matching
28//! shape, computed by `position_encoding_sine_sam3`.
29
30use super::config::SAM3_DET_DIM;
31use super::neck_branch_ir::Sam3NeckBranchCompiled;
32use super::vision_encoder::Sam3VisionOutput;
33use anyhow::{Result, ensure};
34use rlx_core::weight_map::WeightMap;
35use rlx_flow::CompileProfile;
36use rlx_runtime::Device;
37
/// One multi-scale feature map emitted by the detection neck, paired with its
/// sinusoidal positional encoding.
#[derive(Debug, Clone)]
pub struct Sam3FeatureLevel {
    /// Feature map for a single image, flattened NCHW as `[channels, h, w]`.
    pub features: Vec<f32>,
    /// Positional encoding laid out like `features` (`[channels, h, w]`).
    pub pos: Vec<f32>,
    /// Spatial height of this level.
    pub h: usize,
    /// Spatial width of this level.
    pub w: usize,
    /// Number of channels in `features` and `pos`.
    pub channels: usize,
}
46
47#[derive(Default)]
48pub struct Sam3NeckWeights {
49    pub loaded: bool,
50    pub branches: Vec<Sam3NeckBranch>,
51    /// Per-branch IR graphs (filled by [`compile_neck_branches`]).
52    pub ir: Vec<Sam3NeckBranchCompiled>,
53}
54
/// Parameters of one neck branch, i.e. one output scale.
#[derive(Clone, Default)]
pub struct Sam3NeckBranch {
    /// Output scale relative to the trunk grid (4.0, 2.0, 1.0 or 0.5).
    pub scale: f32,
    /// First 2×2 transposed-conv weight; present when `scale` is 2.0 or 4.0.
    pub dconv0_w: Option<Vec<f32>>,
    /// Bias of the first 2×2 transposed conv.
    pub dconv0_b: Option<Vec<f32>>,
    /// Second 2×2 transposed-conv weight; present only when `scale` is 4.0.
    pub dconv1_w: Option<Vec<f32>>,
    /// Bias of the second 2×2 transposed conv.
    pub dconv1_b: Option<Vec<f32>>,
    /// 1×1 conv weight, applied after the optional resampling step.
    pub c1x1_w: Vec<f32>,
    /// 1×1 conv bias.
    pub c1x1_b: Vec<f32>,
    /// Input channel count expected by the 1×1 conv.
    pub c1x1_in: usize,
    /// 3×3 conv weight with in_dim == out_dim == d_model.
    pub c3x3_w: Vec<f32>,
    /// 3×3 conv bias.
    pub c3x3_b: Vec<f32>,
}
72
73pub fn extract_neck_weights(weights: &mut WeightMap) -> Result<Sam3NeckWeights> {
74    let prefixes = [
75        "detector.backbone.vision_backbone",
76        "backbone.vision_backbone",
77        "vision_backbone",
78    ];
79    let scales = [4.0f32, 2.0, 1.0, 0.5];
80    let mut branches = Vec::with_capacity(scales.len());
81    for (i, scale) in scales.iter().enumerate() {
82        let mut found = None;
83        for pref in prefixes {
84            let base = format!("{pref}.convs.{i}");
85            if weights.has(&format!("{base}.conv_1x1.weight")) {
86                found = Some(base);
87                break;
88            }
89        }
90        let base = found.ok_or_else(|| {
91            anyhow::anyhow!("SAM3 neck branch {i} (scale={scale}) not found in checkpoint")
92        })?;
93
94        let (dconv0_w, dconv0_b) = if (*scale - 4.0).abs() < 1e-6 {
95            let (w, ws) = weights.take(&format!("{base}.dconv_2x2_0.weight"))?;
96            ensure!(
97                ws == vec![1024, 512, 2, 2],
98                "dconv_2x2_0.weight shape {ws:?}"
99            );
100            let (b, _) = weights.take(&format!("{base}.dconv_2x2_0.bias"))?;
101            (Some(w), Some(b))
102        } else if (*scale - 2.0).abs() < 1e-6 {
103            let (w, ws) = weights.take(&format!("{base}.dconv_2x2.weight"))?;
104            ensure!(ws == vec![1024, 512, 2, 2], "dconv_2x2.weight shape {ws:?}");
105            let (b, _) = weights.take(&format!("{base}.dconv_2x2.bias"))?;
106            (Some(w), Some(b))
107        } else {
108            (None, None)
109        };
110        let (dconv1_w, dconv1_b) = if (*scale - 4.0).abs() < 1e-6 {
111            let (w, ws) = weights.take(&format!("{base}.dconv_2x2_1.weight"))?;
112            ensure!(
113                ws == vec![512, 256, 2, 2],
114                "dconv_2x2_1.weight shape {ws:?}"
115            );
116            let (b, _) = weights.take(&format!("{base}.dconv_2x2_1.bias"))?;
117            (Some(w), Some(b))
118        } else {
119            (None, None)
120        };
121
122        let (c1x1_w, c1_shape) = weights.take(&format!("{base}.conv_1x1.weight"))?;
123        ensure!(c1_shape.len() == 4 && c1_shape[2] == 1 && c1_shape[3] == 1);
124        let c1x1_in = c1_shape[1];
125        let (c1x1_b, _) = weights.take(&format!("{base}.conv_1x1.bias"))?;
126        let (c3x3_w, c3_shape) = weights.take(&format!("{base}.conv_3x3.weight"))?;
127        ensure!(
128            c3_shape == vec![SAM3_DET_DIM, SAM3_DET_DIM, 3, 3],
129            "conv_3x3.weight shape {c3_shape:?}"
130        );
131        let (c3x3_b, _) = weights.take(&format!("{base}.conv_3x3.bias"))?;
132
133        branches.push(Sam3NeckBranch {
134            scale: *scale,
135            dconv0_w,
136            dconv0_b,
137            dconv1_w,
138            dconv1_b,
139            c1x1_w,
140            c1x1_b,
141            c1x1_in,
142            c3x3_w,
143            c3x3_b,
144        });
145    }
146
147    // Drop the sam2_convs branch — we don't run the SAM2 head in image-only mode.
148    for pref in prefixes {
149        let base = format!("{pref}.sam2_convs");
150        let keys: Vec<String> = weights
151            .keys()
152            .filter(|k| k.starts_with(&base))
153            .map(|s| s.to_string())
154            .collect();
155        for k in keys {
156            let _ = weights.take(&k);
157        }
158    }
159
160    Ok(Sam3NeckWeights {
161        loaded: true,
162        branches,
163        ir: Vec::new(),
164    })
165}
166
167/// Compile each neck branch for the given trunk grid / device.
168pub fn compile_neck_branches(
169    neck: &mut Sam3NeckWeights,
170    in_c: usize,
171    grid: usize,
172    device: Device,
173    profile: &CompileProfile,
174) -> Result<()> {
175    neck.ir = neck
176        .branches
177        .iter()
178        .map(|b| Sam3NeckBranchCompiled::compile_with_profile(b, in_c, grid, grid, device, profile))
179        .collect::<Result<_>>()?;
180    Ok(())
181}
182
183pub fn apply_neck_native(
184    weights: &mut Sam3NeckWeights,
185    vision: &Sam3VisionOutput,
186) -> Result<Vec<Sam3FeatureLevel>> {
187    ensure!(
188        weights.loaded,
189        "SAM3 neck weights not loaded; call extract_neck_weights()"
190    );
191    let grid = vision.grid;
192    let dim = vision.dim;
193
194    // Vision output is NHWC `[grid*grid, dim]`. Reshape to NCHW for convs.
195    let mut x_nchw = vec![0f32; dim * grid * grid];
196    for y in 0..grid {
197        for xc in 0..grid {
198            for c in 0..dim {
199                x_nchw[c * grid * grid + y * grid + xc] = vision.tokens[(y * grid + xc) * dim + c];
200            }
201        }
202    }
203
204    let mut levels = Vec::with_capacity(weights.branches.len());
205    if weights.ir.len() == weights.branches.len() {
206        for compiled in &mut weights.ir {
207            let features = compiled.run(&x_nchw, dim, grid, grid)?;
208            let pos = position_encoding_sine_sam3(SAM3_DET_DIM, compiled.out_h, compiled.out_w);
209            levels.push(Sam3FeatureLevel {
210                features,
211                pos,
212                h: compiled.out_h,
213                w: compiled.out_w,
214                channels: SAM3_DET_DIM,
215            });
216        }
217    } else {
218        for branch in &weights.branches {
219            let level = apply_branch_host(branch, &x_nchw, dim, grid, grid)?;
220            levels.push(level);
221        }
222    }
223    Ok(levels)
224}
225
226fn apply_branch_host(
227    branch: &Sam3NeckBranch,
228    x: &[f32],
229    in_c: usize,
230    h: usize,
231    w: usize,
232) -> Result<Sam3FeatureLevel> {
233    let mut cur = x.to_vec();
234    let mut cur_c = in_c;
235    let mut cur_h = h;
236    let mut cur_w = w;
237
238    if (branch.scale - 4.0).abs() < 1e-6 {
239        let dw0 = branch.dconv0_w.as_ref().unwrap();
240        let db0 = branch.dconv0_b.as_ref().unwrap();
241        cur = conv_transpose2d_stride2_k2(&cur, cur_c, 512, cur_h, cur_w, dw0, db0);
242        cur_c = 512;
243        cur_h *= 2;
244        cur_w *= 2;
245        gelu_inplace(&mut cur);
246        let dw1 = branch.dconv1_w.as_ref().unwrap();
247        let db1 = branch.dconv1_b.as_ref().unwrap();
248        cur = conv_transpose2d_stride2_k2(&cur, cur_c, 256, cur_h, cur_w, dw1, db1);
249        cur_c = 256;
250        cur_h *= 2;
251        cur_w *= 2;
252    } else if (branch.scale - 2.0).abs() < 1e-6 {
253        let dw = branch.dconv0_w.as_ref().unwrap();
254        let db = branch.dconv0_b.as_ref().unwrap();
255        cur = conv_transpose2d_stride2_k2(&cur, cur_c, 512, cur_h, cur_w, dw, db);
256        cur_c = 512;
257        cur_h *= 2;
258        cur_w *= 2;
259    } else if (branch.scale - 0.5).abs() < 1e-6 {
260        cur = maxpool2x2_stride2(&cur, cur_c, cur_h, cur_w);
261        cur_h /= 2;
262        cur_w /= 2;
263        // cur_c unchanged.
264    }
265    ensure!(cur_c == branch.c1x1_in, "branch input channels mismatch");
266
267    // 1×1 conv: cur_c → SAM3_DET_DIM.
268    cur = conv2d_1x1(
269        &cur,
270        cur_c,
271        SAM3_DET_DIM,
272        cur_h,
273        cur_w,
274        &branch.c1x1_w,
275        &branch.c1x1_b,
276    );
277    cur_c = SAM3_DET_DIM;
278
279    // 3×3 conv with padding=1 stride=1.
280    cur = conv2d_3x3_pad1(&cur, cur_c, cur_h, cur_w, &branch.c3x3_w, &branch.c3x3_b);
281
282    let pos = position_encoding_sine_sam3(SAM3_DET_DIM, cur_h, cur_w);
283    Ok(Sam3FeatureLevel {
284        features: cur,
285        pos,
286        h: cur_h,
287        w: cur_w,
288        channels: cur_c,
289    })
290}
291
292fn gelu_inplace(x: &mut [f32]) {
293    // PyTorch's default `nn.GELU()` is the exact (erf-based) form. We
294    // approximate `erf` with a high-accuracy Abramowitz-Stegun series so
295    // we don't pick up a new dep just for the neck branch.
296    let inv_sqrt2 = 1.0f32 / std::f32::consts::SQRT_2;
297    for v in x.iter_mut() {
298        *v = 0.5 * *v * (1.0 + erf_approx(*v * inv_sqrt2));
299    }
300}
301
/// Approximate the error function using Abramowitz & Stegun formula 7.1.26.
///
/// The formula's absolute error is about 1.5e-7 in exact arithmetic, which is
/// plenty for f32. It is evaluated on `|x|`, and the sign is restored via the
/// odd symmetry `erf(-x) = -erf(x)`.
fn erf_approx(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const COEFFS: [f32; 5] = [0.2548296, -0.2844967, 1.4214138, -1.4531521, 1.0614054];
    let [a1, a2, a3, a4, a5] = COEFFS;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner form of t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5)))).
    let poly = t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))));
    let magnitude = 1.0 - poly * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
316
/// 2×2 max-pool with stride 2, applied independently to each channel of an
/// NCHW `[c, h, w]` map.
///
/// Trailing odd rows and columns are dropped (floor), so the result is
/// `[c, h / 2, w / 2]`.
fn maxpool2x2_stride2(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let (half_h, half_w) = (h / 2, w / 2);
    let plane = h * w;
    let mut pooled = Vec::with_capacity(c * half_h * half_w);
    for ch in 0..c {
        let src = &input[ch * plane..(ch + 1) * plane];
        for row in 0..half_h {
            let upper = &src[2 * row * w..(2 * row + 1) * w];
            let lower = &src[(2 * row + 1) * w..(2 * row + 2) * w];
            for col in 0..half_w {
                let x0 = 2 * col;
                pooled.push(
                    upper[x0]
                        .max(upper[x0 + 1])
                        .max(lower[x0])
                        .max(lower[x0 + 1]),
                );
            }
        }
    }
    pooled
}
338
339fn conv2d_1x1(
340    input: &[f32],
341    in_c: usize,
342    out_c: usize,
343    h: usize,
344    w: usize,
345    weight: &[f32], // [out_c, in_c, 1, 1] → row-major [out_c, in_c]
346    bias: &[f32],
347) -> Vec<f32> {
348    // Treat the spatial dims as the GEMM K-axis-free batch; weights map
349    // channels through a single matmul: out[oc, n] = sum_ic w[oc, ic] * in[ic, n].
350    // Use sgemm: A = weight [out_c, in_c], B = input [in_c, hw], C [out_c, hw].
351    let n = h * w;
352    let mut out = vec![0f32; out_c * n];
353    rlx_cpu::blas::sgemm(weight, input, &mut out, out_c, in_c, n);
354    // Add bias.
355    for oc in 0..out_c {
356        let b = bias[oc];
357        let row = &mut out[oc * n..(oc + 1) * n];
358        for v in row {
359            *v += b;
360        }
361    }
362    out
363}
364
/// Dense 3×3 convolution with stride 1 and zero padding 1, mapping `c`
/// channels to `c` channels at unchanged `[h, w]` resolution.
///
/// `input` is NCHW `[c, h, w]`, `weight` is `[c_out = c, c_in = c, 3, 3]` and
/// `bias` is `[c]`. Each output plane starts at its bias. For each input
/// channel in order, that channel's 3×3 window sum is then added to every pixel.
fn conv2d_3x3_pad1(
    input: &[f32],
    c: usize,
    h: usize,
    w: usize,
    weight: &[f32], // [out_c=c, in_c=c, 3, 3]
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    // Seed every output plane with its bias.
    let mut out = Vec::with_capacity(c * plane);
    for oc in 0..c {
        out.resize((oc + 1) * plane, bias[oc]);
    }
    for oc in 0..c {
        for ic in 0..c {
            let k0 = (oc * c + ic) * 9;
            let kernel = &weight[k0..k0 + 9];
            let src = &input[ic * plane..(ic + 1) * plane];
            let dst = &mut out[oc * plane..(oc + 1) * plane];
            for (p, slot) in dst.iter_mut().enumerate() {
                let (oy, ox) = (p / w, p % w);
                let mut acc = 0.0f32;
                for ky in 0..3 {
                    // Source row `oy + ky - 1`; rows in the zero padding are skipped.
                    let Some(iy) = (oy + ky).checked_sub(1).filter(|&r| r < h) else {
                        continue;
                    };
                    for kx in 0..3 {
                        let Some(ix) = (ox + kx).checked_sub(1).filter(|&q| q < w) else {
                            continue;
                        };
                        acc += src[iy * w + ix] * kernel[ky * 3 + kx];
                    }
                }
                *slot += acc;
            }
        }
    }
    out
}
409
/// Transposed 2-D convolution with a 2×2 kernel, stride 2 and no padding.
/// Each input pixel scatters into its own non-overlapping 2×2 output patch.
///
/// Shapes:
/// - `input`: NCHW `[in_c, h, w]`
/// - `weight`: PyTorch `ConvTranspose2d` layout `[in_c, out_c, 2, 2]`
/// - `bias`: `[out_c]`
/// - result: `[out_c, 2h, 2w]`
///
/// Input values equal to zero are skipped, since with finite weights they add
/// nothing.
fn conv_transpose2d_stride2_k2(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32], // PyTorch ConvTranspose2d weight: [in_c, out_c, k, k]
    bias: &[f32],
) -> Vec<f32> {
    let (out_h, out_w) = (2 * h, 2 * w);
    let out_plane = out_h * out_w;
    let mut out = vec![0f32; out_c * out_plane];
    for oc in 0..out_c {
        let b = bias[oc];
        out[oc * out_plane..(oc + 1) * out_plane].fill(b);
    }
    for ic in 0..in_c {
        for iy in 0..h {
            for ix in 0..w {
                let v = input[(ic * h + iy) * w + ix];
                if v == 0.0 {
                    continue;
                }
                for ky in 0..2 {
                    let row_start = (2 * iy + ky) * out_w;
                    for kx in 0..2 {
                        let pixel = row_start + 2 * ix + kx;
                        // Weight index for (ic, oc = 0, ky, kx); each oc step adds 2·2.
                        let mut w_idx = ic * out_c * 4 + ky * 2 + kx;
                        for oc in 0..out_c {
                            out[oc * out_plane + pixel] += v * weight[w_idx];
                            w_idx += 4;
                        }
                    }
                }
            }
        }
    }
    out
}
451
/// SAM3-flavour 2D sinusoidal positional encoding (matches
/// `sam3.model.position_encoding.PositionEmbeddingSine`).
///
/// Returns a flat NCHW `[d_model, h, w]` tensor, with
/// `num_pos_feats = d_model / 2`:
/// - channels `0..num_pos_feats` encode the row coordinate;
/// - channels `num_pos_feats..d_model` encode the column coordinate.
///
/// Within each half, channel `i` holds `sin` (even `i`) or `cos` (odd `i`) of
/// `pos / 10000^(2·⌊i/2⌋ / num_pos_feats)`. Here
/// `pos = (k + 1) / (len + 1e-6) · 2π`, where `k` is the 0-based coordinate on
/// an axis of length `len`.
///
/// The row half depends only on `y` and the column half only on `x`. Each axis
/// is therefore encoded once into a `[num_pos_feats, len]` table and broadcast
/// into the output, instead of evaluating `sin`/`cos` for every pixel. The
/// values are bit-identical to the per-pixel evaluation.
///
/// # Panics
/// Panics if `d_model` is odd.
pub fn position_encoding_sine_sam3(d_model: usize, h: usize, w: usize) -> Vec<f32> {
    assert!(d_model.is_multiple_of(2), "d_model must be even");
    let num_pos_feats = d_model / 2;
    let scale = 2.0 * std::f32::consts::PI;
    let eps = 1e-6f32;
    let temperature = 10000.0f32;

    let dim_t: Vec<f32> = (0..num_pos_feats)
        .map(|i| temperature.powf(2.0 * ((i / 2) as f32) / num_pos_feats as f32))
        .collect();

    // `[num_pos_feats, len]` table of the per-channel encoding along one axis.
    let encode_axis = |len: usize| -> Vec<f32> {
        // The last coordinate after the +1 shift is `len`.
        let denom = len as f32 + eps;
        let mut table = Vec::with_capacity(num_pos_feats * len);
        for (i, &d) in dim_t.iter().enumerate() {
            table.extend((0..len).map(|k| {
                let norm = ((k + 1) as f32) / denom * scale;
                let p = norm / d;
                if i % 2 == 0 { p.sin() } else { p.cos() }
            }));
        }
        table
    };
    let y_table = encode_axis(h);
    let x_table = encode_axis(w);

    let plane = h * w;
    let mut out = vec![0.0f32; d_model * plane];
    if plane == 0 {
        return out;
    }
    // pos_y occupies the first num_pos_feats channels, pos_x the second half.
    let (pos_y, pos_x) = out.split_at_mut(num_pos_feats * plane);
    for (chan, ys) in pos_y.chunks_exact_mut(plane).zip(y_table.chunks_exact(h)) {
        for (row, &v) in chan.chunks_exact_mut(w).zip(ys) {
            row.fill(v);
        }
    }
    for (chan, xs) in pos_x.chunks_exact_mut(plane).zip(x_table.chunks_exact(w)) {
        for row in chan.chunks_exact_mut(w) {
            row.copy_from_slice(xs);
        }
    }
    out
}