//! A Conv2d -> Normalization -> Activation ("CNA") building block for burn.

use bimm_contracts::{
assert_shape_contract_periodically,
unpack_shape_contract,
};
use burn::{
config::Config,
module::Module,
nn::{
activation::{
Activation,
ActivationConfig,
},
conv::{
Conv2d,
Conv2dConfig,
},
norm::{
Normalization,
NormalizationConfig,
},
},
prelude::{
Backend,
Tensor,
},
};
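
/// Partial [`CNA2d`] configuration: the normalization and activation halves,
/// meant to be paired with a [`Conv2dConfig`] via [`AbstractCNA2dConfig::build_config`].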
#[derive(Config, Debug)]
pub struct AbstractCNA2dConfig {
pub norm: NormalizationConfig,
#[config(default = "ActivationConfig::Relu")]
pub act: ActivationConfig,
}

impl AbstractCNA2dConfig {
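    /// Combine this partial config with `conv` into a full [`CNA2dConfig`],
    /// rewriting the norm's feature count to the conv's output channels.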
pub fn build_config(
&self,
conv: Conv2dConfig,
) -> CNA2dConfig {
CNA2dConfig {
conv,
norm: self.norm.clone(),
act: self.act.clone(),
}
.match_norm_features()
}
}
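
/// Shape metadata shared by [`CNA2dConfig`] and [`CNA2d`].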
pub trait CNA2dMeta {
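    /// The number of input channels.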
fn in_channels(&self) -> usize;
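
    /// The number of output channels.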
fn out_channels(&self) -> usize;
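
    /// The number of convolution groups.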
fn groups(&self) -> usize;
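
    /// The convolution stride as `[height, width]`.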
fn stride(&self) -> [usize; 2];
}
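
/// Configuration for a [`CNA2d`] block; the activation defaults to ReLU.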
#[derive(Config, Debug)]
pub struct CNA2dConfig {
pub conv: Conv2dConfig,
pub norm: NormalizationConfig,
#[config(default = "ActivationConfig::Relu")]
pub act: ActivationConfig,
}

impl CNA2dMeta for CNA2dConfig {
fn in_channels(&self) -> usize {
self.conv.channels[0]
}
fn out_channels(&self) -> usize {
self.conv.channels[1]
}
fn groups(&self) -> usize {
self.conv.groups
}
fn stride(&self) -> [usize; 2] {
self.conv.stride
}
}

impl CNA2dConfig {
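    /// Initialize a [`CNA2d`] on `device`, forcing the norm's feature count
    /// to the conv's output channels.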
pub fn init<B: Backend>(
self,
device: &B::Device,
) -> CNA2d<B> {
let out_channels = self.out_channels();
CNA2d {
conv: self.conv.init(device),
norm: self.norm.with_num_features(out_channels).init(device),
act: self.act.init(device),
}
}
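
    /// Rewrite the norm config's feature count to match the conv's output channels.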
pub fn match_norm_features(self) -> Self {
let features = self.out_channels();
let norm = self.norm.with_num_features(features);
Self { norm, ..self }
}
}
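
/// A Conv2d -> Normalization -> Activation block.
///
/// [`CNA2d::forward`] applies the three stages in order;
/// [`CNA2d::map_forward`] additionally applies a caller-supplied hook
/// between the normalization and the activation.
///
/// A sketch of typical usage, mirroring the unit tests below (the
/// `NdArray` backend and the concrete conv/norm settings are illustrative):
///
/// ```rust,ignore
/// use burn::backend::NdArray;
/// use burn::nn::{BatchNormConfig, PaddingConfig2d, norm::NormalizationConfig};
///
/// let device = Default::default();
/// let layer: CNA2d<NdArray<f32>> = CNA2dConfig::new(
///     Conv2dConfig::new([2, 4], [3, 3])
///         .with_stride([2, 2])
///         .with_padding(PaddingConfig2d::Explicit(1, 1))
///         .with_bias(false),
///     NormalizationConfig::Batch(BatchNormConfig::new(0)),
/// )
/// .init(&device);
/// ```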
#[derive(Module, Debug)]
pub struct CNA2d<B: Backend> {
pub conv: Conv2d<B>,
pub norm: Normalization<B>,
pub act: Activation<B>,
}

impl<B: Backend> CNA2dMeta for CNA2d<B> {
fn in_channels(&self) -> usize {
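        // Conv2d weight shape is [out_channels, in_channels / groups, k_h, k_w].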
self.conv.weight.dims()[1] * self.groups()
}
fn out_channels(&self) -> usize {
self.conv.weight.dims()[0]
}
fn groups(&self) -> usize {
self.conv.groups
}
fn stride(&self) -> [usize; 2] {
self.conv.stride
}
}

impl<B: Backend> CNA2d<B> {
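    /// Apply conv -> norm -> act; equivalent to [`CNA2d::map_forward`]
    /// with an identity hook.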
pub fn forward(
&self,
input: Tensor<B, 4>,
) -> Tensor<B, 4> {
self.map_forward(input, |x| x)
}
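
    /// Apply conv -> norm -> `f` -> act, with the hook `f` running between
    /// the normalization and the activation.
    ///
    /// The shape contracts require each spatial input dim to be an exact
    /// multiple of the corresponding stride (i.e. padding that preserves
    /// size up to the stride); the output is
    /// `[batch, out_channels, out_height, out_width]`.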
pub fn map_forward<F>(
&self,
input: Tensor<B, 4>,
f: F,
) -> Tensor<B, 4>
where
F: FnOnce(Tensor<B, 4>) -> Tensor<B, 4>,
{
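        // Check the input against the contract and bind the output sizes.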
let [batch, out_height, out_width] = unpack_shape_contract!(
[
"batch",
"in_channels",
"in_height" = "out_height" * "height_stride",
"in_width" = "out_width" * "width_stride"
],
&input.dims(),
&["batch", "out_height", "out_width"],
&[
("in_channels", self.in_channels()),
("height_stride", self.stride()[0]),
("width_stride", self.stride()[1]),
]
);
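        // Conv: [batch, in_channels, in_h, in_w] -> [batch, out_channels, out_h, out_w].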
let x = self.conv.forward(input);
assert_shape_contract_periodically!(
["batch", "out_channels", "out_height", "out_width"],
&x.dims(),
&[
("batch", batch),
("out_channels", self.out_channels()),
("out_height", out_height),
("out_width", out_width)
]
);
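        // Norm, the caller's hook, then activation; all three preserve the
        // shape, which the second periodic contract re-checks.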
let x = self.norm.forward(x);
let x = f(x);
let x = self.act.forward(x);
assert_shape_contract_periodically!(
["batch", "out_channels", "out_height", "out_width"],
&x.dims(),
&[
("batch", batch),
("out_channels", self.out_channels()),
("out_height", out_height),
("out_width", out_width)
]
);
x
}
}

#[cfg(test)]
mod tests {
use burn::{
backend::{
Autodiff,
NdArray,
},
nn::{
BatchNormConfig,
PaddingConfig2d,
activation::ActivationConfig,
norm::NormalizationConfig,
},
tensor::Distribution,
};
    use super::*;

#[test]
fn test_conv_norm_config() {
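        // num_features = 0 is a placeholder; build_config() rewrites it to
        // the conv's out_channels.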
let abstract_config =
AbstractCNA2dConfig::new(NormalizationConfig::Batch(BatchNormConfig::new(0)));
let conv_config = Conv2dConfig::new([2, 4], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.with_bias(false);
let config: CNA2dConfig = abstract_config.build_config(conv_config.clone());
assert_eq!(config.in_channels(), 2);
assert_eq!(config.out_channels(), 4);
assert_eq!(config.groups(), 1);
assert_eq!(config.stride(), [2, 2]);
    }

#[test]
fn test_cna() {
type B = Autodiff<NdArray<f32>>;
let device = Default::default();
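        // num_features = 0 is a placeholder; init() rewrites it to the
        // conv's out_channels.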
let config = CNA2dConfig::new(
Conv2dConfig::new([2, 4], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.with_bias(false),
NormalizationConfig::Batch(BatchNormConfig::new(0)),
)
.with_act(ActivationConfig::Relu);
let layer: CNA2d<B> = config.init(&device);
assert_eq!(layer.in_channels(), 2);
assert_eq!(layer.out_channels(), 4);
assert_eq!(layer.groups(), 1);
assert_eq!(layer.stride(), [2, 2]);
let batch_size = 2;
let height = 10;
let width = 10;
let channels = 2;
let input = Tensor::random(
[batch_size, channels, height, width],
Distribution::Default,
&device,
);
{
let output = layer.forward(input.clone());
let expected = {
let x = layer.conv.forward(input.clone());
let x = layer.norm.forward(x);
                layer.act.forward(x)
};
output.to_data().assert_eq(&expected.to_data(), true);
}
{
let hook = |x| x * 2.0;
let output = layer.map_forward(input.clone(), hook);
let expected = {
let x = layer.conv.forward(input.clone());
let x = layer.norm.forward(x);
let x = hook(x);
                layer.act.forward(x)
};
output.to_data().assert_eq(&expected.to_data(), true);
}
}
}