#[cfg(test)]
mod tests {
use super::*;
/// With unit weights and no bias, the layer-norm output is expected to have
/// a mean of approximately 0 and a population standard deviation of
/// approximately 1.
#[test]
fn test_layer_norm_basic() {
    let weight = vec![1.0; 4];
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    let count = output.len() as f32;
    let mean: f32 = output.iter().sum::<f32>() / count;
    assert!(mean.abs() < 1e-5, "Mean should be ~0, got {}", mean);

    // Population variance: divide by N, not N - 1.
    let sq_dev_sum: f32 = output.iter().map(|v| (v - mean).powi(2)).sum::<f32>();
    let std = (sq_dev_sum / count).sqrt();
    assert!((std - 1.0).abs() < 0.01, "Std should be ~1, got {}", std);
}
/// A uniform weight of 2 is expected to scale the normalized output so that
/// its population standard deviation is approximately 2.
#[test]
fn test_layer_norm_with_scale() {
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let weight = vec![2.0; 4];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    let n = output.len() as f32;
    let mean: f32 = output.iter().sum::<f32>() / n;
    let sq_dev_sum: f32 = output.iter().map(|v| (v - mean).powi(2)).sum::<f32>();
    let std = (sq_dev_sum / n).sqrt();
    assert!((std - 2.0).abs() < 0.01, "Std should be ~2, got {}", std);
}
/// With unit weights and a constant bias of 5, the output mean is expected
/// to be approximately 5.
#[test]
fn test_layer_norm_with_bias() {
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let (weight, bias) = (vec![1.0; 4], vec![5.0; 4]);
    let output = simd_layer_norm(&input, &weight, Some(&bias), 1e-5);

    let total: f32 = output.iter().sum::<f32>();
    let mean = total / output.len() as f32;
    assert!((mean - 5.0).abs() < 0.01, "Mean should be ~5, got {}", mean);
}
/// Empty input and weight slices must yield an empty output.
#[test]
fn test_layer_norm_empty() {
    let (input, weight): (Vec<f32>, Vec<f32>) = (Vec::new(), Vec::new());
    let output = simd_layer_norm(&input, &weight, None, 1e-5);
    assert!(output.is_empty());
}
/// A single-element input equals its own mean, so the normalized value is
/// expected to be approximately 0.
#[test]
fn test_layer_norm_single_element() {
    let (input, weight) = (vec![5.0], vec![1.0]);
    let output = simd_layer_norm(&input, &weight, None, 1e-5);
    let only = output[0];
    assert!(only.abs() < 1e-3);
}
/// Constant input has zero variance; every normalized value is expected to
/// stay within 0.1 of zero.
///
/// The output length is checked first: without it, an empty output would
/// make the per-element loop pass without checking anything.
#[test]
fn test_layer_norm_uniform_input() {
    let input = vec![3.0; 4];
    let weight = vec![1.0; 4];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    assert_eq!(
        output.len(),
        input.len(),
        "output length should match input length"
    );
    for (i, &x) in output.iter().enumerate() {
        assert!(x.abs() < 0.1, "output[{}] should be ~0, got {}", i, x);
    }
}
/// A strictly increasing input that spans negative and positive values is
/// expected to produce a strictly increasing output.
#[test]
fn test_layer_norm_negative_values() {
    let input = vec![-2.0, -1.0, 1.0, 2.0];
    let weight = vec![1.0; 4];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    // Compare each of the first four outputs with its successor.
    for i in 0..3 {
        assert!(output[i] < output[i + 1]);
    }
}
/// Inputs in the thousands are still expected to normalize to a mean within
/// 1e-4 of zero.
#[test]
fn test_layer_norm_large_values() {
    let input = vec![1000.0, 2000.0, 3000.0, 4000.0];
    let weight = vec![1.0; 4];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    let total: f32 = output.iter().sum::<f32>();
    let mean = total / output.len() as f32;
    assert!(mean.abs() < 1e-4);
}
/// With unit weights, each output is expected to equal the input divided by
/// the RMS of the input (epsilon is ignored in the reference value; the
/// 1e-5 tolerance absorbs it).
///
/// The lengths are compared before the element-wise check because `zip`
/// stops at the shorter iterator, which would let a truncated or empty
/// output pass unnoticed.
#[test]
fn test_rms_norm_basic() {
    let input = vec![1.0, 2.0, 3.0];
    let weight = vec![1.0; 3];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    // RMS of [1, 2, 3] = sqrt((1 + 4 + 9) / 3).
    let rms = (14.0_f32 / 3.0).sqrt();
    let expected: Vec<f32> = input.iter().map(|x| x / rms).collect();

    assert_eq!(
        output.len(),
        expected.len(),
        "output length should match input length"
    );
    for (i, (out, exp)) in output.iter().zip(expected.iter()).enumerate() {
        assert!(
            (out - exp).abs() < 1e-5,
            "output[{}]: expected {}, got {}",
            i,
            exp,
            out
        );
    }
}
/// With a uniform weight of 2, each output is expected to equal the input
/// divided by the RMS of the input, times 2.
///
/// The lengths are compared before the element-wise check because `zip`
/// stops at the shorter iterator, which would let a truncated or empty
/// output pass unnoticed.
#[test]
fn test_rms_norm_with_scale() {
    let input = vec![1.0, 2.0, 3.0];
    let weight = vec![2.0; 3];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    // RMS of [1, 2, 3] = sqrt((1 + 4 + 9) / 3).
    let rms = (14.0_f32 / 3.0).sqrt();
    let expected: Vec<f32> = input.iter().map(|x| x / rms * 2.0).collect();

    assert_eq!(
        output.len(),
        expected.len(),
        "output length should match input length"
    );
    for (i, (out, exp)) in output.iter().zip(expected.iter()).enumerate() {
        assert!(
            (out - exp).abs() < 1e-5,
            "output[{}]: expected {}, got {}",
            i,
            exp,
            out
        );
    }
}
/// Empty input and weight slices must yield an empty output.
#[test]
fn test_rms_norm_empty() {
    let (input, weight): (Vec<f32>, Vec<f32>) = (Vec::new(), Vec::new());
    let output = simd_rms_norm(&input, &weight, 1e-5);
    assert!(output.is_empty());
}
/// A single element with unit weight is expected to normalize to
/// approximately 1 (5 divided by its own RMS of 5).
#[test]
fn test_rms_norm_single_element() {
    let (input, weight) = (vec![5.0], vec![1.0]);
    let output = simd_rms_norm(&input, &weight, 1e-5);
    let deviation = (output[0] - 1.0).abs();
    assert!(deviation < 1e-5);
}
/// For the input [1, 0, 0] the RMS is sqrt(1/3), so the first output is
/// expected to be sqrt(3) and the remaining two approximately 0.
#[test]
fn test_rms_norm_unit_vector() {
    let input = vec![1.0, 0.0, 0.0];
    let weight = vec![1.0; 3];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    // 1 / sqrt(1/3) == sqrt(3).
    let expected = 3.0_f32.sqrt();
    let head_error = (output[0] - expected).abs();
    assert!(head_error < 1e-4, "Expected {}, got {}", expected, output[0]);

    let (second, third) = (output[1], output[2]);
    assert!(second.abs() < 1e-5);
    assert!(third.abs() < 1e-5);
}
/// An all-zero input is expected to produce outputs within 1e-2 of zero.
///
/// The output length is checked first: without it, an empty output would
/// make the per-element loop pass without checking anything.
#[test]
fn test_rms_norm_zeros() {
    let input = vec![0.0; 3];
    let weight = vec![1.0; 3];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    assert_eq!(
        output.len(),
        input.len(),
        "output length should match input length"
    );
    for (i, &x) in output.iter().enumerate() {
        assert!(x.abs() < 1e-2, "output[{}] should be ~0, got {}", i, x);
    }
}
/// Signs are expected to be kept: [-3, 4] divided by its RMS, sqrt(12.5).
#[test]
fn test_rms_norm_negative_values() {
    let input = vec![-3.0, 4.0];
    let weight = vec![1.0; 2];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    // RMS of [-3, 4] = sqrt((9 + 16) / 2) = sqrt(12.5).
    let rms = 12.5_f32.sqrt();
    let (negative, positive) = (output[0], output[1]);
    assert!((negative - (-3.0 / rms)).abs() < 1e-5);
    assert!((positive - (4.0 / rms)).abs() < 1e-5);
}
/// With unit weights the ratio between output components is expected to
/// match the ratio between the input components (4 / 3).
#[test]
fn test_rms_norm_preserves_direction() {
    let input = vec![3.0, 4.0];
    let weight = vec![1.0; 2];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    let expected_ratio = 4.0_f32 / 3.0;
    let ratio = output[1] / output[0];
    assert!((ratio - expected_ratio).abs() < 1e-5);
}
/// Applying RoPE at position 0 is expected to leave the vector unchanged
/// (within 1e-5 per element).
#[test]
fn test_rope_position_zero() {
    let original = vec![1.0, 2.0, 3.0, 4.0];
    let mut x = original.clone();
    apply_rope(&mut x, 4, 1, 0, 10000.0);

    let unchanged = x
        .iter()
        .zip(original.iter())
        .all(|(out, orig)| (out - orig).abs() < 1e-5);
    assert!(unchanged);
}
/// After applying RoPE at position 1, the pairs (x[0], x[2]) and
/// (x[1], x[3]) are each expected to keep a magnitude of 1.
///
/// NOTE(review): the pairing of index i with index i + 2 is taken from this
/// test's own assertions; confirm it matches `apply_rope`'s layout.
#[test]
fn test_rope_rotation_property() {
    let mut x = vec![1.0, 0.0, 0.0, 1.0];
    apply_rope(&mut x, 4, 1, 1, 10000.0);

    let (a0, b0) = (x[0], x[2]);
    let (a1, b1) = (x[1], x[3]);
    let mag0 = (a0 * a0 + b0 * b0).sqrt();
    let mag1 = (a1 * a1 + b1 * b1).sqrt();
    assert!((mag0 - 1.0).abs() < 1e-5, "Magnitude of pair 0 should be 1");
    assert!((mag1 - 1.0).abs() < 1e-5, "Magnitude of pair 1 should be 1");
}
/// With two heads over an 8-element vector, position 0 is still expected to
/// leave every element unchanged (within 1e-5).
#[test]
fn test_rope_multiple_heads() {
    let original = vec![1.0; 8];
    let mut x = original.clone();
    apply_rope(&mut x, 8, 2, 0, 10000.0);

    for (after, before) in x.iter().zip(&original) {
        let drift = (after - before).abs();
        assert!(drift < 1e-5);
    }
}
/// The same input rotated at positions 0 and 1 is expected to differ in at
/// least one of its first two elements.
#[test]
fn test_rope_different_positions() {
    let mut at_pos0 = vec![1.0; 4];
    let mut at_pos1 = vec![1.0; 4];
    apply_rope(&mut at_pos0, 4, 1, 0, 10000.0);
    apply_rope(&mut at_pos1, 4, 1, 1, 10000.0);

    let first_differs = (at_pos0[0] - at_pos1[0]).abs() > 1e-6;
    let second_differs = (at_pos0[1] - at_pos1[1]).abs() > 1e-6;
    assert!(first_differs || second_differs);
}
/// At position 10, theta values of 10000 and 1000 are expected to give
/// different results at index 1 or index 3.
#[test]
fn test_rope_theta_scaling() {
    let mut with_large_theta = vec![1.0; 4];
    let mut with_small_theta = vec![1.0; 4];
    apply_rope(&mut with_large_theta, 4, 1, 10, 10000.0);
    apply_rope(&mut with_small_theta, 4, 1, 10, 1000.0);

    let idx1_differs = (with_large_theta[1] - with_small_theta[1]).abs() > 1e-5;
    let idx3_differs = (with_large_theta[3] - with_small_theta[3]).abs() > 1e-5;
    assert!(
        idx1_differs || idx3_differs,
        "Different theta should give different results for non-zero frequency indices"
    );
}
/// At position 1000 every element is expected to remain finite.
#[test]
fn test_rope_large_position() {
    let mut x = vec![1.0, 2.0, 3.0, 4.0];
    apply_rope(&mut x, 4, 1, 1000, 10000.0);
    assert!(x.iter().all(|val| val.is_finite()));
}
/// A 64-element vector split across 8 heads is expected to stay finite
/// after RoPE at position 5.
#[test]
fn test_rope_eight_heads() {
    let (hidden_dim, num_heads) = (64, 8);
    let mut x = vec![0.5; hidden_dim];
    apply_rope(&mut x, hidden_dim, num_heads, 5, 10000.0);
    assert!(x.iter().all(|val| val.is_finite()));
}
/// RoPE is expected to preserve both the element count and the Euclidean
/// length of the vector.
///
/// The input [3, 4, 0, 0] has a Euclidean norm of exactly 5. The element
/// count alone cannot detect a wrong rotation, so the norm is checked too.
///
/// NOTE(review): the norm check assumes `apply_rope` is a pure rotation
/// with no scaling; confirm against the implementation.
#[test]
fn test_rope_preserves_length() {
    let mut x = vec![3.0, 4.0, 0.0, 0.0];
    apply_rope(&mut x, 4, 1, 1, 10000.0);
    assert_eq!(x.len(), 4);

    // A rotation must not change the Euclidean norm.
    let norm_after = x.iter().map(|v| v * v).sum::<f32>().sqrt();
    assert!(
        (norm_after - 5.0).abs() < 1e-4,
        "Euclidean norm should stay ~5, got {}",
        norm_after
    );
}
/// Pipeline check: RMS-normalize an 8-element vector, apply RoPE with two
/// heads at position 5, and require every resulting value to be finite.
///
/// The length is checked first: without it, an empty normalization result
/// would make the finiteness loop pass without checking anything.
#[test]
fn test_norm_then_rope() {
    let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let weight = vec![1.0; 8];

    let mut output = simd_rms_norm(&input, &weight, 1e-5);
    assert_eq!(
        output.len(),
        input.len(),
        "normalized length should match input length"
    );

    apply_rope(&mut output, 8, 2, 5, 10000.0);
    for (i, &val) in output.iter().enumerate() {
        assert!(val.is_finite(), "output[{}] is not finite: {}", i, val);
    }
}
/// Contrasts the two normalizations on the same input: the layer-norm
/// output is expected to have a mean near 0, while the RMS-norm output is
/// expected to keep a mean whose magnitude exceeds 0.1.
#[test]
fn test_layer_norm_vs_rms_norm() {
    let input = vec![1.0, 2.0, 3.0, 4.0];
    let weight = vec![1.0; 4];

    let ln_output = simd_layer_norm(&input, &weight, None, 1e-5);
    let rms_output = simd_rms_norm(&input, &weight, 1e-5);

    let ln_total: f32 = ln_output.iter().sum::<f32>();
    let rms_total: f32 = rms_output.iter().sum::<f32>();
    let ln_mean = ln_total / 4.0;
    let rms_mean = rms_total / 4.0;

    assert!(ln_mean.abs() < 1e-5, "LayerNorm should have mean ~0");
    assert!(rms_mean.abs() > 0.1, "RMSNorm should not center");
}
/// An all-zero input has zero variance; every output must still be finite
/// (no NaN or infinity). Presumably `eps` is what keeps the division well
/// defined, but this test only checks finiteness.
///
/// The output length is checked first: without it, an empty output would
/// make the per-element loop pass without checking anything.
#[test]
fn test_layer_norm_eps_impact() {
    let input = vec![0.0; 4];
    let weight = vec![1.0; 4];
    let output = simd_layer_norm(&input, &weight, None, 1e-5);

    assert_eq!(
        output.len(),
        input.len(),
        "output length should match input length"
    );
    for (i, &val) in output.iter().enumerate() {
        assert!(val.is_finite(), "output[{}] is not finite: {}", i, val);
    }
}
/// An all-zero input has an RMS of zero; every output must still be finite
/// (no NaN or infinity). Presumably `eps` is what keeps the division well
/// defined, but this test only checks finiteness.
///
/// The output length is checked first: without it, an empty output would
/// make the per-element loop pass without checking anything.
#[test]
fn test_rms_norm_eps_impact() {
    let input = vec![0.0; 2];
    let weight = vec![1.0; 2];
    let output = simd_rms_norm(&input, &weight, 1e-5);

    assert_eq!(
        output.len(),
        input.len(),
        "output length should match input length"
    );
    for (i, &val) in output.iter().enumerate() {
        assert!(val.is_finite(), "output[{}] is not finite: {}", i, val);
    }
}
/// Runs RoPE at position 1 over several (hidden_dim, num_heads) pairs,
/// where hidden_dim / num_heads is 4 in every case, and requires all
/// resulting values to be finite.
#[test]
fn test_rope_half_dim_calculation() {
    let configs = [(8, 2), (16, 4), (32, 8), (64, 16)];
    for (hidden_dim, num_heads) in configs {
        let mut x = vec![1.0; hidden_dim];
        apply_rope(&mut x, hidden_dim, num_heads, 1, 10000.0);
        assert!(x.iter().all(|val| val.is_finite()));
    }
}
}