use super::bottleneck_block::BottleneckPolicyConfig;
use super::layer_block::{
LayerBlock, LayerBlockContractConfig, LayerBlockMeta, LayerBlockStructureConfig,
};
use super::residual_block::{ResidualBlock, ResidualBlockStructureConfig};
use super::resnet_io::pytorch_stubs::load_resnet_stub_record;
use super::util::CONV_INTO_RELU_INITIALIZER;
use crate::layers::blocks::conv_norm::{ConvNorm2d, ConvNorm2dConfig};
use crate::layers::drop::drop_block::DropBlockOptions;
use crate::utility::probability::expect_probability;
use alloc::vec;
use alloc::vec::Vec;
use burn::module::Module;
use burn::nn::BatchNormConfig;
use burn::nn::activation::{Activation, ActivationConfig};
use burn::nn::conv::Conv2dConfig;
use burn::nn::norm::NormalizationConfig;
use burn::nn::pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig, MaxPool2d, MaxPool2dConfig};
use burn::nn::{Initializer, Linear, LinearConfig, PaddingConfig2d};
use burn::prelude::{Backend, Config, Tensor};
use std::path::PathBuf;
pub const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2];
pub const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3];
pub const RESNET50_BLOCKS: [usize; 4] = [3, 4, 6, 3];
pub const RESNET101_BLOCKS: [usize; 4] = [3, 4, 23, 3];
pub const RESNET152_BLOCKS: [usize; 4] = [3, 8, 36, 3];
#[derive(Config, Debug)]
pub struct ResNetContractConfig {
pub layers: Vec<usize>,
pub num_classes: usize,
#[config(default = "64")]
pub stem_width: usize,
#[config(default = "32")]
pub output_stride: usize,
#[config(default = "None")]
pub bottleneck_policy: Option<BottleneckPolicyConfig>,
#[config(default = "NormalizationConfig::Batch(BatchNormConfig::new(0))")]
pub normalization: NormalizationConfig,
#[config(default = "ActivationConfig::Relu")]
pub activation: ActivationConfig,
}
impl ResNetContractConfig {
pub fn with_bottleneck(
self,
enable: bool,
) -> Self {
let policy = if enable {
Some(Default::default())
} else {
None
};
self.with_bottleneck_policy(policy)
}
#[allow(unused)]
pub fn to_layer_contracts(&self) -> Vec<LayerBlockContractConfig> {
let mut net_stride = 4;
let mut dilation = 1;
let mut prev_dilation = 1;
let mut layers: Vec<LayerBlockContractConfig> = Default::default();
let mut in_planes = self.stem_width;
for (stage_idx, &num_blocks) in self.layers.iter().enumerate() {
let downsample_input = {
let mut stride = if stage_idx == 0 { 1 } else { 2 };
if net_stride >= self.output_stride {
dilation *= stride;
stride = 1;
} else {
net_stride *= stride;
}
stride != 1
};
let first_dilation = prev_dilation;
let out_planes = if stage_idx == 0 {
match &self.bottleneck_policy {
Some(policy) => in_planes * policy.pinch_factor,
None => in_planes,
}
} else {
2 * in_planes
};
layers.push(
LayerBlockContractConfig::new(num_blocks, in_planes, out_planes)
.with_downsample_input(downsample_input)
.with_first_dilation(Some(first_dilation))
.with_dilation(dilation)
.with_bottleneck_policy(self.bottleneck_policy.clone())
.with_normalization(self.normalization.clone())
.with_activation(self.activation.clone()),
);
in_planes = out_planes;
prev_dilation = dilation;
}
layers
}
pub fn to_structure(self) -> ResNetStructureConfig {
ResNetStructureConfig::new(
ConvNorm2dConfig::from(
Conv2dConfig::new([3, self.stem_width], [7, 7])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(3, 3))
.with_bias(false),
)
.with_initializer(CONV_INTO_RELU_INITIALIZER.clone()),
self.to_layer_contracts()
.into_iter()
.map(|c| c.into())
.collect::<Vec<_>>(),
self.num_classes,
)
}
pub fn resnet18(num_classes: usize) -> Self {
Self::new(RESNET18_BLOCKS.to_vec(), num_classes) }
}
impl From<ResNetContractConfig> for ResNetStructureConfig {
#[allow(unused)]
fn from(config: ResNetContractConfig) -> Self {
config.to_structure()
}
}
#[derive(Config, Debug)]
pub struct ResNetStructureConfig {
pub input_conv_norm: ConvNorm2dConfig,
#[config(default = "CONV_INTO_RELU_INITIALIZER.clone().into()")]
pub input_conv_norm_initializer: Option<Initializer>,
#[config(default = "ActivationConfig::Relu")]
pub input_act: ActivationConfig,
pub layers: Vec<LayerBlockStructureConfig>,
pub num_classes: usize,
}
impl ResNetStructureConfig {
pub fn init<B: Backend>(
self,
device: &B::Device,
) -> ResNet<B> {
let mut input_conv_norm = self.input_conv_norm.clone();
if let Some(initializer) = &self.input_conv_norm_initializer {
input_conv_norm.conv = input_conv_norm.conv.with_initializer(initializer.clone());
}
let head_planes = self.layers.last().unwrap().out_planes();
ResNet {
input_conv_norm: input_conv_norm.init(device),
input_act: self.input_act.init(device),
input_pool: MaxPool2dConfig::new([3, 3])
.with_strides([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init(),
layers: self
.layers
.into_iter()
.map(|c| c.init(device))
.collect::<Vec<_>>(),
output_pool: AdaptiveAvgPool2dConfig::new([1, 1]).init(),
output_fc: LinearConfig::new(head_planes, self.num_classes).init(device),
}
}
pub fn with_standard_drop_block_prob(
self,
drop_prob: f64,
) -> Self {
let drop_prob = expect_probability(drop_prob);
let k = self.layers.len();
let mut blocks = vec![None; k];
if drop_prob > 0.0 {
blocks[k - 2] = DropBlockOptions::default()
.with_drop_prob(drop_prob)
.with_block_size(5)
.with_gamma_scale(0.25)
.into();
blocks[k - 1] = DropBlockOptions::default()
.with_drop_prob(drop_prob)
.with_block_size(3)
.with_gamma_scale(1.0)
.into();
}
self.with_drop_block_options(blocks)
}
pub fn with_stochastic_depth_drop_path_rate(
self,
drop_path_rate: f64,
) -> Self {
let drop_path_rate = expect_probability(drop_path_rate);
let net_num_blocks = self.layers.iter().map(|b| b.len()).sum::<usize>() - self.layers.len();
let mut net_block_idx = 0;
let mut update_drop_path = |idx: usize, block: ResidualBlockStructureConfig| {
let block_dpr = drop_path_rate * (net_block_idx as f64) / ((net_num_blocks - 1) as f64);
net_block_idx += 1;
if idx != 0 && block_dpr > 0.0 {
block.with_drop_path_prob(block_dpr)
} else {
block
}
};
Self {
layers: self
.layers
.into_iter()
.map(|b| b.map_blocks(&mut update_drop_path))
.collect(),
..self
}
}
pub fn with_drop_block_options(
self,
options: Vec<Option<DropBlockOptions>>,
) -> Self {
assert_eq!(options.len(), self.layers.len());
Self {
layers: self
.layers
.into_iter()
.zip(options)
.map(|(b, o)| b.with_drop_block(o))
.collect(),
..self
}
}
}
#[derive(Module, Debug)]
pub struct ResNet<B: Backend> {
pub input_conv_norm: ConvNorm2d<B>,
pub input_act: Activation<B>,
pub input_pool: MaxPool2d,
pub layers: Vec<LayerBlock<B>>,
pub output_pool: AdaptiveAvgPool2d,
pub output_fc: Linear<B>,
}
impl<B: Backend> ResNet<B> {
pub fn debug_print(&self) {
for (idx, layer) in self.layers.iter().enumerate() {
println!(
"# Stage[{idx}]/{}:: {} :> {}",
layer.len(),
layer.in_planes(),
layer.out_planes()
);
layer.debug_print();
println!();
}
}
pub fn forward(
&self,
input: Tensor<B, 4>,
) -> Tensor<B, 2> {
let x = self.input_conv_norm.forward(input);
let x = self.input_act.forward(x);
let x = self.input_pool.forward(x);
let mut x = x;
for layer in self.layers.iter() {
x = layer.forward(x);
}
let x = self.output_pool.forward(x);
let x = x.flatten(1, 3);
self.output_fc.forward(x)
}
pub fn load_pytorch_weights(
self,
path: PathBuf,
) -> anyhow::Result<Self> {
let device = &self.devices()[0];
let stub_record = load_resnet_stub_record::<B>(path, device)?;
let adapted_target = self.with_classes(stub_record.fc.weight.dims()[0]);
Ok(stub_record.copy_stub_weights(adapted_target))
}
pub fn with_classes(
mut self,
num_classes: usize,
) -> Self {
let [d_input, _d_output] = self.output_fc.weight.dims();
self.output_fc =
LinearConfig::new(d_input, num_classes).init(&self.output_fc.weight.device());
self
}
pub fn with_stochastic_path_depth(
self,
drop_path_rate: f64,
) -> Self {
let drop_path_rate = expect_probability(drop_path_rate);
let net_num_blocks = self.layers.iter().map(|b| b.len()).sum::<usize>();
let mut net_block_idx = 0;
let mut update_drop_path = |_idx: usize, block: ResidualBlock<B>| {
let block_dpr = drop_path_rate * (net_block_idx as f64) / ((net_num_blocks - 1) as f64);
net_block_idx += 1;
if block_dpr > 0.0 {
block.with_drop_path_prob(block_dpr)
} else {
block
}
};
Self {
layers: self
.layers
.into_iter()
.map(|b| b.map_blocks(&mut update_drop_path))
.collect(),
..self
}
}
pub fn with_drop_block_options(
self,
options: Vec<Option<DropBlockOptions>>,
) -> Self {
assert_eq!(options.len(), self.layers.len());
Self {
layers: self
.layers
.into_iter()
.zip(options)
.map(|(b, o)| b.with_drop_block(o))
.collect(),
..self
}
}
pub fn with_stochastic_drop_block(
self,
drop_prob: f64,
) -> Self {
let drop_prob = expect_probability(drop_prob);
let k = self.layers.len();
let mut blocks = vec![None; k];
if drop_prob > 0.0 {
blocks[k - 2] = DropBlockOptions::default()
.with_drop_prob(drop_prob)
.with_block_size(5)
.with_gamma_scale(0.25)
.into();
blocks[k - 1] = DropBlockOptions::default()
.with_drop_prob(drop_prob)
.with_block_size(3)
.with_gamma_scale(1.0)
.into();
}
self.with_drop_block_options(blocks)
}
pub fn map_layers<F>(
self,
f: F,
) -> Self
where
F: Fn(Vec<LayerBlock<B>>) -> Vec<LayerBlock<B>>,
{
Self {
layers: f(self.layers),
..self
}
}
pub fn freeze_layers(self) -> Self {
self.map_layers(|layers| layers.into_iter().map(|layer| layer.no_grad()).collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::Wgpu;
#[test]
fn test_to_layers_34_basic() {
let cfg = ResNetContractConfig::new(RESNET34_BLOCKS.to_vec(), 1000);
let layers = cfg.to_layer_contracts();
println!("{:#?}", layers);
}
#[test]
fn test_to_layers_50_bottleneck() {
type B = Wgpu;
let device = Default::default();
let cfg = ResNetContractConfig::new(RESNET50_BLOCKS.to_vec(), 1000).with_bottleneck(true);
let layers = cfg.to_layer_contracts();
let first_stage = layers[0].clone();
println!("block[0] cfg:\n{:#?}", first_stage);
println!();
let blocks = first_stage
.to_block_contracts()
.into_iter()
.map(|b| b.to_structure())
.collect::<Vec<_>>();
println!("blocks ...");
println!("{:#?}", blocks);
println!();
let model: ResNet<B> = cfg.to_structure().init(&device);
model.debug_print();
}
}