use bimm_contracts::{
assert_shape_contract_periodically,
unpack_shape_contract,
};
use burn::{
config::Config,
module::Module,
nn::{
BatchNorm,
BatchNormConfig,
Initializer,
conv::{
Conv2d,
Conv2dConfig,
},
},
prelude::{
Backend,
Tensor,
},
};
/// Read-only accessors for the convolution metadata shared by
/// [`ConvNorm2dConfig`] and [`ConvNorm2d`].
pub trait ConvNorm2dMeta {
/// Number of input channels of the convolution.
fn in_channels(&self) -> usize;
/// Number of channel groups of the convolution.
fn groups(&self) -> usize;
/// Number of output channels of the convolution.
fn out_channels(&self) -> usize;
/// Convolution stride; `ConvNorm2d::forward` reads index 0 as the height
/// stride and index 1 as the width stride.
fn stride(&self) -> &[usize; 2];
}
/// Configuration for a [`ConvNorm2d`] block.
///
/// Only the convolution is configurable here; [`ConvNorm2dConfig::init`]
/// builds the batch norm from the convolution's output channel count.
#[derive(Config, Debug)]
pub struct ConvNorm2dConfig {
/// Configuration of the inner convolution.
pub conv: Conv2dConfig,
}
/// Metadata view over the wrapped [`Conv2dConfig`].
impl ConvNorm2dMeta for ConvNorm2dConfig {
    /// Input channel count: the first entry of `conv.channels`.
    fn in_channels(&self) -> usize {
        let [input, _] = self.conv.channels;
        input
    }

    /// Group count, read straight from `conv.groups`.
    fn groups(&self) -> usize {
        let conv = &self.conv;
        conv.groups
    }

    /// Output channel count: the second entry of `conv.channels`.
    fn out_channels(&self) -> usize {
        let [_, output] = self.conv.channels;
        output
    }

    /// Borrowed view of `conv.stride`.
    fn stride(&self) -> &[usize; 2] {
        let conv = &self.conv;
        &conv.stride
    }
}
impl From<Conv2dConfig> for ConvNorm2dConfig {
fn from(conv: Conv2dConfig) -> Self {
Self { conv }
}
}
impl ConvNorm2dConfig {
pub fn init<B: Backend>(
self,
device: &B::Device,
) -> ConvNorm2d<B> {
ConvNorm2d {
conv: self.conv.init(device),
norm: BatchNormConfig::new(self.conv.channels[1]).init(device),
}
}
pub fn with_initializer(
self,
initializer: Initializer,
) -> Self {
Self {
conv: self.conv.with_initializer(initializer),
}
}
}
/// A 2D convolution followed by a batch normalization layer.
#[derive(Module, Debug)]
pub struct ConvNorm2d<B: Backend> {
/// The convolution; applied first in `forward`.
pub conv: Conv2d<B>,
/// The batch norm; applied to the convolution's output in `forward`.
pub norm: BatchNorm<B>,
}
/// Metadata recovered from the initialized convolution layer.
impl<B: Backend> ConvNorm2dMeta for ConvNorm2d<B> {
    /// Input channel count: dimension 1 of the convolution weight multiplied
    /// by the group count.
    fn in_channels(&self) -> usize {
        let weight_shape = self.conv.weight.shape();
        let channels_per_group = weight_shape.dims[1];
        channels_per_group * self.groups()
    }

    /// Group count stored on the convolution layer.
    fn groups(&self) -> usize {
        let conv = &self.conv;
        conv.groups
    }

    /// Output channel count: dimension 0 of the convolution weight.
    fn out_channels(&self) -> usize {
        let weight_shape = self.conv.weight.shape();
        weight_shape.dims[0]
    }

    /// Borrowed view of the convolution layer's stride.
    fn stride(&self) -> &[usize; 2] {
        let conv = &self.conv;
        &conv.stride
    }
}
impl<B: Backend> ConvNorm2d<B> {
    /// Overwrites every element of the batch norm `gamma` parameter with zero.
    pub fn zero_init_norm(&mut self) {
        let gamma = self.norm.gamma.clone();
        self.norm.gamma = gamma.map(|weights| weights.slice_fill([..], 0.0));
    }

    /// Applies the convolution and then the batch norm to `input`.
    ///
    /// The input shape is checked against the contract
    /// `[batch, in_channels, out_height * height_stride, out_width * width_stride]`,
    /// and the result is periodically checked to be
    /// `[batch, out_channels, out_height, out_width]`.
    ///
    /// NOTE(review): the contract macros presumably panic on a shape mismatch;
    /// confirm against the `bimm_contracts` documentation.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let [height_stride, width_stride] = *self.stride();

        // Solve the input contract for the dimensions the output must have.
        let [batch, out_height, out_width] = unpack_shape_contract!(
            [
                "batch",
                "in_channels",
                "in_height" = "out_height" * "height_stride",
                "in_width" = "out_width" * "width_stride"
            ],
            &input.dims(),
            &["batch", "out_height", "out_width"],
            &[
                ("in_channels", self.in_channels()),
                ("height_stride", height_stride),
                ("width_stride", width_stride),
            ]
        );

        let convolved = self.conv.forward(input);
        let normalized = self.norm.forward(convolved);

        assert_shape_contract_periodically!(
            ["batch", "out_channels", "out_height", "out_width"],
            &normalized.dims(),
            &[
                ("batch", batch),
                ("out_channels", self.out_channels()),
                ("out_height", out_height),
                ("out_width", out_width)
            ]
        );

        normalized
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::nn::PaddingConfig2d;

    /// Converting a `Conv2dConfig` into a `ConvNorm2dConfig` keeps the
    /// channels, kernel size and stride of the wrapped config.
    #[test]
    fn test_conv_norm_config() {
        let expected = Conv2dConfig::new([2, 4], [3, 3])
            .with_stride([2, 2])
            .with_padding(PaddingConfig2d::Explicit(1, 1))
            .with_bias(false);

        let wrapped = ConvNorm2dConfig::from(expected.clone());
        let ConvNorm2dConfig { conv } = &wrapped;

        assert_eq!(conv.channels, expected.channels);
        assert_eq!(conv.kernel_size, expected.kernel_size);
        assert_eq!(conv.stride, expected.stride);
    }
}