use super::*;
use crate::Tensor;
/// Exercises `batch_norm` on a `[1, 2, 2, 4]` input with the trailing boolean flag set to
/// `false` (presumably the "training" switch, going by the test name — confirm against
/// `batch_norm`'s signature).
///
/// The result must keep the input shape. When the output exposes its data through
/// `as_slice`, every element must have a magnitude below 3.0; if `as_slice` yields `None`
/// the value check is skipped.
#[test]
fn test_batch_norm_inference() {
    // Each half of the buffer is a permutation of 1..=8, i.e. mean 4.5 and population
    // variance 5.25 — the same numbers supplied below as running statistics.
    let samples = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0,
    ];
    let input =
        Tensor::<f32>::from_vec(samples, &[1, 2, 2, 4]).expect("test: operation should succeed");

    // Builds one length-2 parameter tensor from the given values.
    let pair = |values: Vec<f32>| {
        Tensor::<f32>::from_vec(values, &[2]).expect("test: from_vec should succeed")
    };
    let gamma = pair(vec![1.0, 1.0]);
    let beta = pair(vec![0.0, 0.0]);
    let running_mean = pair(vec![4.5, 4.5]);
    let running_var = pair(vec![5.25, 5.25]);

    let output = batch_norm(
        &input,
        &gamma,
        &beta,
        &running_mean,
        &running_var,
        1e-5,
        false,
    )
    .expect("test: operation should succeed");

    assert_eq!(output.shape().dims(), &[1, 2, 2, 4]);

    // Loose sanity bound on the normalized values.
    if let Some(values) = output.as_slice() {
        values.iter().for_each(|&val| assert!(val.abs() < 3.0));
    }
}
/// Exercises `layer_norm` on a `[2, 3]` input with all-ones gamma and all-zeros beta.
///
/// The result must keep the input shape. When the output exposes its data through
/// `as_slice`, the mean of elements `0..3` and of elements `3..6` must each be within
/// `1e-5` of zero; if `as_slice` yields `None` the value check is skipped.
#[test]
fn test_layer_norm() {
    let input = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
        .expect("test: from_vec should succeed");
    let gamma =
        Tensor::<f32>::from_vec(vec![1.0; 3], &[3]).expect("test: from_vec should succeed");
    let beta =
        Tensor::<f32>::from_vec(vec![0.0; 3], &[3]).expect("test: from_vec should succeed");

    let output =
        layer_norm(&input, &gamma, &beta, &[3], 1e-5).expect("test: layer_norm should succeed");

    assert_eq!(output.shape().dims(), &[2, 3]);

    if let Some(data) = output.as_slice() {
        // Arithmetic mean of one three-element row.
        let mean_of = |row: &[f32]| -> f32 { row.iter().sum::<f32>() / 3.0 };

        let row1_mean = mean_of(&data[0..3]);
        assert!((row1_mean).abs() < 1e-5);

        let row2_mean = mean_of(&data[3..6]);
        assert!((row2_mean).abs() < 1e-5);
    }
}
/// Exercises `group_norm` on a `[1, 4, 2, 2]` input, passing `2` as the fourth argument
/// (presumably the number of groups — confirm against `group_norm`'s signature).
///
/// The result must keep the input shape. When the output exposes its data through
/// `as_slice`, every element must have a magnitude below 3.0; if `as_slice` yields `None`
/// the value check is skipped.
#[test]
fn test_group_norm() {
    let samples = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0,
    ];
    let input =
        Tensor::<f32>::from_vec(samples, &[1, 4, 2, 2]).expect("test: operation should succeed");

    // All-ones gamma and all-zeros beta, four entries each.
    let gamma =
        Tensor::<f32>::from_vec(vec![1.0; 4], &[4]).expect("test: from_vec should succeed");
    let beta =
        Tensor::<f32>::from_vec(vec![0.0; 4], &[4]).expect("test: from_vec should succeed");

    let output =
        group_norm(&input, &gamma, &beta, 2, 1e-5).expect("test: group_norm should succeed");

    assert_eq!(output.shape().dims(), &[1, 4, 2, 2]);

    // Loose sanity bound on the normalized values.
    if let Some(values) = output.as_slice() {
        for val in values.iter().copied() {
            assert!(val.abs() < 3.0);
        }
    }
}
/// Exercises `sync_batch_norm` on a `[1, 2, 2, 4]` input with the boolean flag set to
/// `false` and both trailing optional arguments left as `None`.
///
/// The call must succeed and return three tensors: an output with the input's shape and
/// two tensors of shape `[2]`. When the output exposes its data through `as_slice`, every
/// element must have a magnitude below 3.0; if `as_slice` yields `None` the value check
/// is skipped.
#[test]
fn test_sync_batch_norm_inference() {
    let samples = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0,
    ];
    let input =
        Tensor::<f32>::from_vec(samples, &[1, 2, 2, 4]).expect("test: operation should succeed");

    let gamma =
        Tensor::<f32>::from_vec(vec![1.0; 2], &[2]).expect("test: from_vec should succeed");
    let beta =
        Tensor::<f32>::from_vec(vec![0.0; 2], &[2]).expect("test: from_vec should succeed");
    let running_mean =
        Tensor::<f32>::from_vec(vec![4.5; 2], &[2]).expect("test: from_vec should succeed");
    let running_var =
        Tensor::<f32>::from_vec(vec![5.25; 2], &[2]).expect("test: from_vec should succeed");

    let normalized = sync_batch_norm(
        &input,
        &gamma,
        &beta,
        &running_mean,
        &running_var,
        1e-5,
        false,
        None,
        None,
    )
    .expect("test: operation should succeed");
    let (output, mean_out, var_out) = normalized;

    // Shape checks for the output and for both returned statistics tensors.
    assert_eq!(output.shape().dims(), &[1, 2, 2, 4]);
    assert_eq!(mean_out.shape().dims(), &[2]);
    assert_eq!(var_out.shape().dims(), &[2]);

    // Loose sanity bound on the normalized values.
    if let Some(values) = output.as_slice() {
        for val in values.iter().copied() {
            assert!(val.abs() < 3.0);
        }
    }
}
/// Exercises `sync_batch_norm` on a `[1, 2, 2, 4]` input with the boolean flag set to
/// `true`, `Some(0.1)` as the first optional argument and `None` as the second.
///
/// An `Err` result is tolerated: the test prints a note and returns without failing.
/// On `Ok`, the output must keep the input shape and the two returned tensors must have
/// shape `[2]`. When the output exposes its data through `as_slice`, every element must
/// have a magnitude below 5.0; if `as_slice` yields `None` the value check is skipped.
#[test]
fn test_sync_batch_norm_training() {
    let samples = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0,
    ];
    let input =
        Tensor::<f32>::from_vec(samples, &[1, 2, 2, 4]).expect("test: operation should succeed");

    // Builds one length-2 parameter tensor from the given values.
    let pair = |values: Vec<f32>| {
        Tensor::<f32>::from_vec(values, &[2]).expect("test: from_vec should succeed")
    };
    let gamma = pair(vec![1.0, 1.0]);
    let beta = pair(vec![0.0, 0.0]);
    let running_mean = pair(vec![4.0, 4.0]);
    let running_var = pair(vec![1.0, 1.0]);

    let outcome = sync_batch_norm(
        &input,
        &gamma,
        &beta,
        &running_mean,
        &running_var,
        1e-5,
        true,
        Some(0.1),
        None,
    );

    // A failure here is accepted rather than treated as a test error.
    let (output, updated_mean, updated_var) = match outcome {
        Ok(parts) => parts,
        Err(_) => {
            println!("Sync batch norm failed as expected without proper collective setup");
            return;
        }
    };

    assert_eq!(output.shape().dims(), &[1, 2, 2, 4]);
    assert_eq!(updated_mean.shape().dims(), &[2]);
    assert_eq!(updated_var.shape().dims(), &[2]);

    // Looser bound than in the other tests in this file (5.0 instead of 3.0).
    if let Some(values) = output.as_slice() {
        values.iter().for_each(|&val| assert!(val.abs() < 5.0));
    }
}