#![cfg(feature = "neural_network")]
use approx::assert_abs_diff_eq;
use ndarray::Array;
use rustyml::neural_network::layer::activation_layer::linear::Linear;
use rustyml::neural_network::layer::dense::Dense;
use rustyml::neural_network::layer::layer_weight::LayerWeight;
use rustyml::neural_network::layer::regularization_layer::normalization_layer::batch_normalization::BatchNormalization;
use rustyml::neural_network::layer::regularization_layer::normalization_layer::layer_normalization::{LayerNormalization, LayerNormalizationAxis};
use rustyml::neural_network::layer::serialize_weight::SerializableLayerWeight;
use rustyml::neural_network::layer::TrainingParameters;
use rustyml::neural_network::loss_function::mean_squared_error::MeanSquaredError;
use rustyml::neural_network::neural_network_trait::Layer;
use rustyml::neural_network::optimizer::sgd::SGD;
use rustyml::neural_network::sequential::Sequential;
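
// Unit tests for the LayerNormalization layer: forward/backward passes, axis
// selection, optimizer updates, shape handling, and weight serialization.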
#[test]
fn test_layer_normalization_forward_pass_dimensions() {
let mut ln =
LayerNormalization::new(vec![4, 8], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::ones((4, 8)).into_dyn();
let output = ln.forward(&input).unwrap();
assert_eq!(output.shape(), &[4, 8]);
println!(
"Forward pass dimension test passed: {:?} -> {:?}",
input.shape(),
output.shape()
);
}
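
// With `LayerNormalizationAxis::Default`, normalization runs over the last axis,
// so every sample (row) should end up with roughly zero mean and unit variance.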
#[test]
fn test_layer_normalization_default_axis() {
let mut ln =
LayerNormalization::new(vec![3, 4], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec(
(3, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
)
.unwrap()
.into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
let output_2d = output
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
for sample_idx in 0..3 {
let sample_row = output_2d.row(sample_idx);
let mean: f32 = sample_row.mean().unwrap();
let variance: f32 = sample_row.mapv(|x| (x - mean).powi(2)).mean().unwrap();
assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-5);
assert_abs_diff_eq!(variance, 1.0, epsilon = 1e-4);
}
println!("Default axis normalization test passed");
}
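
// `Custom(0)` normalizes along axis 0, so each column (rather than each row)
// should have roughly zero mean and unit variance.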
#[test]
fn test_layer_normalization_custom_axis() {
let mut ln =
LayerNormalization::new(vec![3, 4], LayerNormalizationAxis::Custom(0), 1e-5).unwrap();
let input = Array::from_shape_vec(
(3, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
)
.unwrap()
.into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
let output_2d = output
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
for col_idx in 0..4 {
let col = output_2d.column(col_idx);
let mean: f32 = col.mean().unwrap();
let variance: f32 = col.mapv(|x| (x - mean).powi(2)).mean().unwrap();
assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-5);
assert_abs_diff_eq!(variance, 1.0, epsilon = 1e-4);
}
println!("Custom axis normalization test passed");
}
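
// Axis 5 does not exist for a 2-D input, so the forward pass should return an error.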
#[test]
fn test_layer_normalization_invalid_axis() {
let mut ln =
LayerNormalization::new(vec![3, 4], LayerNormalizationAxis::Custom(5), 1e-5).unwrap();
let input = Array::ones((3, 4)).into_dyn();
let result = ln.forward(&input);
assert!(result.is_err(), "Should return error for invalid axis");
println!("Invalid axis error handling test passed");
}
#[test]
fn test_layer_normalization_training_mode() {
let mut ln =
LayerNormalization::new(vec![2, 6], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec(
(2, 6),
vec![
10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 15.0, 25.0, 35.0, 45.0, 55.0, 65.0,
],
)
.unwrap()
.into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 6]);
let output_2d = output
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
for sample_idx in 0..2 {
let sample = output_2d.row(sample_idx);
let mean: f32 = sample.mean().unwrap();
assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-5);
}
println!("Training mode test passed");
}
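
// After one training-mode pass, the layer is switched to inference mode and must
// still accept new input of the same shape and produce an output of matching shape.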
#[test]
fn test_layer_normalization_inference_mode() {
let mut ln =
LayerNormalization::new(vec![2, 4], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.unwrap()
.into_dyn();
ln.set_training(true);
ln.forward(&input).unwrap();
ln.set_training(false);
let test_input = Array::from_shape_vec((2, 4), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
.unwrap()
.into_dyn();
let output = ln.forward(&test_input).unwrap();
assert_eq!(output.shape(), &[2, 4]);
println!("Inference mode test passed");
}
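
// With an all-ones upstream gradient, the LayerNorm input gradient should sum to
// (approximately) zero within each normalized row: the backward pass removes the
// gradient component along the all-ones direction of the normalized axis.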
#[test]
fn test_layer_normalization_backward_pass() {
let mut ln =
LayerNormalization::new(vec![3, 4], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec(
(3, 4),
vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0],
)
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::ones((3, 4)).into_dyn();
let grad_input = ln.backward(&grad_output).unwrap();
assert_eq!(grad_input.shape(), input.shape());
let grad_2d = grad_input
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
for sample_idx in 0..3 {
let sample_grad_sum: f32 = grad_2d.row(sample_idx).sum();
assert_abs_diff_eq!(sample_grad_sum, 0.0, epsilon = 1e-4);
}
println!("Backward pass test passed");
}
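
// After a forward/backward pass, a single optimizer step should move gamma and/or
// beta. The next three tests repeat the same check for Adam, RMSprop, and AdaGrad.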
#[test]
fn test_layer_normalization_parameter_update_sgd() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::ones((2, 3)).into_dyn();
let _grad_input = ln.backward(&grad_output).unwrap();
if let LayerWeight::LayerNormalizationLayer(weights) = ln.get_weights() {
let initial_gamma = weights.gamma.clone();
let initial_beta = weights.beta.clone();
ln.update_parameters_sgd(0.01);
if let LayerWeight::LayerNormalizationLayer(updated_weights) = ln.get_weights() {
let gamma_changed = updated_weights
.gamma
.as_slice()
.unwrap()
.iter()
.zip(initial_gamma.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
let beta_changed = updated_weights
.beta
.as_slice()
.unwrap()
.iter()
.zip(initial_beta.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
assert!(
gamma_changed || beta_changed,
"Parameters should be updated"
);
}
}
println!("SGD parameter update test passed");
}
#[test]
fn test_layer_normalization_parameter_update_adam() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::ones((2, 3)).into_dyn();
let _grad_input = ln.backward(&grad_output).unwrap();
if let LayerWeight::LayerNormalizationLayer(weights) = ln.get_weights() {
let initial_gamma = weights.gamma.clone();
let initial_beta = weights.beta.clone();
ln.update_parameters_adam(0.001, 0.9, 0.999, 1e-8, 1);
if let LayerWeight::LayerNormalizationLayer(updated_weights) = ln.get_weights() {
let gamma_changed = updated_weights
.gamma
.as_slice()
.unwrap()
.iter()
.zip(initial_gamma.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
let beta_changed = updated_weights
.beta
.as_slice()
.unwrap()
.iter()
.zip(initial_beta.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
assert!(
gamma_changed || beta_changed,
"Parameters should be updated"
);
}
}
println!("Adam parameter update test passed");
}
#[test]
fn test_layer_normalization_parameter_update_rmsprop() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::ones((2, 3)).into_dyn();
let _grad_input = ln.backward(&grad_output).unwrap();
if let LayerWeight::LayerNormalizationLayer(weights) = ln.get_weights() {
let initial_gamma = weights.gamma.clone();
let initial_beta = weights.beta.clone();
ln.update_parameters_rmsprop(0.001, 0.9, 1e-8);
if let LayerWeight::LayerNormalizationLayer(updated_weights) = ln.get_weights() {
let gamma_changed = updated_weights
.gamma
.as_slice()
.unwrap()
.iter()
.zip(initial_gamma.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
let beta_changed = updated_weights
.beta
.as_slice()
.unwrap()
.iter()
.zip(initial_beta.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
assert!(
gamma_changed || beta_changed,
"Parameters should be updated"
);
}
}
println!("RMSprop parameter update test passed");
}
#[test]
fn test_layer_normalization_parameter_update_adagrad() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::ones((2, 3)).into_dyn();
let _grad_input = ln.backward(&grad_output).unwrap();
if let LayerWeight::LayerNormalizationLayer(weights) = ln.get_weights() {
let initial_gamma = weights.gamma.clone();
let initial_beta = weights.beta.clone();
ln.update_parameters_ada_grad(0.01, 1e-8);
if let LayerWeight::LayerNormalizationLayer(updated_weights) = ln.get_weights() {
let gamma_changed = updated_weights
.gamma
.as_slice()
.unwrap()
.iter()
.zip(initial_gamma.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
let beta_changed = updated_weights
.beta
.as_slice()
.unwrap()
.iter()
.zip(initial_beta.as_slice().unwrap().iter())
.any(|(a, b)| (a - b).abs() > 1e-6);
assert!(
gamma_changed || beta_changed,
"Parameters should be updated"
);
}
}
println!("AdaGrad parameter update test passed");
}
#[test]
fn test_layer_normalization_different_batch_sizes() {
let batch_sizes = vec![2, 4, 8, 16];
let features = 5;
for batch_size in batch_sizes {
let mut ln = LayerNormalization::new(
vec![batch_size, features],
LayerNormalizationAxis::Default,
1e-5,
)
.unwrap();
let input =
Array::from_shape_fn((batch_size, features), |(i, j)| (i * features + j) as f32)
.into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
assert_eq!(
output.shape(),
&[batch_size, features],
"Output shape should match input shape"
);
println!(
"Batch size {} test passed: {:?} -> {:?}",
batch_size,
input.shape(),
output.shape()
);
}
}
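
// gamma and beta each hold one value per feature, so an input shape of [4, 10]
// yields 10 + 10 = 20 trainable parameters.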
#[test]
fn test_layer_normalization_parameter_count() {
let ln = LayerNormalization::new(vec![4, 10], LayerNormalizationAxis::Default, 1e-5).unwrap();
    let expected_params = 10 + 10;
    assert_eq!(
ln.param_count(),
TrainingParameters::Trainable(expected_params)
);
println!(
"Parameter count test passed: {} parameters",
expected_params
);
}
#[test]
fn test_layer_normalization_layer_type() {
let ln = LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
assert_eq!(ln.layer_type(), "LayerNormalization");
println!("Layer type test passed");
}
#[test]
fn test_layer_normalization_output_shape() {
let ln = LayerNormalization::new(vec![4, 8], LayerNormalizationAxis::Default, 1e-5).unwrap();
let output_shape = ln.output_shape();
assert_eq!(output_shape, "(4, 8)");
println!("Output shape test passed: {}", output_shape);
}
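
// LayerNormalization inside a small Sequential model: Dense(4 -> 8), LayerNorm over
// the (2, 8) activations, Dense(8 -> 1), trained with SGD and mean squared error.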
#[test]
fn test_layer_normalization_with_sequential_model() {
let mut model = Sequential::new();
model.add(Dense::new(4, 8, Linear::new()).unwrap());
model.add(LayerNormalization::new(vec![2, 8], LayerNormalizationAxis::Default, 1e-5).unwrap());
model.add(Dense::new(8, 1, Linear::new()).unwrap());
model.compile(SGD::new(0.01).unwrap(), MeanSquaredError::new());
let input = Array::ones((2, 4)).into_dyn();
let target = Array::ones((2, 1)).into_dyn();
let output = model.predict(&input).unwrap();
assert_eq!(output.shape(), &[2, 1]);
let result = model.fit(&input, &target, 10);
assert!(result.is_ok(), "Training should succeed");
println!("Sequential model integration test passed");
}
#[test]
fn test_layer_normalization_set_weights() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let new_gamma = Array::from_vec(vec![2.0, 2.0, 2.0]).into_dyn();
let new_beta = Array::from_vec(vec![1.0, 1.0, 1.0]).into_dyn();
ln.set_weights(new_gamma.clone(), new_beta.clone());
if let LayerWeight::LayerNormalizationLayer(weights) = ln.get_weights() {
assert_eq!(
weights.gamma.as_slice().unwrap(),
new_gamma.as_slice().unwrap()
);
assert_eq!(
weights.beta.as_slice().unwrap(),
new_beta.as_slice().unwrap()
);
}
println!("Set weights test passed");
}
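
// A 3-D (batch, sequence, features) input should pass through the layer with its
// shape preserved; only the output shape is checked here.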
#[test]
fn test_layer_normalization_3d_input() {
let batch_size = 2;
let sequence_len = 4;
let features = 3;
let mut ln = LayerNormalization::new(
vec![batch_size, sequence_len, features],
LayerNormalizationAxis::Default,
1e-5,
)
.unwrap();
let input = Array::from_shape_fn((batch_size, sequence_len, features), |(i, j, k)| {
(i * 100 + j * 10 + k) as f32
})
.into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
assert_eq!(output.shape(), &[batch_size, sequence_len, features]);
println!("3D input test passed");
}
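
// Both rows of the input are constant, so their variance is zero; epsilon must keep
// the denominator positive and the output finite.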
#[test]
fn test_layer_normalization_epsilon_effect() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 3), vec![5.0, 5.0, 5.0, 3.0, 3.0, 3.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let result = ln.forward(&input);
assert!(
result.is_ok(),
"Forward pass should succeed with zero variance"
);
let output = result.unwrap();
assert!(
output.iter().all(|&x| x.is_finite()),
"Output should not contain NaN or Inf"
);
println!("Epsilon effect test passed");
}
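
// LayerNorm normalizes across features within each sample, while BatchNorm
// normalizes each feature across the batch, so the two layers should produce
// different outputs for this input.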
#[test]
fn test_layer_normalization_vs_batch_normalization_difference() {
let input = Array::from_shape_vec(
(3, 4),
vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
)
.unwrap()
.into_dyn();
let mut ln =
LayerNormalization::new(vec![3, 4], LayerNormalizationAxis::Default, 1e-5).unwrap();
ln.set_training(true);
let ln_output = ln.forward(&input).unwrap();
let mut bn = BatchNormalization::new(vec![3, 4], 0.9, 1e-5).unwrap();
bn.set_training(true);
let bn_output = bn.forward(&input).unwrap();
let ln_2d = ln_output
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
let bn_2d = bn_output
.as_standard_layout()
.into_dimensionality::<ndarray::Ix2>()
.unwrap();
let outputs_different = ln_2d
.iter()
.zip(bn_2d.iter())
.any(|(a, b)| (a - b).abs() > 1e-5);
assert!(
outputs_different,
"Layer norm and batch norm should produce different outputs"
);
println!("Layer norm vs batch norm difference test passed");
}
#[test]
fn test_layer_normalization_gradient_flow() {
let mut ln =
LayerNormalization::new(vec![2, 4], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.unwrap()
.into_dyn();
ln.set_training(true);
let _output = ln.forward(&input).unwrap();
let grad_output = Array::from_shape_vec((2, 4), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
.unwrap()
.into_dyn();
let grad_input = ln.backward(&grad_output).unwrap();
let has_nonzero_grad = grad_input.iter().any(|&x| x.abs() > 1e-6);
assert!(has_nonzero_grad, "Gradients should flow through the layer");
assert!(
grad_input.iter().all(|&x| x.is_finite()),
"Gradients should be finite"
);
println!("Gradient flow test passed");
}
#[test]
fn test_layer_normalization_multiple_forward_backward() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
for i in 0..5 {
let input = Array::from_shape_fn((2, 3), |(b, f)| (i + b + f) as f32).into_dyn();
ln.set_training(true);
let output = ln.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 3]);
let grad_output = Array::ones((2, 3)).into_dyn();
let grad_input = ln.backward(&grad_output).unwrap();
assert_eq!(grad_input.shape(), &[2, 3]);
ln.update_parameters_sgd(0.01);
}
println!("Multiple forward-backward passes test passed");
}
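
// gamma and beta are per-feature vectors, so the serialized weights should have
// length 3 and shape [3] for a [2, 3] input shape.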
#[test]
fn test_layer_normalization_serialization() {
let mut ln =
LayerNormalization::new(vec![2, 3], LayerNormalizationAxis::Default, 1e-5).unwrap();
let input = Array::ones((2, 3)).into_dyn();
ln.set_training(true);
ln.forward(&input).unwrap();
let weights = ln.get_weights();
let serializable_weights = SerializableLayerWeight::from_layer_weight(&weights);
match serializable_weights {
SerializableLayerWeight::LayerNormalization(w) => {
assert_eq!(w.gamma.len(), 3);
assert_eq!(w.beta.len(), 3);
assert_eq!(w.shape, vec![3]);
}
_ => panic!("Expected LayerNormalization weights"),
}
println!("Serialization test passed");
}