Skip to main content

rlx_sam2/
fpn_neck_ir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM2 FPN neck IR: lateral 1×1 convs + top-down nearest ×2 fusion.
17
18use super::fpn_neck::FpnNeckWeights;
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{bhwc_to_nchw, conv2d_bias};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
27/// Per-backbone-stage lateral 1×1 (`dim_in` → `d_model`).
28pub struct Sam2FpnLateralCompiled {
29    graph: CompiledGraph,
30    pub in_c: usize,
31    pub out_c: usize,
32    pub h: usize,
33    pub w: usize,
34}
35
36impl Sam2FpnLateralCompiled {
37    pub fn compile(
38        in_c: usize,
39        out_c: usize,
40        h: usize,
41        w: usize,
42        weight: &[f32],
43        bias: &[f32],
44        device: Device,
45    ) -> Result<Self> {
46        Self::compile_with_profile(
47            in_c,
48            out_c,
49            h,
50            w,
51            weight,
52            bias,
53            device,
54            &CompileProfile::sam_encoder(),
55        )
56    }
57
58    pub fn compile_with_profile(
59        in_c: usize,
60        out_c: usize,
61        h: usize,
62        w: usize,
63        weight: &[f32],
64        bias: &[f32],
65        device: Device,
66        profile: &CompileProfile,
67    ) -> Result<Self> {
68        let (graph, params) = build_lateral_graph(in_c, out_c, h, w, weight, bias)?;
69        let mut compiled =
70            rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
71        for (name, data) in &params {
72            compiled.set_param(name, data);
73        }
74        Ok(Self {
75            graph: compiled,
76            in_c,
77            out_c,
78            h,
79            w,
80        })
81    }
82
83    /// Encoder stage output, BHWC flat `[1, H, W, in_c]`.
84    pub fn run(&mut self, stage_bhwc: &[f32]) -> Result<Vec<f32>> {
85        let expected = self.in_c * self.h * self.w;
86        anyhow::ensure!(
87            stage_bhwc.len() == expected,
88            "FPN lateral input len {} ≠ {expected}",
89            stage_bhwc.len()
90        );
91        let outs = self.graph.run(&[("stage", stage_bhwc)]);
92        Ok(outs.into_iter().next().expect("fpn lateral output"))
93    }
94}
95
96/// Top-down fusion: `lat + ResizeNearest2x(prev)` at 2× resolution.
97pub struct Sam2FpnTopDownCompiled {
98    graph: CompiledGraph,
99    pub channels: usize,
100    pub prev_h: usize,
101    pub prev_w: usize,
102    pub out_h: usize,
103    pub out_w: usize,
104}
105
106impl Sam2FpnTopDownCompiled {
107    pub fn compile(channels: usize, prev_h: usize, prev_w: usize, device: Device) -> Result<Self> {
108        Self::compile_with_profile(
109            channels,
110            prev_h,
111            prev_w,
112            device,
113            &CompileProfile::sam_encoder(),
114        )
115    }
116
117    pub fn compile_with_profile(
118        channels: usize,
119        prev_h: usize,
120        prev_w: usize,
121        device: Device,
122        profile: &CompileProfile,
123    ) -> Result<Self> {
124        let out_h = prev_h * 2;
125        let out_w = prev_w * 2;
126        let (graph, _params) = build_top_down_graph(channels, prev_h, prev_w)?;
127        let compiled = rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
128        Ok(Self {
129            graph: compiled,
130            channels,
131            prev_h,
132            prev_w,
133            out_h,
134            out_w,
135        })
136    }
137
138    /// `lat` / `prev`: NCHW flat `[C, H, W]` at output / previous resolution.
139    pub fn run(&mut self, lat: &[f32], prev: &[f32]) -> Result<Vec<f32>> {
140        let lat_n = self.channels * self.out_h * self.out_w;
141        let prev_n = self.channels * self.prev_h * self.prev_w;
142        anyhow::ensure!(
143            lat.len() == lat_n,
144            "FPN fuse lat len {} ≠ {lat_n}",
145            lat.len()
146        );
147        anyhow::ensure!(
148            prev.len() == prev_n,
149            "FPN fuse prev len {} ≠ {prev_n}",
150            prev.len()
151        );
152        let outs = self.graph.run(&[("lat", lat), ("prev", prev)]);
153        Ok(outs.into_iter().next().expect("fpn top_down output"))
154    }
155}
156
157/// One compiled lateral per encoder stage (index = stage 0 finest … n-1 coarsest).
158pub struct Sam2FpnNeckIr {
159    pub laterals: Vec<Sam2FpnLateralCompiled>,
160    /// `fuses[stage_idx]` when that stage receives top-down fusion.
161    pub fuses: Vec<Option<Sam2FpnTopDownCompiled>>,
162    /// Per-stage sinusoidal PE `[d_model, h, w]` NCHW flat (index = stage).
163    pub pos: Vec<Vec<f32>>,
164}
165
166pub fn compile_fpn_neck_ir(
167    neck: &FpnNeckWeights,
168    stage_hw: &[(usize, usize)],
169    stage_dims: &[usize],
170    device: Device,
171    profile: &CompileProfile,
172) -> Result<Sam2FpnNeckIr> {
173    let n = stage_hw.len();
174    anyhow::ensure!(
175        stage_dims.len() == n && neck.conv_w.len() == n,
176        "FPN compile: stage count mismatch"
177    );
178    let mut laterals = Vec::with_capacity(n);
179    let mut pos = Vec::with_capacity(n);
180    for stage_idx in 0..n {
181        let (h, w) = stage_hw[stage_idx];
182        pos.push(super::fpn_neck::sinusoidal_pos_2d(neck.d_model, h, w));
183        let conv_idx = n - 1 - stage_idx;
184        let (h, w) = stage_hw[stage_idx];
185        let in_c = stage_dims[stage_idx];
186        laterals.push(Sam2FpnLateralCompiled::compile_with_profile(
187            in_c,
188            neck.d_model,
189            h,
190            w,
191            &neck.conv_w[conv_idx],
192            &neck.conv_b[conv_idx],
193            device,
194            profile,
195        )?);
196    }
197    let mut fuses: Vec<Option<Sam2FpnTopDownCompiled>> = (0..n).map(|_| None).collect();
198    for &stage_idx in &neck.fpn_top_down_levels {
199        anyhow::ensure!(
200            stage_idx < n,
201            "fpn_top_down_levels index {stage_idx} out of range"
202        );
203        let (h, w) = stage_hw[stage_idx];
204        anyhow::ensure!(
205            h % 2 == 0 && w % 2 == 0,
206            "FPN top-down at stage {stage_idx} needs even h,w, got {h}×{w}"
207        );
208        fuses[stage_idx] = Some(Sam2FpnTopDownCompiled::compile_with_profile(
209            neck.d_model,
210            h / 2,
211            w / 2,
212            device,
213            profile,
214        )?);
215    }
216    Ok(Sam2FpnNeckIr {
217        laterals,
218        fuses,
219        pos,
220    })
221}
222
223fn build_top_down_graph(
224    channels: usize,
225    prev_h: usize,
226    prev_w: usize,
227) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
228    let f = DType::F32;
229    let out_h = prev_h * 2;
230    let out_w = prev_w * 2;
231    let mut hir = HirModule::new("sam2_fpn_top_down");
232    let mut g = HirMut::new(&mut hir);
233
234    let lat = g.input("lat", Shape::new(&[1, channels, out_h, out_w], f));
235    let prev = g.input("prev", Shape::new(&[1, channels, prev_h, prev_w], f));
236    let up = g.resize_nearest_2x(prev);
237    let out = g.add(lat, up);
238
239    hir.set_outputs(vec![out]);
240    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
241    Ok((graph, HashMap::new()))
242}
243
244fn build_lateral_graph(
245    in_c: usize,
246    out_c: usize,
247    h: usize,
248    w: usize,
249    weight: &[f32],
250    bias: &[f32],
251) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
252    let f = DType::F32;
253    let mut hir = HirModule::new("sam2_fpn_lateral");
254    let mut params = HashMap::new();
255    let mut g = HirMut::new(&mut hir);
256
257    let stage = g.input("stage", Shape::new(&[1, h, w, in_c], f));
258    let x = bhwc_to_nchw(&mut g, stage, 1, h, w, in_c);
259    let wt = param(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
260    let bt = param(&mut g, &mut params, "b", bias, &[out_c]);
261    let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
262
263    hir.set_outputs(vec![y]);
264    let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
265    Ok((graph, params))
266}
267
268fn param(
269    g: &mut HirMut<'_>,
270    params: &mut HashMap<String, Vec<f32>>,
271    name: &str,
272    data: &[f32],
273    shape: &[usize],
274) -> HirNodeId {
275    let id = g.param(name, Shape::new(shape, DType::F32));
276    params.insert(name.to_string(), data.to_vec());
277    id
278}