use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter};
/// SiLU (swish) activation: `x * sigmoid(x)`.
fn silu(x: &Variable) -> Variable {
    let gate = x.sigmoid();
    x.mul_var(&gate)
}
/// Conv2d → BatchNorm2d → SiLU: the standard convolution unit used
/// throughout this backbone.
pub struct ConvBNSiLU {
    conv: Conv2d,
    bn: BatchNorm2d,
}

impl ConvBNSiLU {
    /// Builds a square-kernel convolution (`kernel` x `kernel`) with the
    /// given stride/padding, followed by batch norm over `out_ch` channels.
    ///
    /// NOTE(review): the conv keeps its bias (`true`) even though it feeds
    /// straight into BatchNorm2d, which makes the bias redundant — left
    /// as-is to keep the parameter layout/checkpoints stable; confirm
    /// whether dropping it is safe.
    pub fn new(in_ch: usize, out_ch: usize, kernel: usize, stride: usize, padding: usize) -> Self {
        let conv = Conv2d::with_options(
            in_ch,
            out_ch,
            (kernel, kernel),
            (stride, stride),
            (padding, padding),
            true,
        );
        let bn = BatchNorm2d::new(out_ch);
        Self { conv, bn }
    }

    /// Applies conv, then batch norm, then SiLU.
    pub fn forward(&self, x: &Variable) -> Variable {
        let y = self.conv.forward(x);
        let y = self.bn.forward(&y);
        silu(&y)
    }

    /// All trainable parameters: conv first, then batch norm.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.conv.parameters();
        params.extend(self.bn.parameters());
        params
    }

    /// Parameters keyed as `{prefix}.conv.*` and `{prefix}.bn.*`.
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.extend(
            self.conv
                .named_parameters()
                .into_iter()
                .map(|(k, v)| (format!("{}.conv.{}", prefix, k), v)),
        );
        params.extend(
            self.bn
                .named_parameters()
                .into_iter()
                .map(|(k, v)| (format!("{}.bn.{}", prefix, k), v)),
        );
        params
    }

    /// Toggles train/eval mode; only batch norm is mode-dependent here.
    pub fn set_training(&mut self, training: bool) {
        self.bn.set_training(training);
    }
}
/// Two-conv residual unit: 1x1 then 3x3, with an optional identity
/// shortcut that is only honored when input and output widths match.
pub struct Bottleneck {
    cv1: ConvBNSiLU,
    cv2: ConvBNSiLU,
    shortcut: bool,
}

impl Bottleneck {
    /// `shortcut` is silently disabled unless `in_ch == out_ch`, since the
    /// residual add requires matching channel counts.
    pub fn new(in_ch: usize, out_ch: usize, shortcut: bool) -> Self {
        let hidden = out_ch;
        let cv1 = ConvBNSiLU::new(in_ch, hidden, 1, 1, 0);
        let cv2 = ConvBNSiLU::new(hidden, out_ch, 3, 1, 1);
        Self {
            cv1,
            cv2,
            shortcut: shortcut && in_ch == out_ch,
        }
    }

    /// Runs both convs; adds the input back when the shortcut is active.
    pub fn forward(&self, x: &Variable) -> Variable {
        let y = self.cv1.forward(x);
        let y = self.cv2.forward(&y);
        if self.shortcut {
            y.add_var(x)
        } else {
            y
        }
    }

    /// Parameters of both conv units, cv1 first.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.cv1.parameters());
        params.extend(self.cv2.parameters());
        params
    }

    /// Parameters keyed under `{prefix}.cv1.*` and `{prefix}.cv2.*`.
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut params = self.cv1.named_parameters(&format!("{}.cv1", prefix));
        for (k, v) in self.cv2.named_parameters(&format!("{}.cv2", prefix)) {
            params.insert(k, v);
        }
        params
    }

    /// Propagates train/eval mode to both conv units.
    pub fn set_training(&mut self, training: bool) {
        for cv in [&mut self.cv1, &mut self.cv2] {
            cv.set_training(training);
        }
    }
}
/// Cross-Stage-Partial block: a stride-2 downsample, then a split into
/// two half-width branches — one runs the bottleneck stack, the other is
/// a bypass — concatenated along channels and fused by a final 1x1 conv.
pub struct CSPBlock {
    downsample: ConvBNSiLU,
    cv1: ConvBNSiLU,
    cv2: ConvBNSiLU,
    bottlenecks: Vec<Bottleneck>,
    cv3: ConvBNSiLU,
    out_ch: usize,
}

impl CSPBlock {
    /// `n_bottlenecks` residual units run on half of the channels.
    /// Note: for odd `out_ch` the fused width is `2 * (out_ch / 2)`.
    pub fn new(in_ch: usize, out_ch: usize, n_bottlenecks: usize) -> Self {
        let half = out_ch / 2;
        let bottlenecks = (0..n_bottlenecks)
            .map(|_| Bottleneck::new(half, half, true))
            .collect();
        Self {
            downsample: ConvBNSiLU::new(in_ch, out_ch, 3, 2, 1),
            cv1: ConvBNSiLU::new(out_ch, half, 1, 1, 0),
            cv2: ConvBNSiLU::new(out_ch, half, 1, 1, 0),
            bottlenecks,
            cv3: ConvBNSiLU::new(half * 2, out_ch, 1, 1, 0),
            out_ch,
        }
    }

    /// Downsample, run both branches, concat on the channel axis, fuse.
    pub fn forward(&self, x: &Variable) -> Variable {
        let down = self.downsample.forward(x);
        // Main branch: cv1 then the bottleneck stack, threaded via fold.
        let main = self
            .bottlenecks
            .iter()
            .fold(self.cv1.forward(&down), |acc, b| b.forward(&acc));
        let bypass = self.cv2.forward(&down);
        let cat = Variable::cat(&[&main, &bypass], 1);
        self.cv3.forward(&cat)
    }

    /// Output channel count this block was configured with.
    pub fn out_channels(&self) -> usize {
        self.out_ch
    }

    /// Parameters in forward order: downsample, cv1, cv2, bottlenecks, cv3.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.downsample.parameters();
        params.extend(self.cv1.parameters());
        params.extend(self.cv2.parameters());
        params.extend(self.bottlenecks.iter().flat_map(|b| b.parameters()));
        params.extend(self.cv3.parameters());
        params
    }

    /// Parameters keyed under `{prefix}.down`, `.cv1`, `.cv2`, `.btn{i}`,
    /// and `.cv3` (key strings match the checkpoint layout — do not change).
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut params = self
            .downsample
            .named_parameters(&format!("{}.down", prefix));
        params.extend(self.cv1.named_parameters(&format!("{}.cv1", prefix)));
        params.extend(self.cv2.named_parameters(&format!("{}.cv2", prefix)));
        for (i, b) in self.bottlenecks.iter().enumerate() {
            params.extend(b.named_parameters(&format!("{}.btn{}", prefix, i)));
        }
        params.extend(self.cv3.named_parameters(&format!("{}.cv3", prefix)));
        params
    }

    /// Propagates train/eval mode to every submodule.
    pub fn set_training(&mut self, training: bool) {
        self.downsample.set_training(training);
        self.cv1.set_training(training);
        self.cv2.set_training(training);
        for b in &mut self.bottlenecks {
            b.set_training(training);
        }
        self.cv3.set_training(training);
    }
}
/// Three-stage CSP backbone for (typically single-channel) thermal
/// imagery. `forward` emits feature maps at 64/128/256 channels, each a
/// further 2x downsample of the previous (stem /2, then one /2 per stage).
pub struct ThermalBackbone {
    stem: ConvBNSiLU,
    ch_adapter: Option<ConvBNSiLU>,
    stage1: CSPBlock,
    stage2: CSPBlock,
    stage3: CSPBlock,
}

impl ThermalBackbone {
    /// Builds the backbone for `in_channels` input planes.
    ///
    /// The stem is fixed at 3 input channels, so any other `in_channels`
    /// gets a 1x1 adapter projecting it to 3. (Previously the adapter was
    /// only created for `in_channels == 1`; any other non-3 value — e.g.
    /// a 2- or 4-channel sensor stack — produced a backbone that shape-
    /// mismatched inside the stem conv at runtime. Behavior for 1 and 3
    /// is unchanged.)
    pub fn new(in_channels: usize) -> Self {
        let ch_adapter = if in_channels == 3 {
            None
        } else {
            Some(ConvBNSiLU::new(in_channels, 3, 1, 1, 0))
        };
        Self {
            stem: ConvBNSiLU::new(3, 32, 3, 2, 1),
            ch_adapter,
            stage1: CSPBlock::new(32, 64, 1),
            stage2: CSPBlock::new(64, 128, 2),
            stage3: CSPBlock::new(128, 256, 2),
        }
    }

    /// Returns the `(p3, p4, p5)` feature pyramid from stages 1-3.
    pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable) {
        // Project to 3 channels first when an adapter is present.
        let x = match self.ch_adapter {
            Some(ref adapter) => adapter.forward(x),
            None => x.clone(),
        };
        let x = self.stem.forward(&x);
        let p3 = self.stage1.forward(&x);
        let p4 = self.stage2.forward(&p3);
        let p5 = self.stage3.forward(&p4);
        (p3, p4, p5)
    }

    /// All trainable parameters, adapter (if any) first, then forward order.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut p = Vec::new();
        if let Some(ref adapter) = self.ch_adapter {
            p.extend(adapter.parameters());
        }
        p.extend(self.stem.parameters());
        p.extend(self.stage1.parameters());
        p.extend(self.stage2.parameters());
        p.extend(self.stage3.parameters());
        p
    }

    /// Flat name → parameter map using the fixed top-level prefixes
    /// `ch_adapter`, `stem`, `stage1`..`stage3` (checkpoint key layout).
    pub fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut p = HashMap::new();
        if let Some(ref adapter) = self.ch_adapter {
            p.extend(adapter.named_parameters("ch_adapter"));
        }
        p.extend(self.stem.named_parameters("stem"));
        p.extend(self.stage1.named_parameters("stage1"));
        p.extend(self.stage2.named_parameters("stage2"));
        p.extend(self.stage3.named_parameters("stage3"));
        p
    }

    /// Propagates train/eval mode to every submodule (batch norms).
    pub fn set_training(&mut self, training: bool) {
        if let Some(ref mut adapter) = self.ch_adapter {
            adapter.set_training(training);
        }
        self.stem.set_training(training);
        self.stage1.set_training(training);
        self.stage2.set_training(training);
        self.stage3.set_training(training);
    }
}