use super::ops::*;
use super::simd::*;
/// Runs `f` on a `&[f32]` of length `len` whose start address is 4-byte
/// aligned (as `f32` requires) but deliberately *not* 32-byte aligned, so that
/// vectorised code under test cannot rely on aligned loads.
///
/// Element `i` of the slice holds `i as f32`.
///
/// # Panics
///
/// Panics if the alignment invariants cannot be established. The 64 bytes of
/// slack in the backing buffer make this unreachable in practice.
fn with_unaligned_f32_slice<F>(len: usize, f: F)
where
    F: FnOnce(&[f32]),
{
    const F32_SIZE: usize = std::mem::size_of::<f32>();
    let byte_len = len * F32_SIZE;
    let mut buffer = vec![0u8; 64 + byte_len];
    let base = buffer.as_ptr() as usize;

    // Smallest offset that is f32-aligned but not on a 32-byte boundary.
    // At most 3 bytes reach 4-byte alignment, plus 4 more to step off a
    // 32-byte boundary, so the slack always contains a suitable offset.
    let offset = (0..64)
        .find(|&o| {
            let addr = base + o;
            addr % 4 == 0 && addr % 32 != 0
        })
        .expect("64 bytes of slack always contain a suitable offset");
    assert!(offset + byte_len <= buffer.len());

    // Populate the values through the byte buffer *before* any `&[f32]` view
    // exists. Writing through `buffer[..]` while a slice derived from
    // `buffer.as_ptr()` is alive would create an aliasing `&mut` borrow,
    // which is undefined behaviour.
    for (i, chunk) in buffer[offset..offset + byte_len]
        .chunks_exact_mut(F32_SIZE)
        .enumerate()
    {
        chunk.copy_from_slice(&(i as f32).to_ne_bytes());
    }

    // SAFETY: `offset + byte_len <= buffer.len()`, so the range lies within one
    // live allocation. The start address is a multiple of 4 (f32 alignment).
    // Every byte was initialised above with native-endian f32 bytes. `buffer`
    // is not mutated again before `f` returns, so nothing aliases this slice
    // mutably.
    let slice: &[f32] =
        unsafe { std::slice::from_raw_parts(buffer.as_ptr().add(offset).cast::<f32>(), len) };

    let slice_addr = slice.as_ptr() as usize;
    assert_eq!(slice_addr % 4, 0, "Slice must be 4-byte aligned for f32");
    assert_ne!(
        slice_addr % 32,
        0,
        "Slice must NOT be 32-byte aligned for testing unaligned loads"
    );
    f(slice);
}
/// Runs `f` on a `&mut [f32]` of length `len` whose start address is 4-byte
/// aligned but deliberately *not* 32-byte aligned, so that in-place vectorised
/// code must handle unaligned stores.
///
/// Element `i` is initialised to `i as f32` before `f` is called.
///
/// # Panics
///
/// Panics if the alignment invariants cannot be established. The 64 bytes of
/// slack in the backing buffer make this unreachable in practice.
fn with_unaligned_f32_slice_mut<F>(len: usize, f: F)
where
    F: FnOnce(&mut [f32]),
{
    const F32_SIZE: usize = std::mem::size_of::<f32>();
    let byte_len = len * F32_SIZE;
    let mut buffer = vec![0u8; 64 + byte_len];
    let base = buffer.as_ptr() as usize;

    // Smallest offset that is f32-aligned but not on a 32-byte boundary
    // (at most 3 + 4 bytes, well within the 64 bytes of slack).
    let offset = (0..64)
        .find(|&o| {
            let addr = base + o;
            addr % 4 == 0 && addr % 32 != 0
        })
        .expect("64 bytes of slack always contain a suitable offset");
    assert!(offset + byte_len <= buffer.len());

    // SAFETY: `offset + byte_len <= buffer.len()`, so the range lies within one
    // live allocation. The start address is a multiple of 4 (f32 alignment).
    // The bytes are zero-initialised, which is a valid f32 (0.0). `buffer` is
    // not accessed through any other path until after `f` returns, so this is
    // the only live reference to that memory.
    let slice: &mut [f32] = unsafe {
        std::slice::from_raw_parts_mut(buffer.as_mut_ptr().add(offset).cast::<f32>(), len)
    };

    // Same guarantees the immutable helper checks; keep both helpers honest.
    let slice_addr = slice.as_ptr() as usize;
    assert_eq!(slice_addr % 4, 0, "Slice must be 4-byte aligned for f32");
    assert_ne!(
        slice_addr % 32,
        0,
        "Slice must NOT be 32-byte aligned for testing unaligned loads"
    );

    for (i, val) in slice.iter_mut().enumerate() {
        *val = i as f32;
    }
    f(slice);
}
/// `dot_product` over two misaligned buffers must agree with a plain
/// iterator-based reference sum.
#[test]
fn test_simd_unaligned_load_dot_product() {
    const LEN: usize = 100;
    with_unaligned_f32_slice(LEN, |lhs| {
        with_unaligned_f32_slice(LEN, |rhs| {
            let simd = dot_product(lhs, rhs).unwrap();
            let reference: f32 = lhs.iter().zip(rhs).map(|(x, y)| x * y).sum();
            let error = (simd - reference).abs();
            assert!(
                error < 1e-4,
                "Unaligned dot product failed: {} vs {}",
                simd,
                reference
            );
        });
    });
}
/// Cosine similarity over two misaligned buffers.
///
/// Both slices come from `with_unaligned_f32_slice` with the same length and
/// therefore hold identical data (`0.0, 1.0, 2.0, …`). Their similarity must
/// be 1.0 up to rounding. The previous `> 0.9` check would have accepted a
/// clearly wrong result such as 0.95.
#[test]
fn test_simd_unaligned_load_cosine_similarity() {
    let len = 128;
    with_unaligned_f32_slice(len, |a| {
        with_unaligned_f32_slice(len, |b| {
            let result = cosine_similarity(a, b).unwrap();
            assert!(
                (result - 1.0).abs() < 1e-4,
                "Cosine similarity of identical vectors should be 1.0, got {}",
                result
            );
        });
    });
}
/// Euclidean distance over two misaligned buffers.
///
/// A length of 33 is odd, so the input cannot be split evenly into
/// power-of-two-wide lanes and a tail remainder is left. Both slices hold
/// identical data, so the distance must be (essentially) zero. The old
/// `>= 0.0` check alone would have accepted any non-negative garbage.
#[test]
fn test_simd_unaligned_load_euclidean_distance() {
    let len = 33;
    with_unaligned_f32_slice(len, |a| {
        with_unaligned_f32_slice(len, |b| {
            let result = euclidean_distance(a, b).unwrap();
            assert!(result >= 0.0);
            assert!(
                result < 1e-4,
                "Distance between identical vectors should be 0, got {}",
                result
            );
        });
    });
}
/// A single NaN lane must make the whole dot product NaN.
#[test]
fn test_dot_product_nan_propagation_exact() {
    let with_nan = vec![1.0, 2.0, 3.0, f32::NAN, 5.0];
    let finite = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let dot = dot_product(&with_nan, &finite).unwrap();
    assert!(dot.is_nan(), "Dot product should propagate NaN");
}
/// `+inf * 1.0` plus finite terms must sum to exactly `+inf`.
#[test]
fn test_dot_product_inf_propagation_exact() {
    let with_inf = vec![1.0, 2.0, f32::INFINITY, 4.0];
    let finite = vec![1.0, 2.0, 1.0, 4.0];
    assert_eq!(
        dot_product(&with_inf, &finite).unwrap(),
        f32::INFINITY,
        "Dot product should propagate Infinity"
    );
}
/// Subnormal inputs must still produce a sensible similarity.
///
/// In f32 arithmetic, squaring `MIN_POSITIVE / 10` underflows to zero, so a
/// naive implementation could compute 0/0. Rejecting only NaN would still
/// allow ±inf or values outside the valid range, so those are checked too.
#[test]
fn test_cosine_similarity_subnormal_handling() {
    let val = f32::MIN_POSITIVE / 10.0;
    let a = vec![val, val];
    let b = vec![val, val];
    let result = cosine_similarity(&a, &b).unwrap();
    assert!(!result.is_nan());
    assert!(
        result.is_finite(),
        "Cosine similarity of subnormal inputs must be finite, got {}",
        result
    );
    // Small epsilon for rounding at the +/-1 boundary.
    assert!(
        result.abs() <= 1.0 + 1e-6,
        "Cosine similarity must lie in [-1, 1], got {}",
        result
    );
}
/// Smallest non-empty input: 2 * 3 = 6.
#[test]
fn test_simd_vector_len_1() {
    let (lhs, rhs) = (vec![2.0], vec![3.0]);
    assert_eq!(dot_product(&lhs, &rhs).unwrap(), 6.0);
}
/// Three elements: 1*2 + 2*3 + 3*4 = 20.
#[test]
fn test_simd_vector_len_3() {
    let (lhs, rhs) = (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0]);
    assert_eq!(dot_product(&lhs, &rhs).unwrap(), 20.0);
}
/// Seven ones dotted with themselves sum to seven; seven is one short of the
/// eight-element "exact chunk" case tested below.
#[test]
fn test_simd_vector_len_7() {
    let ones = vec![1.0; 7];
    assert_eq!(dot_product(&ones, &ones).unwrap(), 7.0);
}
/// Eight elements, presumably exactly one SIMD chunk (per the test name).
#[test]
fn test_simd_vector_len_exact_chunk() {
    let ones = vec![1.0; 8];
    assert_eq!(dot_product(&ones, &ones).unwrap(), 8.0);
}
/// Nine elements: one chunk of eight plus a single-element tail.
#[test]
fn test_simd_vector_len_exact_chunk_plus_one() {
    let ones = vec![1.0; 9];
    assert_eq!(dot_product(&ones, &ones).unwrap(), 9.0);
}
/// A zero left-hand vector yields similarity 0 rather than NaN.
#[test]
fn test_cosine_similarity_zero_vector_lhs() {
    let zero = vec![0.0; 3];
    let other = vec![1.0, 2.0, 3.0];
    assert_eq!(cosine_similarity(&zero, &other).unwrap(), 0.0);
}
/// A zero right-hand vector yields similarity 0 rather than NaN.
#[test]
fn test_cosine_similarity_zero_vector_rhs() {
    let other = vec![1.0, 2.0, 3.0];
    let zero = vec![0.0; 3];
    assert_eq!(cosine_similarity(&other, &zero).unwrap(), 0.0);
}
/// Two zero vectors yield similarity 0.
#[test]
fn test_cosine_similarity_both_zero() {
    let zeros = vec![0.0; 2];
    assert_eq!(cosine_similarity(&zeros, &zeros).unwrap(), 0.0);
}
/// Empty inputs: the dot product and both magnitudes are all zero.
#[test]
fn test_simd_dot_and_magnitudes_zero_length() {
    let empty: Vec<f32> = Vec::new();
    let (dot, mag_a, mag_b) = super::simd::dot_and_magnitudes(&empty, &empty);
    assert_eq!((dot, mag_a, mag_b), (0.0, 0.0, 0.0));
}
/// A NaN in `a` poisons the dot product and `a`'s magnitude, but leaves
/// `b`'s magnitude intact.
#[test]
fn test_simd_dot_and_magnitudes_nan() {
    let with_nan = vec![1.0, f32::NAN, 3.0];
    let finite = vec![1.0, 2.0, 3.0];
    let (dot, mag_nan, mag_finite) = super::simd::dot_and_magnitudes(&with_nan, &finite);
    assert!(dot.is_nan());
    assert!(mag_nan.is_nan());
    // 1 + 4 + 9: the magnitude reported here is the squared norm.
    assert_eq!(mag_finite, 14.0);
}
/// An infinity in `a` makes the dot product and `a`'s magnitude infinite,
/// while `b`'s squared magnitude is unaffected.
#[test]
fn test_simd_dot_and_magnitudes_inf() {
    let with_inf = vec![1.0, f32::INFINITY, 3.0];
    let finite = vec![1.0, 2.0, 3.0];
    let (dot, mag_inf, mag_finite) = super::simd::dot_and_magnitudes(&with_inf, &finite);
    assert!(dot.is_infinite());
    assert!(mag_inf.is_infinite());
    assert_eq!(mag_finite, 14.0);
}
/// The sum of squared differences over empty inputs is zero.
#[test]
fn test_simd_squared_diff_sum_zero_length() {
    let empty: Vec<f32> = Vec::new();
    assert_eq!(super::simd::squared_diff_sum(&empty, &empty), 0.0);
}
/// The raw dot-product sum over empty inputs is zero.
#[test]
fn test_simd_dot_product_sum_zero_length() {
    let empty: Vec<f32> = Vec::new();
    assert_eq!(super::simd::dot_product_sum(&empty, &empty), 0.0);
}
/// `dot_and_magnitudes` on a long, odd-length input must match plain
/// iterator sums. The odd length means it is not a multiple of any
/// power-of-two lane count, so a tail remainder is also exercised.
#[test]
fn test_simd_dot_and_magnitudes_large_vector() {
    fn sum_of_squares(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum()
    }

    let len = 1023;
    let a: Vec<f32> = (0..len).map(|i| (i % 10) as f32).collect();
    // Same pattern shifted by one: element i is (i + 1) % 10.
    let b: Vec<f32> = (1..=len).map(|i| (i % 10) as f32).collect();
    let (dot, mag_a, mag_b) = super::simd::dot_and_magnitudes(&a, &b);

    let expected_dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
    let expected_mag_a = sum_of_squares(&a);
    let expected_mag_b = sum_of_squares(&b);
    let epsilon = 0.01;

    assert!(
        (dot - expected_dot).abs() < epsilon,
        "Dot product mismatch: {} vs {}",
        dot,
        expected_dot
    );
    assert!(
        (mag_a - expected_mag_a).abs() < epsilon,
        "Mag A mismatch: {} vs {}",
        mag_a,
        expected_mag_a
    );
    assert!(
        (mag_b - expected_mag_b).abs() < epsilon,
        "Mag B mismatch: {} vs {}",
        mag_b,
        expected_mag_b
    );
}
/// Products alternate +1e5 / -1e5 and cancel pairwise. A reordered SIMD
/// reduction must stay within 1.0 of the sequential scalar reference.
#[test]
fn test_simd_dot_product_associativity() {
    let len = 1000;
    let a: Vec<f32> = [1.0e5_f32, 1.0].iter().copied().cycle().take(len).collect();
    let b: Vec<f32> = [1.0_f32, -1.0e5].iter().copied().cycle().take(len).collect();

    let scalar_dot = super::simd::dot_product_scalar(&a, &b);
    let simd_dot = super::simd::dot_product_sum(&a, &b);
    let diff = (scalar_dot - simd_dot).abs();
    assert!(
        diff < 1.0,
        "SIMD vs Scalar divergence: scalar {}, simd {}, diff {}",
        scalar_dot,
        simd_dot,
        diff
    );
}
/// Doubling a short vector in place.
#[test]
fn test_scale_in_place_basic() {
    let mut values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let doubled = vec![2.0, 4.0, 6.0, 8.0, 10.0];
    scale_in_place(&mut values, 2.0);
    assert_eq!(values, doubled);
}
/// Scaling an empty vector is a no-op and must not panic.
#[test]
fn test_scale_in_place_zero_length() {
    let mut empty = Vec::<f32>::new();
    scale_in_place(&mut empty, 2.0);
    assert_eq!(empty.len(), 0);
}
/// In-place scaling over a misaligned buffer matches element-wise doubling.
#[test]
fn test_scale_in_place_unaligned() {
    with_unaligned_f32_slice_mut(100, |v| {
        let before: Vec<f32> = v.to_vec();
        scale_in_place(&mut *v, 2.0);
        for (i, (&val, &orig)) in v.iter().zip(&before).enumerate() {
            let want = orig * 2.0;
            assert!(
                (val - want).abs() < 1e-6,
                "Index {}: {} vs {}",
                i,
                val,
                want
            );
        }
    });
}
/// Halving a long, odd-length vector in place.
#[test]
fn test_scale_in_place_large_vector() {
    let mut v: Vec<f32> = (0..1023).map(|i| i as f32).collect();
    let snapshot = v.clone();
    scale_in_place(&mut v, 0.5);
    for (&got, &orig) in v.iter().zip(&snapshot) {
        assert!((got - orig * 0.5).abs() < 1e-6);
    }
}
/// Scaling by NaN turns every element into NaN.
#[test]
fn test_scale_in_place_nan_scalar() {
    let mut v = vec![1.0, 2.0, 3.0];
    scale_in_place(&mut v, f32::NAN);
    assert!(v.iter().all(|x| x.is_nan()));
}
/// Scaling by +inf keeps each element's sign, and `0 * inf` is NaN.
#[test]
fn test_scale_in_place_inf_scalar() {
    let mut v = vec![1.0, -2.0, 0.0];
    scale_in_place(&mut v, f32::INFINITY);
    assert_eq!(&v[..2], &[f32::INFINITY, f32::NEG_INFINITY]);
    assert!(v[2].is_nan());
}
/// Scaling by zero zeroes finite elements, while `inf * 0` and `NaN * 0`
/// are NaN.
#[test]
fn test_scale_in_place_zero_scalar() {
    let mut v = vec![1.0, 2.0, f32::INFINITY, f32::NAN];
    scale_in_place(&mut v, 0.0);
    assert_eq!(&v[..2], &[0.0, 0.0]);
    assert!(v[2..].iter().all(|x| x.is_nan()));
}
use super::constants::SQUARED_MAGNITUDE_THRESHOLD;
/// A magnitude of exactly `1 - tolerance` is accepted; anything slightly
/// smaller is rejected.
#[test]
fn test_is_normalized_lower_boundary() {
    let tolerance = 0.1;
    let at_bound = vec![1.0 - tolerance, 0.0];
    assert!(
        is_normalized(&at_bound, tolerance),
        "Should accept magnitude exactly at lower bound"
    );
    let below_bound = vec![at_bound[0] - 1e-6, 0.0];
    assert!(
        !is_normalized(&below_bound, tolerance),
        "Should reject magnitude slightly below lower bound"
    );
}
/// A magnitude of exactly `1 + tolerance` is accepted; anything slightly
/// larger is rejected.
#[test]
fn test_is_normalized_upper_boundary() {
    let tolerance = 0.1;
    let at_bound = vec![1.0 + tolerance, 0.0];
    assert!(
        is_normalized(&at_bound, tolerance),
        "Should accept magnitude exactly at upper bound"
    );
    let above_bound = vec![at_bound[0] + 1e-6, 0.0];
    assert!(
        !is_normalized(&above_bound, tolerance),
        "Should reject magnitude slightly above upper bound"
    );
}
/// Just above the magnitude threshold a 1-D vector normalizes to 1.0; just
/// below it the result is zeroed out.
#[test]
fn test_normalize_threshold_boundary() {
    let threshold_mag = SQUARED_MAGNITUDE_THRESHOLD.sqrt();

    let just_above = vec![threshold_mag * 1.0001];
    let normalized = normalize(&just_above);
    assert!(
        (normalized[0] - 1.0).abs() < 1e-6,
        "Should normalize at/above threshold"
    );

    let just_below = vec![threshold_mag * 0.999];
    let normalized_small = normalize(&just_below);
    assert_eq!(normalized_small[0], 0.0, "Should zero out below threshold");
}