#![cfg(feature = "neural_network")]
use approx::assert_abs_diff_eq;
use ndarray::{Array, Array4, IxDyn};
use rustyml::neural_network::layer::regularization_layer::dropout_layer::spatial_dropout_2d::SpatialDropout2D;
use rustyml::neural_network::neural_network_trait::Layer;
#[test]
fn test_spatial_dropout_2d_forward_pass_dimensions() {
    // The forward pass must preserve the (batch, channels, height, width) shape.
    let mut layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    let x = Array4::ones((2, 8, 10, 10)).into_dyn();
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[2, 8, 10, 10]);
    println!(
        "SpatialDropout2D dimension test passed: {:?} -> {:?}",
        x.shape(),
        y.shape()
    );
}
#[test]
fn test_spatial_dropout_2d_training_mode() {
    // With rate 0.5 in training mode, roughly half of the (batch, channel)
    // feature maps should be zeroed, and every surviving map uniformly scaled.
    let mut layer = SpatialDropout2D::new(0.5, vec![10, 20, 16, 16]).unwrap();
    layer.set_training(true);
    let x = Array4::ones((10, 20, 16, 16)).into_dyn();
    let y = layer.forward(&x).unwrap();
    let dims = y.shape();
    let (batches, chans, rows, cols) = (dims[0], dims[1], dims[2], dims[3]);
    let mut zeroed = 0;
    for b in 0..batches {
        for c in 0..chans {
            // Sum the whole feature map to decide whether it was dropped.
            let total: f32 = (0..rows)
                .flat_map(|h| (0..cols).map(move |w| (h, w)))
                .map(|(h, w)| y[[b, c, h, w]])
                .sum();
            if total == 0.0 {
                zeroed += 1;
            } else {
                // A kept map must be constant (ones input, uniform scale).
                let reference = y[[b, c, 0, 0]];
                for h in 0..rows {
                    for w in 0..cols {
                        assert_abs_diff_eq!(y[[b, c, h, w]], reference, epsilon = 0.001);
                    }
                }
            }
        }
    }
    let map_count = batches * chans;
    let drop_ratio = zeroed as f32 / map_count as f32;
    assert!(
        drop_ratio > 0.3 && drop_ratio < 0.7,
        "Expected ~50% channels dropped, got {:.1}%",
        drop_ratio * 100.0
    );
    println!(
        "SpatialDropout2D training mode test passed: {}/{} channels dropped ({:.1}%)",
        zeroed,
        map_count,
        drop_ratio * 100.0
    );
}
#[test]
fn test_spatial_dropout_2d_feature_map_consistency() {
    // Each feature map must be either entirely zeroed or uniformly rescaled:
    // the relative pattern inside a kept map is preserved up to one scale factor.
    let mut layer = SpatialDropout2D::new(0.3, vec![5, 10, 8, 8]).unwrap();
    layer.set_training(true);
    let x = Array4::from_shape_fn((5, 10, 8, 8), |(b, c, h, w)| {
        (b * 640 + c * 64 + h * 8 + w + 1) as f32
    })
    .into_dyn();
    let y = layer.forward(&x).unwrap();
    for b in 0..5 {
        for c in 0..10 {
            // All input values are >= 1, so a zero sum means the map was dropped.
            let total: f32 = (0..8)
                .flat_map(|h| (0..8).map(move |w| y[[b, c, h, w]]))
                .sum();
            if total == 0.0 {
                // Dropped map: every position is exactly zero.
                for h in 0..8 {
                    for w in 0..8 {
                        assert_eq!(y[[b, c, h, w]], 0.0);
                    }
                }
            } else {
                // Kept map: all positions share the scale implied by the first element.
                let scale = y[[b, c, 0, 0]] / x[[b, c, 0, 0]];
                for h in 0..8 {
                    for w in 0..8 {
                        let want = x[[b, c, h, w]] * scale;
                        assert_abs_diff_eq!(y[[b, c, h, w]], want, epsilon = 0.001);
                    }
                }
            }
        }
    }
    println!("SpatialDropout2D feature map consistency test passed");
}
#[test]
fn test_spatial_dropout_2d_inference_mode() {
    // In inference mode dropout is a no-op: input passes through untouched.
    let mut layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    layer.set_training(false);
    let x = Array4::from_shape_fn((2, 8, 10, 10), |(b, c, h, w)| {
        (b * 800 + c * 100 + h * 10 + w) as f32
    })
    .into_dyn();
    let y = layer.forward(&x).unwrap();
    assert_eq!(y, x);
    println!("SpatialDropout2D inference mode test passed: output equals input");
}
#[test]
fn test_spatial_dropout_2d_rate_zero() {
    // A rate of 0.0 keeps every channel even while training.
    let mut layer = SpatialDropout2D::new(0.0, vec![2, 8, 10, 10]).unwrap();
    layer.set_training(true);
    let x = Array4::ones((2, 8, 10, 10)).into_dyn();
    assert_eq!(layer.forward(&x).unwrap(), x);
    println!("SpatialDropout2D rate=0 test passed: all values retained");
}
#[test]
fn test_spatial_dropout_2d_rate_one() {
    // A rate of 1.0 drops every channel, producing an all-zero output.
    let mut layer = SpatialDropout2D::new(1.0, vec![2, 8, 10, 10]).unwrap();
    layer.set_training(true);
    let x = Array4::ones((2, 8, 10, 10)).into_dyn();
    let y = layer.forward(&x).unwrap();
    assert!(y.iter().all(|&v| v == 0.0));
    println!("SpatialDropout2D rate=1 test passed: all values dropped");
}
#[test]
fn test_spatial_dropout_2d_invalid_rate() {
    // Rates outside [0, 1] must be rejected at construction time.
    assert!(SpatialDropout2D::new(-0.1, vec![2, 8, 10, 10]).is_err());
    assert!(SpatialDropout2D::new(1.5, vec![2, 8, 10, 10]).is_err());
    println!("SpatialDropout2D invalid rate test passed");
}
#[test]
fn test_spatial_dropout_2d_shape_validation() {
    // Forward must reject inputs whose shape differs from the configured one.
    let mut layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    let mismatched = Array4::ones((3, 8, 10, 10)).into_dyn();
    assert!(layer.forward(&mismatched).is_err());
    println!("SpatialDropout2D shape validation test passed");
}
#[test]
fn test_spatial_dropout_2d_dimension_validation() {
    // Only 4-D inputs are valid; 3-D and 5-D tensors must both be rejected.
    let mut layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    for dims in [vec![2, 8, 10], vec![2, 8, 10, 10, 5]] {
        let bad_input = Array::ones(IxDyn(&dims));
        assert!(layer.forward(&bad_input).is_err());
    }
    println!("SpatialDropout2D dimension validation test passed");
}
#[test]
fn test_spatial_dropout_2d_backward_pass() {
    // Gradients must follow the forward mask: zero wherever the output was dropped.
    let mut layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    layer.set_training(true);
    let x = Array4::ones((2, 8, 10, 10)).into_dyn();
    let y = layer.forward(&x).unwrap();
    let upstream = Array4::ones((2, 8, 10, 10)).into_dyn();
    let grad = layer.backward(&upstream).unwrap();
    assert_eq!(grad.shape(), x.shape());
    for b in 0..2 {
        for c in 0..8 {
            for h in 0..10 {
                for w in 0..10 {
                    // Only dropped positions are constrained; kept ones pass the
                    // upstream gradient through (possibly scaled).
                    if y[[b, c, h, w]] == 0.0 {
                        assert_eq!(
                            grad[[b, c, h, w]],
                            0.0,
                            "Gradient should be zero where dropout occurred"
                        );
                    }
                }
            }
        }
    }
    println!("SpatialDropout2D backward pass test passed");
}
#[test]
fn test_spatial_dropout_2d_different_rates() {
    // The fraction of zeroed feature maps should track the configured rate.
    for rate in [0.1, 0.3, 0.5, 0.7, 0.9] {
        let mut layer = SpatialDropout2D::new(rate, vec![10, 50, 8, 8]).unwrap();
        layer.set_training(true);
        let x = Array4::ones((10, 50, 8, 8)).into_dyn();
        let y = layer.forward(&x).unwrap();
        // With a ones input and per-map masking, a map is dropped iff its
        // first element is zero.
        let dropped = (0..10)
            .flat_map(|b| (0..50).map(move |c| (b, c)))
            .filter(|&(b, c)| y[[b, c, 0, 0]] == 0.0)
            .count();
        let map_count = 10 * 50;
        let observed = dropped as f32 / map_count as f32;
        assert!(
            (observed - rate).abs() < 0.15,
            "Rate {:.1}: expected ~{:.1}% dropped, got {:.1}%",
            rate,
            rate * 100.0,
            observed * 100.0
        );
        println!(
            "SpatialDropout2D rate={:.1} test passed: {:.1}% channels dropped (expected {:.1}%)",
            rate,
            observed * 100.0,
            rate * 100.0
        );
    }
}
#[test]
fn test_spatial_dropout_2d_maintains_expected_value() {
    // Inverted dropout scaling keeps the expected activation unchanged (~2.0).
    let mut layer = SpatialDropout2D::new(0.5, vec![20, 30, 12, 12]).unwrap();
    layer.set_training(true);
    let x = Array4::from_elem((20, 30, 12, 12), 2.0).into_dyn();
    let y = layer.forward(&x).unwrap();
    let element_count = 20 * 30 * 12 * 12;
    let mean = y.iter().sum::<f32>() / element_count as f32;
    assert_abs_diff_eq!(mean, 2.0, epsilon = 0.3);
    println!(
        "SpatialDropout2D expected value test passed: mean = {:.2} (expected 2.0)",
        mean
    );
}
#[test]
fn test_spatial_dropout_2d_layer_type() {
    // The layer must report its type name for model summaries.
    let layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    assert_eq!(layer.layer_type(), "SpatialDropout2D");
    println!("SpatialDropout2D layer type test passed");
}
#[test]
fn test_spatial_dropout_2d_output_shape() {
    // The formatted output shape should mention every configured dimension.
    let layer = SpatialDropout2D::new(0.5, vec![2, 8, 10, 10]).unwrap();
    let described = layer.output_shape();
    assert!(["2", "8", "10"].iter().all(|&d| described.contains(d)));
    println!("SpatialDropout2D output shape test passed: {}", described);
}
#[test]
fn test_spatial_dropout_2d_consistency_across_calls() {
    // Successive forward passes should draw fresh random masks
    // (identical masks are astronomically unlikely over 50 feature maps).
    let mut layer = SpatialDropout2D::new(0.5, vec![5, 10, 8, 8]).unwrap();
    layer.set_training(true);
    let x = Array4::ones((5, 10, 8, 8)).into_dyn();
    let first = layer.forward(&x).unwrap();
    let second = layer.forward(&x).unwrap();
    assert_ne!(first, second);
    println!("SpatialDropout2D consistency test passed: different masks on each forward pass");
}
#[test]
fn test_spatial_dropout_2d_various_shapes() {
    // The layer should handle a range of batch/channel/spatial sizes.
    let cases = [
        (1, 4, 8, 8),
        (4, 16, 16, 16),
        (8, 32, 32, 32),
        (16, 8, 64, 64),
    ];
    for &(n, c, h, w) in &cases {
        let mut layer = SpatialDropout2D::new(0.5, vec![n, c, h, w]).unwrap();
        layer.set_training(true);
        let x = Array4::ones((n, c, h, w)).into_dyn();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[n, c, h, w]);
        println!(
            "SpatialDropout2D shape test passed for ({}, {}, {}, {})",
            n, c, h, w
        );
    }
}
#[test]
fn test_spatial_dropout_2d_scaling() {
    // Kept activations are scaled by 1 / (1 - rate) so the expectation is preserved.
    let rate = 0.5;
    let keep_scale = 1.0 / (1.0 - rate);
    let mut layer = SpatialDropout2D::new(rate, vec![2, 10, 8, 8]).unwrap();
    layer.set_training(true);
    let x = Array4::from_elem((2, 10, 8, 8), 1.0).into_dyn();
    let y = layer.forward(&x).unwrap();
    // Only surviving (non-zero) values are constrained to the keep scale.
    for &v in y.iter().filter(|&&v| v != 0.0) {
        assert_abs_diff_eq!(v, keep_scale, epsilon = 0.001);
    }
    println!(
        "SpatialDropout2D scaling test passed: kept values scaled by {:.2}",
        keep_scale
    );
}
#[test]
fn test_spatial_dropout_2d_spatial_structure() {
    // Every spatial position within one feature map must share the same
    // keep/drop decision — that is what makes the dropout "spatial".
    let mut layer = SpatialDropout2D::new(0.5, vec![3, 4, 16, 16]).unwrap();
    layer.set_training(true);
    // All inputs are strictly positive, so a zero output can only mean "dropped".
    let x = Array4::from_shape_fn((3, 4, 16, 16), |(b, c, h, w)| {
        (b as f32 * 100.0) + (c as f32 * 10.0) + (h as f32) + (w as f32 * 0.01) + 1.0
    })
    .into_dyn();
    let y = layer.forward(&x).unwrap();
    for b in 0..3 {
        for c in 0..4 {
            // Derive the map's mask state from its first element and require
            // every other position to agree with it.
            let map_kept = y[[b, c, 0, 0]] != 0.0;
            for h in 0..16 {
                for w in 0..16 {
                    assert_eq!(
                        y[[b, c, h, w]] != 0.0,
                        map_kept,
                        "Spatial position ({},{}) in batch {} channel {} has inconsistent mask",
                        h, w, b, c
                    );
                }
            }
        }
    }
    println!(
        "SpatialDropout2D spatial structure test passed: all positions in each feature map share the same mask"
    );
}
#[test]
fn test_spatial_dropout_2d_non_square_feature_maps() {
    // Rectangular (and odd-sized) feature maps get the same per-map treatment.
    let cases = [(2, 8, 16, 32), (2, 8, 32, 16), (4, 16, 7, 13)];
    for &(n, c, h, w) in &cases {
        let mut layer = SpatialDropout2D::new(0.5, vec![n, c, h, w]).unwrap();
        layer.set_training(true);
        let x = Array4::ones((n, c, h, w)).into_dyn();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[n, c, h, w]);
        // With a ones input every value in a map equals the map's first value
        // (0 when dropped, the keep scale otherwise).
        for b in 0..n {
            for ch in 0..c {
                let reference = y[[b, ch, 0, 0]];
                for row in 0..h {
                    for col in 0..w {
                        assert_abs_diff_eq!(y[[b, ch, row, col]], reference, epsilon = 0.001);
                    }
                }
            }
        }
        println!(
            "SpatialDropout2D non-square test passed for ({}, {}, {}, {})",
            n, c, h, w
        );
    }
}