1use super::lateral::LateralConv1x1;
12use crate::{
13 error::{VisionError, VisionResult},
14 handle::LcgRng,
15};
16
17#[derive(Debug, Clone)]
23pub struct FeatureMap {
24 pub data: Vec<f32>,
26 pub channels: usize,
28 pub height: usize,
30 pub width: usize,
32}
33
34impl FeatureMap {
35 pub fn new(data: Vec<f32>, channels: usize, height: usize, width: usize) -> VisionResult<Self> {
40 let expected = channels * height * width;
41 if data.len() != expected {
42 return Err(VisionError::DimensionMismatch {
43 expected,
44 got: data.len(),
45 });
46 }
47 Ok(Self {
48 data,
49 channels,
50 height,
51 width,
52 })
53 }
54
55 #[inline]
60 pub fn at(&self, c: usize, h_idx: usize, w_idx: usize) -> f32 {
61 self.data[c * self.height * self.width + h_idx * self.width + w_idx]
62 }
63
64 #[inline]
66 fn len(&self) -> usize {
67 self.channels * self.height * self.width
68 }
69}
70
71pub struct FpnConfig {
75 pub in_channels: Vec<usize>,
78 pub out_channels: usize,
80}
81
82impl FpnConfig {
83 pub fn new(in_channels: Vec<usize>, out_channels: usize) -> VisionResult<Self> {
89 if in_channels.is_empty() {
90 return Err(VisionError::EmptyInput("FpnConfig::in_channels"));
91 }
92 if out_channels == 0 {
93 return Err(VisionError::InvalidImageSize {
94 height: 0,
95 width: 0,
96 channels: out_channels,
97 });
98 }
99 Ok(Self {
100 in_channels,
101 out_channels,
102 })
103 }
104
105 #[inline]
107 pub fn n_levels(&self) -> usize {
108 self.in_channels.len()
109 }
110}
111
112pub struct Fpn {
119 pub config: FpnConfig,
121 pub lateral_convs: Vec<LateralConv1x1>,
123 pub smooth_weights: Vec<Vec<f32>>,
125 pub smooth_biases: Vec<Vec<f32>>,
127}
128
129impl Fpn {
130 pub fn new(cfg: FpnConfig, rng: &mut LcgRng) -> VisionResult<Self> {
135 let n = cfg.n_levels();
136 let oc = cfg.out_channels;
137
138 let mut lateral_convs = Vec::with_capacity(n);
140 for &ic in &cfg.in_channels {
141 lateral_convs.push(LateralConv1x1::new(ic, oc, rng)?);
142 }
143
144 let smooth_scale = 1.0_f32 / ((oc * 9) as f32).sqrt();
146 let mut smooth_weights = Vec::with_capacity(n);
147 let mut smooth_biases = Vec::with_capacity(n);
148 for _ in 0..n {
149 let kernel_size = oc * oc * 9; let mut w = vec![0.0f32; kernel_size];
151 rng.fill_normal(&mut w);
152 for v in &mut w {
153 *v *= smooth_scale;
154 }
155 smooth_weights.push(w);
156 smooth_biases.push(vec![0.0f32; oc]);
157 }
158
159 Ok(Self {
160 config: cfg,
161 lateral_convs,
162 smooth_weights,
163 smooth_biases,
164 })
165 }
166
167 pub fn forward(&self, features: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
179 let n = self.config.n_levels();
180 if features.is_empty() {
181 return Err(VisionError::EmptyInput("Fpn::forward features"));
182 }
183 if features.len() != n {
184 return Err(VisionError::DimensionMismatch {
185 expected: n,
186 got: features.len(),
187 });
188 }
189
190 let oc = self.config.out_channels;
191
192 let mut lateral_maps: Vec<FeatureMap> = Vec::with_capacity(n);
194 for (l, feat) in features.iter().enumerate() {
195 let h = feat.height;
196 let w = feat.width;
197 let lateral_data = self.lateral_convs[l].forward(&feat.data, h, w)?;
198 lateral_maps.push(FeatureMap::new(lateral_data, oc, h, w)?);
199 }
200
201 let mut merged: Vec<FeatureMap> = Vec::with_capacity(n);
205 merged.push(lateral_maps[n - 1].clone());
207
208 for l in (0..n - 1).rev() {
211 let target_h = lateral_maps[l].height;
212 let target_w = lateral_maps[l].width;
213 let coarser = merged.last().expect("at least one element");
215 let upsampled = upsample_nearest(coarser, target_h, target_w);
216 let lat = &lateral_maps[l];
218 let mut merged_data = vec![0.0f32; lat.len()];
219 for (i, v) in merged_data.iter_mut().enumerate() {
220 *v = lat.data[i] + upsampled.data[i];
221 }
222 merged.push(FeatureMap::new(merged_data, oc, target_h, target_w)?);
223 }
224
225 merged.reverse();
227
228 let mut output: Vec<FeatureMap> = Vec::with_capacity(n);
230 for (l, fm) in merged.into_iter().enumerate() {
231 let smooth_data = conv3x3_same(
232 &fm.data,
233 oc,
234 fm.height,
235 fm.width,
236 &self.smooth_weights[l],
237 &self.smooth_biases[l],
238 oc,
239 );
240 output.push(FeatureMap::new(smooth_data, oc, fm.height, fm.width)?);
241 }
242
243 Ok(output)
244 }
245}
246
247fn upsample_nearest(feat: &FeatureMap, target_h: usize, target_w: usize) -> FeatureMap {
254 let src_h = feat.height;
255 let src_w = feat.width;
256 let c = feat.channels;
257 let mut out = vec![0.0f32; c * target_h * target_w];
258
259 for ch in 0..c {
260 for i in 0..target_h {
261 let src_i = (i * src_h / target_h).min(src_h.saturating_sub(1));
263 for j in 0..target_w {
264 let src_j = (j * src_w / target_w).min(src_w.saturating_sub(1));
265 out[ch * target_h * target_w + i * target_w + j] = feat.at(ch, src_i, src_j);
266 }
267 }
268 }
269
270 FeatureMap {
272 data: out,
273 channels: c,
274 height: target_h,
275 width: target_w,
276 }
277}
278
/// Stride-1, zero-padded 3×3 convolution whose output keeps the input's `h × w` size.
///
/// Layouts:
/// - `feat`: `channels × h × w`, channel-major and row-major within a channel.
/// - `weight`: `[out_channels][channels][3][3]`, flattened.
/// - `bias`: one value per output channel.
///
/// Returns `out_channels × h × w` values. Taps that fall outside the image contribute
/// nothing.
fn conv3x3_same(
    feat: &[f32],
    channels: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
    out_channels: usize,
) -> Vec<f32> {
    let plane = h * w;
    let mut out = vec![0.0f32; out_channels * plane];

    for o in 0..out_channels {
        let filter_base = o * channels * 9;
        for y in 0..h {
            // Kernel rows whose source row `y + ky - 1` lies inside `0..h`.
            let ky_start = usize::from(y == 0);
            let ky_end = 3.min(h + 1 - y);
            for x in 0..w {
                // Kernel columns whose source column `x + kx - 1` lies inside `0..w`.
                let kx_start = usize::from(x == 0);
                let kx_end = 3.min(w + 1 - x);

                let mut sum = bias[o];
                for c in 0..channels {
                    let filter = filter_base + c * 9;
                    let channel_base = c * plane;
                    for ky in ky_start..ky_end {
                        let row_base = channel_base + (y + ky - 1) * w;
                        for kx in kx_start..kx_end {
                            sum += weight[filter + ky * 3 + kx] * feat[row_base + x + kx - 1];
                        }
                    }
                }
                out[o * plane + y * w + x] = sum;
            }
        }
    }

    out
}
329
#[cfg(test)]
mod tests {
    //! Unit tests for feature maps, FPN configuration, resampling, convolution and the full
    //! FPN forward pass.
    use super::*;

    fn make_rng() -> LcgRng {
        LcgRng::new(123)
    }

    /// Builds a `channels × h × w` feature map filled from `rng.fill_normal`.
    fn random_feature_map(rng: &mut LcgRng, channels: usize, h: usize, w: usize) -> FeatureMap {
        let n = channels * h * w;
        let mut data = vec![0.0f32; n];
        rng.fill_normal(&mut data);
        FeatureMap::new(data, channels, h, w).expect("valid feature map")
    }

    #[test]
    fn feature_map_valid_construction() {
        let data = vec![1.0f32; 3 * 4 * 4];
        let fm = FeatureMap::new(data, 3, 4, 4).expect("valid feature map");
        assert_eq!(fm.channels, 3);
        assert_eq!(fm.height, 4);
        assert_eq!(fm.width, 4);
    }

    #[test]
    fn feature_map_wrong_size_errors() {
        let data = vec![0.0f32; 3 * 4 * 4 - 1];
        let r = FeatureMap::new(data, 3, 4, 4);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn feature_map_at_correct_value() {
        let mut data = vec![0.0f32; 2 * 3 * 3];
        // Fill each channel with its own index.
        for c in 0..2 {
            for pos in 0..9 {
                data[c * 9 + pos] = c as f32;
            }
        }
        let fm = FeatureMap::new(data, 2, 3, 3).expect("valid feature map");
        assert_eq!(fm.at(0, 1, 1), 0.0);
        assert_eq!(fm.at(1, 0, 0), 1.0);
    }

    #[test]
    fn feature_map_at_row_major_layout() {
        // Distinct values pin the row/column strides, not just the channel stride.
        let data: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let fm = FeatureMap::new(data, 2, 3, 4).expect("valid feature map");
        assert_eq!(fm.at(0, 1, 2), 6.0);
        assert_eq!(fm.at(1, 0, 1), 13.0);
        assert_eq!(fm.at(1, 2, 3), 23.0);
    }

    #[test]
    fn fpn_config_valid() {
        let cfg = FpnConfig::new(vec![2048, 1024, 512, 256], 256).expect("valid config");
        assert_eq!(cfg.n_levels(), 4);
    }

    #[test]
    fn fpn_config_empty_in_channels_errors() {
        let r = FpnConfig::new(vec![], 256);
        assert!(r.is_err());
    }

    #[test]
    fn fpn_config_zero_out_channels_errors() {
        let r = FpnConfig::new(vec![512, 256], 0);
        assert!(r.is_err());
    }

    #[test]
    fn upsample_nearest_doubles_size() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let fm = FeatureMap::new(data, 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        assert_eq!(up.height, 4);
        assert_eq!(up.width, 4);
        assert_eq!(up.channels, 1);
        assert_eq!(up.data.len(), 4 * 4);
    }

    #[test]
    fn upsample_nearest_values_replicated() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let fm = FeatureMap::new(data, 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        // Top-left 2×2 block copies source (0, 0).
        assert_eq!(up.at(0, 0, 0), 1.0);
        assert_eq!(up.at(0, 0, 1), 1.0);
        assert_eq!(up.at(0, 1, 0), 1.0);
        assert_eq!(up.at(0, 1, 1), 1.0);
        // Bottom-right 2×2 block copies source (1, 1).
        assert_eq!(up.at(0, 2, 2), 4.0);
        assert_eq!(up.at(0, 3, 3), 4.0);
    }

    #[test]
    fn upsample_nearest_identity_when_same_size() {
        let mut rng = make_rng();
        let fm = random_feature_map(&mut rng, 4, 5, 7);
        let up = upsample_nearest(&fm, 5, 7);
        for (a, b) in fm.data.iter().zip(up.data.iter()) {
            assert_eq!(*a, *b, "identity upsample should be exact copy");
        }
    }

    #[test]
    fn upsample_nearest_shrinks_by_sampling() {
        // 4×4 -> 2×2 picks source rows/cols 0 and 2.
        let data: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let fm = FeatureMap::new(data, 1, 4, 4).expect("valid");
        let down = upsample_nearest(&fm, 2, 2);
        assert_eq!(down.data, vec![0.0, 2.0, 8.0, 10.0]);
    }

    #[test]
    fn fpn_forward_output_channel_count_uniform() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 64, 4, 4),
            random_feature_map(&mut rng, 32, 8, 8),
        ];
        let output = fpn.forward(features).expect("FPN forward ok");
        assert_eq!(output.len(), 2, "two output levels");
        for fm in &output {
            assert_eq!(fm.channels, 16, "all output levels should have 16 channels");
        }
    }

    #[test]
    fn fpn_forward_preserves_spatial_dims() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![32, 16], 8).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 32, 3, 3),
            random_feature_map(&mut rng, 16, 6, 6),
        ];
        let output = fpn.forward(features).expect("FPN forward ok");
        assert_eq!(output[0].height, 3);
        assert_eq!(output[0].width, 3);
        assert_eq!(output[1].height, 6);
        assert_eq!(output[1].width, 6);
    }

    #[test]
    fn fpn_forward_three_levels() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32, 16], 8).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 64, 2, 2),
            random_feature_map(&mut rng, 32, 4, 4),
            random_feature_map(&mut rng, 16, 8, 8),
        ];
        let output = fpn.forward(features).expect("FPN forward 3 levels ok");
        assert_eq!(output.len(), 3);
        for fm in &output {
            assert_eq!(fm.channels, 8);
            assert!(fm.data.iter().all(|v| v.is_finite()), "non-finite output");
        }
    }

    #[test]
    fn fpn_forward_is_deterministic_for_same_seed() {
        let run = || {
            let mut rng = make_rng();
            let cfg = FpnConfig::new(vec![8, 4], 4).expect("valid config");
            let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
            let features = vec![
                random_feature_map(&mut rng, 8, 2, 2),
                random_feature_map(&mut rng, 4, 4, 4),
            ];
            fpn.forward(features).expect("FPN forward ok")
        };
        let first = run();
        let second = run();
        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.data, b.data, "same seed must give identical output");
        }
    }

    #[test]
    fn fpn_forward_wrong_level_count_errors() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        // Config expects two levels; only one is supplied.
        let features = vec![random_feature_map(&mut rng, 64, 4, 4)];
        let r = fpn.forward(features);
        assert!(
            matches!(
                r,
                Err(VisionError::DimensionMismatch {
                    expected: 2,
                    got: 1
                })
            ),
            "expected DimensionMismatch error"
        );
    }

    #[test]
    fn fpn_forward_empty_features_errors() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let r = fpn.forward(vec![]);
        assert!(r.is_err(), "expected error for empty features");
    }

    #[test]
    fn conv3x3_same_output_shape() {
        let feat = vec![0.5f32; 4 * 6 * 6];
        let weight = vec![0.0f32; 4 * 4 * 9];
        let bias = vec![1.0f32; 4];
        let out = conv3x3_same(&feat, 4, 6, 6, &weight, &bias, 4);
        assert_eq!(
            out.len(),
            4 * 6 * 6,
            "output size matches input spatial dims"
        );
        // Zero weights leave only the bias.
        for v in &out {
            assert!((*v - 1.0).abs() < 1e-6, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn conv3x3_same_ones_kernel_counts_neighbours() {
        // With an all-ones kernel and input, each output counts the in-bounds taps:
        // 4 at corners, 6 on edges, 9 in the centre.
        let feat = vec![1.0f32; 9];
        let weight = vec![1.0f32; 9];
        let bias = vec![0.0f32];
        let out = conv3x3_same(&feat, 1, 3, 3, &weight, &bias, 1);
        assert_eq!(out, vec![4.0, 6.0, 4.0, 6.0, 9.0, 6.0, 4.0, 6.0, 4.0]);
    }
}