use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::relu;
use burn::tensor::backend::Backend;
use burn_nn::PaddingConfig2d;
use burn_nn::conv::{Conv2d, Conv2dConfig};
/// SqueezeNet "fire" block: a 1x1 squeeze convolution followed by two
/// parallel expand convolutions (1x1 and 3x3) whose outputs are concatenated
/// along the channel dimension in `forward`.
#[derive(Module, Debug)]
pub struct FireModule<B: Backend> {
/// 1x1 convolution that reduces the channel count before the expand branches.
squeeze: Conv2d<B>,
/// 1x1 expand branch.
expand1x1: Conv2d<B>,
/// 3x3 expand branch; constructed with explicit padding 1 on every side so
/// its spatial output size matches the 1x1 branch (see `FireModule::new`).
expand3x3: Conv2d<B>,
}
impl<B: Backend> FireModule<B> {
    /// Builds a fire block on `device`.
    ///
    /// `in_channels` feeds the 1x1 squeeze layer, which narrows to
    /// `squeeze_channels`; the two expand layers then widen back out to
    /// `expand1x1_channels` and `expand3x3_channels` respectively. All three
    /// convolutions carry a bias term. The 3x3 branch is padded by 1 on every
    /// side so both branches keep the same spatial resolution.
    pub fn new(
        in_channels: usize,
        squeeze_channels: usize,
        expand1x1_channels: usize,
        expand3x3_channels: usize,
        device: &B::Device,
    ) -> Self {
        let squeeze = Conv2dConfig::new([in_channels, squeeze_channels], [1, 1])
            .with_bias(true)
            .init(device);
        let expand1x1 = Conv2dConfig::new([squeeze_channels, expand1x1_channels], [1, 1])
            .with_bias(true)
            .init(device);
        let expand3x3 = Conv2dConfig::new([squeeze_channels, expand3x3_channels], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
            .with_bias(true)
            .init(device);

        Self {
            squeeze,
            expand1x1,
            expand3x3,
        }
    }

    /// Runs the fire block: squeeze -> ReLU, then both expand branches
    /// (each followed by ReLU), concatenated on the channel axis (dim 1).
    ///
    /// Output channel count is `expand1x1_channels + expand3x3_channels`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let narrowed = relu(self.squeeze.forward(x));
        // The squeezed activation feeds both branches, so one branch takes a clone.
        let branch_1x1 = relu(self.expand1x1.forward(narrowed.clone()));
        let branch_3x3 = relu(self.expand3x3.forward(narrowed));
        Tensor::cat(vec![branch_1x1, branch_3x3], 1)
    }
}
/// SqueezeNet-style convolutional trunk: an initial strided 3x3 convolution
/// followed by eight fire blocks. `forward` returns intermediate activations
/// ("feature taps") collected at several depths rather than a single output.
#[derive(Module, Debug)]
pub struct SqueezeFeatureExtractor<B: Backend> {
/// Stem: 3x3 convolution, stride 2, 3 -> 64 channels (see `new`).
conv1: Conv2d<B>,
fire1: FireModule<B>,
fire2: FireModule<B>,
fire3: FireModule<B>,
fire4: FireModule<B>,
fire5: FireModule<B>,
fire6: FireModule<B>,
fire7: FireModule<B>,
fire8: FireModule<B>,
}
impl<B: Backend> SqueezeFeatureExtractor<B> {
    /// Builds the trunk on `device`.
    ///
    /// The stem is a 3x3 / stride-2 convolution mapping 3 -> 64 channels;
    /// the fire blocks then follow the channel progression
    /// 64 -> 128 -> 128 -> 256 -> 256 -> 384 -> 384 -> 512 -> 512
    /// (each fire block's output width is the sum of its two expand widths).
    pub fn new(device: &B::Device) -> Self {
        let conv1 = Conv2dConfig::new([3, 64], [3, 3])
            .with_stride([2, 2])
            .with_bias(true)
            .init(device);

        Self {
            conv1,
            fire1: FireModule::new(64, 16, 64, 64, device),
            fire2: FireModule::new(128, 16, 64, 64, device),
            fire3: FireModule::new(128, 32, 128, 128, device),
            fire4: FireModule::new(256, 32, 128, 128, device),
            fire5: FireModule::new(256, 48, 192, 192, device),
            fire6: FireModule::new(384, 48, 192, 192, device),
            fire7: FireModule::new(384, 64, 256, 256, device),
            fire8: FireModule::new(512, 64, 256, 256, device),
        }
    }

    /// Runs the trunk and returns seven intermediate activations, shallowest
    /// first: after the stem ReLU, after fire2, fire4, fire5, fire6, fire7,
    /// and fire8. Max pooling (via `max_pool2d_squeeze`) downsamples between
    /// the stem/fire1, fire2/fire3, and fire4/fire5 boundaries.
    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {
        // Exactly seven taps are pushed below.
        let mut taps = Vec::with_capacity(7);

        let x = relu(self.conv1.forward(x));
        taps.push(x.clone());

        let x = self.fire1.forward(max_pool2d_squeeze(x));
        let x = self.fire2.forward(x);
        taps.push(x.clone());

        let x = self.fire3.forward(max_pool2d_squeeze(x));
        let x = self.fire4.forward(x);
        taps.push(x.clone());

        let x = self.fire5.forward(max_pool2d_squeeze(x));
        taps.push(x.clone());

        let x = self.fire6.forward(x);
        taps.push(x.clone());

        let x = self.fire7.forward(x);
        taps.push(x.clone());

        let x = self.fire8.forward(x);
        taps.push(x);

        taps
    }
}
/// 3x3, stride-2, unpadded max pooling with dilation 1, used to downsample
/// between stages of the feature extractor.
/// NOTE(review): the trailing `true` is presumably ceil-mode rounding of the
/// output size (matching SqueezeNet's `ceil_mode=True` pooling) — confirm
/// against the `max_pool2d` signature of the burn version in use.
fn max_pool2d_squeeze<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], true)
}