use crate::asg::NodeType;
use crate::nn::init::Initializer;
use crate::nn::Module;
use crate::tensor::{GraphContext, Tensor};
use std::cell::RefCell;
use std::rc::Rc;
/// Value that `BatchNorm::with_axis` stores in the `eps` field; callers can
/// replace it afterwards with `BatchNorm::with_eps`.
const DEFAULT_EPS: f32 = 1e-5;
/// Value that `BatchNorm::with_axis` stores in the `momentum` field; callers can
/// replace it afterwards with `BatchNorm::with_momentum`.
const DEFAULT_MOMENTUM: f32 = 0.1;
/// Batch-normalization layer that emits a `NodeType::BatchNorm` graph node
/// from its `Module::forward` implementation.
///
/// Owns two parameters (`gamma`, `beta`) registered under
/// `"<name>.gamma"` / `"<name>.beta"` plus the configuration for the node.
pub struct BatchNorm {
/// Parameter registered as `"<name>.gamma"` with shape `[num_features]`,
/// initialized with `Initializer::Ones`.
pub gamma: Tensor,
/// Parameter registered as `"<name>.beta"` with shape `[num_features]`,
/// initialized with `Initializer::Zeros`.
pub beta: Tensor,
/// Epsilon forwarded to the `NodeType::BatchNorm` node; defaults to `DEFAULT_EPS`.
pub eps: f32,
/// Channel axis forwarded to the `NodeType::BatchNorm` node; `BatchNorm::new` uses 1.
pub channel_axis: usize,
/// Momentum setting; defaults to `DEFAULT_MOMENTUM`.
/// NOTE(review): `forward` in this file does not read this field —
/// presumably intended for running statistics; confirm against the backend.
pub momentum: f32,
/// Mode flag, `true` after construction and toggled by `train()` / `eval()`.
/// NOTE(review): `forward` in this file does not read this field — confirm
/// where (if anywhere) it is consumed.
pub training: bool,
/// Layer name; used as the prefix of the parameter names.
pub name: String,
/// Length of the `gamma` and `beta` parameters.
pub num_features: usize,
}
impl BatchNorm {
pub fn new(ctx: &Rc<RefCell<GraphContext>>, name: &str, num_features: usize) -> Self {
Self::with_axis(ctx, name, num_features, 1)
}
pub fn with_axis(
ctx: &Rc<RefCell<GraphContext>>,
name: &str,
num_features: usize,
channel_axis: usize,
) -> Self {
let gamma = Tensor::new_parameter_with_shape(
ctx,
&format!("{}.gamma", name),
vec![num_features],
Initializer::Ones,
);
let beta = Tensor::new_parameter_with_shape(
ctx,
&format!("{}.beta", name),
vec![num_features],
Initializer::Zeros,
);
Self {
gamma,
beta,
eps: DEFAULT_EPS,
channel_axis,
momentum: DEFAULT_MOMENTUM,
training: true,
name: name.to_string(),
num_features,
}
}
pub fn with_momentum(mut self, momentum: f32) -> Self {
self.momentum = momentum;
self
}
pub fn with_eps(mut self, eps: f32) -> Self {
self.eps = eps;
self
}
pub fn train(&mut self) {
self.training = true;
}
pub fn eval(&mut self) {
self.training = false;
}
}
impl Module for BatchNorm {
    /// Appends a `NodeType::BatchNorm` node to the main graph of `x`'s context
    /// and returns a `Tensor` handle to that node, sharing `x`'s context.
    ///
    /// The node receives the ids of `x`, `gamma` and `beta` together with
    /// `eps` and `channel_axis`. The `momentum` and `training` fields are not
    /// read here.
    ///
    /// NOTE(review): the `gamma` / `beta` node ids are used as-is, so `x`
    /// presumably has to live in the same context the layer was built with —
    /// confirm against callers.
    fn forward(&self, x: &Tensor) -> Tensor {
        let op = NodeType::BatchNorm {
            input: x.node_id,
            gamma: self.gamma.node_id,
            beta: self.beta.node_id,
            eps: self.eps,
            channel_axis: self.channel_axis,
        };

        let context = Rc::clone(&x.context);
        // The mutable borrow ends with this statement, before `context` is moved below.
        let node_id = context.borrow_mut().main_graph_mut().add_node(None, op);

        Tensor { node_id, context }
    }

    /// Returns clones of the layer's parameters, in the order `[gamma, beta]`.
    fn parameters(&self) -> Vec<Tensor> {
        [&self.gamma, &self.beta]
            .iter()
            .map(|&param| param.clone())
            .collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Constructing a layer must register `bn1.gamma` and `bn1.beta` in the
/// context, each with shape `[32]`.
#[test]
fn batchnorm_registers_shapes() {
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let _bn = BatchNorm::new(&ctx, "bn1", 32);

    let graph_ctx = ctx.borrow();

    let gamma_meta = graph_ctx.parameter_meta("bn1.gamma").unwrap();
    assert_eq!(gamma_meta.shape, vec![32]);

    let beta_meta = graph_ctx.parameter_meta("bn1.beta").unwrap();
    assert_eq!(beta_meta.shape, vec![32]);
}
/// `eval()` must clear the `training` flag and `train()` must set it again.
#[test]
fn batchnorm_train_eval_toggle() {
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let mut layer = BatchNorm::new(&ctx, "bn1", 16);

    // Switch to evaluation mode first, since the layer starts in training mode.
    layer.eval();
    assert!(!layer.training);

    layer.train();
    assert!(layer.training);
}
/// `BatchNorm::new` must pick channel axis 1.
#[test]
fn batchnorm_default_channel_axis_is_one() {
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let axis = BatchNorm::new(&ctx, "bn", 8).channel_axis;
    assert_eq!(axis, 1);
}
/// Runs a standalone `BatchNormBackward` node on the CPU backend and compares
/// the produced input gradient with hand-computed values.
///
/// Setup: `x` is 3x3, `gamma` is all ones, `eps = 0`, `channel_axis = 1`, and
/// the upstream gradient `dy` is 1 at position (0, 0) and 0 elsewhere. The
/// expected values are consistent with normalizing each column over the rows
/// using the biased variance: column 0 of `x` is `[1, 2, 3]`, giving
/// `dx[:, 0] = [0.2041, -0.4082, 0.2041]` and zeros in the other columns.
#[test]
fn batchnorm_backward_matches_hand_calc() {
    use crate::analysis::shape_inference::ShapeInference;
    use crate::asg::{DType, Value};
    use crate::runtime::backend::Backend;
    use crate::runtime::cpu_backend::CpuBackend;
    use ndarray::{array, ArrayD};
    use std::collections::HashMap;

    // Build a graph whose only computation is the backward node.
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let dy = Tensor::new_input(&ctx, "dy");
    let x = Tensor::new_input(&ctx, "x");
    let gamma = Tensor::new_input(&ctx, "gamma");
    let backward_id = ctx.borrow_mut().main_graph_mut().add_node(
        None,
        crate::asg::NodeType::BatchNormBackward {
            grad_output: dy.node_id,
            input: x.node_id,
            gamma: gamma.node_id,
            eps: 0.0,
            channel_axis: 1,
        },
    );
    ctx.borrow_mut().main_graph_mut().set_output(backward_id);

    let mut shapes = HashMap::new();
    shapes.insert("dy".to_string(), (vec![3, 3], DType::F32));
    shapes.insert("x".to_string(), (vec![3, 3], DType::F32));
    shapes.insert("gamma".to_string(), (vec![3], DType::F32));
    let mut g = ctx.borrow().main_graph().clone();
    ShapeInference::run(&mut g, &shapes).unwrap();

    let mut data = HashMap::new();
    data.insert(
        "x".to_string(),
        Value::Tensor(array![[1.0_f32, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]].into_dyn()),
    );
    data.insert(
        "gamma".to_string(),
        Value::Tensor(array![1.0_f32, 1.0, 1.0].into_dyn()),
    );
    data.insert(
        "dy".to_string(),
        Value::Tensor(array![[1.0_f32, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]].into_dyn()),
    );

    // Key every loaded value by (graph id, id of the Input node with that name).
    let backend = CpuBackend::new();
    let device = backend.load_data(&data).unwrap();
    let mut memo = HashMap::new();
    for (name, val) in device {
        let nid = g
            .nodes
            .iter()
            .find(|(_, n)| {
                matches!(&n.node_type,
                    crate::asg::NodeType::Input { name: nn } if nn == &name)
            })
            .map(|(id, _)| *id)
            .unwrap();
        memo.insert((g.id, nid), val);
    }

    let (out, _) = backend.run(&g, memo).unwrap();
    let result = match &backend.retrieve_data(&out).unwrap()[0] {
        Value::Tensor(t) => t.clone(),
        _ => panic!("expected tensor"),
    };

    let expected: ArrayD<f32> = array![
        [0.2041_f32, 0.0, 0.0],
        [-0.4082, 0.0, 0.0],
        [0.2041, 0.0, 0.0],
    ]
    .into_dyn();

    // `zip` stops at the shorter iterator, so without this check a result with
    // the wrong shape (or no elements at all) would pass the loop below.
    assert_eq!(
        result.shape(),
        expected.shape(),
        "BatchNorm backward produced an output of unexpected shape"
    );
    for (i, (a, b)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-3,
            "BatchNorm backward mismatch at idx {}: got {} expected {}",
            i,
            a,
            b
        );
    }
}
/// Differentiates `sum(BatchNorm(x) * mask)` with respect to `x` through
/// `Gradients` and compares the result with hand-computed values.
///
/// Because the loss is a sum of `bn * mask`, the gradient flowing into the
/// BatchNorm output equals `mask` (1 at position (0, 0), 0 elsewhere). With
/// `eps = 0`, unit `gamma`, zero `beta` and `channel_axis = 1`, the expected
/// values are consistent with per-column normalization using the biased
/// variance: `dx[:, 0] = [0.2041, -0.4082, 0.2041]`, zeros elsewhere.
#[test]
fn batchnorm_backward_via_autograd() {
    use crate::analysis::shape_inference::ShapeInference;
    use crate::asg::{DType, Value};
    use crate::autograd::Gradients;
    use crate::runtime::backend::Backend;
    use crate::runtime::cpu_backend::CpuBackend;
    use ndarray::{array, ArrayD};
    use std::collections::HashMap;

    // Forward graph: loss = sum(batchnorm(x, gamma, beta) * mask).
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let x = Tensor::new_input(&ctx, "x");
    let gamma = Tensor::new_input(&ctx, "gamma");
    let beta = Tensor::new_input(&ctx, "beta");
    let mask = Tensor::new_input(&ctx, "mask");
    let bn_id = ctx.borrow_mut().main_graph_mut().add_node(
        None,
        crate::asg::NodeType::BatchNorm {
            input: x.node_id,
            gamma: gamma.node_id,
            beta: beta.node_id,
            eps: 0.0,
            channel_axis: 1,
        },
    );
    let bn = Tensor {
        node_id: bn_id,
        context: Rc::clone(&ctx),
    };
    let masked = &bn * &mask;
    let loss = masked.sum();

    let mut shapes = HashMap::new();
    shapes.insert("x".to_string(), (vec![3, 3], DType::F32));
    shapes.insert("gamma".to_string(), (vec![3], DType::F32));
    shapes.insert("beta".to_string(), (vec![3], DType::F32));
    shapes.insert("mask".to_string(), (vec![3, 3], DType::F32));
    let mut fwd = ctx.borrow().main_graph().clone();
    fwd.set_output(loss.node_id);
    ShapeInference::run(&mut fwd, &shapes).unwrap();

    // Gradient graph for d(loss)/d(x), built from the shape-annotated forward graph.
    let grad_graph = Gradients::new(fwd.clone())
        .build(loss.node_id, &[x.node_id])
        .expect("grad build");

    let mut data = HashMap::new();
    data.insert(
        "x".to_string(),
        Value::Tensor(array![[1.0_f32, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]].into_dyn()),
    );
    data.insert(
        "gamma".to_string(),
        Value::Tensor(array![1.0_f32, 1.0, 1.0].into_dyn()),
    );
    data.insert(
        "beta".to_string(),
        Value::Tensor(array![0.0_f32, 0.0, 0.0].into_dyn()),
    );
    data.insert(
        "mask".to_string(),
        Value::Tensor(array![[1.0_f32, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]].into_dyn()),
    );

    // Key every loaded value by (graph id, id of the Input node with that name).
    let backend = CpuBackend::new();
    let device = backend.load_data(&data).unwrap();
    let mut memo = HashMap::new();
    for (name, val) in device {
        let nid = fwd
            .nodes
            .iter()
            .find(|(_, n)| {
                matches!(&n.node_type,
                    crate::asg::NodeType::Input { name: nn } if nn == &name)
            })
            .map(|(id, _)| *id)
            .unwrap();
        memo.insert((fwd.id, nid), val);
    }

    // Run the forward graph first; its memo is then handed to the gradient graph run.
    let (_, fwd_memo) = backend.run(&fwd, memo).unwrap();
    let (grad_out, _) = backend.run(&grad_graph, fwd_memo).unwrap();
    let result = match &backend.retrieve_data(&grad_out).unwrap()[0] {
        Value::Tensor(t) => t.clone(),
        _ => panic!("expected tensor"),
    };

    let expected: ArrayD<f32> = array![
        [0.2041_f32, 0.0, 0.0],
        [-0.4082, 0.0, 0.0],
        [0.2041, 0.0, 0.0],
    ]
    .into_dyn();

    // `zip` stops at the shorter iterator, so without this check a gradient
    // with the wrong shape (or no elements at all) would pass the loop below.
    assert_eq!(
        result.shape(),
        expected.shape(),
        "autograd dx has an unexpected shape"
    );
    for (i, (a, b)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-3,
            "autograd dx mismatch at idx {}: got {} expected {}",
            i,
            a,
            b
        );
    }
}
/// Runs a single `BatchNorm` node on the CPU backend and compares the output
/// with hand-computed values.
///
/// Setup: `x = [[1, 2, 3], [4, 5, 6]]`, unit `gamma`, zero `beta`, `eps = 0`,
/// `channel_axis = 1`. The expected values are consistent with per-column
/// normalization using the biased variance: every column has mean offset
/// `±1.5` and standard deviation `1.5`, giving `-1` in row 0 and `+1` in row 1.
#[test]
fn batchnorm_forward_matches_hand_calc() {
    use crate::analysis::shape_inference::ShapeInference;
    use crate::asg::{DType, Value};
    use crate::runtime::backend::Backend;
    use crate::runtime::cpu_backend::CpuBackend;
    use ndarray::{array, ArrayD};
    use std::collections::HashMap;

    // Build a graph whose output is the BatchNorm node.
    let ctx = Rc::new(RefCell::new(GraphContext::new()));
    let x = Tensor::new_input(&ctx, "x");
    let gamma = Tensor::new_input(&ctx, "gamma");
    let beta = Tensor::new_input(&ctx, "beta");
    let bn_id = ctx.borrow_mut().main_graph_mut().add_node(
        None,
        crate::asg::NodeType::BatchNorm {
            input: x.node_id,
            gamma: gamma.node_id,
            beta: beta.node_id,
            eps: 0.0,
            channel_axis: 1,
        },
    );
    ctx.borrow_mut().main_graph_mut().set_output(bn_id);

    let mut shapes = HashMap::new();
    shapes.insert("x".to_string(), (vec![2, 3], DType::F32));
    shapes.insert("gamma".to_string(), (vec![3], DType::F32));
    shapes.insert("beta".to_string(), (vec![3], DType::F32));
    let mut g = ctx.borrow().main_graph().clone();
    ShapeInference::run(&mut g, &shapes).unwrap();

    let x_data: ArrayD<f32> = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].into_dyn();
    let gamma_data: ArrayD<f32> = array![1.0, 1.0, 1.0].into_dyn();
    let beta_data: ArrayD<f32> = array![0.0, 0.0, 0.0].into_dyn();
    let mut data = HashMap::new();
    data.insert("x".to_string(), Value::Tensor(x_data));
    data.insert("gamma".to_string(), Value::Tensor(gamma_data));
    data.insert("beta".to_string(), Value::Tensor(beta_data));

    // Key every loaded value by (graph id, id of the Input node with that name).
    let backend = CpuBackend::new();
    let device = backend.load_data(&data).unwrap();
    let mut memo = HashMap::new();
    for (name, val) in device {
        let nid = g
            .nodes
            .iter()
            .find(|(_, n)| {
                matches!(&n.node_type,
                    crate::asg::NodeType::Input { name: nn } if nn == &name)
            })
            .map(|(id, _)| *id)
            .unwrap();
        memo.insert((g.id, nid), val);
    }

    let (out, _) = backend.run(&g, memo).unwrap();
    let result = match &backend.retrieve_data(&out).unwrap()[0] {
        Value::Tensor(t) => t.clone(),
        _ => panic!("expected tensor"),
    };

    let expected = array![[-1.0_f32, -1.0, -1.0], [1.0, 1.0, 1.0]].into_dyn();

    // `zip` stops at the shorter iterator, so without this check an output
    // with the wrong shape (or no elements at all) would pass the loop below.
    assert_eq!(
        result.shape(),
        expected.shape(),
        "BatchNorm forward produced an output of unexpected shape"
    );
    for (a, b) in result.iter().zip(expected.iter()) {
        assert!(
            (a - b).abs() < 1e-5,
            "BatchNorm forward mismatch: got {} expected {}",
            a,
            b
        );
    }
}
}