#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_bf16_to_f32_various_values() {
let values = [0.0f32, 1.0, -1.0, 0.5, 2.0, -0.5, 100.0, -100.0];
let mut bf16_bytes = Vec::new();
for &v in &values {
bf16_bytes.extend_from_slice(&half::bf16::from_f32(v).to_le_bytes());
}
let result = simd_bf16_to_f32(&bf16_bytes);
assert_eq!(result.len(), values.len());
for (i, (&expected, &actual)) in values.iter().zip(result.iter()).enumerate() {
let tol = expected.abs().max(1.0) * 0.01;
assert!(
(actual - expected).abs() < tol,
"Value {} mismatch: expected {}, got {}",
i,
expected,
actual
);
}
}
#[test]
fn test_simd_bf16_to_f32_large_batch() {
let count = 17; let mut bf16_bytes = Vec::with_capacity(count * 2);
for i in 0..count {
let v = i as f32 * 0.1;
bf16_bytes.extend_from_slice(&half::bf16::from_f32(v).to_le_bytes());
}
let result = simd_bf16_to_f32(&bf16_bytes);
assert_eq!(result.len(), count);
for i in 0..count {
let expected = i as f32 * 0.1;
let tol = 0.01;
assert!(
(result[i] - expected).abs() < tol,
"Index {} mismatch: expected {}, got {}",
i,
expected,
result[i]
);
}
}
#[test]
fn test_simd_f16_to_f32_single() {
let f16_bytes = half::f16::from_f32(1.0).to_le_bytes();
let result = simd_f16_to_f32(&f16_bytes);
assert_eq!(result.len(), 1);
assert!((result[0] - 1.0).abs() < 1e-3);
}
#[test]
fn test_simd_f16_to_f32_various_values() {
let values = [0.0f32, 1.0, -1.0, 0.5, 2.0, -0.5];
let mut f16_bytes = Vec::new();
for &v in &values {
f16_bytes.extend_from_slice(&half::f16::from_f32(v).to_le_bytes());
}
let result = simd_f16_to_f32(&f16_bytes);
assert_eq!(result.len(), values.len());
for (expected, actual) in values.iter().zip(result.iter()) {
assert!(
(actual - expected).abs() < 0.01,
"Mismatch: expected {}, got {}",
expected,
actual
);
}
}
#[test]
fn test_simd_bf16_dot_basic() {
let a_vals = [1.0f32, 2.0, 3.0, 4.0];
let b_vals = [1.0f32, 1.0, 1.0, 1.0];
let mut a_bytes = Vec::new();
let mut b_bytes = Vec::new();
for (&a, &b) in a_vals.iter().zip(b_vals.iter()) {
a_bytes.extend_from_slice(&half::bf16::from_f32(a).to_le_bytes());
b_bytes.extend_from_slice(&half::bf16::from_f32(b).to_le_bytes());
}
let result = simd_bf16_dot(&a_bytes, &b_bytes);
assert!((result - 10.0).abs() < 0.1);
}
#[test]
fn test_simd_bf16_dot_large() {
let n = 256; let mut a_bytes = Vec::with_capacity(n * 2);
let mut b_bytes = Vec::with_capacity(n * 2);
for i in 0..n {
let v = ((i % 10) as f32) * 0.1;
a_bytes.extend_from_slice(&half::bf16::from_f32(v).to_le_bytes());
b_bytes.extend_from_slice(&half::bf16::from_f32(1.0).to_le_bytes());
}
let result = simd_bf16_dot(&a_bytes, &b_bytes);
let expected = 25.0 * 4.5 + 1.5;
assert!(
(result - expected).abs() < 1.0,
"Large dot: expected ~{}, got {}",
expected,
result
);
}
#[test]
fn test_simd_bf16_matmul_identity() {
let input = vec![1.0f32, 2.0, 3.0];
let mut weight_bytes = Vec::new();
let identity = [[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
for row in &identity {
for &v in row {
weight_bytes.extend_from_slice(&half::bf16::from_f32(v).to_le_bytes());
}
}
let output = simd_bf16_matmul(&input, &weight_bytes, 3, 3);
assert_eq!(output.len(), 3);
assert!((output[0] - 1.0).abs() < 0.01);
assert!((output[1] - 2.0).abs() < 0.01);
assert!((output[2] - 3.0).abs() < 0.01);
}
#[test]
fn test_simd_bf16_matmul_projection() {
let input = vec![1.0f32, 2.0, 3.0, 4.0];
let weight = [
[1.0f32, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, -1.0], ];
let mut weight_bytes = Vec::new();
for row in &weight {
for &v in row {
weight_bytes.extend_from_slice(&half::bf16::from_f32(v).to_le_bytes());
}
}
let output = simd_bf16_matmul(&input, &weight_bytes, 4, 2);
assert_eq!(output.len(), 2);
assert!(
(output[0] - 10.0).abs() < 0.1,
"Sum: expected 10, got {}",
output[0]
);
assert!(
(output[1] - (-3.0)).abs() < 0.1,
"Diff: expected -3, got {}",
output[1]
);
}
include!("simd_matmul.rs");
}