use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::Parameter;
use super::backbone::ConvBNSiLU;
use crate::ops::{InterpolateMode, interpolate_var};
/// Feature-pyramid neck that fuses three backbone levels (P3, P4, P5) with a top-down
/// pass followed by a bottom-up pass (see `forward`), producing three outputs that all
/// have `out_channels` channels.
pub struct ThermalFPN {
/// Projects the P5 backbone feature to `out_channels` (built as `ConvBNSiLU::new(p5_ch, out_ch, 1, 1, 0)`).
lateral_p5: ConvBNSiLU,
/// Projects the P4 backbone feature to `out_channels` (built as `ConvBNSiLU::new(p4_ch, out_ch, 1, 1, 0)`).
lateral_p4: ConvBNSiLU,
/// Projects the P3 backbone feature to `out_channels` (built as `ConvBNSiLU::new(p3_ch, out_ch, 1, 1, 0)`).
lateral_p3: ConvBNSiLU,
/// Applied to (upsampled P5 lateral + P4 lateral) in the top-down pass.
td_p4: ConvBNSiLU,
/// Applied to (upsampled `td_p4` output + P3 lateral) in the top-down pass.
td_p3: ConvBNSiLU,
/// Carries the top-down P3 result to P4 in the bottom-up pass (built with stride argument 2).
bu_p4: ConvBNSiLU,
/// Carries the bottom-up P4 result to P5 in the bottom-up pass (built with stride argument 2).
bu_p5: ConvBNSiLU,
/// Channel width shared by every output level; the `out_ch` passed to `new`.
out_channels: usize,
}
impl ThermalFPN {
pub fn new(p3_ch: usize, p4_ch: usize, p5_ch: usize, out_ch: usize) -> Self {
Self {
lateral_p5: ConvBNSiLU::new(p5_ch, out_ch, 1, 1, 0),
lateral_p4: ConvBNSiLU::new(p4_ch, out_ch, 1, 1, 0),
lateral_p3: ConvBNSiLU::new(p3_ch, out_ch, 1, 1, 0),
td_p4: ConvBNSiLU::new(out_ch, out_ch, 3, 1, 1),
td_p3: ConvBNSiLU::new(out_ch, out_ch, 3, 1, 1),
bu_p4: ConvBNSiLU::new(out_ch, out_ch, 3, 2, 1),
bu_p5: ConvBNSiLU::new(out_ch, out_ch, 3, 2, 1),
out_channels: out_ch,
}
}
pub fn default_config() -> Self {
Self::new(64, 128, 256, 128)
}
pub fn out_channels(&self) -> usize {
self.out_channels
}
pub fn forward(
&self,
p3: &Variable,
p4: &Variable,
p5: &Variable,
) -> (Variable, Variable, Variable) {
let l5 = self.lateral_p5.forward(p5);
let l4 = self.lateral_p4.forward(p4);
let l3 = self.lateral_p3.forward(p3);
let p4_shape = l4.shape();
let up5 = interpolate_var(&l5, p4_shape[2], p4_shape[3], InterpolateMode::Nearest);
let td4 = self.td_p4.forward(&up5.add_var(&l4));
let p3_shape = l3.shape();
let up4 = interpolate_var(&td4, p3_shape[2], p3_shape[3], InterpolateMode::Nearest);
let td3 = self.td_p3.forward(&up4.add_var(&l3));
let bu4 = self.bu_p4.forward(&td3).add_var(&td4);
let bu5 = self.bu_p5.forward(&bu4).add_var(&l5);
(td3, bu4, bu5)
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
p.extend(self.lateral_p5.parameters());
p.extend(self.lateral_p4.parameters());
p.extend(self.lateral_p3.parameters());
p.extend(self.td_p4.parameters());
p.extend(self.td_p3.parameters());
p.extend(self.bu_p4.parameters());
p.extend(self.bu_p5.parameters());
p
}
pub fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut p = HashMap::new();
p.extend(self.lateral_p5.named_parameters("lat_p5"));
p.extend(self.lateral_p4.named_parameters("lat_p4"));
p.extend(self.lateral_p3.named_parameters("lat_p3"));
p.extend(self.td_p4.named_parameters("td_p4"));
p.extend(self.td_p3.named_parameters("td_p3"));
p.extend(self.bu_p4.named_parameters("bu_p4"));
p.extend(self.bu_p5.named_parameters("bu_p5"));
p
}
pub fn set_training(&mut self, training: bool) {
self.lateral_p5.set_training(training);
self.lateral_p4.set_training(training);
self.lateral_p3.set_training(training);
self.td_p4.set_training(training);
self.td_p3.set_training(training);
self.bu_p4.set_training(training);
self.bu_p5.set_training(training);
}
}