1use anndists::prelude::Distance;
20
/// Marker type selecting the squared Euclidean (L2) metric.
///
/// Its `Distance<f32>` implementation forwards to [`l2_squared`], so the value
/// it reports is the *squared* distance; no square root is taken.
#[derive(Debug, Default, Clone, Copy)]
pub struct SimdL2;
24
/// Marker type selecting the dot-product based metric.
///
/// Its `Distance<f32>` implementation returns `1.0 - dot_product(a, b)`, which
/// can be negative whenever the dot product exceeds one.
#[derive(Debug, Default, Clone, Copy)]
pub struct SimdDot;
28
/// Marker type selecting the cosine metric.
///
/// Its `Distance<f32>` implementation forwards to [`cosine_distance`], i.e.
/// one minus the cosine similarity of the two vectors.
#[derive(Debug, Default, Clone, Copy)]
pub struct SimdCosine;
32
/// Portable scalar computation of the squared Euclidean distance.
///
/// Elements are paired positionally; if the slices differ in length, the
/// surplus elements of the longer one are ignored.
#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs));
    squared_gaps.sum()
}
47
/// Portable scalar dot product.
///
/// Elements are paired positionally; surplus elements of the longer slice are
/// ignored.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
52
/// Portable scalar sum of squares (the squared Euclidean norm) of `a`.
#[inline]
fn norm_squared_scalar(a: &[f32]) -> f32 {
    let squares = a.iter().map(|&value| value * value);
    squares.sum()
}
57
/// x86_64 kernels built on `std::arch` intrinsics.
///
/// Every kernel is an `unsafe fn` gated by `#[target_feature]`; callers must
/// confirm the matching CPU features at runtime ([`has_avx2`] / [`has_sse41`])
/// before calling.
///
/// Two-slice kernels expect equally long inputs. Debug builds assert this; in
/// builds without debug assertions the surplus of the longer slice is ignored,
/// so no load can ever reach past the end of either slice.
#[cfg(target_arch = "x86_64")]
mod x86_simd {
    use std::arch::x86_64::*;

    /// Number of `f32` lanes in a 256-bit AVX register.
    const AVX_LANES: usize = 8;
    /// Number of `f32` lanes in a 128-bit SSE register.
    const SSE_LANES: usize = 4;

    /// Returns `true` when the running CPU can execute the `*_avx2` kernels.
    ///
    /// Those kernels use fused multiply-add, which is a CPU feature separate
    /// from AVX2, so both `avx2` and `fma` have to be present.
    #[inline]
    pub fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Returns `true` when the running CPU can execute the `*_sse41` kernels.
    #[inline]
    pub fn has_sse41() -> bool {
        is_x86_feature_detected!("sse4.1")
    }

    /// Adds the four lanes of `v` into a single `f32`.
    ///
    /// # Safety
    /// The CPU must support SSE3 (`_mm_movehdup_ps`).
    #[target_feature(enable = "sse3")]
    #[inline]
    unsafe fn hsum128(v: __m128) -> f32 {
        // Lanes (0+1, _, 2+3, _), then fold the upper pair onto lane 0.
        let odd_lanes = _mm_movehdup_ps(v);
        let pair_sums = _mm_add_ps(v, odd_lanes);
        let upper_pair = _mm_movehl_ps(pair_sums, pair_sums);
        _mm_cvtss_f32(_mm_add_ss(pair_sums, upper_pair))
    }

    /// Adds the eight lanes of `v` into a single `f32`.
    ///
    /// # Safety
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    #[inline]
    unsafe fn hsum256(v: __m256) -> f32 {
        let high = _mm256_extractf128_ps(v, 1);
        let low = _mm256_castps256_ps128(v);
        // SAFETY: AVX implies SSE3, which is all `hsum128` needs.
        unsafe { hsum128(_mm_add_ps(high, low)) }
    }

    /// Squared Euclidean distance, eight lanes at a time with fused multiply-add.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA (see [`has_avx2`]).
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn l2_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(AVX_LANES);
        let mut blocks_b = b[..n].chunks_exact(AVX_LANES);

        let mut acc = _mm256_setzero_ps();
        for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
            // SAFETY: `chunks_exact` yields slices of exactly `AVX_LANES`
            // floats, so each unaligned 256-bit load is fully in bounds.
            let (va, vb) = unsafe {
                (
                    _mm256_loadu_ps(block_a.as_ptr()),
                    _mm256_loadu_ps(block_b.as_ptr()),
                )
            };
            let diff = _mm256_sub_ps(va, vb);
            acc = _mm256_fmadd_ps(diff, diff, acc);
        }

        // SAFETY: AVX2 (enabled on this function) implies AVX.
        let mut result = unsafe { hsum256(acc) };
        // Fewer than `AVX_LANES` elements remain; finish them one by one.
        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            let d = x - y;
            result += d * d;
        }
        result
    }

    /// Dot product, eight lanes at a time with fused multiply-add.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA (see [`has_avx2`]).
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(AVX_LANES);
        let mut blocks_b = b[..n].chunks_exact(AVX_LANES);

        let mut acc = _mm256_setzero_ps();
        for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
            // SAFETY: `chunks_exact` yields slices of exactly `AVX_LANES`
            // floats, so each unaligned 256-bit load is fully in bounds.
            let (va, vb) = unsafe {
                (
                    _mm256_loadu_ps(block_a.as_ptr()),
                    _mm256_loadu_ps(block_b.as_ptr()),
                )
            };
            acc = _mm256_fmadd_ps(va, vb, acc);
        }

        // SAFETY: AVX2 (enabled on this function) implies AVX.
        let mut result = unsafe { hsum256(acc) };
        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            result += x * y;
        }
        result
    }

    /// Sum of squares of `a`, eight lanes at a time with fused multiply-add.
    ///
    /// # Safety
    /// The CPU must support AVX2 and FMA (see [`has_avx2`]).
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn norm_squared_avx2(a: &[f32]) -> f32 {
        let mut blocks = a.chunks_exact(AVX_LANES);

        let mut acc = _mm256_setzero_ps();
        for block in blocks.by_ref() {
            // SAFETY: `chunks_exact` yields slices of exactly `AVX_LANES`
            // floats, so the unaligned 256-bit load is fully in bounds.
            let va = unsafe { _mm256_loadu_ps(block.as_ptr()) };
            acc = _mm256_fmadd_ps(va, va, acc);
        }

        // SAFETY: AVX2 (enabled on this function) implies AVX.
        let mut result = unsafe { hsum256(acc) };
        for x in blocks.remainder() {
            result += x * x;
        }
        result
    }

    /// Squared Euclidean distance, four lanes at a time.
    ///
    /// # Safety
    /// The CPU must support SSE4.1 (see [`has_sse41`]).
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn l2_squared_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(SSE_LANES);
        let mut blocks_b = b[..n].chunks_exact(SSE_LANES);

        let mut acc = _mm_setzero_ps();
        for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
            // SAFETY: `chunks_exact` yields slices of exactly `SSE_LANES`
            // floats, so each unaligned 128-bit load is fully in bounds.
            let (va, vb) = unsafe {
                (
                    _mm_loadu_ps(block_a.as_ptr()),
                    _mm_loadu_ps(block_b.as_ptr()),
                )
            };
            let diff = _mm_sub_ps(va, vb);
            acc = _mm_add_ps(acc, _mm_mul_ps(diff, diff));
        }

        // SAFETY: SSE4.1 (enabled on this function) implies SSE3.
        let mut result = unsafe { hsum128(acc) };
        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            let d = x - y;
            result += d * d;
        }
        result
    }

    /// Dot product, four lanes at a time.
    ///
    /// # Safety
    /// The CPU must support SSE4.1 (see [`has_sse41`]).
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn dot_product_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(SSE_LANES);
        let mut blocks_b = b[..n].chunks_exact(SSE_LANES);

        let mut acc = _mm_setzero_ps();
        for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
            // SAFETY: `chunks_exact` yields slices of exactly `SSE_LANES`
            // floats, so each unaligned 128-bit load is fully in bounds.
            let (va, vb) = unsafe {
                (
                    _mm_loadu_ps(block_a.as_ptr()),
                    _mm_loadu_ps(block_b.as_ptr()),
                )
            };
            acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
        }

        // SAFETY: SSE4.1 (enabled on this function) implies SSE3.
        let mut result = unsafe { hsum128(acc) };
        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            result += x * y;
        }
        result
    }
}
251
/// aarch64 kernels built on NEON intrinsics from `std::arch`.
///
/// Two-slice kernels expect equally long inputs. Debug builds assert this; in
/// builds without debug assertions the surplus of the longer slice is ignored,
/// so no load can ever reach past the end of either slice.
#[cfg(target_arch = "aarch64")]
mod neon_simd {
    use std::arch::aarch64::*;

    /// Number of `f32` lanes in a 128-bit NEON register.
    const LANES: usize = 4;

    /// Squared Euclidean distance, four lanes at a time with fused multiply-add.
    #[inline]
    pub fn l2_squared_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(LANES);
        let mut blocks_b = b[..n].chunks_exact(LANES);

        // SAFETY: `chunks_exact` yields slices of exactly `LANES` floats, so
        // every 128-bit load below is fully in bounds.
        let mut result = unsafe {
            let mut acc = vdupq_n_f32(0.0);
            for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
                let va = vld1q_f32(block_a.as_ptr());
                let vb = vld1q_f32(block_b.as_ptr());
                let diff = vsubq_f32(va, vb);
                acc = vfmaq_f32(acc, diff, diff);
            }
            vaddvq_f32(acc)
        };

        // Fewer than `LANES` elements remain; finish them one by one.
        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            let d = x - y;
            result += d * d;
        }
        result
    }

    /// Dot product, four lanes at a time with fused multiply-add.
    #[inline]
    pub fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        // Clamp to the common length so the raw loads below stay in bounds
        // even when the debug assertion is compiled out.
        let n = a.len().min(b.len());
        let mut blocks_a = a[..n].chunks_exact(LANES);
        let mut blocks_b = b[..n].chunks_exact(LANES);

        // SAFETY: `chunks_exact` yields slices of exactly `LANES` floats, so
        // every 128-bit load below is fully in bounds.
        let mut result = unsafe {
            let mut acc = vdupq_n_f32(0.0);
            for (block_a, block_b) in blocks_a.by_ref().zip(blocks_b.by_ref()) {
                let va = vld1q_f32(block_a.as_ptr());
                let vb = vld1q_f32(block_b.as_ptr());
                acc = vfmaq_f32(acc, va, vb);
            }
            vaddvq_f32(acc)
        };

        for (x, y) in blocks_a.remainder().iter().zip(blocks_b.remainder()) {
            result += x * y;
        }
        result
    }

    /// Sum of squares of `a`, four lanes at a time with fused multiply-add.
    #[inline]
    pub fn norm_squared_neon(a: &[f32]) -> f32 {
        let mut blocks = a.chunks_exact(LANES);

        // SAFETY: `chunks_exact` yields slices of exactly `LANES` floats, so
        // every 128-bit load below is fully in bounds.
        let mut result = unsafe {
            let mut acc = vdupq_n_f32(0.0);
            for block in blocks.by_ref() {
                let va = vld1q_f32(block.as_ptr());
                acc = vfmaq_f32(acc, va, va);
            }
            vaddvq_f32(acc)
        };

        for x in blocks.remainder() {
            result += x * x;
        }
        result
    }
}
348
349#[inline]
355pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
356 #[cfg(target_arch = "x86_64")]
357 {
358 if x86_simd::has_avx2() {
359 return unsafe { x86_simd::l2_squared_avx2(a, b) };
360 }
361 if x86_simd::has_sse41() {
362 return unsafe { x86_simd::l2_squared_sse41(a, b) };
363 }
364 }
365
366 #[cfg(target_arch = "aarch64")]
367 {
368 return neon_simd::l2_squared_neon(a, b);
369 }
370
371 #[allow(unreachable_code)]
372 l2_squared_scalar(a, b)
373}
374
375#[inline]
377pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
378 #[cfg(target_arch = "x86_64")]
379 {
380 if x86_simd::has_avx2() {
381 return unsafe { x86_simd::dot_product_avx2(a, b) };
382 }
383 if x86_simd::has_sse41() {
384 return unsafe { x86_simd::dot_product_sse41(a, b) };
385 }
386 }
387
388 #[cfg(target_arch = "aarch64")]
389 {
390 return neon_simd::dot_product_neon(a, b);
391 }
392
393 #[allow(unreachable_code)]
394 dot_product_scalar(a, b)
395}
396
397#[inline]
399pub fn norm_squared(a: &[f32]) -> f32 {
400 #[cfg(target_arch = "x86_64")]
401 {
402 if x86_simd::has_avx2() {
403 return unsafe { x86_simd::norm_squared_avx2(a) };
404 }
405 }
406
407 #[cfg(target_arch = "aarch64")]
408 {
409 return neon_simd::norm_squared_neon(a);
410 }
411
412 #[allow(unreachable_code)]
413 norm_squared_scalar(a)
414}
415
416#[inline]
419pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
420 let dot = dot_product(a, b);
421 let norm_a = norm_squared(a).sqrt();
422 let norm_b = norm_squared(b).sqrt();
423
424 if norm_a == 0.0 || norm_b == 0.0 {
425 return 1.0;
426 }
427
428 let cosine_sim = dot / (norm_a * norm_b);
429 1.0 - cosine_sim.clamp(-1.0, 1.0)
430}
431
432impl Distance<f32> for SimdL2 {
437 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
438 l2_squared(a, b)
439 }
440}
441
442impl Distance<f32> for SimdDot {
443 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
444 1.0 - dot_product(a, b)
447 }
448}
449
450impl Distance<f32> for SimdCosine {
451 fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
452 cosine_distance(a, b)
453 }
454}
455
456pub fn simd_info() -> SimdInfo {
462 SimdInfo {
463 #[cfg(target_arch = "x86_64")]
464 avx2: x86_simd::has_avx2(),
465 #[cfg(not(target_arch = "x86_64"))]
466 avx2: false,
467
468 #[cfg(target_arch = "x86_64")]
469 sse41: x86_simd::has_sse41(),
470 #[cfg(not(target_arch = "x86_64"))]
471 sse41: false,
472
473 #[cfg(target_arch = "aarch64")]
474 neon: true,
475 #[cfg(not(target_arch = "aarch64"))]
476 neon: false,
477 }
478}
479
/// Snapshot of the SIMD code paths available, as produced by [`simd_info`].
#[derive(Clone, Debug)]
pub struct SimdInfo {
    /// The AVX2 kernels were detected as usable (always `false` off x86_64).
    pub avx2: bool,
    /// The SSE4.1 kernels were detected as usable (always `false` off x86_64).
    pub sse41: bool,
    /// The build targets aarch64, where the NEON kernels are used.
    pub neon: bool,
}
487
488impl std::fmt::Display for SimdInfo {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 let mut features = Vec::new();
491 if self.avx2 {
492 features.push("AVX2");
493 }
494 if self.sse41 {
495 features.push("SSE4.1");
496 }
497 if self.neon {
498 features.push("NEON");
499 }
500 if features.is_empty() {
501 write!(f, "SIMD: none (scalar fallback)")
502 } else {
503 write!(f, "SIMD: {}", features.join(", "))
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Four elements: exactly one SSE/NEON block and no full AVX block.
    #[test]
    fn test_l2_squared_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();

        let result = l2_squared(&a, &b);
        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    /// 133 is not a multiple of 4 or 8, so both the vector loop and the
    /// scalar tail of every kernel are exercised.
    #[test]
    fn test_l2_squared_large() {
        let dim = 133;
        let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i * 2) as f32).collect();

        let expected = l2_squared_scalar(&a, &b);
        let result = l2_squared(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let result = dot_product(&a, &b);

        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_dot_product_large() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();

        let expected = dot_product_scalar(&a, &b);
        let result = dot_product(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }

    #[test]
    fn test_cosine_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = cosine_distance(&a, &a);
        assert!(result.abs() < 1e-5, "identical vectors should have distance ~0, got {result}");
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let result = cosine_distance(&a, &b);
        assert!((result - 1.0).abs() < 1e-5, "orthogonal vectors should have distance ~1, got {result}");
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-5, "opposite vectors should have distance ~2, got {result}");
    }

    /// A zero-norm operand takes the early-return branch and yields exactly 1.
    #[test]
    fn test_cosine_zero_vector() {
        let zero = vec![0.0_f32; 4];
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];

        assert_eq!(cosine_distance(&zero, &a), 1.0);
        assert_eq!(cosine_distance(&a, &zero), 1.0);
        assert_eq!(cosine_distance(&zero, &zero), 1.0);
    }

    /// The rendered summary must agree with the individual flags.
    #[test]
    fn test_simd_info() {
        let info = simd_info();
        let rendered = info.to_string();
        println!("{}", rendered);

        assert!(rendered.starts_with("SIMD: "), "unexpected summary: {rendered}");
        assert_eq!(rendered.contains("AVX2"), info.avx2);
        assert_eq!(rendered.contains("SSE4.1"), info.sse41);
        assert_eq!(rendered.contains("NEON"), info.neon);
        assert_eq!(
            rendered == "SIMD: none (scalar fallback)",
            !(info.avx2 || info.sse41 || info.neon)
        );

        // NEON is a compile-time decision; x86 flags are never set elsewhere.
        assert_eq!(info.neon, cfg!(target_arch = "aarch64"));
        if !cfg!(target_arch = "x86_64") {
            assert!(!info.avx2 && !info.sse41);
        }
    }

    /// The trait implementations must forward to the free functions.
    #[test]
    fn test_distance_trait_impl() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let l2 = SimdL2;
        let result = l2.eval(&a, &b);
        assert!(result > 0.0);
        // Every coordinate differs by 4, so the squared distance is 4 * 16.
        assert!((result - 64.0).abs() < 1e-5, "expected 64, got {result}");
        assert_eq!(result, l2_squared(&a, &b));

        let cosine = SimdCosine;
        let result = cosine.eval(&a, &b);
        assert!((0.0..=2.0).contains(&result));
        assert_eq!(result, cosine_distance(&a, &b));
    }
}