use crate::layers::blocks::cna::{AbstractCNA2dConfig, CNA2d, CNA2dConfig, CNA2dMeta};
use crate::layers::drop::drop_block::{DropBlock2d, DropBlock2dConfig, DropBlockOptions};
use crate::layers::drop::drop_path::{DropPath, DropPathConfig};
use crate::models::resnet::downsample::{ResNetDownsample, ResNetDownsampleConfig};
use crate::models::resnet::util::{scalar_to_array, stride_div_output_resolution};
use crate::utility::probability::expect_probability;
use burn::nn::activation::ActivationConfig;
use burn::nn::conv::Conv2dConfig;
use burn::nn::norm::NormalizationConfig;
use burn::nn::{BatchNormConfig, PaddingConfig2d};
use burn::prelude::{Backend, Config, Module, Tensor};
/// Structural policy knobs shared by bottleneck blocks.
#[derive(Config, Debug)]
pub struct BottleneckPolicyConfig {
/// Ratio between a block's `out_planes` and its pinched `planes`
/// (`planes = out_planes / pinch_factor`, see [`BottleneckBlockMeta::planes`]).
///
/// Defaults to [`BOTTLENECK_BLOCK_DEFAULT_PINCH_FACTOR`].
#[config(default = "BOTTLENECK_BLOCK_DEFAULT_PINCH_FACTOR")]
pub pinch_factor: usize,
}
impl Default for BottleneckPolicyConfig {
fn default() -> Self {
Self::new()
}
}
/// Channel and spatial geometry shared by bottleneck configs and initialized blocks.
///
/// Implementors provide the raw hyper-parameters. The derived quantities
/// (`planes`, `width`, `first_planes`, `output_resolution`) have default
/// implementations computed from those hyper-parameters.
pub trait BottleneckBlockMeta {
    /// Number of channels entering the block.
    fn in_planes(&self) -> usize;

    /// Number of channels leaving the block.
    fn out_planes(&self) -> usize;

    /// Ratio of `out_planes` to the pinched `planes`.
    fn pinch_factor(&self) -> usize;

    /// Pinched channel count: `out_planes / pinch_factor` (integer division).
    fn planes(&self) -> usize {
        let out_planes = self.out_planes();
        out_planes / self.pinch_factor()
    }

    /// Output channels of the first 1x1 stage: `width / reduce_first`.
    fn first_planes(&self) -> usize {
        let width = self.width();
        width / self.reduce_first()
    }

    /// Dilation parameter of the block.
    fn dilation(&self) -> usize;

    /// Number of groups in the 3x3 stage.
    fn cardinality(&self) -> usize;

    /// Width multiplier; `width` scales with `base_width / 64`.
    fn base_width(&self) -> usize;

    /// Divisor taking `width` down to `first_planes`.
    fn reduce_first(&self) -> usize;

    /// Channel width of the 3x3 stage:
    /// `floor(planes * base_width / 64) * cardinality`.
    fn width(&self) -> usize {
        // Per-group width, truncated toward zero before scaling by the group count.
        let group_width = self.planes() as f64 * self.base_width() as f64 / 64.0;
        (group_width as usize) * self.cardinality()
    }

    /// Spatial stride of the block.
    fn stride(&self) -> usize;

    /// Output `[height, width]` for `input_resolution`, each dimension divided by `stride`.
    ///
    /// # Panics
    ///
    /// Via `stride_div_output_resolution` when a dimension is not a multiple of
    /// `stride` (e.g. `[7, 7]` with stride 2).
    fn output_resolution(
        &self,
        input_resolution: [usize; 2],
    ) -> [usize; 2] {
        let stride = self.stride();
        stride_div_output_resolution(input_resolution, stride)
    }
}
/// Default [`BottleneckPolicyConfig::pinch_factor`]: `out_planes` is 4x the pinched `planes`.
pub const BOTTLENECK_BLOCK_DEFAULT_PINCH_FACTOR: usize = 4;
/// Configuration for a [`BottleneckBlock`].
///
/// Only `in_planes` and `out_planes` are required. Every other field has a
/// `#[config(default)]` and a generated `with_*` setter.
#[derive(Config, Debug)]
pub struct BottleneckBlockConfig {
/// Number of channels entering the block.
pub in_planes: usize,
/// Number of channels leaving the block.
pub out_planes: usize,
/// Group count of the 3x3 stage; also multiplies the block `width`.
#[config(default = "1")]
pub cardinality: usize,
/// Width multiplier; `width` scales with `base_width / 64`.
#[config(default = "64")]
pub base_width: usize,
/// Shared structural policy (holds the pinch factor).
#[config(default = "BottleneckPolicyConfig::default()")]
pub policy: BottleneckPolicyConfig,
/// Divisor applied to `width` to get the first 1x1 stage's output channels.
#[config(default = 1)]
pub reduce_first: usize,
/// Stride of the 3x3 stage and of the downsample path.
#[config(default = 1)]
pub stride: usize,
/// Block dilation. It is passed to the downsample, and used by the 3x3 stage
/// unless `first_dilation` is set.
#[config(default = 1)]
pub dilation: usize,
/// When `Some`, overrides the 3x3 stage's dilation (and matching padding).
#[config(default = "None")]
pub first_dilation: Option<usize>,
/// Kernel size handed to the downsample path.
#[config(default = "1")]
pub down_kernel_size: usize,
/// Drop-path probability on the main path; `0.0` disables drop-path.
#[config(default = "0.0")]
pub drop_path_prob: f64,
/// Optional drop-block applied within the 3x3 stage.
#[config(default = "None")]
pub drop_block: Option<DropBlockOptions>,
/// Normalization used by every conv stage and the downsample.
///
/// NOTE(review): the default `BatchNormConfig::new(0)` has 0 features. Presumably
/// the CNA builder and the downsample substitute the real channel count — confirm.
#[config(default = "NormalizationConfig::Batch(BatchNormConfig::new(0))")]
pub normalization: NormalizationConfig,
/// Activation used by the conv stages.
#[config(default = "ActivationConfig::Relu")]
pub activation: ActivationConfig,
}
/// Every accessor is a plain field read from the config.
impl BottleneckBlockMeta for BottleneckBlockConfig {
    // Channel counts.
    fn in_planes(&self) -> usize {
        self.in_planes
    }

    fn out_planes(&self) -> usize {
        self.out_planes
    }

    // The pinch factor lives on the nested policy rather than on the config itself.
    fn pinch_factor(&self) -> usize {
        let policy = &self.policy;
        policy.pinch_factor
    }

    // Grouped-conv sizing.
    fn cardinality(&self) -> usize {
        self.cardinality
    }

    fn base_width(&self) -> usize {
        self.base_width
    }

    fn reduce_first(&self) -> usize {
        self.reduce_first
    }

    // Spatial parameters.
    fn stride(&self) -> usize {
        self.stride
    }

    fn dilation(&self) -> usize {
        self.dilation
    }
}
impl BottleneckBlockConfig {
    /// The explicitly configured dilation of the 3x3 stage, if any.
    fn first_dilation(&self) -> Option<usize> {
        self.first_dilation
    }

    /// Dilation actually applied by the 3x3 stage: `first_dilation` when set,
    /// otherwise `dilation`.
    fn effective_first_dilation(&self) -> usize {
        self.first_dilation().unwrap_or(self.dilation())
    }

    /// Initialize a [`BottleneckBlock`] on `device`.
    ///
    /// The main path consists of three conv/norm/act stages:
    /// - `cna1`: 1x1, `in_planes -> first_planes`;
    /// - `cna2`: 3x3, strided, dilated and grouped by `cardinality`, `first_planes -> width`;
    /// - `cna3`: 1x1, `width -> out_planes`.
    ///
    /// A downsample is built for the residual path whenever `stride != 1` or
    /// `in_planes != out_planes`.
    ///
    /// # Panics
    ///
    /// - If `expect_probability` rejects `drop_path_prob` (presumably values
    ///   outside `[0, 1]` — confirm against `expect_probability`).
    /// - If the derived channel counts of the three stages fail to chain.
    pub fn init<B: Backend>(
        self,
        device: &B::Device,
    ) -> BottleneckBlock<B> {
        let drop_path_prob = expect_probability(self.drop_path_prob);
        let in_planes = self.in_planes();
        let first_planes = self.first_planes();
        let width = self.width();
        let out_planes = self.out_planes();
        let stride = self.stride();
        let first_dilation = self.effective_first_dilation();
        let pinch_factor = self.pinch_factor();

        // Anti-aliased (AA) downsampling is not supported: the 3x3 stage always
        // applies the stride directly.

        // The residual path needs a projection whenever the main path changes
        // the spatial resolution or the channel count.
        let downsample = (stride != 1 || in_planes != out_planes).then(|| {
            ResNetDownsampleConfig::new(in_planes, out_planes, self.down_kernel_size)
                .with_stride(stride)
                .with_dilation(self.dilation())
                .with_norm(self.normalization.clone())
        });

        let cna_builder = AbstractCNA2dConfig {
            norm: self.normalization.clone(),
            act: self.activation.clone(),
        };

        // 1x1 reduction: in_planes -> first_planes.
        let cna1: CNA2dConfig = cna_builder.build_config(
            Conv2dConfig::new([in_planes, first_planes], scalar_to_array(1)).with_bias(false),
        );
        // 3x3 grouped stage. A padding equal to the dilation preserves the spatial
        // size before striding (effective kernel extent is 2 * dilation + 1).
        let cna2: CNA2dConfig = cna_builder.build_config(
            Conv2dConfig::new([first_planes, width], scalar_to_array(3))
                .with_bias(false)
                .with_stride(scalar_to_array(stride))
                .with_dilation(scalar_to_array(first_dilation))
                .with_padding(PaddingConfig2d::Explicit(first_dilation, first_dilation))
                .with_groups(self.cardinality()),
        );
        // 1x1 expansion: width -> out_planes.
        let cna3: CNA2dConfig = cna_builder.build_config(
            Conv2dConfig::new([width, out_planes], scalar_to_array(1)).with_bias(false),
        );

        // The three stages must chain exactly from in_planes to out_planes.
        assert_eq!(self.in_planes(), cna1.in_channels());
        assert_eq!(cna1.out_channels(), cna2.in_channels());
        assert_eq!(cna2.out_channels(), cna3.in_channels());
        assert_eq!(cna3.out_channels(), self.out_planes());

        BottleneckBlock {
            base_width: self.base_width,
            pinch_factor,
            reduce_first: self.reduce_first,
            downsample: downsample.map(|config| config.init(device)),
            cna1: cna1.init(device),
            cna2: cna2.init(device),
            cna3: cna3.init(device),
            // `self` is owned, so the options can be moved rather than cloned.
            drop_block: self
                .drop_block
                .map(|options| DropBlock2dConfig::from(options).init()),
            // A probability of exactly zero means "no drop-path layer at all".
            drop_path: (drop_path_prob != 0.0).then(|| {
                DropPathConfig::new()
                    .with_drop_prob(drop_path_prob)
                    .init()
            }),
        }
    }
}
/// ResNet-style bottleneck residual block.
///
/// The main path runs three conv/norm/act stages (1x1 -> 3x3 grouped -> 1x1) and
/// is summed with a residual. The residual is the input, or `downsample(input)`
/// when present. Built by [`BottleneckBlockConfig::init`].
#[derive(Module, Debug)]
pub struct BottleneckBlock<B: Backend> {
/// `base_width` copied from the config, kept for [`BottleneckBlockMeta`].
pub base_width: usize,
/// Pinch factor copied from the config's policy, kept for [`BottleneckBlockMeta`].
pub pinch_factor: usize,
/// `reduce_first` copied from the config, kept for [`BottleneckBlockMeta`].
pub reduce_first: usize,
/// Residual-path projection; `Some` when stride != 1 or in/out planes differ.
pub downsample: Option<ResNetDownsample<B>>,
/// 1x1 stage: `in_planes -> first_planes`.
pub cna1: CNA2d<B>,
/// 3x3 strided/dilated/grouped stage: `first_planes -> width`.
pub cna2: CNA2d<B>,
/// 1x1 stage: `width -> out_planes`.
pub cna3: CNA2d<B>,
/// Optional drop-block applied within the 3x3 stage.
pub drop_block: Option<DropBlock2d>,
/// Optional drop-path applied to the main path before the residual add.
pub drop_path: Option<DropPath>,
}
/// Recovers the block's geometry from its initialized layers, plus the scalar
/// hyper-parameters stored on the block itself.
impl<B: Backend> BottleneckBlockMeta for BottleneckBlock<B> {
    fn in_planes(&self) -> usize {
        self.cna1.in_channels()
    }

    fn out_planes(&self) -> usize {
        self.cna3.out_channels()
    }

    fn pinch_factor(&self) -> usize {
        self.pinch_factor
    }

    /// Dilation of the 3x3 stage (`cna2`).
    ///
    /// This matches the config's effective first dilation: `first_dilation` when
    /// set, otherwise `dilation`. The 1x1 stages are built without dilation, so
    /// they always report 1 and cannot be used here.
    fn dilation(&self) -> usize {
        self.cna2.conv.dilation[0]
    }

    /// Group count of the 3x3 stage.
    fn cardinality(&self) -> usize {
        self.cna2.groups()
    }

    fn base_width(&self) -> usize {
        self.base_width
    }

    fn reduce_first(&self) -> usize {
        self.reduce_first
    }

    /// Read directly from the layers rather than recomputed from `planes`.
    fn width(&self) -> usize {
        self.cna3.in_channels()
    }

    /// Stride of the 3x3 stage, which is the only strided conv on the main path.
    fn stride(&self) -> usize {
        self.cna2.stride()[0]
    }
}
impl<B: Backend> BottleneckBlock<B> {
    /// Print the block's derived geometry and per-stage channel counts to stderr.
    pub fn debug_print(&self) {
        eprintln!("#### BottleneckBlock");
        if self.downsample.is_some() {
            eprintln!(" downsample");
        }
        eprintln!(" in_planes: {}", self.in_planes());
        eprintln!(" pinch_planes: {}", self.planes());
        eprintln!(" width: {}", self.width());
        eprintln!(" out_planes: {}", self.out_planes());
        eprintln!(" stride: {}", self.stride());
        eprintln!(" reduce_first: {}", self.reduce_first());
        eprintln!(" cardinality: {}", self.cardinality());
        eprintln!();
        eprintln!(" cna1.in_channels: {}", self.cna1.in_channels());
        eprintln!(" cna1.out_channels: {}", self.cna1.out_channels());
        eprintln!(" cna2.in_channels: {}", self.cna2.in_channels());
        eprintln!(" cna2.out_channels: {}", self.cna2.out_channels());
        eprintln!(" cna3.in_channels: {}", self.cna3.in_channels());
        eprintln!(" cna3.out_channels: {}", self.cna3.out_channels());
    }

    /// Apply the block to `input`.
    ///
    /// # Shapes
    ///
    /// - `input`: `[batch, in_planes, in_height, in_width]`; both spatial dims
    ///   must be multiples of `stride`.
    /// - returns: `[batch, out_planes, in_height / stride, in_width / stride]`.
    ///
    /// The main path is `cna1 -> cna2 -> cna3`:
    /// - `drop_block` (if any) runs inside `cna2.map_forward`;
    /// - `drop_path` (if any) and the residual add run inside `cna3.map_forward`.
    ///
    /// The residual is `input` itself, or `downsample(input)` when a downsample
    /// exists. NOTE(review): presumably the `map_forward` hook runs between
    /// norm and activation — confirm against `CNA2d::map_forward`.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let [batch, in_height, out_height, in_width, out_width] = bimm_contracts::unpack_shape_contract!(
            [
                "batch",
                "in_planes",
                "in_height" = "out_height" * "stride",
                "in_width" = "out_width" * "stride"
            ],
            &input.dims(),
            &["batch", "in_height", "out_height", "in_width", "out_width"],
            &[("in_planes", self.in_planes()), ("stride", self.stride())],
        );

        let identity = match &self.downsample {
            Some(downsample) => downsample.forward(input.clone()),
            None => input.clone(),
        };

        bimm_contracts::define_shape_contract!(
            OUT_CONTRACT,
            ["batch", "out_planes", "out_height", "out_width"],
        );
        let out_bindings = [
            ("batch", batch),
            ("out_planes", self.out_planes()),
            ("out_height", out_height),
            ("out_width", out_width),
        ];
        bimm_contracts::assert_shape_contract_periodically!(
            OUT_CONTRACT,
            &identity.dims(),
            &out_bindings
        );

        // `cna1` emits `first_planes` (= width / reduce_first) channels. This is
        // generally NOT the pinched `planes`: the two differ whenever
        // cardinality/base_width/reduce_first leave their defaults (e.g. ResNeXt).
        let x = self.cna1.forward(input);
        bimm_contracts::assert_shape_contract_periodically!(
            ["batch", "first_planes", "in_height", "in_width"],
            &x.dims(),
            &[
                ("batch", batch),
                ("first_planes", self.first_planes()),
                ("in_height", in_height),
                ("in_width", in_width),
            ],
        );

        let x = self.cna2.map_forward(x, |x| match &self.drop_block {
            Some(drop_block) => drop_block.forward(x),
            None => x,
        });
        bimm_contracts::assert_shape_contract_periodically!(
            ["batch", "width", "out_height", "out_width"],
            &x.dims(),
            &[
                ("batch", batch),
                ("width", self.width()),
                ("out_height", out_height),
                ("out_width", out_width),
            ],
        );

        self.cna3.map_forward(x, |x| {
            bimm_contracts::assert_shape_contract_periodically!(
                OUT_CONTRACT,
                &x.dims(),
                &out_bindings
            );
            let x = match &self.drop_path {
                Some(drop_path) => drop_path.forward(x),
                None => x,
            };
            x + identity
        })
    }

    /// Return the block with its drop-path replaced.
    ///
    /// A probability of exactly `0.0` removes the drop-path layer.
    ///
    /// # Panics
    ///
    /// If `expect_probability` rejects `drop_path_prob`.
    pub fn with_drop_path_prob(
        self,
        drop_path_prob: f64,
    ) -> Self {
        let drop_path_prob = expect_probability(drop_path_prob);
        Self {
            drop_path: if drop_path_prob == 0.0 {
                None
            } else {
                DropPathConfig::new()
                    .with_drop_prob(drop_path_prob)
                    .init()
                    .into()
            },
            ..self
        }
    }

    /// Return the block with its drop-block replaced (`None` removes it).
    pub fn with_drop_block(
        self,
        drop_block: Option<DropBlockOptions>,
    ) -> Self {
        Self {
            drop_block: drop_block.map(|options| DropBlock2dConfig::from(options).init()),
            ..self
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bimm_contracts::assert_shape_contract;
    use burn::backend::NdArray;
    use burn::nn::activation::ActivationConfig;

    #[test]
    fn test_bottleneck_block_config() {
        let in_planes = 16;
        let out_planes = 32;
        let config = BottleneckBlockConfig::new(in_planes, out_planes);
        assert_eq!(config.in_planes(), in_planes);
        assert_eq!(config.out_planes(), out_planes);
        assert_eq!(config.stride(), 1);
        assert_eq!(config.output_resolution([16, 16]), [16, 16]);
        assert!(matches!(config.activation, ActivationConfig::Relu));

        // Default policy: pinch 4, base_width 64, cardinality 1, reduce_first 1.
        assert_eq!(config.planes(), 8);
        assert_eq!(config.width(), 8);
        assert_eq!(config.first_planes(), 8);

        let config = config
            .with_stride(2)
            .with_activation(ActivationConfig::Sigmoid);
        assert_eq!(config.stride(), 2);
        assert_eq!(config.output_resolution([16, 16]), [8, 8]);
        assert!(matches!(config.activation, ActivationConfig::Sigmoid));
    }

    #[test]
    fn test_bottleneck_block_config_resnext_widths() {
        // ResNeXt-style sizing: the grouped stage is wider than the pinched planes.
        let config = BottleneckBlockConfig::new(64, 64)
            .with_cardinality(4)
            .with_base_width(32);
        assert_eq!(config.planes(), 16);
        assert_eq!(config.width(), 32);
        assert_eq!(config.first_planes(), 32);

        let config = config.with_reduce_first(2);
        assert_eq!(config.first_planes(), 16);
    }

    #[test]
    #[should_panic(expected = "7 !~ in_height=(out_height*stride)")]
    fn test_downsample_config_panic() {
        let config = BottleneckBlockConfig::new(16, 32).with_stride(2);
        assert_eq!(config.stride(), 2);
        config.output_resolution([7, 7]);
    }

    #[test]
    fn test_bottleneck_block_meta() {
        type B = NdArray<f32>;
        let device = Default::default();
        // out_planes must be >= the pinch factor, or the inner stages get 0 channels.
        let in_planes = 8;
        let out_planes = 8;
        let block: BottleneckBlock<B> =
            BottleneckBlockConfig::new(in_planes, out_planes).init(&device);
        assert_eq!(block.in_planes(), in_planes);
        assert_eq!(block.out_planes(), out_planes);
        assert_eq!(block.planes(), 2);
        assert_eq!(block.stride(), 1);
        assert_eq!(block.output_resolution([16, 16]), [16, 16]);
        assert!(block.downsample.is_none());
    }

    #[test]
    fn test_bottleneck_block_forward_resnext_widths() {
        type B = NdArray<f32>;
        let device = Default::default();
        let batch_size = 2;
        let planes = 64;
        let in_height = 4;
        let in_width = 4;
        let block: BottleneckBlock<B> = BottleneckBlockConfig::new(planes, planes)
            .with_cardinality(4)
            .with_base_width(32)
            .init(&device);

        // cna1 emits `first_planes`, which differs from the pinched `planes` here.
        assert_eq!(block.planes(), 16);
        assert_eq!(block.first_planes(), 32);
        assert_eq!(block.cna1.out_channels(), 32);
        assert!(block.downsample.is_none());

        let input = Tensor::ones([batch_size, planes, in_height, in_width], &device);
        let output = block.forward(input);
        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", planes),
                ("out_height", in_height),
                ("out_width", in_width)
            ],
        );
    }

    #[test]
    fn test_bottleneck_block_forward_same_channels_no_downsample_autodiff() {
        use burn::backend::{Autodiff, Wgpu};
        type B = Autodiff<Wgpu>;
        let device = Default::default();
        let batch_size = 2;
        let in_planes = 8;
        let planes = 8;
        let in_height = 8;
        let in_width = 8;
        let block: BottleneckBlock<B> = BottleneckBlockConfig::new(in_planes, planes).init(&device);
        assert!(block.downsample.is_none());
        let out_planes = block.out_planes();
        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);
        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", out_planes),
                ("out_height", in_height),
                ("out_width", in_width)
            ],
        );
    }

    #[test]
    fn test_bottleneck_block_forward_downsample_drop_block_drop_path_autodiff() {
        use burn::backend::{Autodiff, Wgpu};
        type B = Autodiff<Wgpu>;
        let device = Default::default();
        let batch_size = 2;
        let in_planes = 2;
        let planes = 4;
        let in_height = 8;
        let in_width = 8;
        let block: BottleneckBlock<B> = BottleneckBlockConfig::new(in_planes, planes)
            .with_drop_path_prob(0.1)
            .with_drop_block(Some(DropBlockOptions::default()))
            .with_stride(2)
            .init(&device);
        assert!(block.downsample.is_some());
        let out_planes = block.out_planes();
        let [out_height, out_width] = block.output_resolution([in_height, in_width]);
        assert_eq!(out_height, 4);
        assert_eq!(out_width, 4);
        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);
        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", out_planes),
                ("out_height", out_height),
                ("out_width", out_width)
            ],
        );
    }
}