use burn_core as burn;
use crate::PaddingConfig2d;
use crate::conv::{Conv2d, Conv2dConfig};
use burn::module::Module;
use burn::tensor::{
Tensor,
activation::relu,
backend::Backend,
module::{avg_pool2d, max_pool2d},
};
/// Truncated VGG-19 convolutional feature extractor.
///
/// Holds the convolution layers of VGG-19 up to and including `conv5_1`
/// (the deeper `conv5_2`..`conv5_4` layers of the full network are omitted).
/// `forward` returns the flattened post-ReLU activations of the first
/// convolution in each of the five blocks — the layer set commonly tapped
/// for perceptual / style losses.
///
/// NOTE(review): this type only defines the architecture; no weight loading
/// is visible here — presumably pretrained weights are loaded elsewhere.
#[derive(Module, Debug)]
pub struct Vgg19<B: Backend> {
    /// When `true`, blocks are downsampled with average pooling instead of
    /// max pooling (a common substitution for smoother style-transfer
    /// gradients).
    use_avg_pool: bool,
    // Block 1: 3 -> 64 channels. Public, unlike the other layers —
    // presumably some caller accesses it directly; verify before changing.
    pub conv1_1: Conv2d<B>,
    conv1_2: Conv2d<B>,
    // Block 2: 64 -> 128 channels.
    conv2_1: Conv2d<B>,
    conv2_2: Conv2d<B>,
    // Block 3: 128 -> 256 channels, four conv layers.
    conv3_1: Conv2d<B>,
    conv3_2: Conv2d<B>,
    conv3_3: Conv2d<B>,
    conv3_4: Conv2d<B>,
    // Block 4: 256 -> 512 channels, four conv layers.
    conv4_1: Conv2d<B>,
    conv4_2: Conv2d<B>,
    conv4_3: Conv2d<B>,
    conv4_4: Conv2d<B>,
    // Block 5: only the first 512 -> 512 layer is kept.
    conv5_1: Conv2d<B>,
}
impl<B: Backend> Vgg19<B> {
    /// Builds the truncated VGG-19 with freshly initialized weights.
    ///
    /// Every layer is a 3x3 convolution with stride 1 and `Same` padding,
    /// so convolutions preserve spatial size; only pooling downsamples.
    ///
    /// * `use_avg_pool` — select average pooling (`true`) or max pooling
    ///   (`false`) between blocks in `forward`.
    /// * `device` — backend device the layer parameters are allocated on.
    pub fn new(use_avg_pool: bool, device: &B::Device) -> Self {
        // All layers share the same shape policy; only channel counts vary.
        let conv = |channels_in: usize, channels_out: usize| {
            Conv2dConfig::new([channels_in, channels_out], [3, 3])
                .with_stride([1, 1])
                .with_padding(PaddingConfig2d::Same)
                .init(device)
        };

        Self {
            use_avg_pool,
            // Block 1
            conv1_1: conv(3, 64),
            conv1_2: conv(64, 64),
            // Block 2
            conv2_1: conv(64, 128),
            conv2_2: conv(128, 128),
            // Block 3
            conv3_1: conv(128, 256),
            conv3_2: conv(256, 256),
            conv3_3: conv(256, 256),
            conv3_4: conv(256, 256),
            // Block 4
            conv4_1: conv(256, 512),
            conv4_2: conv(512, 512),
            conv4_3: conv(512, 512),
            conv4_4: conv(512, 512),
            // Block 5 (truncated after its first layer)
            conv5_1: conv(512, 512),
        }
    }

    /// Runs the network and collects one feature map per block.
    ///
    /// Returns the post-ReLU activation of the *first* convolution in each
    /// of the five blocks, flattened over the spatial dimensions
    /// (`[batch, channels, h, w]` -> `[batch, channels, h * w]`), in block
    /// order. Intermediate activations that feed the next layer are cloned
    /// before flattening so the forward pass can continue.
    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 3>> {
        // 2x2 / stride-2 downsampling between blocks; the pooling operator
        // is chosen once per model via `use_avg_pool`.
        let downsample = |t| {
            if self.use_avg_pool {
                avg_pool2d(t, [2, 2], [2, 2], [0, 0], false, false)
            } else {
                max_pool2d(t, [2, 2], [2, 2], [0, 0], [1, 1], false)
            }
        };

        // Exactly one tapped feature per block.
        let mut features = Vec::with_capacity(5);

        // Block 1: tap conv1_1, then finish the block and downsample.
        let tap1 = relu(self.conv1_1.forward(x));
        features.push(tap1.clone().flatten(2, 3));
        let block1 = downsample(relu(self.conv1_2.forward(tap1)));

        // Block 2
        let tap2 = relu(self.conv2_1.forward(block1));
        features.push(tap2.clone().flatten(2, 3));
        let block2 = downsample(relu(self.conv2_2.forward(tap2)));

        // Block 3
        let tap3 = relu(self.conv3_1.forward(block2));
        features.push(tap3.clone().flatten(2, 3));
        let block3 = downsample(relu(
            self.conv3_4
                .forward(relu(self.conv3_3.forward(relu(self.conv3_2.forward(tap3))))),
        ));

        // Block 4
        let tap4 = relu(self.conv4_1.forward(block3));
        features.push(tap4.clone().flatten(2, 3));
        let block4 = downsample(relu(
            self.conv4_4
                .forward(relu(self.conv4_3.forward(relu(self.conv4_2.forward(tap4))))),
        ));

        // Block 5: final tap — no further layers consume it, so no clone.
        let tap5 = relu(self.conv5_1.forward(block4));
        features.push(tap5.flatten(2, 3));

        features
    }
}