1#[cfg(feature = "simd")]
7use wide::f32x8;
8
#[cfg(feature = "simd")]
/// Squared Euclidean (L2) distance between `a` and `b`, computed eight lanes at a time.
///
/// Full 8-element chunks are accumulated lane-wise in an [`f32x8`]. The eight lane
/// accumulators are then summed horizontally, and any tail elements (fewer than eight) are
/// added with scalar code.
///
/// `a` and `b` are expected to have the same length. This is checked with
/// `debug_assert_eq!` only. In builds without debug assertions, mismatched inputs are
/// compared over their common prefix. That matches the scalar fallback, which pairs
/// elements with `Iterator::zip`.
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    /// Loads one exact 8-element chunk into a SIMD vector.
    fn load(chunk: &[f32]) -> f32x8 {
        let mut lanes = [0.0_f32; 8];
        lanes.copy_from_slice(chunk);
        f32x8::new(lanes)
    }

    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");

    // Trim both inputs to their common length up front. Mismatched lengths then behave
    // like the scalar fallback instead of panicking on an out-of-bounds index, and every
    // later slice access is known to be in bounds.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    // Both slices have length `len`, so the two remainders have equal length (< 8).
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut sum = f32x8::ZERO;
    for (a_chunk, b_chunk) in a_chunks.zip(b_chunks) {
        let diff = load(a_chunk) - load(b_chunk);
        sum += diff * diff;
    }

    // Horizontal sum of the eight lane accumulators.
    let lanes: [f32; 8] = sum.into();
    let mut total: f32 = lanes.iter().sum();

    // Scalar pass over the trailing elements that did not fill a whole chunk.
    for (x, y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
64
#[cfg(feature = "simd")]
/// Euclidean (L2) distance between `a` and `b`.
///
/// This is the square root of [`l2_distance_squared_simd`]. Prefer the squared form when
/// only relative ordering of distances matters, since it skips the `sqrt`.
#[must_use]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let squared = l2_distance_squared_simd(a, b);
    squared.sqrt()
}
71
#[cfg(not(feature = "simd"))]
/// Squared Euclidean (L2) distance between `a` and `b` (portable scalar fallback).
///
/// Used when the `simd` feature is disabled. Like the SIMD version, it expects `a` and `b`
/// to have the same length and checks this with `debug_assert_eq!`. Without debug
/// assertions, elements are paired with `Iterator::zip`, so only the common prefix is
/// compared.
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
85
86#[cfg(not(feature = "simd"))]
88pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
89 l2_distance_squared_simd(a, b).sqrt()
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_distance_squared_basic() {
        let a = [0.0, 0.0, 0.0];
        let b = [3.0, 4.0, 0.0];
        let dist_sq = l2_distance_squared_simd(&a, &b);
        assert!(
            (dist_sq - 25.0).abs() < 1e-6,
            "expected 25.0, got {}",
            dist_sq
        );
    }

    #[test]
    fn test_l2_distance_basic() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        let dist = l2_distance_simd(&a, &b);
        assert!((dist - 5.0).abs() < 1e-6, "expected 5.0, got {}", dist);
    }

    #[test]
    fn test_l2_distance_384_dims() {
        // 384 = 48 * 8: exercises only full 8-lane chunks, with no scalar tail.
        let a: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
        let b: Vec<f32> = (0..384).map(|i| (i + 1) as f32 * 0.01).collect();

        let dist_simd = l2_distance_simd(&a, &b);

        // Plain scalar reference computation.
        let dist_scalar: f32 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt();

        assert!(
            (dist_simd - dist_scalar).abs() < 1e-4,
            "SIMD {} vs Scalar {}",
            dist_simd,
            dist_scalar
        );
    }

    #[test]
    fn test_l2_distance_squared_chunk_and_tail() {
        // 13 = one full 8-lane chunk + a 5-element tail, so both the vectorized loop and
        // the scalar remainder loop contribute. Integer-valued inputs keep the sum exact.
        let a: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..13).map(|i| 2.0 * i as f32).collect();
        // diff_i = -i, so the result is 0^2 + 1^2 + ... + 12^2 = 650.
        assert_eq!(l2_distance_squared_simd(&a, &b), 650.0);
    }

    #[test]
    fn test_l2_distance_empty() {
        assert_eq!(l2_distance_squared_simd(&[], &[]), 0.0);
        assert_eq!(l2_distance_simd(&[], &[]), 0.0);
    }

    #[test]
    fn test_l2_distance_identical_vectors() {
        let v: Vec<f32> = (0..20).map(|i| i as f32 * 0.5).collect();
        assert_eq!(l2_distance_squared_simd(&v, &v), 0.0);
        assert_eq!(l2_distance_simd(&v, &v), 0.0);
    }
}