Skip to main content

oxicuda_vision/fpn/
top_down.rs

1//! Feature Pyramid Network top-down pathway.
2//!
3//! Implements the FPN as described in "Feature Pyramid Networks for Object Detection"
4//! (Lin et al., 2017).  The forward pass:
5//!
6//! 1. Apply a 1×1 lateral convolution per backbone level to unify channel counts.
7//! 2. Top-down merge: start from the coarsest level, upsample (nearest-neighbour)
8//!    and add to the next finer level.
9//! 3. Apply a 3×3 "anti-aliasing" (smoothing) convolution to each merged level.
10
11use super::lateral::LateralConv1x1;
12use crate::{
13    error::{VisionError, VisionResult},
14    handle::LcgRng,
15};
16
17// ─── FeatureMap ───────────────────────────────────────────────────────────────
18
19/// A single-level feature map in CHW layout.
20///
21/// `data` is stored as `[channels, height, width]` in row-major order.
22#[derive(Debug, Clone)]
23pub struct FeatureMap {
24    /// Flat data buffer: `[channels × height × width]`.
25    pub data: Vec<f32>,
26    /// Number of feature channels.
27    pub channels: usize,
28    /// Spatial height (rows).
29    pub height: usize,
30    /// Spatial width (columns).
31    pub width: usize,
32}
33
34impl FeatureMap {
35    /// Construct a `FeatureMap`, validating that the data length matches the shape.
36    ///
37    /// # Errors
38    /// `DimensionMismatch` if `data.len() != channels * height * width`.
39    pub fn new(data: Vec<f32>, channels: usize, height: usize, width: usize) -> VisionResult<Self> {
40        let expected = channels * height * width;
41        if data.len() != expected {
42            return Err(VisionError::DimensionMismatch {
43                expected,
44                got: data.len(),
45            });
46        }
47        Ok(Self {
48            data,
49            channels,
50            height,
51            width,
52        })
53    }
54
55    /// Sample the value at channel `c`, row `h_idx`, column `w_idx`.
56    ///
57    /// Panics in debug builds if any index is out of range (same contract as
58    /// direct slice indexing — callers are responsible for bounds checking).
59    #[inline]
60    pub fn at(&self, c: usize, h_idx: usize, w_idx: usize) -> f32 {
61        self.data[c * self.height * self.width + h_idx * self.width + w_idx]
62    }
63
64    /// Total number of elements.
65    #[inline]
66    fn len(&self) -> usize {
67        self.channels * self.height * self.width
68    }
69}
70
71// ─── FpnConfig ────────────────────────────────────────────────────────────────
72
73/// Configuration for the Feature Pyramid Network.
74pub struct FpnConfig {
75    /// Input channel counts per backbone level, ordered **coarse→fine**
76    /// (e.g. `[2048, 1024, 512, 256]`).
77    pub in_channels: Vec<usize>,
78    /// Output channel count (uniform across all FPN levels after lateral convs).
79    pub out_channels: usize,
80}
81
82impl FpnConfig {
83    /// Construct and validate an `FpnConfig`.
84    ///
85    /// # Errors
86    /// - `EmptyInput` if `in_channels` is empty.
87    /// - `InvalidImageSize` if `out_channels == 0`.
88    pub fn new(in_channels: Vec<usize>, out_channels: usize) -> VisionResult<Self> {
89        if in_channels.is_empty() {
90            return Err(VisionError::EmptyInput("FpnConfig::in_channels"));
91        }
92        if out_channels == 0 {
93            return Err(VisionError::InvalidImageSize {
94                height: 0,
95                width: 0,
96                channels: out_channels,
97            });
98        }
99        Ok(Self {
100            in_channels,
101            out_channels,
102        })
103    }
104
105    /// Number of FPN levels.
106    #[inline]
107    pub fn n_levels(&self) -> usize {
108        self.in_channels.len()
109    }
110}
111
112// ─── Fpn ──────────────────────────────────────────────────────────────────────
113
114/// Feature Pyramid Network.
115///
116/// Accepts a list of feature maps at multiple scales (coarsest first) and
117/// returns a pyramid with uniform `out_channels` after the top-down path.
118pub struct Fpn {
119    /// FPN configuration (includes level input channels and output channels).
120    pub config: FpnConfig,
121    /// 1×1 lateral convolutions, one per backbone level.
122    pub lateral_convs: Vec<LateralConv1x1>,
123    /// 3×3 smoothing convolution weights per level: each `[out_c × out_c × 9]`.
124    pub smooth_weights: Vec<Vec<f32>>,
125    /// Biases for the smoothing convolutions: each `[out_c]`.
126    pub smooth_biases: Vec<Vec<f32>>,
127}
128
129impl Fpn {
130    /// Build an FPN with Xavier-initialised weights.
131    ///
132    /// # Errors
133    /// Propagates errors from `FpnConfig` or lateral conv construction.
134    pub fn new(cfg: FpnConfig, rng: &mut LcgRng) -> VisionResult<Self> {
135        let n = cfg.n_levels();
136        let oc = cfg.out_channels;
137
138        // One lateral conv per level: in_channels[l] → out_channels
139        let mut lateral_convs = Vec::with_capacity(n);
140        for &ic in &cfg.in_channels {
141            lateral_convs.push(LateralConv1x1::new(ic, oc, rng)?);
142        }
143
144        // Xavier scale for 3×3 conv: fan_in = out_channels * 9
145        let smooth_scale = 1.0_f32 / ((oc * 9) as f32).sqrt();
146        let mut smooth_weights = Vec::with_capacity(n);
147        let mut smooth_biases = Vec::with_capacity(n);
148        for _ in 0..n {
149            let kernel_size = oc * oc * 9; // [out_c × out_c × 3 × 3]
150            let mut w = vec![0.0f32; kernel_size];
151            rng.fill_normal(&mut w);
152            for v in &mut w {
153                *v *= smooth_scale;
154            }
155            smooth_weights.push(w);
156            smooth_biases.push(vec![0.0f32; oc]);
157        }
158
159        Ok(Self {
160            config: cfg,
161            lateral_convs,
162            smooth_weights,
163            smooth_biases,
164        })
165    }
166
167    /// Run the FPN forward pass.
168    ///
169    /// `features` must be ordered **coarse→fine**, one `FeatureMap` per level.
170    ///
171    /// Returns the FPN pyramid in the same coarse→fine order, with all maps
172    /// having `out_channels` channels.
173    ///
174    /// # Errors
175    /// - `DimensionMismatch` if the number of features ≠ number of configured levels.
176    /// - `EmptyInput` if `features` is empty.
177    /// - Propagates errors from lateral convolutions.
178    pub fn forward(&self, features: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
179        let n = self.config.n_levels();
180        if features.is_empty() {
181            return Err(VisionError::EmptyInput("Fpn::forward features"));
182        }
183        if features.len() != n {
184            return Err(VisionError::DimensionMismatch {
185                expected: n,
186                got: features.len(),
187            });
188        }
189
190        let oc = self.config.out_channels;
191
192        // ── Step 1: lateral convolutions ──────────────────────────────────────
193        let mut lateral_maps: Vec<FeatureMap> = Vec::with_capacity(n);
194        for (l, feat) in features.iter().enumerate() {
195            let h = feat.height;
196            let w = feat.width;
197            let lateral_data = self.lateral_convs[l].forward(&feat.data, h, w)?;
198            lateral_maps.push(FeatureMap::new(lateral_data, oc, h, w)?);
199        }
200
201        // ── Step 2: top-down merge ────────────────────────────────────────────
202        // merged[L-1] = lateral[L-1]  (coarsest level)
203        // merged[l]   = lateral[l] + upsample(merged[l+1]) for l in (L-2)..=0
204        let mut merged: Vec<FeatureMap> = Vec::with_capacity(n);
205        // Clone the coarsest level to start the top-down path.
206        merged.push(lateral_maps[n - 1].clone());
207
208        // Build merged levels from (n-2) down to 0, pushing in reverse order
209        // so we'll need to reverse at the end.
210        for l in (0..n - 1).rev() {
211            let target_h = lateral_maps[l].height;
212            let target_w = lateral_maps[l].width;
213            // The coarser merged level was added last to our temporary vec.
214            let coarser = merged.last().expect("at least one element");
215            let upsampled = upsample_nearest(coarser, target_h, target_w);
216            // Element-wise addition: lateral[l] + upsampled
217            let lat = &lateral_maps[l];
218            let mut merged_data = vec![0.0f32; lat.len()];
219            for (i, v) in merged_data.iter_mut().enumerate() {
220                *v = lat.data[i] + upsampled.data[i];
221            }
222            merged.push(FeatureMap::new(merged_data, oc, target_h, target_w)?);
223        }
224
225        // Reverse so that index 0 corresponds to the coarsest level again.
226        merged.reverse();
227
228        // ── Step 3: 3×3 smoothing conv ────────────────────────────────────────
229        let mut output: Vec<FeatureMap> = Vec::with_capacity(n);
230        for (l, fm) in merged.into_iter().enumerate() {
231            let smooth_data = conv3x3_same(
232                &fm.data,
233                oc,
234                fm.height,
235                fm.width,
236                &self.smooth_weights[l],
237                &self.smooth_biases[l],
238                oc,
239            );
240            output.push(FeatureMap::new(smooth_data, oc, fm.height, fm.width)?);
241        }
242
243        Ok(output)
244    }
245}
246
247// ─── Internal helpers ─────────────────────────────────────────────────────────
248
249/// Nearest-neighbour 2D upsampling to `(target_h, target_w)`.
250///
251/// Each output pixel `(c, i, j)` is assigned the value from the nearest
252/// input pixel using floor-based index mapping.
253fn upsample_nearest(feat: &FeatureMap, target_h: usize, target_w: usize) -> FeatureMap {
254    let src_h = feat.height;
255    let src_w = feat.width;
256    let c = feat.channels;
257    let mut out = vec![0.0f32; c * target_h * target_w];
258
259    for ch in 0..c {
260        for i in 0..target_h {
261            // Map output row → source row (nearest neighbour, floor)
262            let src_i = (i * src_h / target_h).min(src_h.saturating_sub(1));
263            for j in 0..target_w {
264                let src_j = (j * src_w / target_w).min(src_w.saturating_sub(1));
265                out[ch * target_h * target_w + i * target_w + j] = feat.at(ch, src_i, src_j);
266            }
267        }
268    }
269
270    // Cannot fail: data.len() == c * target_h * target_w by construction.
271    FeatureMap {
272        data: out,
273        channels: c,
274        height: target_h,
275        width: target_w,
276    }
277}
278
/// Dense 3×3 convolution with stride 1 and zero ("same") padding.
///
/// - `feat`:         input laid out as `[channels, h, w]`.
/// - `channels`:     number of input channels.
/// - `h`, `w`:       spatial size of the input (and of the output).
/// - `weight`:       kernel laid out as `[oc][ic][ki][kj]`, i.e. the tap for
///                   `(oc, ic, ki, kj)` is at `oc * channels * 9 + ic * 9 + ki * 3 + kj`.
/// - `bias`:         one value per output channel.
/// - `out_channels`: number of output channels.
///
/// Returns the result laid out as `[out_channels, h, w]`. Taps that would read
/// outside the input contribute nothing.
fn conv3x3_same(
    feat: &[f32],
    channels: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
    out_channels: usize,
) -> Vec<f32> {
    let plane = h * w;
    let mut result = vec![0.0f32; out_channels * plane];

    for o in 0..out_channels {
        for y in 0..h {
            // Source rows that fall inside the image for this output row.
            let row_lo = y.saturating_sub(1);
            let row_hi = (y + 1).min(h - 1);
            for x in 0..w {
                // Source columns that fall inside the image for this output column.
                let col_lo = x.saturating_sub(1);
                let col_hi = (x + 1).min(w - 1);

                let mut sum = bias[o];
                for c in 0..channels {
                    let kernel_base = o * channels * 9 + c * 9;
                    let feat_base = c * plane;
                    for sy in row_lo..=row_hi {
                        // Kernel row for this source row: sy = y + ky - 1.
                        let ky = sy + 1 - y;
                        for sx in col_lo..=col_hi {
                            let kx = sx + 1 - x;
                            sum += weight[kernel_base + ky * 3 + kx] * feat[feat_base + sy * w + sx];
                        }
                    }
                }
                result[o * plane + y * w + x] = sum;
            }
        }
    }

    result
}
329
330// ─── Tests ───────────────────────────────────────────────────────────────────
331
#[cfg(test)]
mod tests {
    use super::*;

    /// RNG used by every test, constructed from the constant `123`.
    fn seeded_rng() -> LcgRng {
        LcgRng::new(123)
    }

    /// Fill a `channels × h × w` feature map with samples drawn from `rng`.
    fn noise_map(rng: &mut LcgRng, channels: usize, h: usize, w: usize) -> FeatureMap {
        let mut samples = vec![0.0f32; channels * h * w];
        rng.fill_normal(&mut samples);
        FeatureMap::new(samples, channels, h, w).expect("valid feature map")
    }

    /// Build an FPN for the given channel layout, drawing its weights from `rng`.
    fn build_fpn(rng: &mut LcgRng, in_channels: Vec<usize>, out_channels: usize) -> Fpn {
        let cfg = FpnConfig::new(in_channels, out_channels).expect("valid config");
        Fpn::new(cfg, rng).expect("valid FPN")
    }

    // ── FeatureMap ─────────────────────────────────────────────────────────────

    #[test]
    fn feature_map_valid_construction() {
        let fm = FeatureMap::new(vec![1.0f32; 3 * 4 * 4], 3, 4, 4).expect("valid feature map");
        assert_eq!(fm.channels, 3);
        assert_eq!(fm.height, 4);
        assert_eq!(fm.width, 4);
    }

    #[test]
    fn feature_map_wrong_size_errors() {
        // One element short of 3 × 4 × 4.
        let short = vec![0.0f32; 3 * 4 * 4 - 1];
        let outcome = FeatureMap::new(short, 3, 4, 4);
        assert!(matches!(
            outcome,
            Err(VisionError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn feature_map_at_correct_value() {
        // Every element of channel `c` holds the value `c`.
        let data: Vec<f32> = (0..2)
            .flat_map(|c| std::iter::repeat(c as f32).take(9))
            .collect();
        let fm = FeatureMap::new(data, 2, 3, 3).expect("valid feature map");
        assert_eq!(fm.at(0, 1, 1), 0.0);
        assert_eq!(fm.at(1, 0, 0), 1.0);
    }

    // ── FpnConfig ──────────────────────────────────────────────────────────────

    #[test]
    fn fpn_config_valid() {
        let levels = vec![2048, 1024, 512, 256];
        let cfg = FpnConfig::new(levels, 256).expect("valid config");
        assert_eq!(cfg.n_levels(), 4);
    }

    #[test]
    fn fpn_config_empty_in_channels_errors() {
        assert!(FpnConfig::new(vec![], 256).is_err());
    }

    #[test]
    fn fpn_config_zero_out_channels_errors() {
        assert!(FpnConfig::new(vec![512, 256], 0).is_err());
    }

    // ── upsample_nearest ──────────────────────────────────────────────────────

    #[test]
    fn upsample_nearest_doubles_size() {
        // Single channel, 2×2 source.
        let fm = FeatureMap::new(vec![1.0, 2.0, 3.0, 4.0], 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        assert_eq!(up.height, 4);
        assert_eq!(up.width, 4);
        assert_eq!(up.channels, 1);
        assert_eq!(up.data.len(), 4 * 4);
    }

    #[test]
    fn upsample_nearest_values_replicated() {
        // A 1-channel 2×2 map blown up to 4×4: each source pixel covers a 2×2 block.
        let fm = FeatureMap::new(vec![1.0f32, 2.0, 3.0, 4.0], 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        // The top-left 2×2 block comes from source pixel [0,0].
        for &(row, col) in &[(0, 0), (0, 1), (1, 0), (1, 1)] {
            assert_eq!(up.at(0, row, col), 1.0);
        }
        // The bottom-right block comes from source pixel [1,1].
        for &(row, col) in &[(2, 2), (3, 3)] {
            assert_eq!(up.at(0, row, col), 4.0);
        }
    }

    #[test]
    fn upsample_nearest_identity_when_same_size() {
        let mut rng = seeded_rng();
        let fm = noise_map(&mut rng, 4, 5, 7);
        let up = upsample_nearest(&fm, 5, 7);
        for (original, resized) in fm.data.iter().zip(up.data.iter()) {
            assert_eq!(*original, *resized, "identity upsample should be exact copy");
        }
    }

    // ── FPN forward ───────────────────────────────────────────────────────────

    #[test]
    fn fpn_forward_output_channel_count_uniform() {
        let mut rng = seeded_rng();
        let fpn = build_fpn(&mut rng, vec![64, 32], 16);
        let coarse = noise_map(&mut rng, 64, 4, 4);
        let fine = noise_map(&mut rng, 32, 8, 8);
        let output = fpn.forward(vec![coarse, fine]).expect("FPN forward ok");
        assert_eq!(output.len(), 2, "two output levels");
        for level in &output {
            assert_eq!(
                level.channels, 16,
                "all output levels should have 16 channels"
            );
        }
    }

    #[test]
    fn fpn_forward_preserves_spatial_dims() {
        let mut rng = seeded_rng();
        let fpn = build_fpn(&mut rng, vec![32, 16], 8);
        let coarse = noise_map(&mut rng, 32, 3, 3);
        let fine = noise_map(&mut rng, 16, 6, 6);
        let output = fpn.forward(vec![coarse, fine]).expect("FPN forward ok");
        assert_eq!(output[0].height, 3);
        assert_eq!(output[0].width, 3);
        assert_eq!(output[1].height, 6);
        assert_eq!(output[1].width, 6);
    }

    #[test]
    fn fpn_forward_three_levels() {
        let mut rng = seeded_rng();
        let fpn = build_fpn(&mut rng, vec![64, 32, 16], 8);
        let level_0 = noise_map(&mut rng, 64, 2, 2);
        let level_1 = noise_map(&mut rng, 32, 4, 4);
        let level_2 = noise_map(&mut rng, 16, 8, 8);
        let output = fpn
            .forward(vec![level_0, level_1, level_2])
            .expect("FPN forward 3 levels ok");
        assert_eq!(output.len(), 3);
        for level in &output {
            assert_eq!(level.channels, 8);
            assert!(level.data.iter().all(|v| v.is_finite()), "non-finite output");
        }
    }

    #[test]
    fn fpn_forward_wrong_level_count_errors() {
        let mut rng = seeded_rng();
        let fpn = build_fpn(&mut rng, vec![64, 32], 16);
        // The FPN expects two levels; hand it just one.
        let lone_level = vec![noise_map(&mut rng, 64, 4, 4)];
        let outcome = fpn.forward(lone_level);
        let is_expected = matches!(
            outcome,
            Err(VisionError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        );
        assert!(is_expected, "expected DimensionMismatch error");
    }

    #[test]
    fn fpn_forward_empty_features_errors() {
        let mut rng = seeded_rng();
        let fpn = build_fpn(&mut rng, vec![64, 32], 16);
        let outcome = fpn.forward(vec![]);
        assert!(outcome.is_err(), "expected error for empty features");
    }

    // ── conv3x3_same ──────────────────────────────────────────────────────────

    #[test]
    fn conv3x3_same_output_shape() {
        let (channels, side) = (4usize, 6usize);
        let feat = vec![0.5f32; channels * side * side];
        let weight = vec![0.0f32; channels * channels * 9];
        // Zero weights leave only the bias, so every output must equal 1.
        let bias = vec![1.0f32; channels];
        let out = conv3x3_same(&feat, channels, side, side, &weight, &bias, channels);
        assert_eq!(
            out.len(),
            channels * side * side,
            "output size matches input spatial dims"
        );
        for v in &out {
            assert!((*v - 1.0).abs() < 1e-6, "expected 1.0, got {v}");
        }
    }
}