use std::collections::HashMap;
use axonml_autograd::Variable;
use crate::activation::ReLU;
use crate::module::Module;
use crate::parameter::Parameter;
use crate::sequential::Sequential;
/// A residual (skip-connection) block: `out = activation(main_path(x) + skip(x))`.
///
/// The skip path is the raw input unless a `downsample` sub-network is
/// attached (needed when `main_path` changes the feature shape, so the
/// element-wise addition stays shape-compatible).
pub struct ResidualBlock {
// The transformation applied to the input (e.g. a conv/linear stack).
main_path: Sequential,
// Optional projection for the skip path; `None` means identity skip.
downsample: Option<Sequential>,
// Activation applied after the addition; defaults to ReLU in `new`,
// `None` (via `without_activation`) disables post-activation entirely.
activation: Option<Box<dyn Module>>,
// Training-mode flag, propagated to all sub-modules by `set_training`.
training: bool,
}
impl ResidualBlock {
    /// Creates a residual block with an identity skip connection and a
    /// default ReLU applied after the addition.
    pub fn new(main_path: Sequential) -> Self {
        // Default post-activation is ReLU; builders below can override it.
        let default_act: Box<dyn Module> = Box::new(ReLU);
        Self {
            training: true,
            activation: Some(default_act),
            downsample: None,
            main_path,
        }
    }

    /// Builder: attaches a projection to the skip path, for use when the
    /// main path changes the feature shape.
    pub fn with_downsample(mut self, projection: Sequential) -> Self {
        self.downsample = Some(projection);
        self
    }

    /// Builder: replaces the default ReLU with a custom post-activation.
    pub fn with_activation<M: Module + 'static>(mut self, act: M) -> Self {
        self.activation = Some(Box::new(act));
        self
    }

    /// Builder: removes the post-activation entirely, so the block output
    /// is the raw sum of the main and skip paths.
    pub fn without_activation(mut self) -> Self {
        self.activation = None;
        self
    }
}
impl Module for ResidualBlock {
    /// Computes `activation(main_path(x) + skip(x))`, where `skip` is the
    /// downsample network when configured and the identity otherwise.
    fn forward(&self, input: &Variable) -> Variable {
        // Skip path: project when a downsample network is present,
        // otherwise pass the input through unchanged.
        let shortcut = self
            .downsample
            .as_ref()
            .map_or_else(|| input.clone(), |ds| ds.forward(input));
        let summed = self.main_path.forward(input).add_var(&shortcut);
        // Optional post-addition nonlinearity.
        match self.activation.as_deref() {
            Some(act) => act.forward(&summed),
            None => summed,
        }
    }

    /// Collects the parameters of every sub-module (main path, optional
    /// downsample, optional activation), in that order.
    fn parameters(&self) -> Vec<Parameter> {
        let mut collected = self.main_path.parameters();
        // `Option::iter` yields zero or one sub-module, so both optional
        // paths are handled without explicit `if let`.
        for ds in self.downsample.iter() {
            collected.extend(ds.parameters());
        }
        for act in self.activation.iter() {
            collected.extend(act.parameters());
        }
        collected
    }

    /// Like `parameters`, but keyed by a dotted path that records which
    /// sub-module each parameter belongs to (`main_path.*`,
    /// `downsample.*`, `activation.*`).
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut named: HashMap<String, Parameter> = self
            .main_path
            .named_parameters()
            .into_iter()
            .map(|(key, param)| (format!("main_path.{key}"), param))
            .collect();
        if let Some(ds) = &self.downsample {
            named.extend(
                ds.named_parameters()
                    .into_iter()
                    .map(|(key, param)| (format!("downsample.{key}"), param)),
            );
        }
        if let Some(act) = &self.activation {
            named.extend(
                act.named_parameters()
                    .into_iter()
                    .map(|(key, param)| (format!("activation.{key}"), param)),
            );
        }
        named
    }

    /// Sets the training flag on this block and propagates it to every
    /// sub-module (relevant for layers like batch norm).
    fn set_training(&mut self, mode: bool) {
        self.training = mode;
        self.main_path.set_training(mode);
        if let Some(ds) = self.downsample.as_mut() {
            ds.set_training(mode);
        }
        if let Some(act) = self.activation.as_mut() {
            act.set_training(mode);
        }
    }

    /// Reports the block-level training flag.
    fn is_training(&self) -> bool {
        self.training
    }

    /// Module display name.
    fn name(&self) -> &'static str {
        "ResidualBlock"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::{GELU, ReLU};
    use crate::layers::{BatchNorm1d, Conv1d, Linear};
    use axonml_tensor::Tensor;

    /// Builds a `Variable` of all-ones with the given element count/shape.
    fn ones_var(len: usize, shape: &[usize], requires_grad: bool) -> Variable {
        let tensor = Tensor::from_vec(vec![1.0; len], shape).expect("tensor creation failed");
        Variable::new(tensor, requires_grad)
    }

    #[test]
    fn test_residual_block_identity_skip() {
        // Main path preserves shape, so the identity skip is valid.
        let main = Sequential::new()
            .add(Linear::new(32, 32))
            .add(ReLU)
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(main);
        let output = block.forward(&ones_var(64, &[2, 32], false));
        assert_eq!(output.shape(), vec![2, 32]);
    }

    #[test]
    fn test_residual_block_with_downsample() {
        // Main path widens 32 -> 64, so the skip needs a projection.
        let main = Sequential::new()
            .add(Linear::new(32, 64))
            .add(ReLU)
            .add(Linear::new(64, 64));
        let downsample = Sequential::new().add(Linear::new(32, 64));
        let block = ResidualBlock::new(main).with_downsample(downsample);
        let output = block.forward(&ones_var(64, &[2, 32], false));
        assert_eq!(output.shape(), vec![2, 64]);
    }

    #[test]
    fn test_residual_block_custom_activation() {
        let main = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(main).with_activation(GELU);
        let output = block.forward(&ones_var(32, &[2, 16], false));
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_no_activation() {
        let main = Sequential::new().add(Linear::new(16, 16));
        let block = ResidualBlock::new(main).without_activation();
        let output = block.forward(&ones_var(32, &[2, 16], false));
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_residual_block_parameters() {
        // Two Linear layers -> weight + bias each = 4 parameters.
        let main = Sequential::new()
            .add(Linear::new(32, 32))
            .add(Linear::new(32, 32));
        let block = ResidualBlock::new(main);
        assert_eq!(block.parameters().len(), 4);
    }

    #[test]
    fn test_residual_block_named_parameters() {
        let main = Sequential::new()
            .add_named("conv1", Linear::new(32, 32))
            .add_named("conv2", Linear::new(32, 32));
        let downsample = Sequential::new().add_named("proj", Linear::new(32, 32));
        let block = ResidualBlock::new(main).with_downsample(downsample);
        let params = block.named_parameters();
        // Names are prefixed by which sub-module owns them.
        assert!(params.contains_key("main_path.conv1.weight"));
        assert!(params.contains_key("main_path.conv2.weight"));
        assert!(params.contains_key("downsample.proj.weight"));
    }

    #[test]
    fn test_residual_block_training_mode() {
        let main = Sequential::new()
            .add(BatchNorm1d::new(32))
            .add(Linear::new(32, 32));
        let mut block = ResidualBlock::new(main);
        // Training mode defaults to true and toggles both ways.
        assert!(block.is_training());
        block.set_training(false);
        assert!(!block.is_training());
        block.set_training(true);
        assert!(block.is_training());
    }

    #[test]
    fn test_residual_block_conv1d_with_downsample() {
        // Two unpadded k=3 convs shrink length 20 -> 16; the skip uses a
        // single k=5 conv to land on the same length.
        let main = Sequential::new()
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64))
            .add(ReLU)
            .add(Conv1d::new(64, 64, 3))
            .add(BatchNorm1d::new(64));
        let downsample = Sequential::new()
            .add(Conv1d::new(64, 64, 5))
            .add(BatchNorm1d::new(64));
        let block = ResidualBlock::new(main).with_downsample(downsample);
        let output = block.forward(&ones_var(2 * 64 * 20, &[2, 64, 20], false));
        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 64);
        assert_eq!(output.shape()[2], 16);
    }

    #[test]
    fn test_residual_block_gradient_flow() {
        let main = Sequential::new().add(Linear::new(4, 4));
        let block = ResidualBlock::new(main);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).expect("tensor creation failed"),
            true,
        );
        // Backprop through the block should not panic and parameters exist.
        let output = block.forward(&input);
        output.sum().backward();
        assert!(!block.parameters().is_empty());
    }
}