#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use crate::arch;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use crate::MIN_DIM_SIMD;
/// Minimum slice length at which the x86_64 dispatchers in this module try the
/// AVX-512F kernels. Shorter inputs fall through to the next candidate
/// (the AVX2 path, then the portable scalar code).
#[cfg(target_arch = "x86_64")]
const MIN_DIM_AVX512: usize = 64;
/// Inner (dot) product `Σ a_i · b_i` of two equal-length `f32` slices.
///
/// The implementation is chosen at runtime:
/// - x86_64: an AVX-512F kernel for inputs of at least `MIN_DIM_AVX512` elements,
///   otherwise an AVX2+FMA kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - aarch64: a NEON kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - Everything else goes to [`dot_portable`].
///
/// The SIMD kernels are defined elsewhere and may accumulate in a different order
/// than the portable path, so results are not guaranteed to be bit-identical
/// across paths.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
#[must_use]
#[allow(unsafe_code)]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a,
        len_b,
        "innr::dot: slice length mismatch ({} vs {})",
        len_a,
        len_b
    );

    #[cfg(target_arch = "x86_64")]
    {
        if len_a >= MIN_DIM_AVX512 && is_x86_feature_detected!("avx512f") {
            // SAFETY: avx512f support was confirmed at runtime just above, and the
            // slices were asserted to have equal length.
            // NOTE(review): confirm these are the kernel's full preconditions.
            return unsafe { arch::x86_64::dot_avx512(a, b) };
        }
        let use_avx2 = len_a >= MIN_DIM_SIMD
            && is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma");
        if use_avx2 {
            // SAFETY: avx2 and fma support were confirmed at runtime just above,
            // and the slices were asserted to have equal length.
            return unsafe { arch::x86_64::dot_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if len_a >= MIN_DIM_SIMD {
            // SAFETY: no runtime feature check is made; this relies on NEON being a
            // baseline feature of the aarch64 target (NOTE(review): confirm for any
            // no-NEON/softfloat targets). Slice lengths were asserted equal above.
            return unsafe { arch::aarch64::dot_neon(a, b) };
        }
    }

    // Scalar fallback for short inputs, other architectures, or CPUs without the
    // required extensions. `unreachable_code` is presumably allowed defensively in
    // case a cfg arm above ever returns unconditionally.
    #[allow(unreachable_code)]
    dot_portable(a, b)
}
/// Portable scalar dot product; this is the fallback used by [`dot`].
///
/// Unlike [`dot`], lengths are **not** checked. Only the first
/// `min(a.len(), b.len())` elements of each slice contribute.
///
/// The main loop keeps four independent partial sums (lane `k` accumulates the
/// products at indices `4i + k`). This shortens the floating-point dependency
/// chain. The lanes are then combined as `((s0 + s1) + s2) + s3`, and the 0–3
/// leftover elements are added in index order. The fixed order makes the result
/// reproducible.
#[inline]
#[must_use]
pub fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    // Re-slice to the common length and walk exact 4-element chunks, so the
    // compiler can see that every access is in bounds (no per-element checks).
    let (a, b) = (&a[..n], &b[..n]);
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut lanes = [0.0f32; 4];
    for (x, y) in a_chunks.zip(b_chunks) {
        lanes[0] += x[0] * y[0];
        lanes[1] += x[1] * y[1];
        lanes[2] += x[2] * y[2];
        lanes[3] += x[3] * y[3];
    }

    let mut result = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        result += x * y;
    }
    result
}
/// Euclidean (L2) norm of `v`, computed as `sqrt(dot(v, v))`.
///
/// No rescaling is performed. The sum of squares can therefore overflow to
/// infinity when components are large enough that their squares approach `f32::MAX`.
#[inline]
#[must_use]
pub fn norm(v: &[f32]) -> f32 {
    let sum_of_squares = dot(v, v);
    sum_of_squares.sqrt()
}
/// Scales `v` in place to unit Euclidean length.
///
/// If the norm is not strictly greater than `crate::NORM_EPSILON` (this includes
/// a NaN norm), `v` is left untouched. Zero or near-zero vectors are therefore
/// never divided by ~0.
pub fn normalize(v: &mut [f32]) {
    let length = norm(v);
    if length > crate::NORM_EPSILON {
        // Multiply every element by the reciprocal instead of dividing each one.
        let scale = length.recip();
        v.iter_mut().for_each(|x| *x *= scale);
    }
}
/// Cosine similarity `a·b / (‖a‖·‖b‖)` of two equal-length `f32` slices.
///
/// The implementation is chosen at runtime:
/// - x86_64: an AVX-512F kernel for inputs of at least `MIN_DIM_AVX512` elements,
///   otherwise an AVX2+FMA kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - aarch64: a NEON kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - Everything else goes to [`cosine_portable`].
///
/// The portable path returns `0.0` when either squared norm is at most
/// `crate::NORM_EPSILON_SQ`. NOTE(review): the SIMD kernels are defined elsewhere
/// and are assumed to follow the same convention.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
#[must_use]
#[allow(unsafe_code)]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a,
        len_b,
        "innr::cosine: slice length mismatch ({} vs {})",
        len_a,
        len_b
    );

    #[cfg(target_arch = "x86_64")]
    {
        if len_a >= MIN_DIM_AVX512 && is_x86_feature_detected!("avx512f") {
            // SAFETY: avx512f support was confirmed at runtime just above, and the
            // slices were asserted to have equal length.
            // NOTE(review): confirm these are the kernel's full preconditions.
            return unsafe { arch::x86_64::cosine_avx512(a, b) };
        }
        let use_avx2 = len_a >= MIN_DIM_SIMD
            && is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma");
        if use_avx2 {
            // SAFETY: avx2 and fma support were confirmed at runtime just above,
            // and the slices were asserted to have equal length.
            return unsafe { arch::x86_64::cosine_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if len_a >= MIN_DIM_SIMD {
            // SAFETY: no runtime feature check is made; this relies on NEON being a
            // baseline feature of the aarch64 target (NOTE(review): confirm for any
            // no-NEON/softfloat targets). Slice lengths were asserted equal above.
            return unsafe { arch::aarch64::cosine_neon(a, b) };
        }
    }

    // Scalar fallback for short inputs, other architectures, or CPUs without the
    // required extensions. `unreachable_code` is presumably allowed defensively in
    // case a cfg arm above ever returns unconditionally.
    #[allow(unreachable_code)]
    cosine_portable(a, b)
}
/// Portable scalar cosine similarity; this is the fallback used by [`cosine`].
///
/// Unlike [`cosine`], lengths are **not** checked. Only the first
/// `min(a.len(), b.len())` elements of each slice contribute.
///
/// `a·b`, `a·a` and `b·b` are each accumulated in four independent lanes
/// (lane `k` takes indices `4i + k`). The lanes are combined as
/// `((l0 + l1) + l2) + l3`, and the 0–3 leftover elements are added in index order.
///
/// Returns `0.0` when either squared norm is not strictly greater than
/// `crate::NORM_EPSILON_SQ`. This covers empty and zero vectors, and also a NaN
/// squared norm. The result is not clamped, so rounding can push it marginally
/// outside `[-1, 1]`.
#[inline]
#[must_use]
pub fn cosine_portable(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    // Re-slice to the common length and walk exact 4-element chunks, so every
    // access is provably in bounds (no per-element bounds checks).
    let (a, b) = (&a[..n], &b[..n]);
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    // Four partial sums for each of a·b, a·a and b·b.
    let mut ab_lanes = [0.0f32; 4];
    let mut aa_lanes = [0.0f32; 4];
    let mut bb_lanes = [0.0f32; 4];
    for (x, y) in a_chunks.zip(b_chunks) {
        ab_lanes[0] += x[0] * y[0];
        aa_lanes[0] += x[0] * x[0];
        bb_lanes[0] += y[0] * y[0];
        ab_lanes[1] += x[1] * y[1];
        aa_lanes[1] += x[1] * x[1];
        bb_lanes[1] += y[1] * y[1];
        ab_lanes[2] += x[2] * y[2];
        aa_lanes[2] += x[2] * x[2];
        bb_lanes[2] += y[2] * y[2];
        ab_lanes[3] += x[3] * y[3];
        aa_lanes[3] += x[3] * x[3];
        bb_lanes[3] += y[3] * y[3];
    }

    let mut ab = ab_lanes[0] + ab_lanes[1] + ab_lanes[2] + ab_lanes[3];
    let mut aa = aa_lanes[0] + aa_lanes[1] + aa_lanes[2] + aa_lanes[3];
    let mut bb = bb_lanes[0] + bb_lanes[1] + bb_lanes[2] + bb_lanes[3];
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        ab += ai * bi;
        aa += ai * ai;
        bb += bi * bi;
    }

    // (Near-)zero vectors have no direction: report 0.0 instead of dividing by ~0.
    // Taking the two square roots separately (rather than sqrt(aa * bb)) keeps the
    // intermediate product from overflowing when both norms are large.
    if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
        ab / (aa.sqrt() * bb.sqrt())
    } else {
        0.0
    }
}
/// Angular distance between `a` and `b`, normalised to `[0, 1]`.
///
/// Computes `acos(cos_sim) / π`. The cosine similarity is first clamped to
/// `[-1, 1]`, so finite rounding overshoot cannot push `acos` out of its domain.
/// `0` means same direction, `0.5` orthogonal, and `1` opposite.
///
/// # Panics
///
/// Panics if `a.len() != b.len()` (via [`cosine`]).
#[inline]
#[must_use]
pub fn angular_distance(a: &[f32], b: &[f32]) -> f32 {
    use std::f32::consts::PI;
    let theta = cosine(a, b).clamp(-1.0, 1.0).acos();
    theta / PI
}
/// Dot product restricted to the first `prefix_len` dimensions of `a` and `b`.
///
/// The prefix is clipped to the shorter of the two slices. This function
/// therefore never panics, whether the lengths differ or `prefix_len` exceeds them.
#[inline]
#[must_use]
pub fn matryoshka_dot(a: &[f32], b: &[f32], prefix_len: usize) -> f32 {
    let cut = a.len().min(b.len()).min(prefix_len);
    let (head_a, head_b) = (&a[..cut], &b[..cut]);
    dot(head_a, head_b)
}
/// Cosine similarity restricted to the first `prefix_len` dimensions of `a` and `b`.
///
/// The prefix is clipped to the shorter of the two slices. This function
/// therefore never panics, whether the lengths differ or `prefix_len` exceeds them.
#[inline]
#[must_use]
pub fn matryoshka_cosine(a: &[f32], b: &[f32], prefix_len: usize) -> f32 {
    let cut = a.len().min(b.len()).min(prefix_len);
    let (head_a, head_b) = (&a[..cut], &b[..cut]);
    cosine(head_a, head_b)
}
/// Euclidean distance `sqrt(Σ (a_i - b_i)²)`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The panic message comes from
/// [`l2_distance_squared`], which this delegates to.
#[inline]
#[must_use]
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared = l2_distance_squared(a, b);
    squared.sqrt()
}
/// Manhattan (L1) distance `Σ |a_i - b_i|` between two equal-length slices.
///
/// The implementation is chosen at runtime:
/// - x86_64: an AVX-512F kernel for inputs of at least `MIN_DIM_AVX512` elements,
///   otherwise an AVX2 kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - aarch64: a NEON kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - Everything else goes to [`l1_distance_portable`].
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
#[must_use]
#[allow(unsafe_code)]
pub fn l1_distance(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a,
        len_b,
        "innr::l1_distance: slice length mismatch ({} vs {})",
        len_a,
        len_b
    );

    #[cfg(target_arch = "x86_64")]
    {
        if len_a >= MIN_DIM_AVX512 && is_x86_feature_detected!("avx512f") {
            // SAFETY: avx512f support was confirmed at runtime just above, and the
            // slices were asserted to have equal length.
            // NOTE(review): confirm these are the kernel's full preconditions.
            return unsafe { arch::x86_64::l1_avx512(a, b) };
        }
        // Unlike the other dispatchers, only AVX2 is required here (not FMA):
        // an L1 sum has no multiply-add to fuse.
        let use_avx2 = len_a >= MIN_DIM_SIMD && is_x86_feature_detected!("avx2");
        if use_avx2 {
            // SAFETY: avx2 support was confirmed at runtime just above, and the
            // slices were asserted to have equal length.
            return unsafe { arch::x86_64::l1_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if len_a >= MIN_DIM_SIMD {
            // SAFETY: no runtime feature check is made; this relies on NEON being a
            // baseline feature of the aarch64 target (NOTE(review): confirm for any
            // no-NEON/softfloat targets). Slice lengths were asserted equal above.
            return unsafe { arch::aarch64::l1_neon(a, b) };
        }
    }

    // Scalar fallback for short inputs, other architectures, or CPUs without the
    // required extensions. `unreachable_code` is presumably allowed defensively in
    // case a cfg arm above ever returns unconditionally.
    #[allow(unreachable_code)]
    l1_distance_portable(a, b)
}
/// Portable scalar Manhattan (L1) distance; this is the fallback used by [`l1_distance`].
///
/// Unlike [`l1_distance`], lengths are **not** checked. Only the first
/// `min(a.len(), b.len())` elements of each slice contribute.
///
/// The main loop keeps four independent partial sums (lane `k` takes indices
/// `4i + k`). They are combined as `((s0 + s1) + s2) + s3`, and the 0–3 leftover
/// elements are added in index order.
#[inline]
#[must_use]
pub fn l1_distance_portable(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    // Re-slice to the common length and walk exact 4-element chunks, so every
    // access is provably in bounds (no per-element bounds checks).
    let (a, b) = (&a[..n], &b[..n]);
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut lanes = [0.0f32; 4];
    for (x, y) in a_chunks.zip(b_chunks) {
        lanes[0] += (x[0] - y[0]).abs();
        lanes[1] += (x[1] - y[1]).abs();
        lanes[2] += (x[2] - y[2]).abs();
        lanes[3] += (x[3] - y[3]).abs();
    }

    let mut result = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        result += (x - y).abs();
    }
    result
}
/// Squared Euclidean distance `Σ (a_i - b_i)²` between two equal-length slices.
///
/// This is cheaper than [`l2_distance`] because it skips the square root. It
/// orders pairs the same way, so it suits nearest-neighbour comparisons.
///
/// The implementation is chosen at runtime:
/// - x86_64: an AVX-512F kernel for inputs of at least `MIN_DIM_AVX512` elements,
///   otherwise an AVX2+FMA kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - aarch64: a NEON kernel for inputs of at least `MIN_DIM_SIMD` elements.
/// - Everything else goes to [`l2_distance_squared_portable`].
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
#[must_use]
#[allow(unsafe_code)]
pub fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(
        len_a,
        len_b,
        "innr::l2_distance_squared: slice length mismatch ({} vs {})",
        len_a,
        len_b
    );

    #[cfg(target_arch = "x86_64")]
    {
        if len_a >= MIN_DIM_AVX512 && is_x86_feature_detected!("avx512f") {
            // SAFETY: avx512f support was confirmed at runtime just above, and the
            // slices were asserted to have equal length.
            // NOTE(review): confirm these are the kernel's full preconditions.
            return unsafe { arch::x86_64::l2_squared_avx512(a, b) };
        }
        let use_avx2 = len_a >= MIN_DIM_SIMD
            && is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma");
        if use_avx2 {
            // SAFETY: avx2 and fma support were confirmed at runtime just above,
            // and the slices were asserted to have equal length.
            return unsafe { arch::x86_64::l2_squared_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if len_a >= MIN_DIM_SIMD {
            // SAFETY: no runtime feature check is made; this relies on NEON being a
            // baseline feature of the aarch64 target (NOTE(review): confirm for any
            // no-NEON/softfloat targets). Slice lengths were asserted equal above.
            return unsafe { arch::aarch64::l2_squared_neon(a, b) };
        }
    }

    // Scalar fallback for short inputs, other architectures, or CPUs without the
    // required extensions. `unreachable_code` is presumably allowed defensively in
    // case a cfg arm above ever returns unconditionally.
    #[allow(unreachable_code)]
    l2_distance_squared_portable(a, b)
}
/// Portable scalar squared Euclidean distance; this is the fallback used by
/// [`l2_distance_squared`].
///
/// Unlike [`l2_distance_squared`], lengths are **not** checked. Only the first
/// `min(a.len(), b.len())` elements of each slice contribute.
///
/// The differences are squared directly rather than expanded as
/// `‖a‖² + ‖b‖² - 2a·b`, which avoids cancellation for nearby vectors. The four
/// lane partial sums (lane `k` takes indices `4i + k`) are combined as
/// `((s0 + s1) + s2) + s3`, and the 0–3 leftover elements are added in index order.
#[inline]
#[must_use]
pub fn l2_distance_squared_portable(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    // Re-slice to the common length and walk exact 4-element chunks, so every
    // access is provably in bounds (no per-element bounds checks).
    let (a, b) = (&a[..n], &b[..n]);
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut lanes = [0.0f32; 4];
    for (x, y) in a_chunks.zip(b_chunks) {
        let d0 = x[0] - y[0];
        let d1 = x[1] - y[1];
        let d2 = x[2] - y[2];
        let d3 = x[3] - y[3];
        lanes[0] += d0 * d0;
        lanes[1] += d1 * d1;
        lanes[2] += d2 * d2;
        lanes[3] += d3 * d3;
    }

    let mut result = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        result += d * d;
    }
    result
}
// Example-based unit tests: hand-picked inputs with known results, edge cases
// (empty, single-element, zero vectors, large magnitudes) and length-mismatch panics.
#[cfg(test)]
mod tests {
    use super::*;

    // Ranking three documents by cosine must give the same order at full length
    // and when only the first two dimensions are used.
    #[test]
    fn test_matryoshka_ranking_preservation() {
        let query = [1.0, 0.5, 0.2, 0.1];
        let doc1 = [0.9, 0.4, 0.1, 0.05];
        let doc2 = [0.1, 0.1, 0.1, 0.1];
        let doc3 = [-0.5, -0.2, 0.0, 0.0];
        let sim1_full = cosine(&query, &doc1);
        let sim2_full = cosine(&query, &doc2);
        let sim3_full = cosine(&query, &doc3);
        assert!(sim1_full > sim2_full);
        assert!(sim2_full > sim3_full);
        let sim1_prefix = matryoshka_cosine(&query, &doc1, 2);
        let sim2_prefix = matryoshka_cosine(&query, &doc2, 2);
        let sim3_prefix = matryoshka_cosine(&query, &doc3, 2);
        assert!(sim1_prefix > sim2_prefix);
        assert!(sim2_prefix > sim3_prefix);
    }

    // Compares against the exact integer sum of squares for 8 and 32 elements.
    // NOTE(review): whether these sizes straddle MIN_DIM_SIMD depends on its value,
    // which is defined outside this file.
    #[test]
    fn test_dot_simd_threshold() {
        let small_a: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let small_b: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let result_small = dot(&small_a, &small_b);
        let large_a: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let large_b: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let result_large = dot(&large_a, &large_b);
        let expected_small: f32 = (0..8).map(|i| (i * i) as f32).sum();
        let expected_large: f32 = (0..32).map(|i| (i * i) as f32).sum();
        assert!((result_small - expected_small).abs() < 1e-3);
        assert!((result_large - expected_large).abs() < 1e-1);
    }

    // Right triangle with legs 1 and 1: d(a,c) = 1 <= d(a,b) + d(b,c) = 1 + sqrt(2).
    #[test]
    fn test_l2_distance_triangle_inequality() {
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 0.0];
        let c = [0.0_f32, 1.0];
        let ab = l2_distance(&a, &b);
        let bc = l2_distance(&b, &c);
        let ac = l2_distance(&a, &c);
        assert!(ac <= ab + bc + 1e-6);
    }

    // The empty dot product is the additive identity.
    #[test]
    fn test_dot_empty() {
        assert_eq!(dot(&[], &[]), 0.0);
    }

    // Cosine of empty slices is defined as 0.0 (no direction), not NaN.
    #[test]
    fn test_cosine_empty() {
        let result = cosine(&[], &[]);
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_dot_single() {
        assert_eq!(dot(&[3.0], &[4.0]), 12.0);
    }

    // 16 elements: presumably a SIMD width/threshold boundary — confirm against
    // MIN_DIM_SIMD.
    #[test]
    fn test_dot_exactly_16_elements() {
        let a: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..16).map(|i| (i + 1) as f32).collect();
        let result = dot(&a, &b);
        let expected: f32 = (0..16).map(|i| (i * (i + 1)) as f32).sum();
        assert!(
            (result - expected).abs() < 1e-3,
            "dot at dim=16: got {result}, expected {expected}"
        );
    }

    // b = 2a, so the vectors are parallel and the cosine must be ~1.
    #[test]
    fn test_cosine_exactly_16_elements() {
        let a: Vec<f32> = (1..=16).map(|i| i as f32).collect();
        let b: Vec<f32> = (1..=16).map(|i| i as f32 * 2.0).collect();
        let result = cosine(&a, &b);
        assert!(
            (result - 1.0).abs() < 1e-5,
            "cosine of parallel vectors at dim=16: got {result}"
        );
    }

    // norm must agree with sqrt(dot(v, v)).
    #[test]
    fn test_norm_exactly_16_elements() {
        let v: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let result = norm(&v);
        let expected = dot(&v, &v).sqrt();
        assert!(
            (result - expected).abs() < 1e-5,
            "norm at dim=16: got {result}, expected {expected}"
        );
    }

    // 2 * (1e18)^2 = 2e36 is well below f32::MAX (~3.4e38), so the result must be finite.
    #[test]
    fn test_dot_large_values() {
        let a = [1e18_f32, 1e18];
        let b = [1e18_f32, 1e18];
        let result = dot(&a, &b);
        assert!(result.is_finite(), "dot with large values should be finite");
        assert!(result > 0.0);
    }

    // 2 * (1e19)^2 = 2e38 is still below f32::MAX, so the norm must stay finite.
    #[test]
    fn test_norm_large_vector() {
        let v = [1e19_f32, 1e19];
        let result = norm(&v);
        assert!(result.is_finite(), "norm of large vector should be finite");
        assert!(result > 0.0);
    }

    // Zero vectors have no direction; cosine reports 0.0 rather than NaN.
    #[test]
    fn test_cosine_zero_vector_both() {
        let zero = [0.0_f32, 0.0, 0.0];
        let result = cosine(&zero, &zero);
        assert_eq!(result, 0.0, "cosine of two zero vectors should be 0.0");
    }

    #[test]
    fn test_norm_zero_vector() {
        let zero = [0.0_f32, 0.0, 0.0];
        assert_eq!(norm(&zero), 0.0);
    }

    // (-1)(-4) + (-2)(-5) + (-3)(-6) = 4 + 10 + 18 = 32.
    #[test]
    fn test_dot_all_negatives() {
        let a = [-1.0_f32, -2.0, -3.0];
        let b = [-4.0_f32, -5.0, -6.0];
        let result = dot(&a, &b);
        assert!((result - 32.0).abs() < 1e-6);
    }

    // Each checked entry point must panic with its own message on length mismatch.
    #[test]
    #[should_panic(expected = "innr::dot: slice length mismatch")]
    fn dot_panics_on_length_mismatch() {
        let _ = dot(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
    }

    #[test]
    #[should_panic(expected = "innr::l1_distance: slice length mismatch")]
    fn l1_distance_panics_on_length_mismatch() {
        let _ = l1_distance(&[1.0], &[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "innr::l2_distance_squared: slice length mismatch")]
    fn l2_distance_squared_panics_on_length_mismatch() {
        let _ = l2_distance_squared(&[1.0], &[1.0, 2.0]);
    }

    // l2_distance delegates to l2_distance_squared, so the message names that function.
    #[test]
    #[should_panic(expected = "innr::l2_distance_squared: slice length mismatch")]
    fn l2_distance_panics_on_length_mismatch() {
        let _ = l2_distance(&[1.0], &[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "innr::cosine: slice length mismatch")]
    fn cosine_panics_on_length_mismatch() {
        let _ = cosine(&[1.0, 2.0], &[1.0]);
    }

    // b = -a, so the vectors are antiparallel and the cosine must be ~-1.
    #[test]
    fn test_cosine_mixed_signs() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [-1.0_f32, -2.0, -3.0];
        let result = cosine(&a, &b);
        assert!(
            (result - (-1.0)).abs() < 1e-5,
            "cosine of antiparallel vectors: got {result}, expected -1.0"
        );
    }

    // The classic 3-4-5 vector normalizes to unit length.
    #[test]
    fn test_normalize_unit_norm() {
        let mut v = vec![3.0_f32, 4.0];
        normalize(&mut v);
        let n = norm(&v);
        assert!(
            (n - 1.0).abs() < 1e-6,
            "norm(normalize(v)) should be ~1.0, got {n}"
        );
    }

    // An already-unit basis vector must be unchanged in direction.
    #[test]
    fn test_normalize_direction_preserved() {
        let mut v = vec![1.0_f32, 0.0, 0.0];
        normalize(&mut v);
        assert!((v[0] - 1.0).abs() < 1e-6);
        assert!(v[1].abs() < 1e-6);
        assert!(v[2].abs() < 1e-6);
    }

    // Zero vectors fall under the epsilon guard and are left untouched (no NaNs).
    #[test]
    fn test_normalize_zero_vector_unchanged() {
        let mut v = vec![0.0_f32, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    // Covers a spread of dimensions so both short and longer inputs are exercised.
    #[test]
    fn test_normalize_various_dims() {
        for dim in [1, 8, 16, 64, 128] {
            let mut v: Vec<f32> = (1..=dim).map(|i| i as f32).collect();
            normalize(&mut v);
            let n = norm(&v);
            assert!(
                (n - 1.0).abs() < 1e-5,
                "norm after normalize should be ~1.0 for dim={dim}, got {n}"
            );
        }
    }

    // matryoshka_dot(a, b, k) must equal dot over the first k elements.
    #[test]
    fn test_matryoshka_dot_equals_prefix_dot() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0_f32, 4.0, 3.0, 2.0, 1.0];
        for prefix in [1, 2, 3, 4, 5] {
            let mrd = matryoshka_dot(&a, &b, prefix);
            let expected = dot(&a[..prefix], &b[..prefix]);
            assert!(
                (mrd - expected).abs() < 1e-6,
                "matryoshka_dot(prefix={prefix}): got {mrd}, expected {expected}"
            );
        }
    }

    // A prefix equal to the full length degenerates to plain dot.
    #[test]
    fn test_matryoshka_dot_full_prefix_equals_dot() {
        let a = vec![1.0_f32, 0.0, -1.0];
        let b = vec![2.0_f32, 3.0, 4.0];
        let full = matryoshka_dot(&a, &b, a.len());
        let expected = dot(&a, &b);
        assert!((full - expected).abs() < 1e-6);
    }

    // An oversized prefix is clipped to the slice length instead of panicking.
    #[test]
    fn test_matryoshka_dot_prefix_longer_than_vec_clips() {
        let a = vec![1.0_f32, 2.0];
        let b = vec![3.0_f32, 4.0];
        let result = matryoshka_dot(&a, &b, 100);
        let expected = dot(&a, &b);
        assert!((result - expected).abs() < 1e-6);
    }

    // matryoshka_cosine(a, b, k) must equal cosine over the first k elements.
    #[test]
    fn test_matryoshka_cosine_equals_prefix_cosine() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![4.0_f32, 3.0, 2.0, 1.0];
        for prefix in [1, 2, 3, 4] {
            let mrc = matryoshka_cosine(&a, &b, prefix);
            let expected = cosine(&a[..prefix], &b[..prefix]);
            assert!(
                (mrc - expected).abs() < 1e-6,
                "matryoshka_cosine(prefix={prefix}): got {mrc}, expected {expected}"
            );
        }
    }

    // Full prefix of two orthogonal unit vectors: equals cosine, which is ~0.
    #[test]
    fn test_matryoshka_cosine_full_prefix_equals_cosine() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let full = matryoshka_cosine(&a, &b, 2);
        let expected = cosine(&a, &b);
        assert!((full - expected).abs() < 1e-6);
        assert!(full.abs() < 1e-6);
    }

    // Only the first component is compared. Both are positive, so the cosine is 1
    // even though the remaining components point in very different directions.
    #[test]
    fn test_matryoshka_cosine_prefix_one() {
        let a = vec![3.0_f32, -99.0, -99.0];
        let b = vec![5.0_f32, 1.0, 1.0];
        let result = matryoshka_cosine(&a, &b, 1);
        assert!((result - 1.0).abs() < 1e-5, "got {result}");
    }

    // Orthogonal, identical, opposite and 90°-rotated pairs must all land in [0, 1].
    #[test]
    fn test_angular_distance_range() {
        let pairs: &[(&[f32], &[f32])] = &[
            (&[1.0, 0.0], &[0.0, 1.0]),
            (&[1.0, 0.0], &[1.0, 0.0]),
            (&[1.0, 0.0], &[-1.0, 0.0]),
            (&[1.0, 1.0], &[1.0, -1.0]),
        ];
        for (a, b) in pairs {
            let d = angular_distance(a, b);
            assert!(
                (0.0..=1.0).contains(&d),
                "angular_distance out of [0,1]: {d}"
            );
        }
    }

    // acos(0) / π = 0.5.
    #[test]
    fn test_angular_distance_orthogonal_is_half() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let d = angular_distance(&a, &b);
        assert!(
            (d - 0.5).abs() < 1e-5,
            "orthogonal angular_distance: got {d}"
        );
    }

    // Loose tolerance on purpose: acos is very steep near 1, so tiny rounding in
    // the cosine turns into a noticeable angle.
    #[test]
    fn test_angular_distance_identical_is_zero() {
        let a = [1.0_f32, 2.0, 3.0];
        let d = angular_distance(&a, &a);
        assert!(d < 1e-3, "identical vectors angular_distance: got {d}");
    }

    // acos(-1) / π = 1.
    #[test]
    fn test_angular_distance_opposite_is_one() {
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        let d = angular_distance(&a, &b);
        assert!(
            (d - 1.0).abs() < 1e-5,
            "opposite vectors angular_distance: got {d}"
        );
    }

    // angular_distance must equal the clamped-acos formula applied to cosine.
    #[test]
    fn test_angular_distance_cosine_relationship() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, -1.0, 2.0];
        let c = cosine(&a, &b).clamp(-1.0, 1.0);
        let expected = c.acos() / std::f32::consts::PI;
        let result = angular_distance(&a, &b);
        assert!(
            (result - expected).abs() < 1e-6,
            "angular_distance mismatch: got {result}, expected {expected}"
        );
    }
}
// Property-based tests (proptest): algebraic invariants of the public kernels,
// checked on random inputs drawn over a fixed set of dimensions.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Dimensions to sample from. NOTE(review): the mix of tiny, odd (15, 17) and
    // power-of-two sizes presumably targets SIMD thresholds and tail handling.
    const DIMS: &[usize] = &[1, 8, 15, 16, 17, 32, 64, 128, 768];

    // Components drawn from [-1000, 1000).
    fn bounded_f32() -> impl Strategy<Value = f32> {
        -1000.0_f32..1000.0_f32
    }

    // A vector of exactly `len` bounded components.
    fn bounded_vec(len: usize) -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(bounded_f32(), len)
    }

    // Picks one dimension from DIMS.
    fn simd_dim() -> impl Strategy<Value = usize> {
        prop::sample::select(DIMS)
    }

    // Every component has magnitude in [0.01, 100), so the vector's norm is never
    // zero and cosine is always well-defined.
    fn nonzero_vec(len: usize) -> impl Strategy<Value = Vec<f32>> {
        prop::collection::vec(
            prop_oneof![0.01_f32..100.0_f32, -100.0_f32..-0.01_f32,],
            len,
        )
    }

    // The combinators below first choose a dimension, then generate vectors of that
    // same length, so the length-checked entry points never panic.
    fn dim_and_two_vecs() -> impl Strategy<Value = (usize, Vec<f32>, Vec<f32>)> {
        simd_dim().prop_flat_map(|d| (Just(d), bounded_vec(d), bounded_vec(d)))
    }

    fn dim_and_vec() -> impl Strategy<Value = (usize, Vec<f32>)> {
        simd_dim().prop_flat_map(|d| (Just(d), bounded_vec(d)))
    }

    fn dim_and_two_nonzero_vecs() -> impl Strategy<Value = (usize, Vec<f32>, Vec<f32>)> {
        simd_dim().prop_flat_map(|d| (Just(d), nonzero_vec(d), nonzero_vec(d)))
    }

    fn dim_and_nonzero_vec() -> impl Strategy<Value = (usize, Vec<f32>)> {
        simd_dim().prop_flat_map(|d| (Just(d), nonzero_vec(d)))
    }

    fn dim_and_three_vecs() -> impl Strategy<Value = (usize, Vec<f32>, Vec<f32>, Vec<f32>)> {
        simd_dim().prop_flat_map(|d| (Just(d), bounded_vec(d), bounded_vec(d), bounded_vec(d)))
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        // dot(a, b) == dot(b, a) up to a relative tolerance.
        #[test]
        fn dot_commutative((dim, a, b) in dim_and_two_vecs()) {
            let ab = dot(&a, &b);
            let ba = dot(&b, &a);
            let tol = 1e-4 * ab.abs().max(ba.abs()).max(1.0);
            prop_assert!(
                (ab - ba).abs() <= tol,
                "dot commutativity failed: dot(a,b)={ab}, dot(b,a)={ba}, dim={dim}"
            );
        }

        // The norm is non-negative, and finite for bounded inputs.
        #[test]
        fn norm_nonnegative((dim, v) in dim_and_vec()) {
            let n = norm(&v);
            prop_assert!(
                n >= 0.0,
                "norm must be >= 0, got {n} for dim={dim}"
            );
            prop_assert!(
                n.is_finite(),
                "norm must be finite for bounded input, got {n} for dim={dim}"
            );
        }

        // Cosine stays within [-1, 1], allowing a small rounding slack.
        #[test]
        fn cosine_range((dim, a, b) in dim_and_two_nonzero_vecs()) {
            let c = cosine(&a, &b);
            prop_assert!(
                (-1.0 - 1e-5..=1.0 + 1e-5).contains(&c),
                "cosine out of range: {c}, dim={dim}"
            );
        }

        // A non-zero vector is perfectly similar to itself.
        #[test]
        fn cosine_self_similarity((dim, v) in dim_and_nonzero_vec()) {
            let c = cosine(&v, &v);
            prop_assert!(
                (c - 1.0).abs() < 1e-4,
                "cosine(v, v) should be ~1.0, got {c}, dim={dim}"
            );
        }

        // Identity of indiscernibles for L2.
        #[test]
        fn l2_self_distance_zero((dim, v) in dim_and_vec()) {
            let d = l2_distance(&v, &v);
            prop_assert!(
                d.abs() < 1e-5,
                "l2_distance(v, v) should be ~0, got {d}, dim={dim}"
            );
        }

        // d(a,c) <= d(a,b) + d(b,c), with slack scaled by the magnitudes involved.
        #[test]
        fn l2_triangle_inequality((dim, a, b, c) in dim_and_three_vecs()) {
            let ab = l2_distance(&a, &b);
            let bc = l2_distance(&b, &c);
            let ac = l2_distance(&a, &c);
            let eps = 1e-4 * (ab + bc).max(1.0);
            prop_assert!(
                ac <= ab + bc + eps,
                "triangle inequality violated: d(a,c)={ac} > d(a,b)={ab} + d(b,c)={bc}, dim={dim}"
            );
        }

        // The direct squared distance must agree with the ‖a‖² + ‖b‖² - 2a·b expansion,
        // up to a tolerance relative to ‖a‖² + ‖b‖² (the expansion suffers cancellation).
        #[test]
        fn l2_direct_vs_expansion((dim, a, b) in dim_and_two_vecs()) {
            let direct = l2_distance_squared(&a, &b);
            let aa: f32 = a.iter().map(|x| x * x).sum();
            let bb: f32 = b.iter().map(|x| x * x).sum();
            let ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
            let expansion = (aa + bb - 2.0 * ab).max(0.0);
            prop_assert!(direct >= 0.0, "direct L2² should be >= 0, got {direct}");
            let scale = (aa + bb).max(1.0);
            let tol = 1e-3 * scale;
            prop_assert!(
                (direct - expansion).abs() <= tol,
                "direct={direct} vs expansion={expansion}, diff={}, scale={scale}, dim={dim}",
                (direct - expansion).abs()
            );
        }

        // L1 is symmetric.
        #[test]
        fn l1_commutative((dim, a, b) in dim_and_two_vecs()) {
            let ab = l1_distance(&a, &b);
            let ba = l1_distance(&b, &a);
            let tol = 1e-4 * ab.abs().max(ba.abs()).max(1.0);
            prop_assert!(
                (ab - ba).abs() <= tol,
                "L1 commutativity failed: l1(a,b)={ab}, l1(b,a)={ba}, dim={dim}"
            );
        }

        // L1 is non-negative.
        #[test]
        fn l1_nonnegative((dim, a, b) in dim_and_two_vecs()) {
            let d = l1_distance(&a, &b);
            prop_assert!(
                d >= 0.0,
                "L1 must be >= 0, got {d} for dim={dim}"
            );
        }

        // Identity of indiscernibles for L1.
        #[test]
        fn l1_self_distance_zero((dim, v) in dim_and_vec()) {
            let d = l1_distance(&v, &v);
            prop_assert!(
                d.abs() < 1e-5,
                "l1_distance(v, v) should be ~0, got {d}, dim={dim}"
            );
        }

        // The f32 result must lie within a standard summation error bound
        // (dim * eps * Σ|a_i b_i|) of an f64 reference.
        #[test]
        fn dot_matches_f64_reference((dim, a, b) in dim_and_two_vecs()) {
            let f64_ref: f64 = a.iter().zip(&b).map(|(&x, &y)| x as f64 * y as f64).sum();
            let f32_result = dot(&a, &b);
            let abs_product_sum: f64 = a.iter().zip(&b).map(|(&x, &y)| (x as f64 * y as f64).abs()).sum();
            let tol = (dim as f64) * f64::from(f32::EPSILON) * abs_product_sum.max(1.0);
            let diff = (f64::from(f32_result) - f64_ref).abs();
            prop_assert!(
                diff <= tol,
                "f64 reference mismatch: f32={f32_result}, f64_ref={f64_ref}, diff={diff}, tol={tol}, dim={dim}"
            );
        }

        // For nearly identical vectors, the dispatched kernel must match the portable
        // reference to a tight relative tolerance (no catastrophic cancellation).
        #[test]
        fn l2_close_vectors_accuracy((dim, a) in dim_and_vec()) {
            let b: Vec<f32> = a.iter().map(|x| x + 1e-4).collect();
            let direct = l2_distance_squared(&a, &b);
            let reference = l2_distance_squared_portable(&a, &b);
            let tol = 1e-4 * reference.max(1e-6);
            prop_assert!(
                (direct - reference).abs() <= tol,
                "close vectors: direct={direct} vs reference={reference}, dim={dim}"
            );
        }
    }
}