use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::relu;
use burn::tensor::backend::Backend;
use burn_nn::PaddingConfig2d;
use burn_nn::conv::{Conv2d, Conv2dConfig};
#[derive(Module, Debug)]
pub struct AlexFeatureExtractor<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
conv3: Conv2d<B>,
conv4: Conv2d<B>,
conv5: Conv2d<B>,
}
impl<B: Backend> AlexFeatureExtractor<B> {
pub fn new(device: &B::Device) -> Self {
Self {
conv1: Conv2dConfig::new([3, 64], [11, 11])
.with_stride([4, 4])
.with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
.with_bias(true)
.init(device),
conv2: Conv2dConfig::new([64, 192], [5, 5])
.with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
.with_bias(true)
.init(device),
conv3: Conv2dConfig::new([192, 384], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
.with_bias(true)
.init(device),
conv4: Conv2dConfig::new([384, 256], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
.with_bias(true)
.init(device),
conv5: Conv2dConfig::new([256, 256], [3, 3])
.with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
.with_bias(true)
.init(device),
}
}
pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {
let mut features = Vec::with_capacity(5);
let x = relu(self.conv1.forward(x));
features.push(x.clone());
let x = max_pool2d_alex(x);
let x = relu(self.conv2.forward(x));
features.push(x.clone());
let x = max_pool2d_alex(x);
let x = relu(self.conv3.forward(x));
features.push(x.clone());
let x = relu(self.conv4.forward(x));
features.push(x.clone());
let x = relu(self.conv5.forward(x));
features.push(x);
features
}
}
fn max_pool2d_alex<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], false)
}