use axonml_autograd::Variable;
use axonml_nn::{Conv2d, Module, Parameter};
use crate::ops::{InterpolateMode, interpolate_var};
/// Feature Pyramid Network: fuses multi-scale backbone feature maps into a
/// pyramid with a uniform channel width, using lateral 1x1 convolutions, a
/// top-down pathway with nearest-neighbor upsampling, and 3x3 smoothing
/// convolutions on each fused level.
pub struct FPN {
    // One 1x1 conv per input level: projects that level's backbone channels
    // down/up to the shared `out_channels` width.
    lateral_convs: Vec<Conv2d>,
    // One 3x3 conv (stride 1, padding 1) per level: applied after top-down
    // fusion, preserving spatial size.
    smooth_convs: Vec<Conv2d>,
    // Number of pyramid levels; equals the length of the `in_channels`
    // slice passed to `new`.
    num_levels: usize,
    // Channel count shared by every output level.
    out_channels: usize,
}
impl FPN {
    /// Builds an FPN over the given backbone channel counts, ordered from the
    /// highest spatial resolution to the lowest (e.g. C2..C5).
    ///
    /// Each input level gets a 1x1 lateral convolution projecting it to
    /// `out_channels`, plus a 3x3 smoothing convolution (stride 1, padding 1,
    /// so spatial size is preserved) applied after top-down fusion.
    pub fn new(in_channels: &[usize], out_channels: usize) -> Self {
        let num_levels = in_channels.len();
        let mut lateral_convs = Vec::with_capacity(num_levels);
        let mut smooth_convs = Vec::with_capacity(num_levels);
        for &in_c in in_channels {
            // 1x1 projection to the shared pyramid channel width.
            lateral_convs.push(Conv2d::with_options(
                in_c,
                out_channels,
                (1, 1),
                (1, 1),
                (0, 0),
                true,
            ));
            // 3x3, stride 1, padding 1: smooths the fused map in place.
            smooth_convs.push(Conv2d::with_options(
                out_channels,
                out_channels,
                (3, 3),
                (1, 1),
                (1, 1),
                true,
            ));
        }
        Self {
            lateral_convs,
            smooth_convs,
            num_levels,
            out_channels,
        }
    }

    /// Runs the top-down pathway and returns one smoothed map per level.
    ///
    /// `features` must hold exactly `num_levels` maps ordered from the
    /// highest spatial resolution to the lowest; every output level has
    /// `out_channels` channels and the corresponding input's spatial size.
    ///
    /// # Panics
    /// Panics if `features.len() != num_levels`.
    pub fn forward(&self, features: &[Variable]) -> Vec<Variable> {
        assert_eq!(features.len(), self.num_levels);
        // Lateral 1x1 projections bring every level to a common channel width.
        let mut laterals: Vec<Variable> = features
            .iter()
            .enumerate()
            .map(|(i, feat)| self.lateral_convs[i].forward(feat))
            .collect();
        // Top-down fusion: upsample the coarser neighbor to this level's size
        // and add it in. `saturating_sub` keeps an empty pyramid
        // (num_levels == 0) from underflowing the range bound, which would
        // panic in debug builds and yield a near-usize::MAX range in release.
        for i in (0..self.num_levels.saturating_sub(1)).rev() {
            let upper = &laterals[i + 1];
            // Indexing dims 2/3 assumes NCHW layout ([batch, ch, h, w]).
            let target_h = laterals[i].shape()[2];
            let target_w = laterals[i].shape()[3];
            let upsampled = interpolate_var(upper, target_h, target_w, InterpolateMode::Nearest);
            laterals[i] = laterals[i].add_var(&upsampled);
        }
        // 3x3 smoothing reduces aliasing introduced by nearest upsampling.
        laterals
            .iter()
            .enumerate()
            .map(|(i, lat)| self.smooth_convs[i].forward(lat))
            .collect()
    }

    /// Number of output channels at every pyramid level.
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
}
impl Module for FPN {
    /// Single-input `Module::forward` is unsupported: an FPN consumes a whole
    /// feature pyramid. Always panics; use `FPN::forward(&[Variable])`.
    fn forward(&self, _x: &Variable) -> Variable {
        panic!("FPN requires multi-scale input. Use FPN::forward(&[Variable]) instead.");
    }

    /// Collects the parameters of every lateral and smoothing convolution.
    fn parameters(&self) -> Vec<Parameter> {
        self.lateral_convs
            .iter()
            .chain(self.smooth_convs.iter())
            .flat_map(|conv| conv.parameters())
            .collect()
    }

    // Convolutions here carry no train/eval-dependent state, so mode
    // switches are no-ops.
    fn train(&mut self) {}
    fn eval(&mut self) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Builds a `[1, channels, size, size]` feature map filled with 0.1.
    fn feature(channels: usize, size: usize) -> Variable {
        let numel = channels * size * size;
        Variable::new(
            Tensor::from_vec(vec![0.1; numel], &[1, channels, size, size]).unwrap(),
            false,
        )
    }

    #[test]
    fn test_fpn_creation() {
        let fpn = FPN::new(&[64, 128, 256, 512], 256);
        assert_eq!(fpn.num_levels, 4);
        assert_eq!(fpn.out_channels, 256);
        assert!(!fpn.parameters().is_empty());
    }

    #[test]
    fn test_fpn_forward() {
        let fpn = FPN::new(&[64, 128, 256, 512], 64);
        // Typical backbone pyramid: channels double as resolution halves.
        let specs = [(64, 16), (128, 8), (256, 4), (512, 2)];
        let features: Vec<Variable> = specs.iter().map(|&(c, s)| feature(c, s)).collect();
        let pyramid = fpn.forward(&features);
        assert_eq!(pyramid.len(), 4);
        // Every output level keeps its input's spatial size at 64 channels.
        for (level, &(_, size)) in specs.iter().enumerate() {
            assert_eq!(pyramid[level].shape(), vec![1, 64, size, size]);
        }
    }

    #[test]
    fn test_fpn_parameter_count() {
        let fpn = FPN::new(&[256, 512, 1024, 2048], 256);
        // 4 lateral + 4 smooth convs, each with weight + bias = 16 tensors.
        let params = fpn.parameters();
        assert_eq!(params.len(), 16);
    }
}