use super::lateral::LateralConv1x1;
use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
};
/// A dense `channels x height x width` tensor of `f32` values.
///
/// `data` is stored channel-major: element `(c, h, w)` lives at index
/// `c * height * width + h * width + w` (see [`FeatureMap::at`]).
/// The fields are public, so the invariant
/// `data.len() == channels * height * width` is only checked when the map is
/// built through [`FeatureMap::new`].
#[derive(Debug, Clone)]
pub struct FeatureMap {
/// Flat channel-major buffer holding `channels * height * width` values.
pub data: Vec<f32>,
/// Number of channels (outermost axis of `data`).
pub channels: usize,
/// Number of rows per channel.
pub height: usize,
/// Number of columns per row (innermost axis of `data`).
pub width: usize,
}
impl FeatureMap {
    /// Creates a feature map from a flat, channel-major buffer.
    ///
    /// # Errors
    /// Returns [`VisionError::DimensionMismatch`] when `data.len()` is not
    /// `channels * height * width`. If that product overflows `usize`, the shape
    /// cannot describe any real buffer, and `expected` is reported as `usize::MAX`.
    pub fn new(data: Vec<f32>, channels: usize, height: usize, width: usize) -> VisionResult<Self> {
        // Checked multiplication: a wrapped product could otherwise equal the
        // length of a short buffer in release builds and accept an impossible shape.
        let expected = channels
            .checked_mul(height)
            .and_then(|n| n.checked_mul(width))
            .unwrap_or(usize::MAX);
        if data.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: data.len(),
            });
        }
        Ok(Self {
            data,
            channels,
            height,
            width,
        })
    }

    /// Returns the value at channel `c`, row `h_idx`, column `w_idx`.
    ///
    /// # Panics
    /// In debug builds, panics if any index lies outside the map's shape.
    /// In all builds, panics if the computed flat index is past the end of `data`.
    #[inline]
    pub fn at(&self, c: usize, h_idx: usize, w_idx: usize) -> f32 {
        // Without this check an out-of-range row/column silently reads a
        // neighbouring element instead of failing.
        debug_assert!(
            c < self.channels && h_idx < self.height && w_idx < self.width,
            "FeatureMap::at({c}, {h_idx}, {w_idx}) out of bounds for shape ({}, {}, {})",
            self.channels,
            self.height,
            self.width
        );
        self.data[c * self.height * self.width + h_idx * self.width + w_idx]
    }

    /// Number of elements implied by the shape (`channels * height * width`).
    #[inline]
    fn len(&self) -> usize {
        self.channels * self.height * self.width
    }
}
/// Configuration for [`Fpn`]: one input channel count per pyramid level plus a
/// single channel count shared by every output level.
///
/// [`FpnConfig::new`] rejects an empty level list and zero output channels.
/// The fields are public, and [`Fpn::new`] does not repeat those checks itself.
pub struct FpnConfig {
/// Input channel count of each pyramid level, one entry per level.
pub in_channels: Vec<usize>,
/// Channel count of every lateral, merged and output map.
pub out_channels: usize,
}
impl FpnConfig {
    /// Validates and builds an FPN configuration.
    ///
    /// # Errors
    /// - [`VisionError::EmptyInput`] when `in_channels` lists no levels.
    ///   This check takes precedence over the next one.
    /// - [`VisionError::InvalidImageSize`] (height and width reported as zero)
    ///   when `out_channels` is zero.
    pub fn new(in_channels: Vec<usize>, out_channels: usize) -> VisionResult<Self> {
        let no_levels = in_channels.is_empty();
        match (no_levels, out_channels) {
            (true, _) => Err(VisionError::EmptyInput("FpnConfig::in_channels")),
            (false, 0) => Err(VisionError::InvalidImageSize {
                height: 0,
                width: 0,
                channels: out_channels,
            }),
            (false, _) => Ok(Self {
                in_channels,
                out_channels,
            }),
        }
    }

    /// Number of pyramid levels, i.e. how many entries `in_channels` holds.
    #[inline]
    pub fn n_levels(&self) -> usize {
        let levels = &self.in_channels;
        levels.len()
    }
}
/// Feature Pyramid Network head.
///
/// It applies 1x1 lateral projections, merges levels with nearest-neighbour
/// resizing, and then applies one 3x3 smoothing convolution per level.
pub struct Fpn {
/// The configuration this network was built from.
pub config: FpnConfig,
/// One 1x1 projection per level, mapping `in_channels[l]` to `out_channels`.
pub lateral_convs: Vec<LateralConv1x1>,
/// Per-level 3x3 kernels, each `out_channels * out_channels * 9` values laid out
/// as `[out_channel][in_channel][ky][kx]`.
pub smooth_weights: Vec<Vec<f32>>,
/// Per-level smoothing biases, `out_channels` values each.
pub smooth_biases: Vec<Vec<f32>>,
}
impl Fpn {
    /// Builds an FPN from `cfg`, drawing all weights from `rng`.
    ///
    /// For every level `l` this first creates a [`LateralConv1x1`] mapping
    /// `cfg.in_channels[l]` to `cfg.out_channels`. It then creates, per level, a 3x3
    /// smoothing kernel of `out_channels * out_channels * 9` weights. Lateral
    /// convolutions are all created before any smoothing kernel, which fixes the
    /// RNG draw order. Smoothing weights are filled by `rng.fill_normal` and scaled
    /// by `1 / sqrt(9 * out_channels)`, one over the square root of the kernel's
    /// fan-in. Smoothing biases start at zero.
    ///
    /// # Errors
    /// Propagates any error returned by [`LateralConv1x1::new`].
    pub fn new(cfg: FpnConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        let n = cfg.n_levels();
        let oc = cfg.out_channels;
        let mut lateral_convs = Vec::with_capacity(n);
        for &ic in &cfg.in_channels {
            lateral_convs.push(LateralConv1x1::new(ic, oc, rng)?);
        }
        // 1 / sqrt(fan_in), with fan_in = `oc` input channels * 3 * 3 taps.
        let smooth_scale = 1.0_f32 / ((oc * 9) as f32).sqrt();
        let kernel_size = oc * oc * 9;
        let mut smooth_weights = Vec::with_capacity(n);
        let mut smooth_biases = Vec::with_capacity(n);
        for _ in 0..n {
            let mut w = vec![0.0f32; kernel_size];
            rng.fill_normal(&mut w);
            for v in &mut w {
                *v *= smooth_scale;
            }
            smooth_weights.push(w);
            smooth_biases.push(vec![0.0f32; oc]);
        }
        Ok(Self {
            config: cfg,
            lateral_convs,
            smooth_weights,
            smooth_biases,
        })
    }

    /// Runs the FPN over one feature map per pyramid level.
    ///
    /// 1. Each input level `l` is projected to `out_channels` by `lateral_convs[l]`.
    /// 2. Top-down pathway: starting at the level with the smallest spatial area,
    ///    every larger level adds the nearest-neighbour upsampled merged map of the
    ///    level processed just before it.
    /// 3. Each merged map is smoothed by a zero-padded 3x3 convolution.
    ///
    /// Levels may be supplied coarse-to-fine or fine-to-coarse, because processing
    /// order is derived from spatial size. Outputs are returned in input order, each
    /// with the spatial size of its input level and `out_channels` channels.
    ///
    /// # Errors
    /// - [`VisionError::EmptyInput`] if `features` is empty, or if a level has
    ///   zero height or width.
    /// - [`VisionError::DimensionMismatch`] if the number of levels differs from the
    ///   configuration, or if a level's channel count differs from
    ///   `config.in_channels`.
    /// - Any error propagated from the lateral convolutions or [`FeatureMap::new`].
    pub fn forward(&self, features: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
        let n = self.config.n_levels();
        if features.is_empty() {
            return Err(VisionError::EmptyInput("Fpn::forward features"));
        }
        if features.len() != n {
            return Err(VisionError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        let oc = self.config.out_channels;

        // 1x1 lateral projections: every level becomes an `oc`-channel map of its
        // own spatial size.
        let mut maps: Vec<FeatureMap> = Vec::with_capacity(n);
        for (l, feat) in features.iter().enumerate() {
            let expected_channels = self.config.in_channels[l];
            if feat.channels != expected_channels {
                return Err(VisionError::DimensionMismatch {
                    expected: expected_channels,
                    got: feat.channels,
                });
            }
            if feat.height == 0 || feat.width == 0 {
                // An empty map cannot be sampled when resizing between levels.
                return Err(VisionError::EmptyInput(
                    "Fpn::forward feature map with zero height or width",
                ));
            }
            let (h, w) = (feat.height, feat.width);
            let lateral = self.lateral_convs[l].forward(&feat.data, h, w)?;
            maps.push(FeatureMap::new(lateral, oc, h, w)?);
        }

        // Top-down pathway, merged in place: the coarsest level is kept as-is and
        // each finer level adds the upsampled (already merged) coarser neighbour.
        let order = Self::top_down_order(&maps);
        for step in order.windows(2) {
            let (coarser, finer) = (step[0], step[1]);
            let upsampled =
                upsample_nearest(&maps[coarser], maps[finer].height, maps[finer].width);
            debug_assert_eq!(upsampled.data.len(), maps[finer].len());
            for (v, u) in maps[finer].data.iter_mut().zip(&upsampled.data) {
                *v += *u;
            }
        }

        // Per-level 3x3 smoothing, emitted in input order.
        maps.into_iter()
            .enumerate()
            .map(|(l, fm)| {
                let smoothed = conv3x3_same(
                    &fm.data,
                    oc,
                    fm.height,
                    fm.width,
                    &self.smooth_weights[l],
                    &self.smooth_biases[l],
                    oc,
                );
                FeatureMap::new(smoothed, oc, fm.height, fm.width)
            })
            .collect()
    }

    /// Returns level indices ordered from the smallest spatial area to the largest.
    ///
    /// Equal areas are ordered by descending index. For fine-to-coarse inputs this
    /// reproduces a walk from the last level to the first.
    fn top_down_order(maps: &[FeatureMap]) -> Vec<usize> {
        let mut order: Vec<usize> = (0..maps.len()).collect();
        order.sort_by_key(|&l| {
            (
                maps[l].height.saturating_mul(maps[l].width),
                std::cmp::Reverse(l),
            )
        });
        order
    }
}
/// Nearest-neighbour resize of `feat` to `target_h x target_w`, keeping its channels.
///
/// Target row `i` reads source row `floor(i * src_h / target_h)`, and columns work
/// the same way. Despite the name, this also shrinks a map when the target is
/// smaller than the source. A source with zero height or width cannot be sampled:
/// it panics when there is at least one channel and the target is non-empty.
fn upsample_nearest(feat: &FeatureMap, target_h: usize, target_w: usize) -> FeatureMap {
    let (src_h, src_w) = (feat.height, feat.width);
    // The row/column lookups are the same for every channel, so compute them once.
    // For i < target_h, floor(i * src_h / target_h) <= src_h - 1 already holds,
    // so no clamping is needed.
    let row_src: Vec<usize> = (0..target_h).map(|i| i * src_h / target_h).collect();
    let col_src: Vec<usize> = (0..target_w).map(|j| j * src_w / target_w).collect();

    let mut resized = Vec::with_capacity(feat.channels * target_h * target_w);
    for ch in 0..feat.channels {
        for &si in &row_src {
            resized.extend(col_src.iter().map(|&sj| feat.at(ch, si, sj)));
        }
    }
    FeatureMap {
        data: resized,
        channels: feat.channels,
        height: target_h,
        width: target_w,
    }
}
/// Zero-padded ("same") 3x3 convolution over a channel-major `channels x h x w` buffer.
///
/// `weight` is laid out as `[out_channels][channels][3][3]` and `bias` holds one
/// value per output channel. Taps that fall outside the image contribute nothing.
/// Returns an `out_channels x h x w` buffer in the same channel-major layout.
///
/// # Panics
/// Panics via slice indexing if an accessed element of `feat`, `weight` or `bias`
/// is out of range.
fn conv3x3_same(
    feat: &[f32],
    channels: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
    out_channels: usize,
) -> Vec<f32> {
    let plane = h * w;
    let mut out = Vec::with_capacity(out_channels * plane);
    for o in 0..out_channels {
        for row in 0..h {
            // Rows of the 3x3 window that lie inside the image (row >= 0, h >= 1 here).
            let (top, bottom) = (row.saturating_sub(1), (row + 1).min(h - 1));
            for col in 0..w {
                let (left, right) = (col.saturating_sub(1), (col + 1).min(w - 1));
                let mut acc = bias[o];
                for c in 0..channels {
                    let kernel_base = (o * channels + c) * 9;
                    let plane_base = c * plane;
                    for r in top..=bottom {
                        let ki = r + 1 - row;
                        for s in left..=right {
                            let kj = s + 1 - col;
                            acc += weight[kernel_base + ki * 3 + kj] * feat[plane_base + r * w + s];
                        }
                    }
                }
                out.push(acc);
            }
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-seed RNG shared by the tests.
    fn make_rng() -> LcgRng {
        LcgRng::new(123)
    }

    /// Builds a `channels x h x w` feature map whose data is filled by `rng.fill_normal`.
    fn random_feature_map(rng: &mut LcgRng, channels: usize, h: usize, w: usize) -> FeatureMap {
        let n = channels * h * w;
        let mut data = vec![0.0f32; n];
        rng.fill_normal(&mut data);
        FeatureMap::new(data, channels, h, w).expect("valid feature map")
    }

    #[test]
    fn feature_map_valid_construction() {
        let data = vec![1.0f32; 3 * 4 * 4];
        let fm = FeatureMap::new(data, 3, 4, 4).expect("valid feature map");
        assert_eq!(fm.channels, 3);
        assert_eq!(fm.height, 4);
        assert_eq!(fm.width, 4);
    }

    #[test]
    fn feature_map_wrong_size_errors() {
        let data = vec![0.0f32; 3 * 4 * 4 - 1];
        let r = FeatureMap::new(data, 3, 4, 4);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn feature_map_at_correct_value() {
        let mut data = vec![0.0f32; 2 * 3 * 3];
        for c in 0..2 {
            for pos in 0..9 {
                data[c * 9 + pos] = c as f32;
            }
        }
        let fm = FeatureMap::new(data, 2, 3, 3).expect("valid feature map");
        assert_eq!(fm.at(0, 1, 1), 0.0);
        assert_eq!(fm.at(1, 0, 0), 1.0);
    }

    #[test]
    fn feature_map_at_uses_channel_major_layout() {
        let data: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let fm = FeatureMap::new(data, 2, 2, 3).expect("valid feature map");
        // index = c * 6 + h * 3 + w
        assert_eq!(fm.at(0, 0, 2), 2.0);
        assert_eq!(fm.at(0, 1, 0), 3.0);
        assert_eq!(fm.at(1, 0, 0), 6.0);
        assert_eq!(fm.at(1, 1, 2), 11.0);
    }

    #[test]
    fn fpn_config_valid() {
        let cfg = FpnConfig::new(vec![2048, 1024, 512, 256], 256).expect("valid config");
        assert_eq!(cfg.n_levels(), 4);
    }

    #[test]
    fn fpn_config_empty_in_channels_errors() {
        let r = FpnConfig::new(vec![], 256);
        assert!(r.is_err());
    }

    #[test]
    fn fpn_config_zero_out_channels_errors() {
        let r = FpnConfig::new(vec![512, 256], 0);
        assert!(r.is_err());
    }

    #[test]
    fn upsample_nearest_doubles_size() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let fm = FeatureMap::new(data, 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        assert_eq!(up.height, 4);
        assert_eq!(up.width, 4);
        assert_eq!(up.channels, 1);
        assert_eq!(up.data.len(), 4 * 4);
    }

    #[test]
    fn upsample_nearest_values_replicated() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let fm = FeatureMap::new(data, 1, 2, 2).expect("valid");
        let up = upsample_nearest(&fm, 4, 4);
        assert_eq!(up.at(0, 0, 0), 1.0);
        assert_eq!(up.at(0, 0, 1), 1.0);
        assert_eq!(up.at(0, 1, 0), 1.0);
        assert_eq!(up.at(0, 1, 1), 1.0);
        assert_eq!(up.at(0, 2, 2), 4.0);
        assert_eq!(up.at(0, 3, 3), 4.0);
    }

    #[test]
    fn upsample_nearest_non_integer_ratio() {
        // 2x3 -> 3x4: rows read source rows 0,0,1 and columns read 0,0,1,2
        // (floor(i * src / dst)).
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let fm = FeatureMap::new(data, 1, 2, 3).expect("valid");
        let up = upsample_nearest(&fm, 3, 4);
        assert_eq!(
            up.data,
            vec![1.0, 1.0, 2.0, 3.0, 1.0, 1.0, 2.0, 3.0, 4.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn upsample_nearest_identity_when_same_size() {
        let mut rng = make_rng();
        let fm = random_feature_map(&mut rng, 4, 5, 7);
        let up = upsample_nearest(&fm, 5, 7);
        for (a, b) in fm.data.iter().zip(up.data.iter()) {
            assert_eq!(*a, *b, "identity upsample should be exact copy");
        }
    }

    #[test]
    fn fpn_forward_output_channel_count_uniform() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 64, 4, 4),
            random_feature_map(&mut rng, 32, 8, 8),
        ];
        let output = fpn.forward(features).expect("FPN forward ok");
        assert_eq!(output.len(), 2, "two output levels");
        for fm in &output {
            assert_eq!(fm.channels, 16, "all output levels should have 16 channels");
        }
    }

    #[test]
    fn fpn_forward_preserves_spatial_dims() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![32, 16], 8).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 32, 3, 3),
            random_feature_map(&mut rng, 16, 6, 6),
        ];
        let output = fpn.forward(features).expect("FPN forward ok");
        assert_eq!(output[0].height, 3);
        assert_eq!(output[0].width, 3);
        assert_eq!(output[1].height, 6);
        assert_eq!(output[1].width, 6);
    }

    #[test]
    fn fpn_forward_three_levels() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32, 16], 8).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![
            random_feature_map(&mut rng, 64, 2, 2),
            random_feature_map(&mut rng, 32, 4, 4),
            random_feature_map(&mut rng, 16, 8, 8),
        ];
        let output = fpn.forward(features).expect("FPN forward 3 levels ok");
        assert_eq!(output.len(), 3);
        for fm in &output {
            assert_eq!(fm.channels, 8);
            assert!(fm.data.iter().all(|v| v.is_finite()), "non-finite output");
        }
    }

    #[test]
    fn fpn_forward_single_level() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![8], 4).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![random_feature_map(&mut rng, 8, 5, 3)];
        let output = fpn.forward(features).expect("single-level forward ok");
        assert_eq!(output.len(), 1);
        assert_eq!(
            (output[0].channels, output[0].height, output[0].width),
            (4, 5, 3)
        );
        assert!(output[0].data.iter().all(|v| v.is_finite()), "non-finite output");
    }

    #[test]
    fn fpn_forward_wrong_level_count_errors() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let features = vec![random_feature_map(&mut rng, 64, 4, 4)];
        let r = fpn.forward(features);
        assert!(
            matches!(
                r,
                Err(VisionError::DimensionMismatch {
                    expected: 2,
                    got: 1
                })
            ),
            "expected DimensionMismatch error"
        );
    }

    #[test]
    fn fpn_forward_empty_features_errors() {
        let mut rng = make_rng();
        let cfg = FpnConfig::new(vec![64, 32], 16).expect("valid config");
        let fpn = Fpn::new(cfg, &mut rng).expect("valid FPN");
        let r = fpn.forward(vec![]);
        assert!(r.is_err(), "expected error for empty features");
    }

    #[test]
    fn conv3x3_same_output_shape() {
        let feat = vec![0.5f32; 4 * 6 * 6];
        let weight = vec![0.0f32; 4 * 4 * 9];
        let bias = vec![1.0f32; 4];
        let out = conv3x3_same(&feat, 4, 6, 6, &weight, &bias, 4);
        assert_eq!(
            out.len(),
            4 * 6 * 6,
            "output size matches input spatial dims"
        );
        for v in &out {
            assert!((*v - 1.0).abs() < 1e-6, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn conv3x3_same_box_filter_with_zero_padding() {
        // All-ones kernel: each output is the sum of its in-image 3x3 neighbourhood.
        let feat: Vec<f32> = (1..=9).map(|v| v as f32).collect();
        let out = conv3x3_same(&feat, 1, 3, 3, &[1.0f32; 9], &[0.0f32], 1);
        assert_eq!(out, vec![12.0, 21.0, 16.0, 27.0, 45.0, 33.0, 24.0, 39.0, 28.0]);
    }

    #[test]
    fn conv3x3_same_top_left_tap_reads_up_left_neighbour() {
        // Only weight (ky = 0, kx = 0) is set, so out(i, j) = in(i - 1, j - 1) or 0.
        let feat: Vec<f32> = (1..=9).map(|v| v as f32).collect();
        let mut weight = [0.0f32; 9];
        weight[0] = 1.0;
        let out = conv3x3_same(&feat, 1, 3, 3, &weight, &[0.0f32], 1);
        assert_eq!(out, vec![0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0]);
    }

    #[test]
    fn conv3x3_same_mixes_input_channels_per_output_channel() {
        // Two 1x2 input channels, two output channels.
        let feat = vec![1.0f32, 2.0, 10.0, 20.0];
        let mut weight = vec![0.0f32; 2 * 2 * 9];
        weight[4] = 1.0; // out 0 <- in 0, centre tap
        weight[13] = 2.0; // out 0 <- in 1, centre tap
        weight[32] = 1.0; // out 1 <- in 1, right-neighbour tap
        let out = conv3x3_same(&feat, 2, 1, 2, &weight, &[0.5f32, 0.0], 2);
        assert_eq!(out, vec![21.5, 42.5, 20.0, 0.0]);
    }
}