use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter};
/// Sigmoid Linear Unit (swish) activation: `x * sigmoid(x)`.
fn silu(x: &Variable) -> Variable {
    x.mul_var(&x.sigmoid())
}
/// Conv2d -> BatchNorm2d -> SiLU: the basic convolution unit used by every
/// stage of the backbone.
pub struct ConvBNSiLU {
// 2-D convolution (square kernel; see `new`).
conv: Conv2d,
// Per-channel batch normalization over the conv output.
bn: BatchNorm2d,
}
impl ConvBNSiLU {
    /// Creates a square-kernel conv block; `kernel`, `stride` and `padding`
    /// are applied symmetrically to both spatial dimensions.
    pub fn new(in_ch: usize, out_ch: usize, kernel: usize, stride: usize, padding: usize) -> Self {
        Self {
            conv: Conv2d::with_options(
                in_ch,
                out_ch,
                (kernel, kernel),
                (stride, stride),
                (padding, padding),
                // No conv bias: a constant per-channel offset is cancelled by
                // the mean subtraction in the following BatchNorm, whose own
                // beta parameter plays the same role — keeping it only wastes
                // parameters (standard Conv+BN practice).
                false,
            ),
            bn: BatchNorm2d::new(out_ch),
        }
    }

    /// Runs conv -> batch-norm -> SiLU on `x`.
    pub fn forward(&self, x: &Variable) -> Variable {
        silu(&self.bn.forward(&self.conv.forward(x)))
    }

    /// All trainable parameters, conv first then batch-norm.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut p = self.conv.parameters();
        p.extend(self.bn.parameters());
        p
    }

    /// Parameters keyed as `{prefix}.conv.{name}` / `{prefix}.bn.{name}`.
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut p = HashMap::new();
        for (k, v) in self.conv.named_parameters() {
            p.insert(format!("{}.conv.{}", prefix, k), v);
        }
        for (k, v) in self.bn.named_parameters() {
            p.insert(format!("{}.bn.{}", prefix, k), v);
        }
        p
    }

    /// Switches batch-norm between training and inference statistics.
    pub fn set_training(&mut self, training: bool) {
        self.bn.set_training(training);
    }
}
/// Residual unit: 1x1 conv followed by 3x3 conv, with an optional
/// element-wise skip connection.
pub struct Bottleneck {
// 1x1 pointwise conv block.
cv1: ConvBNSiLU,
// 3x3 conv block (padding 1, stride 1).
cv2: ConvBNSiLU,
// True only when the caller asked for a shortcut AND in/out channels match
// (see `Bottleneck::new`), so the add is shape-safe.
shortcut: bool,
}
impl Bottleneck {
    /// Builds a 1x1 -> 3x3 residual unit. The skip connection is only kept
    /// when requested *and* the channel counts permit an element-wise add.
    pub fn new(in_ch: usize, out_ch: usize, shortcut: bool) -> Self {
        Self {
            cv1: ConvBNSiLU::new(in_ch, out_ch, 1, 1, 0),
            cv2: ConvBNSiLU::new(out_ch, out_ch, 3, 1, 1),
            shortcut: shortcut && in_ch == out_ch,
        }
    }

    /// Applies both conv blocks, adding the input back when the residual
    /// path is active.
    pub fn forward(&self, x: &Variable) -> Variable {
        let y = self.cv1.forward(x);
        let y = self.cv2.forward(&y);
        if self.shortcut {
            y.add_var(x)
        } else {
            y
        }
    }

    /// Trainable parameters of both conv blocks, cv1 first.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.cv1.parameters();
        params.extend(self.cv2.parameters());
        params
    }

    /// Parameters keyed under `{prefix}.cv1` and `{prefix}.cv2`.
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut params = self.cv1.named_parameters(&format!("{}.cv1", prefix));
        params.extend(self.cv2.named_parameters(&format!("{}.cv2", prefix)));
        params
    }

    /// Propagates the training flag to both conv blocks.
    pub fn set_training(&mut self, training: bool) {
        self.cv1.set_training(training);
        self.cv2.set_training(training);
    }
}
/// Cross-Stage-Partial block: a stride-2 downsample, then two half-width
/// 1x1 branches (one running through a chain of bottlenecks), concatenated
/// and fused by a final 1x1 conv.
pub struct CSPBlock {
// Stride-2 3x3 conv that halves spatial size and sets the channel count.
downsample: ConvBNSiLU,
// 1x1 branch feeding the bottleneck chain (out_ch -> out_ch/2).
cv1: ConvBNSiLU,
// 1x1 bypass branch (out_ch -> out_ch/2).
cv2: ConvBNSiLU,
// Residual units applied sequentially on the cv1 branch.
bottlenecks: Vec<Bottleneck>,
// 1x1 fusion conv over the concatenated branches.
cv3: ConvBNSiLU,
// Cached output channel count (reported by `out_channels`).
out_ch: usize,
}
impl CSPBlock {
    /// Builds a CSP stage: a stride-2 downsample into two half-width 1x1
    /// branches, `n_bottlenecks` residual units on the first branch, and a
    /// 1x1 fusion conv back to `out_ch` channels.
    pub fn new(in_ch: usize, out_ch: usize, n_bottlenecks: usize) -> Self {
        let half = out_ch / 2;
        let bottlenecks = (0..n_bottlenecks)
            .map(|_| Bottleneck::new(half, half, true))
            .collect();
        Self {
            downsample: ConvBNSiLU::new(in_ch, out_ch, 3, 2, 1),
            cv1: ConvBNSiLU::new(out_ch, half, 1, 1, 0),
            cv2: ConvBNSiLU::new(out_ch, half, 1, 1, 0),
            bottlenecks,
            cv3: ConvBNSiLU::new(half * 2, out_ch, 1, 1, 0),
            out_ch,
        }
    }

    /// Downsamples, runs both branches, concatenates along the channel axis
    /// (dim 1) and fuses with cv3.
    pub fn forward(&self, x: &Variable) -> Variable {
        let down = self.downsample.forward(x);
        let branch_a = self
            .bottlenecks
            .iter()
            .fold(self.cv1.forward(&down), |acc, b| b.forward(&acc));
        let branch_b = self.cv2.forward(&down);
        self.cv3.forward(&Variable::cat(&[&branch_a, &branch_b], 1))
    }

    /// Number of output channels produced by this stage.
    pub fn out_channels(&self) -> usize {
        self.out_ch
    }

    /// All trainable parameters in forward order.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.downsample.parameters();
        params.extend(self.cv1.parameters());
        params.extend(self.cv2.parameters());
        params.extend(self.bottlenecks.iter().flat_map(|b| b.parameters()));
        params.extend(self.cv3.parameters());
        params
    }

    /// Parameters keyed under `{prefix}.down`, `{prefix}.cv1`, `{prefix}.cv2`,
    /// `{prefix}.btn{i}` and `{prefix}.cv3`.
    pub fn named_parameters(&self, prefix: &str) -> HashMap<String, Parameter> {
        let mut params = self.downsample.named_parameters(&format!("{}.down", prefix));
        params.extend(self.cv1.named_parameters(&format!("{}.cv1", prefix)));
        params.extend(self.cv2.named_parameters(&format!("{}.cv2", prefix)));
        for (i, b) in self.bottlenecks.iter().enumerate() {
            params.extend(b.named_parameters(&format!("{}.btn{}", prefix, i)));
        }
        params.extend(self.cv3.named_parameters(&format!("{}.cv3", prefix)));
        params
    }

    /// Propagates the training flag to every sub-module.
    pub fn set_training(&mut self, training: bool) {
        self.downsample.set_training(training);
        self.cv1.set_training(training);
        self.cv2.set_training(training);
        for b in &mut self.bottlenecks {
            b.set_training(training);
        }
        self.cv3.set_training(training);
    }
}
/// Three-stage CSP backbone producing multi-scale feature maps (p3, p4, p5).
/// Named "thermal" because it supports 1-channel input via an optional
/// channel adapter in front of the 3-channel stem.
pub struct ThermalBackbone {
// Stride-2 stem conv (expects 3 input channels).
stem: ConvBNSiLU,
// Optional 1x1 conv mapping the input to 3 channels; present when the
// input channel count does not match the stem (see `ThermalBackbone::new`).
ch_adapter: Option<ConvBNSiLU>,
// 32 -> 64 channels, stride 2.
stage1: CSPBlock,
// 64 -> 128 channels, stride 2.
stage2: CSPBlock,
// 128 -> 256 channels, stride 2.
stage3: CSPBlock,
}
impl ThermalBackbone {
    /// Builds the backbone. Inputs whose channel count differs from the 3
    /// channels the stem expects are first mapped to 3 channels by a 1x1
    /// adapter — this covers thermal (1-channel) input as before, and now
    /// any other sensor channel count as well.
    pub fn new(in_channels: usize) -> Self {
        // Previously only `in_channels == 1` received an adapter; any other
        // non-RGB count was fed straight into the 3-channel stem and would
        // fail at runtime. Adapting whenever != 3 is backward-compatible for
        // the 1- and 3-channel cases.
        let ch_adapter = if in_channels != 3 {
            Some(ConvBNSiLU::new(in_channels, 3, 1, 1, 0))
        } else {
            None
        };
        Self {
            stem: ConvBNSiLU::new(3, 32, 3, 2, 1),
            ch_adapter,
            stage1: CSPBlock::new(32, 64, 1),
            stage2: CSPBlock::new(64, 128, 2),
            stage3: CSPBlock::new(128, 256, 2),
        }
    }

    /// Runs the backbone, returning feature maps (p3, p4, p5) at strides
    /// 4, 8 and 16 relative to the input.
    pub fn forward(&self, x: &Variable) -> (Variable, Variable, Variable) {
        let x = match self.ch_adapter {
            Some(ref adapter) => adapter.forward(x),
            None => x.clone(),
        };
        let x = self.stem.forward(&x);
        let p3 = self.stage1.forward(&x);
        let p4 = self.stage2.forward(&p3);
        let p5 = self.stage3.forward(&p4);
        (p3, p4, p5)
    }

    /// All trainable parameters: adapter (if present), stem, then stages.
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut p = Vec::new();
        if let Some(ref adapter) = self.ch_adapter {
            p.extend(adapter.parameters());
        }
        p.extend(self.stem.parameters());
        p.extend(self.stage1.parameters());
        p.extend(self.stage2.parameters());
        p.extend(self.stage3.parameters());
        p
    }

    /// Parameters keyed under `ch_adapter`, `stem` and `stage{1,2,3}`.
    pub fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut p = HashMap::new();
        if let Some(ref adapter) = self.ch_adapter {
            p.extend(adapter.named_parameters("ch_adapter"));
        }
        p.extend(self.stem.named_parameters("stem"));
        p.extend(self.stage1.named_parameters("stage1"));
        p.extend(self.stage2.named_parameters("stage2"));
        p.extend(self.stage3.named_parameters("stage3"));
        p
    }

    /// Propagates the training flag to every sub-module.
    pub fn set_training(&mut self, training: bool) {
        if let Some(ref mut adapter) = self.ch_adapter {
            adapter.set_training(training);
        }
        self.stem.set_training(training);
        self.stage1.set_training(training);
        self.stage2.set_training(training);
        self.stage3.set_training(training);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Builds a constant-filled (0.1) input variable with the given shape,
    /// with gradients disabled.
    fn input(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![0.1; numel], shape).unwrap(), false)
    }

    #[test]
    fn test_conv_bn_silu_shape() {
        let block = ConvBNSiLU::new(3, 32, 3, 1, 1);
        let y = block.forward(&input(&[2, 3, 8, 8]));
        assert_eq!(y.data().shape(), &[2, 32, 8, 8]);
    }

    #[test]
    fn test_conv_bn_silu_params() {
        assert!(!ConvBNSiLU::new(3, 32, 3, 1, 1).parameters().is_empty());
    }

    #[test]
    fn test_bottleneck_shortcut() {
        let block = Bottleneck::new(16, 16, true);
        let y = block.forward(&input(&[1, 16, 4, 4]));
        assert_eq!(y.data().shape(), &[1, 16, 4, 4]);
    }

    #[test]
    fn test_bottleneck_no_shortcut() {
        let block = Bottleneck::new(16, 32, false);
        let y = block.forward(&input(&[1, 16, 4, 4]));
        assert_eq!(y.data().shape(), &[1, 32, 4, 4]);
    }

    #[test]
    fn test_csp_block_shape() {
        let block = CSPBlock::new(32, 64, 1);
        let y = block.forward(&input(&[1, 32, 8, 8]));
        // The stride-2 downsample halves the 8x8 spatial dims.
        assert_eq!(y.data().shape(), &[1, 64, 4, 4]);
    }

    #[test]
    fn test_thermal_backbone_3ch() {
        let backbone = ThermalBackbone::new(3);
        let (p3, p4, p5) = backbone.forward(&input(&[1, 3, 64, 64]));
        assert_eq!(p3.data().shape()[1], 64);
        assert_eq!(p3.data().shape()[2], 16);
        assert_eq!(p4.data().shape()[1], 128);
        assert_eq!(p4.data().shape()[2], 8);
        assert_eq!(p5.data().shape()[1], 256);
        assert_eq!(p5.data().shape()[2], 4);
    }

    #[test]
    fn test_thermal_backbone_1ch() {
        let backbone = ThermalBackbone::new(1);
        let (p3, p4, p5) = backbone.forward(&input(&[1, 1, 32, 32]));
        assert_eq!(p3.data().shape()[1], 64);
        assert_eq!(p4.data().shape()[1], 128);
        assert_eq!(p5.data().shape()[1], 256);
    }

    #[test]
    fn test_thermal_backbone_params() {
        let thermal = ThermalBackbone::new(1);
        assert!(!thermal.parameters().is_empty());
        let rgb = ThermalBackbone::new(3);
        // The 1-channel variant carries the extra channel adapter.
        assert!(thermal.parameters().len() > rgb.parameters().len());
    }
}