use rustorch::nn::{BatchNorm1d, BatchNorm2d, Conv2d, MaxPool2d, Sequential};
use rustorch::prelude::*;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🔥 RusTorch BatchNorm Demo");
println!("========================");
let mut cnn_with_bn = Sequential::<f32>::new();
println!("🏗️ Building CNN with BatchNorm layers:");
let conv1 = Conv2d::new(3, 32, (3, 3), Some((1, 1)), Some((1, 1)), Some(true));
cnn_with_bn.add_module(conv1);
println!(" ✅ Conv2d(3→32, 3x3, padding=1)");
let bn1 = BatchNorm2d::new(32, None, None, None);
cnn_with_bn.add_module(bn1);
println!(" ✅ BatchNorm2d(32)");
let pool1 = MaxPool2d::new((2, 2), Some((2, 2)), Some((0, 0)));
cnn_with_bn.add_module(pool1);
println!(" ✅ MaxPool2d(2x2, stride=2)");
let conv2 = Conv2d::new(32, 64, (3, 3), Some((1, 1)), Some((1, 1)), Some(true));
cnn_with_bn.add_module(conv2);
println!(" ✅ Conv2d(32→64, 3x3, padding=1)");
let bn2 = BatchNorm2d::new(64, None, None, None);
cnn_with_bn.add_module(bn2);
println!(" ✅ BatchNorm2d(64)");
let pool2 = MaxPool2d::new((2, 2), Some((2, 2)), Some((0, 0)));
cnn_with_bn.add_module(pool2);
println!(" ✅ MaxPool2d(2x2, stride=2)");
let conv3 = Conv2d::new(64, 128, (3, 3), Some((1, 1)), Some((1, 1)), Some(true));
cnn_with_bn.add_module(conv3);
println!(" ✅ Conv2d(64→128, 3x3, padding=1)");
let bn3 = BatchNorm2d::new(128, None, None, None);
cnn_with_bn.add_module(bn3);
println!(" ✅ BatchNorm2d(128)");
println!("\n📊 Testing BatchNorm components individually:");
let test_bn2d = BatchNorm2d::<f32>::new(16, None, None, None);
println!("🧪 BatchNorm2d test:");
println!(" - Created with 16 channels");
println!(" - Training mode: {}", test_bn2d.is_training());
println!(" - Epsilon: {:.2e}", test_bn2d.eps());
println!(" - Momentum: {:.1}", test_bn2d.momentum());
let test_input_2d = Variable::new(
Tensor::from_vec(
(0..4 * 16 * 8 * 8).map(|i| (i as f32) * 0.01).collect(),
vec![4, 16, 8, 8],
),
true,
);
println!(
" - Input shape: {:?}",
test_input_2d.data().read().unwrap().shape()
);
let output_train = test_bn2d.forward(&test_input_2d);
println!(
" - Training output shape: {:?}",
output_train.data().read().unwrap().shape()
);
test_bn2d.eval();
let output_eval = test_bn2d.forward(&test_input_2d);
println!(
" - Evaluation output shape: {:?}",
output_eval.data().read().unwrap().shape()
);
println!(" - Evaluation mode: {}", !test_bn2d.is_training());
println!("\n🧪 BatchNorm1d test:");
let test_bn1d = BatchNorm1d::<f32>::new(128, None, None, None);
println!(" - Created with 128 features");
println!(" - Training mode: {}", test_bn1d.is_training());
let test_input_1d = Variable::new(
Tensor::from_vec(
(0..32 * 128).map(|i| (i as f32) * 0.001).collect(),
vec![32, 128],
),
true,
);
println!(
" - Input shape: {:?}",
test_input_1d.data().read().unwrap().shape()
);
let output_1d = test_bn1d.forward(&test_input_1d);
println!(
" - Output shape: {:?}",
output_1d.data().read().unwrap().shape()
);
println!("\n📈 Parameter Analysis:");
let bn2d_params = test_bn2d.parameters();
println!(" BatchNorm2d parameters: {}", bn2d_params.len());
for (i, param) in bn2d_params.iter().enumerate() {
let param_binding = param.data();
let param_data = param_binding.read().unwrap();
let param_count: usize = param_data.shape().iter().product();
println!(
" Parameter {}: shape {:?}, count: {}",
i,
param_data.shape(),
param_count
);
}
let bn1d_params = test_bn1d.parameters();
println!(" BatchNorm1d parameters: {}", bn1d_params.len());
for (i, param) in bn1d_params.iter().enumerate() {
let param_binding = param.data();
let param_data = param_binding.read().unwrap();
let param_count: usize = param_data.shape().iter().product();
println!(
" Parameter {}: shape {:?}, count: {}",
i,
param_data.shape(),
param_count
);
}
println!("\n🎯 Training Simulation:");
let conv_layer = Conv2d::<f32>::new(3, 16, (3, 3), Some((1, 1)), Some((1, 1)), Some(true));
let bn_layer = BatchNorm2d::<f32>::new(16, None, None, None);
let train_input = Variable::new(
Tensor::from_vec(
(0..8 * 3 * 16 * 16)
.map(|i| ((i as f32) / 1000.0).sin())
.collect(),
vec![8, 3, 16, 16],
),
true,
);
let target = Variable::new(Tensor::ones(&[8, 16, 16, 16]), false);
println!(
" Training batch input shape: {:?}",
train_input.data().read().unwrap().shape()
);
let mut all_params = conv_layer.parameters();
all_params.extend(bn_layer.parameters());
let mut optimizer = SGD::new(0.01);
for epoch in 0..3 {
bn_layer.train();
let conv_out = conv_layer.forward(&train_input);
let bn_out = bn_layer.forward(&conv_out);
let diff = &bn_out - ⌖
let loss = (&diff * &diff).sum().mean_autograd();
loss.backward();
for param in &all_params {
let param_data = param.data();
let param_tensor = param_data.read().unwrap();
let grad_data = param.grad();
let grad_guard = grad_data.read().unwrap();
if let Some(ref grad_tensor) = *grad_guard {
optimizer.step(¶m_tensor, grad_tensor);
}
}
println!(
" Epoch {}: Loss shape: {:?}",
epoch + 1,
loss.data().read().unwrap().shape()
);
}
println!("\n🔍 Evaluation Mode Test:");
bn_layer.eval();
let eval_out = conv_layer.forward(&train_input);
let eval_bn_out = bn_layer.forward(&eval_out);
println!(
" Evaluation output shape: {:?}",
eval_bn_out.data().read().unwrap().shape()
);
println!(" BatchNorm in eval mode: {}", !bn_layer.is_training());
println!("\n📊 Running Statistics:");
let running_mean = bn_layer.running_mean();
let running_var = bn_layer.running_var();
println!(" Running mean shape: {:?}", running_mean.shape());
println!(" Running var shape: {:?}", running_var.shape());
println!("\n🎉 BatchNorm Demo completed successfully!");
println!(" - BatchNorm1d for fully connected layers ✅");
println!(" - BatchNorm2d for convolutional layers ✅");
println!(" - Training/Evaluation mode switching ✅");
println!(" - Running statistics tracking ✅");
println!(" - Parameter management ✅");
println!(" - Integration with optimizers ✅");
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape-only smoke test. A Conv2d → BatchNorm2d chain and a standalone
    /// BatchNorm1d must each yield an output whose shape matches the expected
    /// shape for its input.
    #[test]
    fn test_batchnorm_integration() {
        // 2-D path: a [1, 1, 8, 8] constant input through a 3x3 conv
        // (stride 1, padding 1), followed by BatchNorm2d over 4 channels.
        let image = Variable::new(
            Tensor::from_vec(vec![0.1; 8 * 8], vec![1, 1, 8, 8]),
            false,
        );
        let conv = Conv2d::<f32>::new(1, 4, (3, 3), Some((1, 1)), Some((1, 1)), Some(true));
        let norm2d = BatchNorm2d::<f32>::new(4, None, None, None);
        let features = conv.forward(&image);
        let normalized = norm2d.forward(&features);
        assert_eq!(normalized.data().read().unwrap().shape(), &[1, 4, 8, 8]);

        // 1-D path: a [5, 10] constant input through BatchNorm1d over 10 features.
        let rows = Variable::new(Tensor::from_vec(vec![0.1; 5 * 10], vec![5, 10]), false);
        let norm1d = BatchNorm1d::<f32>::new(10, None, None, None);
        let normalized_rows = norm1d.forward(&rows);
        assert_eq!(normalized_rows.data().read().unwrap().shape(), &[5, 10]);
    }
}