Skip to main content

rlx_sam2/
fpn_neck.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 FPN neck (mirrors `sam2/modeling/backbones/image_encoder.py::FpnNeck`).
17//!
18//! Per-stage 1×1 lateral convs and top-down fusion run via IR
19//! ([`super::fpn_neck_ir`]); sinusoidal PE is precomputed at compile time.
20
21use super::config::{Sam2FpnConfig, Sam2HieraConfig};
22use anyhow::{Result, ensure};
23use rlx_core::weight_map::WeightMap;
24use std::f32::consts::PI;
25
/// Weights for the FPN neck — one 1×1 conv (`weight` + `bias`) per
/// backbone level. Stored coarse → fine to match the checkpoint's
/// `image_encoder.neck.convs.{i}.conv.{weight,bias}` ordering.
#[derive(Debug, Clone)]
pub struct FpnNeckWeights {
    /// `[d_model, backbone_channel_list[i]]` per level, row-major (1×1
    /// conv = a per-pixel linear, so the kernel dims collapse). The
    /// host conv reads element `(oc, ic)` at `oc * cin + ic`.
    pub conv_w: Vec<Vec<f32>>,
    /// Per-level bias, one value per output channel (`d_model` entries).
    pub conv_b: Vec<Vec<f32>>,
    /// Output channel count shared by every level.
    pub d_model: usize,
    /// Input channel count of each lateral conv, in the same
    /// coarse → fine order as `conv_w` / `conv_b`.
    pub backbone_channel_list: Vec<usize>,
    /// Stage indices (fine-first numbering, i.e. indices into the
    /// `stage_outputs` slice given to the neck) that receive the
    /// top-down sum from the next-coarser level.
    pub fpn_top_down_levels: Vec<usize>,
    /// Copied from `Sam2FpnConfig::interpolation_nearest`. The host
    /// fallback in this file always upsamples by nearest ×2 replication
    /// and does not read this flag.
    pub nearest: bool,
}
39
40pub(super) fn extract_fpn_weights(
41    weights: &mut WeightMap,
42    cfg: &Sam2HieraConfig,
43) -> Result<FpnNeckWeights> {
44    let fpn = Sam2FpnConfig::for_hiera(cfg);
45    let n = fpn.backbone_channel_list.len();
46    let d = fpn.d_model;
47
48    let mut conv_w = Vec::with_capacity(n);
49    let mut conv_b = Vec::with_capacity(n);
50    for i in 0..n {
51        let cin = fpn.backbone_channel_list[i];
52        let (raw_w, w_shape) =
53            weights.take(&format!("image_encoder.neck.convs.{i}.conv.weight"))?;
54        ensure!(
55            w_shape == vec![d, cin, 1, 1],
56            "neck.convs.{i}.conv.weight expected [{d}, {cin}, 1, 1], got {w_shape:?}"
57        );
58        let (raw_b, _) = weights.take(&format!("image_encoder.neck.convs.{i}.conv.bias"))?;
59        conv_w.push(raw_w);
60        conv_b.push(raw_b);
61    }
62    Ok(FpnNeckWeights {
63        conv_w,
64        conv_b,
65        d_model: d,
66        backbone_channel_list: fpn.backbone_channel_list,
67        fpn_top_down_levels: fpn.fpn_top_down_levels,
68        nearest: fpn.interpolation_nearest,
69    })
70}
71
/// A single FPN level output — channel-major features plus the matching
/// sinusoidal positional encoding. There is no batch axis: both buffers
/// hold exactly one `[d_model, h, w]` map.
#[derive(Debug, Clone)]
pub struct FpnLevel {
    /// Flat `[d_model, h, w]` buffer; element `(c, y, x)` lives at
    /// `c * h * w + y * w + x`.
    pub features: Vec<f32>,
    /// Flat `[d_model, h, w]` buffer with the same layout as `features`
    /// — sinusoidal absolute pos embed.
    pub pos: Vec<f32>,
    /// Spatial height of this level.
    pub h: usize,
    /// Spatial width of this level.
    pub w: usize,
}
82
83/// Run the FPN neck. `stage_outputs[i]` is the encoder's
84/// stage-`i` output flattened from BHWC `[1, h, w, dim]` to
85/// `[h·w·dim]`. `stage_dims[i] = dim`, `stage_hw[i] = (h, w)` — pulled
86/// straight from the graph's stage-output shapes (or computed from
87/// `cfg.embed_dim_at_stage(s)` / `cfg.grid_size_at_stage(s)`).
88///
89/// Returns four `FpnLevel`s in **fine → coarse** order (so callers
90/// downstream can naturally index `[0]` for the highest-resolution
91/// stride-4 feature map, matching the reference).
92pub fn apply_fpn_neck(
93    neck: &FpnNeckWeights,
94    ir: &mut super::fpn_neck_ir::Sam2FpnNeckIr,
95    stage_outputs: &[Vec<f32>],
96    stage_hw: &[(usize, usize)],
97    stage_dims: &[usize],
98) -> Result<Vec<FpnLevel>> {
99    apply_fpn_neck_impl(neck, Some(ir), stage_outputs, stage_hw, stage_dims)
100}
101
102/// Host-only lateral convs (legacy entry point).
103pub fn apply_fpn_neck_host(
104    neck: &FpnNeckWeights,
105    stage_outputs: &[Vec<f32>],
106    stage_hw: &[(usize, usize)],
107    stage_dims: &[usize],
108) -> Vec<FpnLevel> {
109    apply_fpn_neck_impl(neck, None, stage_outputs, stage_hw, stage_dims).expect("host FPN neck")
110}
111
112fn apply_fpn_neck_impl(
113    neck: &FpnNeckWeights,
114    mut ir: Option<&mut super::fpn_neck_ir::Sam2FpnNeckIr>,
115    stage_outputs: &[Vec<f32>],
116    stage_hw: &[(usize, usize)],
117    stage_dims: &[usize],
118) -> Result<Vec<FpnLevel>> {
119    let n = neck.backbone_channel_list.len();
120    assert_eq!(stage_outputs.len(), n);
121    assert_eq!(stage_hw.len(), n);
122    assert_eq!(stage_dims.len(), n);
123    let d = neck.d_model;
124
125    // The reference loops `i = n-1 .. 0` (coarse → fine). Our
126    // `stage_outputs[0]` is the *finest* stage (stride 4); but the
127    // neck's `convs[0]` projects the *coarsest*. So convs index =
128    // `n - 1 - stage_idx`.
129    //
130    // We iterate coarse → fine to do the top-down sum, then return
131    // results in fine → coarse order.
132    let mut top_down: Option<Vec<f32>> = None;
133    let mut top_down_hw: Option<(usize, usize)> = None;
134    let mut levels: Vec<FpnLevel> = Vec::with_capacity(n);
135
136    for coarse_i in 0..n {
137        // coarse_i = 0 is the coarsest stage; iterate fine -> coarse
138        // via stage index then reverse at the end.
139        let stage_idx = n - 1 - coarse_i; // n-1, n-2, ..., 0
140        let conv_idx = coarse_i; // matches `convs[n-i]` with i=stage_idx since (n-1) - stage_idx = coarse_i
141        let (h, w) = stage_hw[stage_idx];
142        let dim_in = stage_dims[stage_idx];
143        debug_assert_eq!(dim_in, neck.backbone_channel_list[conv_idx]);
144
145        // 1) 1×1 lateral conv: dim_in → d_model.
146        let lat = match ir.as_deref_mut() {
147            Some(ir_neck) => ir_neck.laterals[stage_idx].run(&stage_outputs[stage_idx])?,
148            None => lateral_conv_host(
149                &neck.conv_w[conv_idx],
150                &neck.conv_b[conv_idx],
151                &stage_outputs[stage_idx],
152                dim_in,
153                d,
154                h,
155                w,
156            ),
157        };
158
159        // 2) Optional top-down sum (nearest ×2 upsample of `top_down`).
160        let level_features = if neck.fpn_top_down_levels.contains(&stage_idx)
161            && let Some(td) = top_down.as_ref()
162        {
163            let (th, tw) = top_down_hw.unwrap();
164            debug_assert_eq!(th * 2, h);
165            debug_assert_eq!(tw * 2, w);
166            if let Some(ir_neck) = ir.as_deref_mut() {
167                if let Some(fuse) = ir_neck.fuses.get_mut(stage_idx).and_then(|f| f.as_mut()) {
168                    fuse.run(&lat, td)?
169                } else {
170                    top_down_add_host(&lat, td, d, h, w, th, tw)
171                }
172            } else {
173                top_down_add_host(&lat, td, d, h, w, th, tw)
174            }
175        } else {
176            lat
177        };
178
179        // 3) Sinusoidal position encoding (precomputed at IR compile time).
180        let pos = ir
181            .as_ref()
182            .map(|ir| ir.pos[stage_idx].clone())
183            .unwrap_or_else(|| sinusoidal_pos_2d(d, h, w));
184
185        levels.push(FpnLevel {
186            features: level_features.clone(),
187            pos,
188            h,
189            w,
190        });
191        top_down = Some(level_features);
192        top_down_hw = Some((h, w));
193    }
194
195    // Levels were pushed coarse → fine. Reverse to fine → coarse.
196    levels.reverse();
197    Ok(levels)
198}
199
/// Host fallback for the top-down fusion: returns `lat` plus `prev`
/// upsampled by nearest-neighbour ×2 replication.
///
/// Both buffers are channel-major: `lat` is read as `[d, h, w]` and
/// `prev` as `[d, th, tw]`. Output pixel `(y, x)` of channel `c` picks
/// up `prev` pixel `(y / 2, x / 2)` of the same channel.
fn top_down_add_host(
    lat: &[f32],
    prev: &[f32],
    d: usize,
    h: usize,
    w: usize,
    th: usize,
    tw: usize,
) -> Vec<f32> {
    let mut fused = lat.to_vec();
    for ch in 0..d {
        let fine_base = ch * h * w;
        let coarse_base = ch * th * tw;
        for row in 0..h {
            let fine_row = fine_base + row * w;
            let coarse_row = coarse_base + (row / 2) * tw;
            for col in 0..w {
                fused[fine_row + col] += prev[coarse_row + col / 2];
            }
        }
    }
    fused
}
221
/// Host fallback for a 1×1 lateral convolution.
///
/// * `cw` — weights, row-major `[d, dim_in]`.
/// * `cb` — bias, one value per output channel.
/// * `src` — input pixels, channel-last `[h, w, dim_in]`.
///
/// Returns a channel-major `[d, h, w]` buffer. Each output value starts
/// from the bias and accumulates the input channels in ascending order.
fn lateral_conv_host(
    cw: &[f32],
    cb: &[f32],
    src: &[f32],
    dim_in: usize,
    d: usize,
    h: usize,
    w: usize,
) -> Vec<f32> {
    let plane = h * w;
    let mut lat = vec![0f32; d * plane];
    for pixel in 0..plane {
        let in_off = pixel * dim_in;
        for oc in 0..d {
            let bias = cb[oc];
            let px = &src[in_off..in_off + dim_in];
            let kernel = &cw[oc * dim_in..(oc + 1) * dim_in];
            lat[oc * plane + pixel] = px
                .iter()
                .zip(kernel)
                .fold(bias, |acc, (&s, &k)| acc + s * k);
        }
    }
    lat
}
246
247/// Sinusoidal absolute position embedding (`PositionEmbeddingSine`),
248/// matching `sam2/modeling/position_encoding.py`:
249///   - `num_pos_feats = d_model / 2`, half for x, half for y
250///   - `normalize=True`, `temperature=10000`, `scale=2π`
251///   - output `[d_model, h, w]` NCHW with channel layout
252///     `[y_sin, y_cos, …, x_sin, x_cos, …]`.
253pub(super) fn sinusoidal_pos_2d(d_model: usize, h: usize, w: usize) -> Vec<f32> {
254    let nf = d_model / 2; // num_pos_feats per axis
255    let temperature: f32 = 10000.0;
256    let scale: f32 = 2.0 * PI;
257    let eps: f32 = 1e-6;
258    let mut out = vec![0f32; d_model * h * w];
259
260    // Per-axis dim_t scaling factors.
261    // dim_t = temperature ** (2 * (i // 2) / num_pos_feats)
262    let mut dim_t = vec![0f32; nf];
263    for i in 0..nf {
264        let exp = 2.0 * ((i / 2) as f32) / (nf as f32);
265        dim_t[i] = temperature.powf(exp);
266    }
267
268    // Build normalised y and x embeddings.
269    for y in 0..h {
270        let y_emb = ((y + 1) as f32) / ((h as f32) + eps) * scale;
271        for x in 0..w {
272            let x_emb = ((x + 1) as f32) / ((w as f32) + eps) * scale;
273            // y channels go to [0..nf), x channels to [nf..d_model).
274            for i in 0..nf {
275                let py = y_emb / dim_t[i];
276                let val = if i % 2 == 0 { py.sin() } else { py.cos() };
277                out[i * h * w + y * w + x] = val;
278            }
279            for i in 0..nf {
280                let px = x_emb / dim_t[i];
281                let val = if i % 2 == 0 { px.sin() } else { px.cos() };
282                out[(nf + i) * h * w + y * w + x] = val;
283            }
284        }
285    }
286    out
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Sam2HieraConfig;

    /// The positional encoding has one value per `(channel, y, x)` and
    /// never produces NaN/inf.
    #[test]
    fn pos_2d_shape_and_finite() {
        let pos = sinusoidal_pos_2d(256, 32, 32);
        assert_eq!(pos.len(), 256 * 32 * 32);
        assert!(pos.iter().all(|v| v.is_finite()));
    }

    /// The B+ neck config lists its input channels coarse → fine.
    #[test]
    fn base_plus_channel_list_is_coarse_to_fine() {
        let cfg = Sam2HieraConfig::base_plus();
        let fpn = Sam2FpnConfig::for_hiera(&cfg);
        assert_eq!(fpn.backbone_channel_list, vec![896, 448, 224, 112]);
    }

    /// Runs the host path on a tiny synthetic two-level neck and checks
    /// that levels come back fine → coarse and that the top-down sum is
    /// applied to the fine level only.
    #[test]
    fn fpn_levels_returned_fine_to_coarse() {
        // Zero weights make each lateral equal to its bias, so the
        // expected features can be written down directly.
        let neck = FpnNeckWeights {
            // conv 0 = coarsest level (2 input channels), conv 1 = finest (1).
            conv_w: vec![vec![0.0; 4], vec![0.0; 2]],
            conv_b: vec![vec![1.0, 2.0], vec![10.0, 20.0]],
            d_model: 2,
            backbone_channel_list: vec![2, 1],
            fpn_top_down_levels: vec![0],
            nearest: true,
        };
        // Stage 0 is the finest (4×4×1), stage 1 the coarsest (2×2×2).
        let stage_outputs = vec![vec![0.5f32; 16], vec![0.5f32; 8]];
        let stage_hw = [(4, 4), (2, 2)];
        let stage_dims = [1, 2];

        let levels = apply_fpn_neck_host(&neck, &stage_outputs, &stage_hw, &stage_dims);

        assert_eq!(levels.len(), 2);
        assert_eq!((levels[0].h, levels[0].w), (4, 4));
        assert_eq!((levels[1].h, levels[1].w), (2, 2));

        // Coarse level: bias only.
        assert_eq!(levels[1].features, [vec![1.0f32; 4], vec![2.0f32; 4]].concat());
        // Fine level: bias + upsampled coarse level.
        assert_eq!(levels[0].features, [vec![11.0f32; 16], vec![22.0f32; 16]].concat());

        // Positional encodings match each level's shape.
        assert_eq!(levels[0].pos.len(), 2 * 4 * 4);
        assert_eq!(levels[1].pos.len(), 2 * 2 * 2);
    }
}