use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::relu;
use burn::tensor::backend::Backend;
use burn_nn::PaddingConfig2d;
use burn_nn::conv::{Conv2d, Conv2dConfig};
use super::l2pool::{L2Pool2d, L2Pool2dConfig};
/// VGG-16 convolutional trunk with L2 pooling in place of max pooling.
///
/// Holds the thirteen 3x3 convolution layers of VGG-16 (blocks 1-5) and four
/// `L2Pool2d` downsampling layers between blocks. The final pool after block 5
/// and the fully-connected classifier head are intentionally absent: `forward`
/// returns intermediate activations rather than class scores, so this type is
/// a feature *extractor* — presumably for a perceptual-loss / LPIPS-style
/// comparison (NOTE(review): confirm intended use against callers).
///
/// Field naming follows the standard VGG convention `convB_N` = N-th conv of
/// block B; each `poolB` halves spatial resolution after block B.
#[derive(Module, Debug)]
pub struct Vgg16L2PoolExtractor<B: Backend> {
// Block 1: 3 -> 64 channels.
pub(crate) conv1_1: Conv2d<B>,
pub(crate) conv1_2: Conv2d<B>,
pub(crate) pool1: L2Pool2d<B>,
// Block 2: 64 -> 128 channels.
pub(crate) conv2_1: Conv2d<B>,
pub(crate) conv2_2: Conv2d<B>,
pub(crate) pool2: L2Pool2d<B>,
// Block 3: 128 -> 256 channels.
pub(crate) conv3_1: Conv2d<B>,
pub(crate) conv3_2: Conv2d<B>,
pub(crate) conv3_3: Conv2d<B>,
pub(crate) pool3: L2Pool2d<B>,
// Block 4: 256 -> 512 channels.
pub(crate) conv4_1: Conv2d<B>,
pub(crate) conv4_2: Conv2d<B>,
pub(crate) conv4_3: Conv2d<B>,
pub(crate) pool4: L2Pool2d<B>,
// Block 5: 512 -> 512 channels (no trailing pool — see type docs).
pub(crate) conv5_1: Conv2d<B>,
pub(crate) conv5_2: Conv2d<B>,
pub(crate) conv5_3: Conv2d<B>,
}
impl<B: Backend> Vgg16L2PoolExtractor<B> {
pub fn new(device: &B::Device) -> Self {
let pool_config = L2Pool2dConfig::default();
Self {
conv1_1: Conv2dConfig::new([3, 64], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv1_2: Conv2dConfig::new([64, 64], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
pool1: pool_config.init(64, device),
conv2_1: Conv2dConfig::new([64, 128], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv2_2: Conv2dConfig::new([128, 128], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
pool2: pool_config.init(128, device),
conv3_1: Conv2dConfig::new([128, 256], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv3_2: Conv2dConfig::new([256, 256], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv3_3: Conv2dConfig::new([256, 256], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
pool3: pool_config.init(256, device),
conv4_1: Conv2dConfig::new([256, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv4_2: Conv2dConfig::new([512, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv4_3: Conv2dConfig::new([512, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
pool4: pool_config.init(512, device),
conv5_1: Conv2dConfig::new([512, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv5_2: Conv2dConfig::new([512, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
conv5_3: Conv2dConfig::new([512, 512], [3, 3])
.with_padding(PaddingConfig2d::Same)
.init(device),
}
}
pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {
let mut features = Vec::with_capacity(6);
features.push(x.clone());
let x = relu(self.conv1_1.forward(x));
let x = relu(self.conv1_2.forward(x));
features.push(x.clone());
let x = self.pool1.forward(x);
let x = relu(self.conv2_1.forward(x));
let x = relu(self.conv2_2.forward(x));
features.push(x.clone());
let x = self.pool2.forward(x);
let x = relu(self.conv3_1.forward(x));
let x = relu(self.conv3_2.forward(x));
let x = relu(self.conv3_3.forward(x));
features.push(x.clone());
let x = self.pool3.forward(x);
let x = relu(self.conv4_1.forward(x));
let x = relu(self.conv4_2.forward(x));
let x = relu(self.conv4_3.forward(x));
features.push(x.clone());
let x = self.pool4.forward(x);
let x = relu(self.conv5_1.forward(x));
let x = relu(self.conv5_2.forward(x));
let x = relu(self.conv5_3.forward(x));
features.push(x);
features
}
}