use crate::layers::blocks::cna::{AbstractCNA2dConfig, CNA2d, CNA2dConfig, CNA2dMeta};
use crate::layers::drop::drop_block::{DropBlock2d, DropBlock2dConfig, DropBlockOptions};
use crate::layers::drop::drop_path::{DropPath, DropPathConfig};
use crate::models::resnet::downsample::{ResNetDownsample, ResNetDownsampleConfig};
use crate::models::resnet::util::{scalar_to_array, stride_div_output_resolution};
use crate::utility::probability::expect_probability;
use burn::nn::BatchNormConfig;
use burn::nn::PaddingConfig2d;
use burn::nn::activation::ActivationConfig;
use burn::nn::conv::Conv2dConfig;
use burn::nn::norm::NormalizationConfig;
use burn::prelude::{Backend, Config, Module, Tensor};
/// Read-only view of the hyper-parameters describing a ResNet basic block.
///
/// Implemented in this file by both [`BasicBlockConfig`] and the initialized
/// [`BasicBlock`] module, so either can be queried the same way.
pub trait BasicBlockMeta {
    /// Number of input channels.
    fn in_planes(&self) -> usize;

    /// Number of output channels.
    fn out_planes(&self) -> usize;

    /// Dilation of the block.
    fn dilation(&self) -> usize;

    /// Optional dilation override for the first convolution.
    fn first_dilation(&self) -> Option<usize>;

    /// Dilation of the first convolution: [`Self::first_dilation`] when set,
    /// otherwise [`Self::dilation`].
    fn effective_first_dilation(&self) -> usize {
        let explicit = self.first_dilation();
        let fallback = self.dilation();
        explicit.unwrap_or(fallback)
    }

    /// Divisor applied to `out_planes` to obtain `first_planes`.
    fn reduce_first(&self) -> usize;

    /// Number of channels between the two convolutions:
    /// `out_planes / reduce_first` (integer division).
    fn first_planes(&self) -> usize {
        let total = self.out_planes();
        let divisor = self.reduce_first();
        total / divisor
    }

    /// Stride of the block.
    fn stride(&self) -> usize;

    /// Output `[height, width]` for a given input `[height, width]`.
    ///
    /// Delegates to `stride_div_output_resolution` with this block's stride.
    /// The tests in this file expect a panic when a dimension is not a
    /// multiple of the stride.
    fn output_resolution(
        &self,
        input_resolution: [usize; 2],
    ) -> [usize; 2] {
        let stride = self.stride();
        stride_div_output_resolution(input_resolution, stride)
    }
}
/// Configuration for a [`BasicBlock`].
///
/// `in_planes` and `out_planes` are required; every other field has a default.
#[derive(Config, Debug)]
pub struct BasicBlockConfig {
    /// Number of input channels of the first convolution.
    pub in_planes: usize,

    /// Number of output channels of the second convolution.
    pub out_planes: usize,

    /// Divisor applied to `out_planes` to size the intermediate feature map.
    #[config(default = 1)]
    pub reduce_first: usize,

    /// Stride of the first convolution (and of the downsample shortcut).
    #[config(default = 1)]
    pub stride: usize,

    /// Dilation (and padding) of the second convolution.
    #[config(default = 1)]
    pub dilation: usize,

    /// Dilation (and padding) of the first convolution; falls back to
    /// `dilation` when `None`.
    #[config(default = "None")]
    pub first_dilation: Option<usize>,

    /// Kernel-size argument handed to the downsample shortcut config.
    #[config(default = 1)]
    pub down_kernel_size: usize,

    /// Drop-path probability; `0.0` disables the drop-path layer.
    #[config(default = "0.0")]
    pub drop_path_prob: f64,

    /// Options for the optional drop-block layer; `None` disables it.
    #[config(default = "None")]
    pub drop_block: Option<DropBlockOptions>,

    /// Normalization used by both conv stages and the downsample shortcut.
    // NOTE(review): the feature count of 0 in the default looks like a
    // placeholder that the CNA builder is expected to fill in - confirm.
    #[config(default = "NormalizationConfig::Batch(BatchNormConfig::new(0))")]
    pub normalization: NormalizationConfig,

    /// Activation used by both conv stages.
    #[config(default = "ActivationConfig::Relu")]
    pub activation: ActivationConfig,
}
/// Metadata accessors backed directly by the config's fields.
impl BasicBlockMeta for BasicBlockConfig {
    /// Returns the configured `in_planes`.
    fn in_planes(&self) -> usize {
        self.in_planes
    }

    /// Returns the configured `out_planes`.
    fn out_planes(&self) -> usize {
        self.out_planes
    }

    /// Returns the configured `dilation`.
    fn dilation(&self) -> usize {
        self.dilation
    }

    /// Returns the configured `first_dilation` override, if any.
    fn first_dilation(&self) -> Option<usize> {
        self.first_dilation
    }

    /// Returns the configured `reduce_first`.
    fn reduce_first(&self) -> usize {
        self.reduce_first
    }

    /// Returns the configured `stride`.
    fn stride(&self) -> usize {
        self.stride
    }
}
impl BasicBlockConfig {
    /// Consumes the config and initializes a [`BasicBlock`] on `device`.
    ///
    /// Layout of the resulting block:
    /// - `cna1`: 3x3 conv `in_planes -> first_planes`, using the block stride
    ///   and the effective first dilation (also used as padding), no bias.
    /// - `cna2`: 3x3 conv `first_planes -> out_planes`, using `dilation`
    ///   (also used as padding), no bias.
    /// - `downsample`: only present when `stride != 1` or
    ///   `in_planes != out_planes`.
    /// - `drop_block`: present when `drop_block` options are set.
    /// - `drop_path`: present when `drop_path_prob` is non-zero.
    ///
    /// `drop_path_prob` is validated by `expect_probability` before anything
    /// is built (presumably panicking on values outside `[0, 1]`).
    pub fn init<B: Backend>(
        self,
        device: &B::Device,
    ) -> BasicBlock<B> {
        let drop_path_prob = expect_probability(self.drop_path_prob);

        // Resolve everything that needs `&self` up front; fields are moved
        // out of `self` below, after which trait methods can't be called.
        let in_planes = self.in_planes();
        let first_planes = self.first_planes();
        let out_planes = self.out_planes();
        let dilation = self.dilation();
        let first_dilation = self.effective_first_dilation();
        let stride = self.stride();

        // A shortcut projection is only needed when the residual branch
        // changes the spatial resolution or the channel count.
        let downsample = if stride != 1 || in_planes != out_planes {
            Some(
                ResNetDownsampleConfig::new(in_planes, out_planes, self.down_kernel_size)
                    .with_stride(stride)
                    .with_dilation(first_dilation)
                    .with_norm(self.normalization.clone()),
            )
        } else {
            None
        };

        // `self` is owned, so the norm/activation configs can be moved here
        // instead of cloned.
        let cna_builder = AbstractCNA2dConfig {
            norm: self.normalization,
            act: self.activation,
        };

        let cna1: CNA2dConfig = cna_builder.build_config(
            Conv2dConfig::new([in_planes, first_planes], scalar_to_array(3))
                .with_stride(scalar_to_array(stride))
                .with_dilation(scalar_to_array(first_dilation))
                .with_padding(PaddingConfig2d::Explicit(first_dilation, first_dilation))
                .with_bias(false),
        );

        let cna2: CNA2dConfig = cna_builder.build_config(
            Conv2dConfig::new([first_planes, out_planes], scalar_to_array(3))
                .with_dilation(scalar_to_array(dilation))
                .with_padding(PaddingConfig2d::Explicit(dilation, dilation))
                .with_bias(false),
        );

        // Modules are initialized in the same order as before:
        // downsample, cna1, cna2.
        BasicBlock {
            reduce_first: self.reduce_first,
            downsample: downsample.map(|cfg| cfg.init(device)),
            cna1: cna1.init(device),
            cna2: cna2.init(device),
            drop_block: self
                .drop_block
                .map(|options| DropBlock2dConfig::from(options).init()),
            drop_path: if drop_path_prob != 0.0 {
                DropPathConfig::new()
                    .with_drop_prob(drop_path_prob)
                    .init()
                    .into()
            } else {
                None
            },
        }
    }
}
/// ResNet basic block: two conv/norm/activation stages with a residual
/// shortcut, plus optional drop-block and drop-path regularization.
#[derive(Module, Debug)]
pub struct BasicBlock<B: Backend> {
    /// Divisor that was used to derive `first_planes` from `out_planes`.
    pub reduce_first: usize,

    /// Optional projection applied to the input to form the shortcut.
    pub downsample: Option<ResNetDownsample<B>>,

    /// First conv/norm/activation stage.
    pub cna1: CNA2d<B>,

    /// Second conv/norm/activation stage.
    pub cna2: CNA2d<B>,

    /// Optional drop-block layer, hooked into the first stage.
    pub drop_block: Option<DropBlock2d>,

    /// Optional drop-path layer, hooked into the second stage before the
    /// shortcut is added.
    pub drop_path: Option<DropPath>,
}
/// Metadata accessors recovered from the initialized sub-modules.
impl<B: Backend> BasicBlockMeta for BasicBlock<B> {
    /// Input channels of the first convolution.
    fn in_planes(&self) -> usize {
        self.cna1.in_channels()
    }

    /// Output channels of the second convolution.
    fn out_planes(&self) -> usize {
        self.cna2.out_channels()
    }

    /// Block dilation, read from the second convolution.
    ///
    /// `BasicBlockConfig::init` builds `cna1` with the effective *first*
    /// dilation and `cna2` with the block dilation, so the block dilation
    /// must come from `cna2`; reading `cna1` would report the first
    /// dilation whenever the two differ.
    fn dilation(&self) -> usize {
        self.cna2.conv.dilation[0]
    }

    /// First-convolution dilation, reported only when it differs from the
    /// block dilation.
    ///
    /// An explicit override equal to the block dilation is therefore
    /// reported as `None`; `effective_first_dilation` is unaffected.
    fn first_dilation(&self) -> Option<usize> {
        let first = self.cna1.conv.dilation[0];
        let block = self.cna2.conv.dilation[0];
        if first == block { None } else { Some(first) }
    }

    /// The stored `reduce_first` divisor.
    fn reduce_first(&self) -> usize {
        self.reduce_first
    }

    /// Output channels of the first convolution (exact, not re-derived from
    /// `out_planes / reduce_first`).
    fn first_planes(&self) -> usize {
        self.cna1.out_channels()
    }

    /// Stride of the first convolution along its first spatial axis.
    fn stride(&self) -> usize {
        self.cna1.stride()[0]
    }
}
impl<B: Backend> BasicBlock<B> {
    /// Prints a short summary of the block to stdout: whether a downsample
    /// shortcut is present, and the input/output channel counts.
    pub fn debug_print(&self) {
        println!("#### BasicBlock");
        if self.downsample.is_some() {
            println!(" downsample");
        }
        println!(" in_planes: {}", self.in_planes());
        println!(" out_planes: {}", self.out_planes());
    }

    /// Runs the block on `input`.
    ///
    /// In debug builds the shapes are checked against contracts: the input
    /// must be `[batch, in_planes, out_height * stride, out_width * stride]`
    /// and the result is `[batch, out_planes, out_height, out_width]`.
    ///
    /// The shortcut is the downsampled input when a downsample module is
    /// present, otherwise the input itself. Drop-block (if any) is hooked
    /// into the first stage; drop-path (if any) followed by the shortcut
    /// addition is hooked into the second stage.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        #[cfg(debug_assertions)]
        let [batch, out_height, out_width] = bimm_contracts::unpack_shape_contract!(
            [
                "batch",
                "in_planes",
                "in_height" = "out_height" * "stride",
                "in_width" = "out_width" * "stride"
            ],
            &input.dims(),
            &["batch", "out_height", "out_width"],
            &[("in_planes", self.in_planes()), ("stride", self.stride())],
        );

        // Shortcut branch; computed before `input` is consumed below.
        let identity = if let Some(downsample) = &self.downsample {
            downsample.forward(input.clone())
        } else {
            input.clone()
        };

        #[cfg(debug_assertions)]
        bimm_contracts::define_shape_contract!(
            OUT_CONTRACT,
            ["batch", "out_planes", "out_height", "out_width"],
        );
        #[cfg(debug_assertions)]
        let out_bindings = [
            ("batch", batch),
            ("out_planes", self.out_planes()),
            ("out_height", out_height),
            ("out_width", out_width),
        ];

        // First stage, with the optional drop-block hook.
        let hidden = self.cna1.map_forward(input, |t| {
            if let Some(drop_block) = &self.drop_block {
                drop_block.forward(t)
            } else {
                t
            }
        });

        #[cfg(debug_assertions)]
        bimm_contracts::assert_shape_contract_periodically!(
            ["batch", "first_planes", "out_height", "out_width"],
            &hidden.dims(),
            &[
                ("batch", batch),
                ("first_planes", self.first_planes()),
                ("out_height", out_height),
                ("out_width", out_width),
            ]
        );

        // Second stage: optional drop-path, then add the shortcut.
        let output = self.cna2.map_forward(hidden, |t| {
            let residual = if let Some(drop_path) = &self.drop_path {
                drop_path.forward(t)
            } else {
                t
            };
            residual + identity
        });

        #[cfg(debug_assertions)]
        bimm_contracts::assert_shape_contract_periodically!(
            OUT_CONTRACT,
            &output.dims(),
            &out_bindings
        );

        output
    }

    /// Returns the block with its drop-path layer replaced.
    ///
    /// `drop_path_prob` is validated by `expect_probability`; a value of
    /// `0.0` removes the layer.
    pub fn with_drop_path_prob(
        self,
        drop_path_prob: f64,
    ) -> Self {
        let drop_path_prob = expect_probability(drop_path_prob);
        Self {
            drop_path: if drop_path_prob != 0.0 {
                DropPathConfig::new()
                    .with_drop_prob(drop_path_prob)
                    .init()
                    .into()
            } else {
                None
            },
            ..self
        }
    }

    /// Returns the block with its drop-block layer replaced; `None` removes
    /// the layer.
    pub fn with_drop_block(
        self,
        drop_block: Option<DropBlockOptions>,
    ) -> Self {
        let drop_block = drop_block.map(|opts| DropBlock2dConfig::from(opts).init());
        Self { drop_block, ..self }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bimm_contracts::assert_shape_contract;
    use burn::backend::{Autodiff, NdArray};
    use burn::nn::activation::ActivationConfig;

    /// Config accessors and builder methods round-trip.
    #[test]
    fn test_basic_block_config() {
        let in_planes = 16;
        let out_planes = 32;

        let config = BasicBlockConfig::new(in_planes, out_planes);
        assert_eq!(config.in_planes(), in_planes);
        assert_eq!(config.out_planes(), out_planes);
        assert_eq!(config.stride(), 1);
        assert_eq!(config.output_resolution([16, 16]), [16, 16]);
        assert!(matches!(config.activation, ActivationConfig::Relu));

        let config = config
            .with_stride(2)
            .with_activation(ActivationConfig::Sigmoid);
        assert_eq!(config.stride(), 2);
        assert_eq!(config.output_resolution([16, 16]), [8, 8]);
        assert!(matches!(config.activation, ActivationConfig::Sigmoid));
    }

    /// A resolution that is not a multiple of the stride is rejected.
    #[test]
    #[should_panic(expected = "7 !~ in_height=(out_height*stride)")]
    fn test_downsample_config_panic() {
        let config = BasicBlockConfig::new(16, 32).with_stride(2);
        assert_eq!(config.stride(), 2);
        config.output_resolution([7, 7]);
    }

    /// Metadata read back from an initialized block matches the config.
    #[test]
    fn test_basic_block_meta() {
        type B = NdArray<f32>;
        let device = Default::default();

        let in_planes = 2;
        let out_planes = in_planes;

        let block: BasicBlock<B> = BasicBlockConfig::new(in_planes, out_planes).init(&device);
        assert_eq!(block.in_planes(), in_planes);
        assert_eq!(block.out_planes(), out_planes);
        assert_eq!(block.stride(), 1);
        assert_eq!(block.output_resolution([16, 16]), [16, 16]);
    }

    /// Same channel count and stride 1: no downsample module is built and
    /// the input itself is used as the shortcut.
    #[test]
    fn test_basic_block_forward_same_channels_no_downsample_autodiff() {
        type B = Autodiff<NdArray<f32>>;
        let device = Default::default();

        let batch_size = 2;
        let in_planes = 4;
        let out_planes = in_planes;
        let in_height = 8;
        let in_width = 8;

        let block: BasicBlock<B> = BasicBlockConfig::new(in_planes, out_planes).init(&device);
        assert!(block.downsample.is_none());
        let out_planes = block.out_planes();

        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);

        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", out_planes),
                ("out_height", in_height),
                ("out_width", in_width)
            ],
        );
    }

    /// Stride 1 but a different channel count: a downsample module is built
    /// to project the shortcut, and the resolution is preserved.
    #[test]
    fn test_basic_block_forward_channel_change_downsample_autodiff() {
        type B = Autodiff<NdArray<f32>>;
        let device = Default::default();

        let batch_size = 2;
        let in_planes = 2;
        let out_planes = 8;
        let in_height = 8;
        let in_width = 8;

        let block: BasicBlock<B> = BasicBlockConfig::new(in_planes, out_planes).init(&device);
        assert!(block.downsample.is_some());
        let out_planes = block.out_planes();

        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);

        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", out_planes),
                ("out_height", in_height),
                ("out_width", in_width)
            ],
        );
    }

    /// Strided block with drop-block and drop-path enabled.
    #[test]
    fn test_basic_block_forward_downsample_drop_block_drop_path_autodiff() {
        type B = Autodiff<NdArray<f32>>;
        let device = Default::default();

        let batch_size = 2;
        let in_planes = 2;
        let planes = 4;
        let in_height = 8;
        let in_width = 8;

        let block: BasicBlock<B> = BasicBlockConfig::new(in_planes, planes)
            .with_drop_path_prob(0.1)
            .with_drop_block(Some(DropBlockOptions::default()))
            .with_stride(2)
            .init(&device);
        let out_planes = block.out_planes();

        let [out_height, out_width] = block.output_resolution([in_height, in_width]);
        assert_eq!(out_height, 4);
        assert_eq!(out_width, 4);

        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);

        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", out_planes),
                ("out_height", out_height),
                ("out_width", out_width)
            ],
        );
    }
}