1use super::fpn_neck::FpnNeckWeights;
19use anyhow::Result;
20use rlx_core::vision_ops_ir::{bhwc_to_nchw, conv2d_bias};
21use rlx_flow::CompileProfile;
22use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
23use rlx_ir::{DType, Graph, HirGraphExt, Shape};
24use rlx_runtime::{CompiledGraph, Device};
25use std::collections::HashMap;
26
27pub struct Sam2FpnLateralCompiled {
29 graph: CompiledGraph,
30 pub in_c: usize,
31 pub out_c: usize,
32 pub h: usize,
33 pub w: usize,
34}
35
36impl Sam2FpnLateralCompiled {
37 pub fn compile(
38 in_c: usize,
39 out_c: usize,
40 h: usize,
41 w: usize,
42 weight: &[f32],
43 bias: &[f32],
44 device: Device,
45 ) -> Result<Self> {
46 Self::compile_with_profile(
47 in_c,
48 out_c,
49 h,
50 w,
51 weight,
52 bias,
53 device,
54 &CompileProfile::sam_encoder(),
55 )
56 }
57
58 pub fn compile_with_profile(
59 in_c: usize,
60 out_c: usize,
61 h: usize,
62 w: usize,
63 weight: &[f32],
64 bias: &[f32],
65 device: Device,
66 profile: &CompileProfile,
67 ) -> Result<Self> {
68 let (graph, params) = build_lateral_graph(in_c, out_c, h, w, weight, bias)?;
69 let mut compiled =
70 rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
71 for (name, data) in ¶ms {
72 compiled.set_param(name, data);
73 }
74 Ok(Self {
75 graph: compiled,
76 in_c,
77 out_c,
78 h,
79 w,
80 })
81 }
82
83 pub fn run(&mut self, stage_bhwc: &[f32]) -> Result<Vec<f32>> {
85 let expected = self.in_c * self.h * self.w;
86 anyhow::ensure!(
87 stage_bhwc.len() == expected,
88 "FPN lateral input len {} ≠ {expected}",
89 stage_bhwc.len()
90 );
91 let outs = self.graph.run(&[("stage", stage_bhwc)]);
92 Ok(outs.into_iter().next().expect("fpn lateral output"))
93 }
94}
95
96pub struct Sam2FpnTopDownCompiled {
98 graph: CompiledGraph,
99 pub channels: usize,
100 pub prev_h: usize,
101 pub prev_w: usize,
102 pub out_h: usize,
103 pub out_w: usize,
104}
105
106impl Sam2FpnTopDownCompiled {
107 pub fn compile(channels: usize, prev_h: usize, prev_w: usize, device: Device) -> Result<Self> {
108 Self::compile_with_profile(
109 channels,
110 prev_h,
111 prev_w,
112 device,
113 &CompileProfile::sam_encoder(),
114 )
115 }
116
117 pub fn compile_with_profile(
118 channels: usize,
119 prev_h: usize,
120 prev_w: usize,
121 device: Device,
122 profile: &CompileProfile,
123 ) -> Result<Self> {
124 let out_h = prev_h * 2;
125 let out_w = prev_w * 2;
126 let (graph, _params) = build_top_down_graph(channels, prev_h, prev_w)?;
127 let compiled = rlx_core::flow_bridge::compile_graph_with_profile(device, graph, profile)?;
128 Ok(Self {
129 graph: compiled,
130 channels,
131 prev_h,
132 prev_w,
133 out_h,
134 out_w,
135 })
136 }
137
138 pub fn run(&mut self, lat: &[f32], prev: &[f32]) -> Result<Vec<f32>> {
140 let lat_n = self.channels * self.out_h * self.out_w;
141 let prev_n = self.channels * self.prev_h * self.prev_w;
142 anyhow::ensure!(
143 lat.len() == lat_n,
144 "FPN fuse lat len {} ≠ {lat_n}",
145 lat.len()
146 );
147 anyhow::ensure!(
148 prev.len() == prev_n,
149 "FPN fuse prev len {} ≠ {prev_n}",
150 prev.len()
151 );
152 let outs = self.graph.run(&[("lat", lat), ("prev", prev)]);
153 Ok(outs.into_iter().next().expect("fpn top_down output"))
154 }
155}
156
157pub struct Sam2FpnNeckIr {
159 pub laterals: Vec<Sam2FpnLateralCompiled>,
160 pub fuses: Vec<Option<Sam2FpnTopDownCompiled>>,
162 pub pos: Vec<Vec<f32>>,
164}
165
166pub fn compile_fpn_neck_ir(
167 neck: &FpnNeckWeights,
168 stage_hw: &[(usize, usize)],
169 stage_dims: &[usize],
170 device: Device,
171 profile: &CompileProfile,
172) -> Result<Sam2FpnNeckIr> {
173 let n = stage_hw.len();
174 anyhow::ensure!(
175 stage_dims.len() == n && neck.conv_w.len() == n,
176 "FPN compile: stage count mismatch"
177 );
178 let mut laterals = Vec::with_capacity(n);
179 let mut pos = Vec::with_capacity(n);
180 for stage_idx in 0..n {
181 let (h, w) = stage_hw[stage_idx];
182 pos.push(super::fpn_neck::sinusoidal_pos_2d(neck.d_model, h, w));
183 let conv_idx = n - 1 - stage_idx;
184 let (h, w) = stage_hw[stage_idx];
185 let in_c = stage_dims[stage_idx];
186 laterals.push(Sam2FpnLateralCompiled::compile_with_profile(
187 in_c,
188 neck.d_model,
189 h,
190 w,
191 &neck.conv_w[conv_idx],
192 &neck.conv_b[conv_idx],
193 device,
194 profile,
195 )?);
196 }
197 let mut fuses: Vec<Option<Sam2FpnTopDownCompiled>> = (0..n).map(|_| None).collect();
198 for &stage_idx in &neck.fpn_top_down_levels {
199 anyhow::ensure!(
200 stage_idx < n,
201 "fpn_top_down_levels index {stage_idx} out of range"
202 );
203 let (h, w) = stage_hw[stage_idx];
204 anyhow::ensure!(
205 h % 2 == 0 && w % 2 == 0,
206 "FPN top-down at stage {stage_idx} needs even h,w, got {h}×{w}"
207 );
208 fuses[stage_idx] = Some(Sam2FpnTopDownCompiled::compile_with_profile(
209 neck.d_model,
210 h / 2,
211 w / 2,
212 device,
213 profile,
214 )?);
215 }
216 Ok(Sam2FpnNeckIr {
217 laterals,
218 fuses,
219 pos,
220 })
221}
222
223fn build_top_down_graph(
224 channels: usize,
225 prev_h: usize,
226 prev_w: usize,
227) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
228 let f = DType::F32;
229 let out_h = prev_h * 2;
230 let out_w = prev_w * 2;
231 let mut hir = HirModule::new("sam2_fpn_top_down");
232 let mut g = HirMut::new(&mut hir);
233
234 let lat = g.input("lat", Shape::new(&[1, channels, out_h, out_w], f));
235 let prev = g.input("prev", Shape::new(&[1, channels, prev_h, prev_w], f));
236 let up = g.resize_nearest_2x(prev);
237 let out = g.add(lat, up);
238
239 hir.set_outputs(vec![out]);
240 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
241 Ok((graph, HashMap::new()))
242}
243
244fn build_lateral_graph(
245 in_c: usize,
246 out_c: usize,
247 h: usize,
248 w: usize,
249 weight: &[f32],
250 bias: &[f32],
251) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
252 let f = DType::F32;
253 let mut hir = HirModule::new("sam2_fpn_lateral");
254 let mut params = HashMap::new();
255 let mut g = HirMut::new(&mut hir);
256
257 let stage = g.input("stage", Shape::new(&[1, h, w, in_c], f));
258 let x = bhwc_to_nchw(&mut g, stage, 1, h, w, in_c);
259 let wt = param(&mut g, &mut params, "w", weight, &[out_c, in_c, 1, 1]);
260 let bt = param(&mut g, &mut params, "b", bias, &[out_c]);
261 let y = conv2d_bias(&mut g, x, wt, bt, 1, out_c, 1, 1, [1, 1], [0, 0], h, w);
262
263 hir.set_outputs(vec![y]);
264 let graph = Graph::from_hir(hir).map_err(|e| anyhow::anyhow!("{e}"))?;
265 Ok((graph, params))
266}
267
268fn param(
269 g: &mut HirMut<'_>,
270 params: &mut HashMap<String, Vec<f32>>,
271 name: &str,
272 data: &[f32],
273 shape: &[usize],
274) -> HirNodeId {
275 let id = g.param(name, Shape::new(shape, DType::F32));
276 params.insert(name.to_string(), data.to_vec());
277 id
278}