use aprender::autograd::Tensor;
use aprender::nn::BatchNorm1d;
use aprender::nn::Module;
use proptest::prelude::*;
proptest! {
/// Training-mode `BatchNorm1d` must standardize every feature column: the
/// per-feature mean of the output is zero up to floating-point rounding.
///
/// The tolerance is deliberately tight. In exact arithmetic the mean is 0, so
/// a loose bound (e.g. 0.5) would let an implementation that never subtracts
/// the batch mean slip through many cases.
/// NOTE(review): assumes `BatchNorm1d::new` starts in training mode with a
/// zero shift term — TODO confirm against aprender.
#[test]
fn prop_training_standardization(
    data in proptest::collection::vec(-10.0f32..10.0, 20..21usize)
) {
    // The strategy always yields exactly batch * features = 20 values, so the
    // data can be fed straight into the tensor without padding.
    let batch = 5;
    let features = 4;
    let bn = BatchNorm1d::new(features);
    let x = Tensor::new(&data, &[batch, features]);
    let y = bn.forward(&x);
    let y_data = y.data();
    prop_assert_eq!(y_data.len(), batch * features);
    for f in 0..features {
        // Column `f` of the row-major [batch, features] output.
        let mean = y_data.iter().skip(f).step_by(features).sum::<f32>() / batch as f32;
        prop_assert!(
            mean.abs() < 1e-3,
            "feature {f} mean={mean}, expected ~0"
        );
    }
}
/// `BatchNorm1d` output must stay finite. A non-finite value would indicate
/// that the normalization denominator was zero or negative.
///
/// Uniformly random inputs essentially never produce a zero-variance column.
/// The property is therefore also checked on a degenerate batch in which every
/// row repeats the first row. In that batch the per-feature variance is zero,
/// so only the epsilon term can keep the denominator positive.
#[test]
fn prop_denominator_positive(
    data in proptest::collection::vec(-100.0f32..100.0, 20..21usize)
) {
    // The strategy always yields exactly batch * features = 20 values.
    let batch = 5;
    let features = 4;
    let bn = BatchNorm1d::new(features);

    let random = Tensor::new(&data, &[batch, features]);
    // Every column constant => zero per-feature batch variance.
    let constant_rows = data[..features].repeat(batch);
    let degenerate = Tensor::new(&constant_rows, &[batch, features]);

    for x in [&random, &degenerate] {
        let y = bn.forward(x);
        for (i, &val) in y.data().iter().enumerate() {
            prop_assert!(
                val.is_finite(),
                "output[{i}]={val} is not finite — denominator may be non-positive"
            );
        }
    }
}
/// The moving-average running-variance update
/// `(1 - momentum) * var + momentum * batch_var` stays non-negative. This is
/// checked for non-negative `var` / `batch_var` and for `momentum` in the
/// sampled range [0.01, 0.5).
///
/// NOTE(review): this exercises the update formula in isolation; it never
/// touches `BatchNorm1d`, so it does not show the module applies it.
#[test]
fn prop_running_variance_non_negative(
    var in 0.0f32..100.0,
    batch_var in 0.0f32..100.0,
    momentum in 0.01f32..0.5
) {
    // Weight kept from the previous running value.
    let retained = (1.0 - momentum) * var;
    // Contribution of the current batch statistic.
    let observed = momentum * batch_var;
    let new_var = retained + observed;
    prop_assert!(
        new_var >= 0.0,
        "running variance={new_var}, expected >= 0"
    );
}
/// In evaluation mode `BatchNorm1d` should normalize with its stored running
/// statistics rather than with statistics of the current batch.
///
/// The observable consequence is batch independence: replacing one row of a
/// batch must not change the outputs of the other rows. The test first runs
/// two warm-up passes in the default mode and checks they are finite, as
/// before. It then switches to eval mode and checks batch independence.
/// NOTE(review): assumes `Module::eval(&mut self)` puts the layer into
/// evaluation mode — confirm against aprender's `Module` trait.
#[test]
fn prop_eval_uses_running_stats(
    data1 in proptest::collection::vec(-10.0f32..10.0, 8..9usize),
    data2 in proptest::collection::vec(-10.0f32..10.0, 8..9usize)
) {
    // Each strategy always yields exactly batch * features = 8 values.
    let batch = 2;
    let features = 4;
    let mut bn = BatchNorm1d::new(features);
    let x1 = Tensor::new(&data1, &[batch, features]);
    let x2 = Tensor::new(&data2, &[batch, features]);

    // Warm-up passes in the default mode; any running-statistics update they
    // perform happens before the layer is frozen.
    let warm1 = bn.forward(&x1);
    let warm2 = bn.forward(&x2);
    for &val in warm1.data().iter().chain(warm2.data().iter()) {
        prop_assert!(val.is_finite(), "non-finite BN output={val}");
    }

    bn.eval();

    // Row 0 taken from `data1`, row 1 taken from `data2`.
    let mixed: Vec<f32> = data1[..features]
        .iter()
        .chain(&data2[features..])
        .copied()
        .collect();
    let xm = Tensor::new(&mixed, &[batch, features]);

    let y1 = bn.forward(&x1);
    let y2 = bn.forward(&x2);
    let ym = bn.forward(&xm);
    let (out1, out2, out_mixed) = (y1.data(), y2.data(), ym.data());
    for &val in out1.iter().chain(out2.iter()).chain(out_mixed.iter()) {
        prop_assert!(val.is_finite(), "non-finite BN output={val}");
    }

    // With frozen statistics each row is normalized on its own, so a row must
    // produce the same output no matter which batch it appears in.
    for f in 0..features {
        let pairs = [
            (out1[f], out_mixed[f]),
            (out2[features + f], out_mixed[features + f]),
        ];
        for (expected, actual) in pairs {
            prop_assert!(
                (expected - actual).abs() <= 1e-5 * expected.abs().max(1.0),
                "eval output depends on batch composition: feature {f} expected={expected}, got={actual}"
            );
        }
    }
}
// Placeholder property. Per its ignore reason, SIMD-vs-scalar equivalence is
// treated as belonging to the trueno domain, so nothing is checked here.
#[test]
#[ignore = "SIMD equivalence — trueno domain"]
fn prop_simd_equivalence(
    _lanes in proptest::collection::vec(-100.0f32..100.0, 1..32usize)
) {
    // Intentionally empty: the generated input is unused.
}
}