use crate::layers::blocks::cna::{CNA2d, CNA2dConfig};
use burn::module::Module;
use burn::nn::PaddingConfig2d;
use burn::nn::activation::ActivationConfig;
use burn::nn::conv::Conv2dConfig;
use burn::nn::norm::NormalizationConfig;
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::prelude::{Backend, Tensor};
/// High-level choice of ResNet stem variant.
///
/// Lowered into a concrete [`ResNetStemStructureConfig`] by
/// [`ResNetStemContractConfig::to_structure`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ResNetStemContractConfig {
    /// Single 7x7 stride-2 conv/norm/act block followed by a 3x3 stride-2 max-pool.
    #[default]
    Default,
    /// Deep stem variant parameterized by `stem_width`.
    ///
    /// NOTE(review): not implemented yet; `to_structure` panics for this variant.
    Deep {
        /// Stem width; presumably the channel count of the intermediate stem convs —
        /// confirm when this variant is implemented.
        stem_width: usize,
    },
    /// Tiered deep stem variant parameterized by `stem_width`.
    ///
    /// NOTE(review): not implemented yet; `to_structure` panics for this variant.
    DeepTiered {
        /// Stem width; presumably the channel count of the intermediate stem convs —
        /// confirm when this variant is implemented.
        stem_width: usize,
    },
}
impl ResNetStemContractConfig {
    /// Lowers this stem contract into a concrete [`ResNetStemStructureConfig`].
    ///
    /// For [`ResNetStemContractConfig::Default`] this produces:
    /// - `cna1`: a 7x7 convolution `in_channels -> 64` with stride 2, padding 3 and no bias,
    ///   paired with `normalization` and `activation`;
    /// - `cna2`, `cna3`: `None`;
    /// - `pool`: a 3x3 max-pool with stride 2 and padding 1.
    ///
    /// With these settings each conv/pool stage maps a spatial size `H` to `ceil(H / 2)`,
    /// so the stem roughly quarters height and width (e.g. 224 -> 112 -> 56).
    ///
    /// # Arguments
    /// * `in_channels` - channel count of the stem input.
    /// * `normalization` - normalization config stored in the `cna1` block.
    /// * `activation` - activation config stored in the `cna1` block.
    ///
    /// # Panics
    /// Panics via `unimplemented!` for [`ResNetStemContractConfig::Deep`] and
    /// [`ResNetStemContractConfig::DeepTiered`], which are not supported yet.
    pub fn to_structure(
        &self,
        in_channels: usize,
        normalization: NormalizationConfig,
        activation: ActivationConfig,
    ) -> ResNetStemStructureConfig {
        // Exhaustive on purpose: a newly added variant must be handled here explicitly
        // instead of silently falling into a catch-all arm.
        match self {
            Self::Default => {}
            Self::Deep { .. } | Self::DeepTiered { .. } => unimplemented!("{:?}", self),
        }

        let cna1 = CNA2dConfig {
            conv: Conv2dConfig::new([in_channels, 64], [7, 7])
                .with_stride([2, 2])
                .with_padding(PaddingConfig2d::Explicit(3, 3))
                .with_bias(false),
            // Both configs are owned and used exactly once, so move them rather than clone.
            norm: normalization,
            act: activation,
        };
        let pool = MaxPool2dConfig::new([3, 3])
            .with_strides([2, 2])
            .with_padding(PaddingConfig2d::Explicit(1, 1));

        ResNetStemStructureConfig {
            cna1,
            cna2: None,
            cna3: None,
            pool: Some(pool),
        }
    }
}
/// Concrete, layer-by-layer configuration of a ResNet stem.
///
/// Produced by [`ResNetStemContractConfig::to_structure`]; mirrors the fields of
/// [`ResNetStem`]. Stages left as `None` are presumably omitted from the built module —
/// confirm against the initializer.
#[derive(Debug, Clone)]
pub struct ResNetStemStructureConfig {
    /// Config of the first (mandatory) conv/norm/act block.
    pub cna1: CNA2dConfig,
    /// Config of an optional second conv/norm/act block.
    pub cna2: Option<CNA2dConfig>,
    /// Config of an optional third conv/norm/act block.
    pub cna3: Option<CNA2dConfig>,
    /// Config of an optional trailing max-pool.
    pub pool: Option<MaxPool2dConfig>,
}
/// ResNet stem module: one mandatory conv/norm/act block, up to two optional
/// conv/norm/act blocks, and an optional max-pool, applied in that order by `forward`.
#[derive(Module, Debug)]
pub struct ResNetStem<B: Backend> {
    /// First conv/norm/act block; always applied.
    pub cna1: CNA2d<B>,
    /// Optional second conv/norm/act block; skipped when `None`.
    pub cna2: Option<CNA2d<B>>,
    /// Optional third conv/norm/act block; skipped when `None`.
    pub cna3: Option<CNA2d<B>>,
    /// Optional max-pool applied last; skipped when `None`.
    pub pool: Option<MaxPool2d>,
}
impl<B: Backend> ResNetStem<B> {
    /// Runs the stem on `input`.
    ///
    /// Applies `cna1`, then whichever of `cna2` and `cna3` are present (in that order),
    /// and finally the max-pool if one is configured.
    ///
    /// # Arguments
    /// * `input` - rank-4 input tensor; presumably `[batch, channels, height, width]`
    ///   (burn's conv layout) — confirm against callers.
    ///
    /// # Returns
    /// The rank-4 stem output.
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut features = self.cna1.forward(input);

        // Optional extra conv blocks, in declaration order; absent ones yield nothing.
        for block in self.cna2.iter().chain(self.cna3.iter()) {
            features = block.forward(features);
        }

        match &self.pool {
            Some(pool) => pool.forward(features),
            None => features,
        }
    }
}