1use super::config::{Sam2FpnConfig, Sam2HieraConfig};
22use anyhow::{Result, ensure};
23use rlx_core::weight_map::WeightMap;
24use std::f32::consts::PI;
25
/// Weights and configuration for the FPN neck, as produced by `extract_fpn_weights`.
pub struct FpnNeckWeights {
    /// Flat weight data for each lateral conv, one entry per conv index.
    /// `extract_fpn_weights` checks that entry `i` has shape
    /// `[d_model, backbone_channel_list[i], 1, 1]`.
    pub conv_w: Vec<Vec<f32>>,
    /// Bias data for each lateral conv, indexed like `conv_w`.
    pub conv_b: Vec<Vec<f32>>,
    /// Number of output channels produced by every lateral conv.
    pub d_model: usize,
    /// Input channel count of each lateral conv, indexed by conv index.
    pub backbone_channel_list: Vec<usize>,
    /// Stage indices at which the previously processed level is added to the lateral output.
    pub fpn_top_down_levels: Vec<usize>,
    /// Copied from the FPN config's `interpolation_nearest` flag.
    // NOTE(review): no code visible in this file reads this field; the host top-down path
    // always uses nearest-neighbour upsampling.
    pub nearest: bool,
}
39
40pub(super) fn extract_fpn_weights(
41 weights: &mut WeightMap,
42 cfg: &Sam2HieraConfig,
43) -> Result<FpnNeckWeights> {
44 let fpn = Sam2FpnConfig::for_hiera(cfg);
45 let n = fpn.backbone_channel_list.len();
46 let d = fpn.d_model;
47
48 let mut conv_w = Vec::with_capacity(n);
49 let mut conv_b = Vec::with_capacity(n);
50 for i in 0..n {
51 let cin = fpn.backbone_channel_list[i];
52 let (raw_w, w_shape) =
53 weights.take(&format!("image_encoder.neck.convs.{i}.conv.weight"))?;
54 ensure!(
55 w_shape == vec![d, cin, 1, 1],
56 "neck.convs.{i}.conv.weight expected [{d}, {cin}, 1, 1], got {w_shape:?}"
57 );
58 let (raw_b, _) = weights.take(&format!("image_encoder.neck.convs.{i}.conv.bias"))?;
59 conv_w.push(raw_w);
60 conv_b.push(raw_b);
61 }
62 Ok(FpnNeckWeights {
63 conv_w,
64 conv_b,
65 d_model: d,
66 backbone_channel_list: fpn.backbone_channel_list,
67 fpn_top_down_levels: fpn.fpn_top_down_levels,
68 nearest: fpn.interpolation_nearest,
69 })
70}
71
/// One output level of the FPN neck.
pub struct FpnLevel {
    /// Feature map for this level. On the host path it is channel-major with
    /// `d_model * h * w` entries (index `c * h * w + y * w + x`).
    // NOTE(review): the IR path's output layout is not visible here — presumably identical.
    pub features: Vec<f32>,
    /// Positional encoding for this level; on the host path it has the same layout and
    /// size as `features`.
    pub pos: Vec<f32>,
    /// Spatial height of this level.
    pub h: usize,
    /// Spatial width of this level.
    pub w: usize,
}
82
83pub fn apply_fpn_neck(
93 neck: &FpnNeckWeights,
94 ir: &mut super::fpn_neck_ir::Sam2FpnNeckIr,
95 stage_outputs: &[Vec<f32>],
96 stage_hw: &[(usize, usize)],
97 stage_dims: &[usize],
98) -> Result<Vec<FpnLevel>> {
99 apply_fpn_neck_impl(neck, Some(ir), stage_outputs, stage_hw, stage_dims)
100}
101
102pub fn apply_fpn_neck_host(
104 neck: &FpnNeckWeights,
105 stage_outputs: &[Vec<f32>],
106 stage_hw: &[(usize, usize)],
107 stage_dims: &[usize],
108) -> Vec<FpnLevel> {
109 apply_fpn_neck_impl(neck, None, stage_outputs, stage_hw, stage_dims).expect("host FPN neck")
110}
111
112fn apply_fpn_neck_impl(
113 neck: &FpnNeckWeights,
114 mut ir: Option<&mut super::fpn_neck_ir::Sam2FpnNeckIr>,
115 stage_outputs: &[Vec<f32>],
116 stage_hw: &[(usize, usize)],
117 stage_dims: &[usize],
118) -> Result<Vec<FpnLevel>> {
119 let n = neck.backbone_channel_list.len();
120 assert_eq!(stage_outputs.len(), n);
121 assert_eq!(stage_hw.len(), n);
122 assert_eq!(stage_dims.len(), n);
123 let d = neck.d_model;
124
125 let mut top_down: Option<Vec<f32>> = None;
133 let mut top_down_hw: Option<(usize, usize)> = None;
134 let mut levels: Vec<FpnLevel> = Vec::with_capacity(n);
135
136 for coarse_i in 0..n {
137 let stage_idx = n - 1 - coarse_i; let conv_idx = coarse_i; let (h, w) = stage_hw[stage_idx];
142 let dim_in = stage_dims[stage_idx];
143 debug_assert_eq!(dim_in, neck.backbone_channel_list[conv_idx]);
144
145 let lat = match ir.as_deref_mut() {
147 Some(ir_neck) => ir_neck.laterals[stage_idx].run(&stage_outputs[stage_idx])?,
148 None => lateral_conv_host(
149 &neck.conv_w[conv_idx],
150 &neck.conv_b[conv_idx],
151 &stage_outputs[stage_idx],
152 dim_in,
153 d,
154 h,
155 w,
156 ),
157 };
158
159 let level_features = if neck.fpn_top_down_levels.contains(&stage_idx)
161 && let Some(td) = top_down.as_ref()
162 {
163 let (th, tw) = top_down_hw.unwrap();
164 debug_assert_eq!(th * 2, h);
165 debug_assert_eq!(tw * 2, w);
166 if let Some(ir_neck) = ir.as_deref_mut() {
167 if let Some(fuse) = ir_neck.fuses.get_mut(stage_idx).and_then(|f| f.as_mut()) {
168 fuse.run(&lat, td)?
169 } else {
170 top_down_add_host(&lat, td, d, h, w, th, tw)
171 }
172 } else {
173 top_down_add_host(&lat, td, d, h, w, th, tw)
174 }
175 } else {
176 lat
177 };
178
179 let pos = ir
181 .as_ref()
182 .map(|ir| ir.pos[stage_idx].clone())
183 .unwrap_or_else(|| sinusoidal_pos_2d(d, h, w));
184
185 levels.push(FpnLevel {
186 features: level_features.clone(),
187 pos,
188 h,
189 w,
190 });
191 top_down = Some(level_features);
192 top_down_hw = Some((h, w));
193 }
194
195 levels.reverse();
197 Ok(levels)
198}
199
/// Adds a 2x nearest-neighbour upsampling of `prev` onto `lat` and returns the sum.
///
/// `lat` is read as channel-major `d x h x w` and `prev` as channel-major `d x th x tw`.
/// Output pixel `(y, x)` receives `prev` pixel `(y / 2, x / 2)` of the same channel. Any
/// elements of `lat` beyond `d * h * w` are copied through unchanged.
///
/// # Panics
///
/// Panics if an index derived from the given dimensions falls outside `lat` or `prev`.
fn top_down_add_host(
    lat: &[f32],
    prev: &[f32],
    d: usize,
    h: usize,
    w: usize,
    th: usize,
    tw: usize,
) -> Vec<f32> {
    let fine_plane = h * w;
    let coarse_plane = th * tw;
    let mut out = lat.to_vec();
    for ch in 0..d {
        let fine_base = ch * fine_plane;
        let coarse_base = ch * coarse_plane;
        for row in 0..h {
            let dst_row = fine_base + row * w;
            let src_row = coarse_base + (row / 2) * tw;
            for col in 0..w {
                out[dst_row + col] += prev[src_row + col / 2];
            }
        }
    }
    out
}
221
/// Host implementation of a 1x1 convolution with bias.
///
/// `src` is read channels-last (`dim_in` consecutive values per pixel, pixels in row-major
/// order). `cw` is read as `d` rows of `dim_in` weights and `cb` as `d` biases. The result is
/// channel-major with `d * h * w` entries (index `oc * h * w + y * w + x`).
///
/// # Panics
///
/// Panics if `cw`, `cb` or `src` is too short for the given dimensions.
fn lateral_conv_host(
    cw: &[f32],
    cb: &[f32],
    src: &[f32],
    dim_in: usize,
    d: usize,
    h: usize,
    w: usize,
) -> Vec<f32> {
    let plane = h * w;
    let mut projected = vec![0f32; d * plane];
    // Visiting pixels in row-major order is the same as iterating `y * w + x`.
    for pixel in 0..plane {
        let in_off = pixel * dim_in;
        for oc in 0..d {
            let bias = cb[oc];
            let inputs = &src[in_off..in_off + dim_in];
            let kernel = &cw[oc * dim_in..(oc + 1) * dim_in];
            // Accumulate left-to-right starting from the bias.
            projected[oc * plane + pixel] = inputs
                .iter()
                .zip(kernel)
                .fold(bias, |acc, (&value, &weight)| acc + value * weight);
        }
    }
    projected
}
246
247pub(super) fn sinusoidal_pos_2d(d_model: usize, h: usize, w: usize) -> Vec<f32> {
254 let nf = d_model / 2; let temperature: f32 = 10000.0;
256 let scale: f32 = 2.0 * PI;
257 let eps: f32 = 1e-6;
258 let mut out = vec![0f32; d_model * h * w];
259
260 let mut dim_t = vec![0f32; nf];
263 for i in 0..nf {
264 let exp = 2.0 * ((i / 2) as f32) / (nf as f32);
265 dim_t[i] = temperature.powf(exp);
266 }
267
268 for y in 0..h {
270 let y_emb = ((y + 1) as f32) / ((h as f32) + eps) * scale;
271 for x in 0..w {
272 let x_emb = ((x + 1) as f32) / ((w as f32) + eps) * scale;
273 for i in 0..nf {
275 let py = y_emb / dim_t[i];
276 let val = if i % 2 == 0 { py.sin() } else { py.cos() };
277 out[i * h * w + y * w + x] = val;
278 }
279 for i in 0..nf {
280 let px = x_emb / dim_t[i];
281 let val = if i % 2 == 0 { px.sin() } else { px.cos() };
282 out[(nf + i) * h * w + y * w + x] = val;
283 }
284 }
285 }
286 out
287}
288
// Unit tests for the host positional encoding and the FPN configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Sam2HieraConfig;

    /// The positional encoding has `d_model * h * w` entries and contains no NaN/inf values.
    #[test]
    fn pos_2d_shape_and_finite() {
        let pos = sinusoidal_pos_2d(256, 32, 32);
        assert_eq!(pos.len(), 256 * 32 * 32);
        assert!(pos.iter().all(|v| v.is_finite()));
    }

    /// Pins the FPN channel list derived from the base-plus Hiera config.
    // NOTE(review): despite its name, this test only checks `backbone_channel_list`; it does
    // not run the neck or inspect the order of the returned levels.
    #[test]
    fn fpn_levels_returned_fine_to_coarse() {
        let cfg = Sam2HieraConfig::base_plus();
        let fpn = Sam2FpnConfig::for_hiera(&cfg);
        assert_eq!(fpn.backbone_channel_list, vec![896, 448, 224, 112]);
    }
}