#![allow(clippy::float_cmp)]
use innr::{cosine, dot, fast_cosine_dispatch, l2_distance, l2_distance_squared};
fn ref_dot(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
let mut sum = 0.0f32;
for i in 0..a.len() {
sum += a[i] * b[i];
}
sum
}
fn ref_l2_squared(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len());
let mut sum = 0.0f32;
for i in 0..a.len() {
let d = a[i] - b[i];
sum += d * d;
}
sum
}
fn ref_l2(a: &[f32], b: &[f32]) -> f32 {
ref_l2_squared(a, b).sqrt()
}
fn ref_cosine(a: &[f32], b: &[f32]) -> f32 {
let dot = ref_dot(a, b);
let norm_a = ref_l2_squared(a, &vec![0.0; a.len()]).sqrt();
let norm_b = ref_l2_squared(b, &vec![0.0; b.len()]).sqrt();
if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
0.0
} else {
dot / (norm_a * norm_b)
}
}
fn test_vec(dim: usize, seed: u64) -> Vec<f32> {
(0..dim)
.map(|i| {
let x = (seed.wrapping_mul(31).wrapping_add(i as u64 * 17)) as f32;
(x * 0.001).sin()
})
.collect()
}
fn approx_eq(a: f32, b: f32, rel_eps: f32) -> bool {
let diff = (a - b).abs();
let max_val = a.abs().max(b.abs()).max(1.0);
diff < rel_eps * max_val
}
#[test]
fn simd_correctness_dot_small_dims() {
for dim in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17, 31, 32, 33] {
for seed in 0..10 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let simd_result = dot(&a, &b);
let ref_result = ref_dot(&a, &b);
assert!(
approx_eq(simd_result, ref_result, 1e-5),
"dot mismatch at dim={}, seed={}: simd={}, ref={}",
dim,
seed,
simd_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_dot_large_dims() {
for dim in [64, 128, 256, 384, 512, 768, 1024, 1536] {
for seed in 0..5 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let simd_result = dot(&a, &b);
let ref_result = ref_dot(&a, &b);
assert!(
approx_eq(simd_result, ref_result, 1e-4),
"dot mismatch at dim={}: simd={}, ref={}",
dim,
simd_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_l2_squared() {
for dim in [1, 7, 8, 15, 16, 31, 32, 64, 128, 384, 768] {
for seed in 0..5 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let simd_result = l2_distance_squared(&a, &b);
let ref_result = ref_l2_squared(&a, &b);
assert!(
approx_eq(simd_result, ref_result, 1e-4),
"l2_squared mismatch at dim={}: simd={}, ref={}",
dim,
simd_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_l2() {
for dim in [1, 8, 16, 32, 64, 128, 384, 768] {
for seed in 0..5 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let simd_result = l2_distance(&a, &b);
let ref_result = ref_l2(&a, &b);
assert!(
approx_eq(simd_result, ref_result, 1e-4),
"l2 mismatch at dim={}: simd={}, ref={}",
dim,
simd_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_cosine() {
for dim in [1, 8, 16, 32, 64, 128, 384, 768] {
for seed in 0..5 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let simd_result = cosine(&a, &b);
let ref_result = ref_cosine(&a, &b);
assert!(
approx_eq(simd_result, ref_result, 1e-4),
"cosine mismatch at dim={}: simd={}, ref={}",
dim,
simd_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_fast_cosine_dispatch() {
for dim in [1, 8, 16, 32, 64, 128, 384, 768] {
for seed in 0..5 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 1000);
let dispatch_result = fast_cosine_dispatch(&a, &b);
let ref_result = ref_cosine(&a, &b);
assert!(
approx_eq(dispatch_result, ref_result, 1e-2),
"fast_cosine_dispatch mismatch at dim={}: dispatch={}, ref={}",
dim,
dispatch_result,
ref_result
);
}
}
}
#[test]
fn simd_correctness_edge_cases() {
assert_eq!(dot(&[], &[]), 0.0);
assert_eq!(dot(&[2.0], &[3.0]), 6.0);
let zeros = vec![0.0f32; 100];
let ones = vec![1.0f32; 100];
assert_eq!(dot(&zeros, &ones), 0.0);
assert_eq!(l2_distance_squared(&zeros, &zeros), 0.0);
let v = test_vec(64, 42);
assert!(approx_eq(l2_distance(&v, &v), 0.0, 1e-6));
assert!(approx_eq(cosine(&v, &v), 1.0, 1e-5));
}
#[test]
fn simd_correctness_special_values() {
let small: Vec<f32> = (0..64).map(|i| 1e-20 * (i as f32 + 1.0)).collect();
let ref_result = ref_dot(&small, &small);
let simd_result = dot(&small, &small);
assert!(
(simd_result - ref_result).abs() < 1e-30,
"small value mismatch"
);
let large: Vec<f32> = (0..64).map(|i| 1e10 * (i as f32 + 1.0)).collect();
let ref_result = ref_dot(&large, &large);
let simd_result = dot(&large, &large);
assert!(
approx_eq(simd_result, ref_result, 1e-3),
"large value mismatch: simd={}, ref={}",
simd_result,
ref_result
);
let mixed: Vec<f32> = (0..64)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let ref_result = ref_dot(&mixed, &mixed);
let simd_result = dot(&mixed, &mixed);
assert_eq!(simd_result, ref_result);
}
#[test]
fn simd_invariant_dot_commutative() {
for dim in [32, 64, 128, 384] {
let a = test_vec(dim, 1);
let b = test_vec(dim, 2);
assert!(
approx_eq(dot(&a, &b), dot(&b, &a), 1e-6),
"dot should be commutative"
);
}
}
#[test]
fn simd_invariant_l2_symmetric() {
for dim in [32, 64, 128, 384] {
let a = test_vec(dim, 1);
let b = test_vec(dim, 2);
assert!(
approx_eq(l2_distance(&a, &b), l2_distance(&b, &a), 1e-6),
"l2 should be symmetric"
);
}
}
#[test]
fn simd_invariant_l2_nonnegative() {
for dim in [32, 64, 128, 384] {
let a = test_vec(dim, 1);
let b = test_vec(dim, 2);
assert!(l2_distance(&a, &b) >= 0.0, "l2 should be non-negative");
assert!(
l2_distance_squared(&a, &b) >= 0.0,
"l2_squared should be non-negative"
);
}
}
#[test]
fn simd_invariant_cosine_range() {
for dim in [32, 64, 128, 384] {
for seed in 0..10 {
let a = test_vec(dim, seed);
let b = test_vec(dim, seed + 100);
let sim = cosine(&a, &b);
assert!(
(-1.0 - 1e-5..=1.0 + 1e-5).contains(&sim),
"cosine should be in [-1, 1], got {}",
sim
);
}
}
}