use crate::layers::drop::drop_block::DropBlockOptions;
use crate::models::resnet::basic_block::{BasicBlock, BasicBlockConfig, BasicBlockMeta};
use crate::models::resnet::bottleneck_block::{
BottleneckBlock, BottleneckBlockConfig, BottleneckBlockMeta, BottleneckPolicyConfig,
};
use crate::models::resnet::util::stride_div_output_resolution;
use crate::utility::probability::expect_probability;
use burn::nn::BatchNormConfig;
use burn::nn::activation::ActivationConfig;
use burn::nn::norm::NormalizationConfig;
use burn::prelude::{Backend, Config, Module, Tensor};
/// Residual block description expressed as a contract (channel counts,
/// downsampling, optional bottleneck policy) rather than a concrete structure.
///
/// Lower it with [`ResidualBlockContractConfig::to_structure`] (or `From`/`Into`)
/// into a [`ResidualBlockStructureConfig`]. The result is a [`BasicBlockConfig`]
/// when `bottleneck_policy` is `None`, otherwise a [`BottleneckBlockConfig`].
#[derive(Config, Debug)]
pub struct ResidualBlockContractConfig {
    /// Number of input planes (channels).
    pub in_planes: usize,
    /// Number of output planes (channels).
    pub out_planes: usize,
    /// Dilation forwarded unchanged to the built block config.
    #[config(default = 1)]
    pub dilation: usize,
    /// First dilation forwarded unchanged to the built block config.
    #[config(default = "None")]
    pub first_dilation: Option<usize>,
    /// When `true`, the built block uses stride 2; otherwise stride 1.
    #[config(default = "false")]
    pub downsample_input: bool,
    /// `None` builds a basic block; `Some(policy)` builds a bottleneck block
    /// configured with that policy.
    #[config(default = "None")]
    pub bottleneck_policy: Option<BottleneckPolicyConfig>,
    /// Normalization layer config forwarded to the built block config.
    /// NOTE(review): the default `BatchNormConfig::new(0)` feature count is
    /// presumably overridden by the block itself; confirm.
    #[config(default = "NormalizationConfig::Batch(BatchNormConfig::new(0))")]
    pub normalization: NormalizationConfig,
    /// Activation config forwarded to the built block config.
    #[config(default = "ActivationConfig::Relu")]
    pub activation: ActivationConfig,
}
impl ResidualBlockContractConfig {
pub fn to_structure(self) -> ResidualBlockStructureConfig {
let stride = if self.downsample_input { 2 } else { 1 };
match self.bottleneck_policy {
None => BasicBlockConfig::new(self.in_planes, self.out_planes)
.with_stride(stride)
.with_dilation(self.dilation)
.with_first_dilation(self.first_dilation)
.with_normalization(self.normalization)
.with_activation(self.activation)
.into(),
Some(policy) => BottleneckBlockConfig::new(self.in_planes, self.out_planes)
.with_stride(stride)
.with_dilation(self.dilation)
.with_first_dilation(self.first_dilation)
.with_normalization(self.normalization)
.with_activation(self.activation)
.with_policy(policy)
.into(),
}
}
}
impl From<ResidualBlockContractConfig> for ResidualBlockStructureConfig {
fn from(config: ResidualBlockContractConfig) -> Self {
config.to_structure()
}
}
/// Shape metadata shared by residual block configs and residual block modules.
pub trait ResidualBlockMeta {
    /// Number of input planes (channels) the block consumes.
    fn in_planes(&self) -> usize;

    /// Number of output planes (channels) the block produces.
    fn out_planes(&self) -> usize;

    /// Spatial stride applied by the block.
    fn stride(&self) -> usize;

    /// Spatial output resolution produced for `input_resolution`.
    ///
    /// The default implementation divides the input resolution by
    /// [`ResidualBlockMeta::stride`] using `stride_div_output_resolution`.
    fn output_resolution(
        &self,
        input_resolution: [usize; 2],
    ) -> [usize; 2] {
        let stride = self.stride();
        stride_div_output_resolution(input_resolution, stride)
    }
}
/// Concrete structure of a residual block: either a basic-block config or a
/// bottleneck-block config.
///
/// Build one from a [`BasicBlockConfig`] / [`BottleneckBlockConfig`] via `From`,
/// or by lowering a [`ResidualBlockContractConfig`]. Call
/// [`ResidualBlockStructureConfig::init`] to construct the matching
/// [`ResidualBlock`] module.
#[derive(Config, Debug)]
pub enum ResidualBlockStructureConfig {
    /// Basic residual block config.
    Basic(BasicBlockConfig),
    /// Bottleneck residual block config.
    Bottleneck(BottleneckBlockConfig),
}
impl ResidualBlockMeta for ResidualBlockStructureConfig {
fn in_planes(&self) -> usize {
match self {
Self::Basic(config) => config.in_planes(),
Self::Bottleneck(config) => config.in_planes(),
}
}
fn out_planes(&self) -> usize {
match self {
Self::Basic(config) => config.out_planes(),
Self::Bottleneck(config) => config.out_planes(),
}
}
fn stride(&self) -> usize {
match self {
Self::Basic(config) => config.stride(),
Self::Bottleneck(config) => config.stride(),
}
}
fn output_resolution(
&self,
input_resolution: [usize; 2],
) -> [usize; 2] {
match self {
Self::Basic(config) => config.output_resolution(input_resolution),
Self::Bottleneck(config) => config.output_resolution(input_resolution),
}
}
}
impl From<BasicBlockConfig> for ResidualBlockStructureConfig {
fn from(config: BasicBlockConfig) -> Self {
Self::Basic(config)
}
}
impl From<BottleneckBlockConfig> for ResidualBlockStructureConfig {
fn from(config: BottleneckBlockConfig) -> Self {
Self::Bottleneck(config)
}
}
impl ResidualBlockStructureConfig {
pub fn init<B: Backend>(
self,
device: &B::Device,
) -> ResidualBlock<B> {
match self {
Self::Basic(config) => config.init(device).into(),
Self::Bottleneck(config) => config.init(device).into(),
}
}
pub fn with_drop_block(
self,
options: Option<DropBlockOptions>,
) -> Self {
match self {
Self::Basic(config) => config.with_drop_block(options).into(),
Self::Bottleneck(config) => config.with_drop_block(options).into(),
}
}
pub fn with_drop_path_prob(
self,
drop_path_prob: f64,
) -> Self {
let drop_path_prob = expect_probability(drop_path_prob);
match self {
Self::Basic(config) => config.with_drop_path_prob(drop_path_prob).into(),
Self::Bottleneck(config) => config.with_drop_path_prob(drop_path_prob).into(),
}
}
}
/// A residual block module: either a [`BasicBlock`] or a [`BottleneckBlock`].
///
/// Usually constructed with [`ResidualBlockStructureConfig::init`]. Its methods
/// dispatch to the wrapped block.
#[derive(Module, Debug)]
// The variants are stored inline (unboxed), so clippy's variant-size-difference
// lint is silenced. NOTE(review): presumably boxing was avoided to keep the burn
// `Module` derive simple; confirm the size gap is acceptable.
#[allow(clippy::large_enum_variant)]
pub enum ResidualBlock<B: Backend> {
    /// Basic residual block module.
    Basic(BasicBlock<B>),
    /// Bottleneck residual block module.
    Bottleneck(BottleneckBlock<B>),
}
impl<B: Backend> From<BasicBlock<B>> for ResidualBlock<B> {
fn from(block: BasicBlock<B>) -> Self {
Self::Basic(block)
}
}
impl<B: Backend> From<BottleneckBlock<B>> for ResidualBlock<B> {
fn from(block: BottleneckBlock<B>) -> Self {
Self::Bottleneck(block)
}
}
/// Every query is answered by the wrapped basic/bottleneck block.
impl<B: Backend> ResidualBlockMeta for ResidualBlock<B> {
    fn in_planes(&self) -> usize {
        match self {
            Self::Basic(block) => block.in_planes(),
            Self::Bottleneck(block) => block.in_planes(),
        }
    }

    fn out_planes(&self) -> usize {
        match self {
            Self::Basic(block) => block.out_planes(),
            Self::Bottleneck(block) => block.out_planes(),
        }
    }

    fn stride(&self) -> usize {
        match self {
            Self::Basic(block) => block.stride(),
            Self::Bottleneck(block) => block.stride(),
        }
    }

    /// Delegates to the wrapped block rather than using the trait default.
    ///
    /// `ResidualBlockStructureConfig` delegates `output_resolution` to its
    /// variant. Without this override, a module could report a different output
    /// resolution than the config it was built from whenever a variant's own
    /// rule differs from plain stride division.
    fn output_resolution(
        &self,
        input_resolution: [usize; 2],
    ) -> [usize; 2] {
        match self {
            Self::Basic(block) => block.output_resolution(input_resolution),
            Self::Bottleneck(block) => block.output_resolution(input_resolution),
        }
    }
}
impl<B: Backend> ResidualBlock<B> {
    /// Prints debug information by delegating to the wrapped block's
    /// `debug_print`.
    pub fn debug_print(&self) {
        match self {
            Self::Basic(block) => block.debug_print(),
            Self::Bottleneck(block) => block.debug_print(),
        }
    }

    /// Runs the wrapped block's forward pass on `input`.
    ///
    /// This module's tests feed `[batch, in_planes, height, width]` and observe
    /// `[batch, out_planes, out_height, out_width]`, with the spatial dims
    /// reduced by the stride.
    pub fn forward(
        &self,
        input: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        match self {
            Self::Basic(block) => block.forward(input),
            Self::Bottleneck(block) => block.forward(input),
        }
    }

    /// Returns this block with `with_drop_path_prob` applied to the wrapped
    /// block.
    ///
    /// `drop_path_prob` is first validated by `expect_probability`.
    pub fn with_drop_path_prob(
        self,
        drop_path_prob: f64,
    ) -> Self {
        let drop_path_prob = expect_probability(drop_path_prob);
        match self {
            Self::Basic(block) => block.with_drop_path_prob(drop_path_prob).into(),
            Self::Bottleneck(block) => block.with_drop_path_prob(drop_path_prob).into(),
        }
    }

    /// Returns this block with `with_drop_block(options)` applied to the
    /// wrapped block.
    pub fn with_drop_block(
        self,
        options: Option<DropBlockOptions>,
    ) -> Self {
        // The bindings are modules, not configs; name them accordingly.
        match self {
            Self::Basic(block) => block.with_drop_block(options).into(),
            Self::Bottleneck(block) => block.with_drop_block(options).into(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bimm_contracts::assert_shape_contract;
    use burn::backend::NdArray;

    /// `to_structure` picks stride from `downsample_input` and keeps plane counts.
    #[test]
    fn test_residual_block_contract_config() {
        let in_planes = 16;
        let out_planes = 32;

        let strided = ResidualBlockContractConfig::new(in_planes, out_planes)
            .with_downsample_input(true)
            .to_structure();
        assert!(matches!(strided, ResidualBlockStructureConfig::Basic(_)));
        assert_eq!(strided.in_planes(), in_planes);
        assert_eq!(strided.out_planes(), out_planes);
        assert_eq!(strided.stride(), 2);

        let unstrided: ResidualBlockStructureConfig =
            ResidualBlockContractConfig::new(in_planes, out_planes).into();
        assert!(matches!(unstrided, ResidualBlockStructureConfig::Basic(_)));
        assert_eq!(unstrided.stride(), 1);
    }

    #[test]
    fn test_residual_block_config() {
        let in_planes = 16;
        let out_planes = 32;
        {
            let cfg: ResidualBlockStructureConfig = BasicBlockConfig::new(in_planes, out_planes)
                .with_stride(2)
                .into();
            assert!(matches!(cfg, ResidualBlockStructureConfig::Basic(_)));
            assert_eq!(cfg.in_planes(), in_planes);
            assert_eq!(cfg.out_planes(), out_planes);
            assert_eq!(cfg.stride(), 2);
            assert_eq!(cfg.output_resolution([20, 20]), [10, 10]);
        }
        {
            let cfg: ResidualBlockStructureConfig =
                BottleneckBlockConfig::new(in_planes, out_planes)
                    .with_stride(2)
                    .into();
            assert!(matches!(cfg, ResidualBlockStructureConfig::Bottleneck(_)));
            assert_eq!(cfg.in_planes(), in_planes);
            assert_eq!(cfg.out_planes(), out_planes);
            assert_eq!(cfg.stride(), 2);
            assert_eq!(cfg.output_resolution([20, 20]), [10, 10]);
        }
    }

    #[test]
    fn test_residual_block_basic_block() {
        type B = NdArray;
        let device = Default::default();
        let batch_size = 2;
        let in_planes = 16;
        let planes = 32;
        let in_height = 8;
        let in_width = 8;
        let out_height = 4;
        let out_width = 4;
        let cfg: ResidualBlockStructureConfig = BasicBlockConfig::new(in_planes, planes)
            .with_stride(2)
            .into();
        let block: ResidualBlock<B> = cfg.init(&device);
        assert!(matches!(block, ResidualBlock::Basic(_)));
        assert_eq!(block.in_planes(), in_planes);
        assert_eq!(block.out_planes(), planes);
        assert_eq!(block.stride(), 2);
        assert_eq!(block.output_resolution([20, 20]), [10, 10]);
        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);
        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", planes),
                ("out_height", out_height),
                ("out_width", out_width)
            ],
        );
    }

    /// Uses the CPU NdArray backend (like the basic-block test) so the test does
    /// not require GPU hardware or the wgpu backend feature.
    #[test]
    fn test_residual_block_bottleneck_block() {
        type B = NdArray;
        let device = Default::default();
        let batch_size = 2;
        let in_planes = 16;
        let planes = 32;
        let in_height = 8;
        let in_width = 8;
        let out_height = 4;
        let out_width = 4;
        let cfg: ResidualBlockStructureConfig = BottleneckBlockConfig::new(in_planes, planes)
            .with_stride(2)
            .into();
        let block: ResidualBlock<B> = cfg.init(&device);
        assert!(matches!(block, ResidualBlock::Bottleneck(_)));
        assert_eq!(block.in_planes(), in_planes);
        assert_eq!(block.out_planes(), planes);
        assert_eq!(block.stride(), 2);
        assert_eq!(block.output_resolution([20, 20]), [10, 10]);
        let input = Tensor::ones([batch_size, in_planes, in_height, in_width], &device);
        let output = block.forward(input);
        assert_shape_contract!(
            ["batch", "out_channels", "out_height", "out_width"],
            &output.dims(),
            &[
                ("batch", batch_size),
                ("out_channels", planes),
                ("out_height", out_height),
                ("out_width", out_width)
            ],
        );
    }
}