use anndists::prelude::Distance;

/// Squared Euclidean (L2) distance over `f32` slices, SIMD-accelerated where
/// the CPU supports it.
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdL2;

/// Dot-product distance (`1.0 - a · b`), SIMD-accelerated where the CPU
/// supports it.
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdDot;

/// Cosine distance (`1.0 - cos(a, b)`), SIMD-accelerated where the CPU
/// supports it.
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdCosine;

// Scalar fallbacks, used when no SIMD path is available for the target.

#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

#[inline]
fn norm_squared_scalar(a: &[f32]) -> f32 {
    a.iter().map(|x| x * x).sum()
}

#[cfg(target_arch = "x86_64")]
mod x86_simd {
    use std::arch::x86_64::*;

    /// The AVX2 kernels below also use FMA (`_mm256_fmadd_ps`), so require
    /// both features before dispatching to them.
    #[inline]
    pub fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    #[inline]
    pub fn has_sse41() -> bool {
        is_x86_feature_detected!("sse4.1")
    }

    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn l2_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        // Accumulate (a - b)^2 eight lanes at a time.
        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
            i += 8;
        }

        // Horizontal reduction of the 8 partial sums to a single f32.
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        // Scalar tail for lengths that are not a multiple of 8.
        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, vb, sum);
            i += 8;
        }

        // Horizontal reduction, then a scalar tail.
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }

    #[target_feature(enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn norm_squared_avx2(a: &[f32]) -> f32 {
        let n = a.len();
        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, va, sum);
            i += 8;
        }

        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * a[i];
            i += 1;
        }

        result
    }

    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn l2_squared_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        // Accumulate (a - b)^2 four lanes at a time.
        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let diff = _mm_sub_ps(va, vb);
            let sq = _mm_mul_ps(diff, diff);
            sum = _mm_add_ps(sum, sq);
            i += 4;
        }

        // Horizontal reduction of the 4 partial sums, then a scalar tail.
        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn dot_product_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let prod = _mm_mul_ps(va, vb);
            sum = _mm_add_ps(sum, prod);
            i += 4;
        }

        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }
}

#[cfg(target_arch = "aarch64")]
mod neon_simd {
    use std::arch::aarch64::*;

    #[inline]
    pub fn l2_squared_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            // Accumulate (a - b)^2 four lanes at a time.
            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                let diff = vsubq_f32(va, vb);
                sum = vfmaq_f32(sum, diff, diff);
                i += 4;
            }

            // Horizontal add of the 4 lanes, then a scalar tail.
            let mut result = vaddvq_f32(sum);

            while i < n {
                let d = a[i] - b[i];
                result += d * d;
                i += 1;
            }

            result
        }
    }

    #[inline]
    pub fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, vb);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * b[i];
                i += 1;
            }

            result
        }
    }

    #[inline]
    pub fn norm_squared_neon(a: &[f32]) -> f32 {
        let n = a.len();

        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, va);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * a[i];
                i += 1;
            }

            result
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod neon_quant {
    use std::arch::aarch64::*;

    /// Bulk f16 -> f32 conversion. Currently converts element-wise via the
    /// `half` crate; no NEON intrinsics are used for the conversion itself.
    #[inline]
    pub fn f16_to_f32_bulk_neon(input: &[u16], output: &mut [f32]) {
        debug_assert_eq!(input.len(), output.len());
        for (i, &bits) in input.iter().enumerate() {
            output[i] = half::f16::from_bits(bits).to_f32();
        }
    }

    /// Squared L2 distance between f16-encoded data and an f32 query:
    /// dequantizes into a temporary buffer, then reuses the NEON f32 kernel.
    #[inline]
    pub fn l2_f16_vs_f32_neon(f16_data: &[u16], query: &[f32]) -> f32 {
        debug_assert_eq!(f16_data.len(), query.len());
        let n = f16_data.len();

        let mut db = vec![0.0f32; n];
        for (i, &bits) in f16_data.iter().enumerate() {
            db[i] = half::f16::from_bits(bits).to_f32();
        }

        super::neon_simd::l2_squared_neon(&db, query)
    }

    /// Squared L2 distance between per-dimension scaled u8 data and an f32
    /// query. Each element is dequantized as `u8 * scale + offset`.
    #[inline]
    pub fn l2_u8_scaled_vs_f32_neon(
        u8_data: &[u8],
        query: &[f32],
        scales: &[f32],
        offsets: &[f32],
    ) -> f32 {
        debug_assert_eq!(u8_data.len(), query.len());
        debug_assert_eq!(scales.len(), query.len());
        debug_assert_eq!(offsets.len(), query.len());
        let n = u8_data.len();
        let mut i = 0;

        unsafe {
            let mut sum = vdupq_n_f32(0.0);

            while i + 4 <= n {
                // Widen four bytes to f32 lanes, dequantize, and accumulate.
                let b0 = u8_data[i] as f32;
                let b1 = u8_data[i + 1] as f32;
                let b2 = u8_data[i + 2] as f32;
                let b3 = u8_data[i + 3] as f32;
                let vals = [b0, b1, b2, b3];
                let vu8 = vld1q_f32(vals.as_ptr());

                let vscale = vld1q_f32(scales.as_ptr().add(i));
                let voff = vld1q_f32(offsets.as_ptr().add(i));
                let vq = vld1q_f32(query.as_ptr().add(i));

                let dequant = vfmaq_f32(voff, vu8, vscale);
                let diff = vsubq_f32(dequant, vq);
                sum = vfmaq_f32(sum, diff, diff);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                let dequant = u8_data[i] as f32 * scales[i] + offsets[i];
                let d = dequant - query[i];
                result += d * d;
                i += 1;
            }

            result
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86_quant {
    use std::arch::x86_64::*;

    #[inline]
    pub fn has_f16c() -> bool {
        is_x86_feature_detected!("f16c")
    }

    /// Bulk f16 -> f32 conversion using the F16C `vcvtph2ps` instruction,
    /// eight values at a time, with a scalar tail via the `half` crate.
    #[target_feature(enable = "f16c")]
    #[inline]
    pub unsafe fn f16_to_f32_bulk_f16c(input: &[u16], output: &mut [f32]) {
        debug_assert_eq!(input.len(), output.len());
        let n = input.len();
        let mut i = 0;

        while i + 8 <= n {
            let half8 = _mm_loadu_si128(input.as_ptr().add(i) as *const __m128i);
            let f8 = _mm256_cvtph_ps(half8);
            _mm256_storeu_ps(output.as_mut_ptr().add(i), f8);
            i += 8;
        }

        while i < n {
            output[i] = half::f16::from_bits(input[i]).to_f32();
            i += 1;
        }
    }

    /// Squared L2 distance between f16-encoded data and an f32 query,
    /// converting eight values per iteration with F16C and accumulating
    /// with FMA.
    #[target_feature(enable = "f16c", enable = "avx2", enable = "fma")]
    #[inline]
    pub unsafe fn l2_f16_vs_f32_f16c(f16_data: &[u16], query: &[f32]) -> f32 {
        debug_assert_eq!(f16_data.len(), query.len());
        let n = f16_data.len();
        let mut i = 0;
        let mut sum = _mm256_setzero_ps();

        while i + 8 <= n {
            let half8 = _mm_loadu_si128(f16_data.as_ptr().add(i) as *const __m128i);
            let db = _mm256_cvtph_ps(half8);
            let q = _mm256_loadu_ps(query.as_ptr().add(i));
            let diff = _mm256_sub_ps(db, q);
            sum = _mm256_fmadd_ps(diff, diff, sum);
            i += 8;
        }

        // Horizontal reduction of the 8 partial sums, then a scalar tail.
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            let f = half::f16::from_bits(f16_data[i]).to_f32();
            let d = f - query[i];
            result += d * d;
            i += 1;
        }

        result
    }
}

/// Squared Euclidean distance between `a` and `b`, dispatching to the best
/// available SIMD implementation at runtime.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::l2_squared_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::l2_squared_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::l2_squared_neon(a, b);
    }

    #[allow(unreachable_code)]
    l2_squared_scalar(a, b)
}

/// Dot product of `a` and `b`, dispatching to the best available SIMD
/// implementation at runtime.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::dot_product_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::dot_product_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::dot_product_neon(a, b);
    }

    #[allow(unreachable_code)]
    dot_product_scalar(a, b)
}

/// Squared Euclidean norm of `a` (i.e. `a · a`).
#[inline]
pub fn norm_squared(a: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::norm_squared_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::norm_squared_neon(a);
    }

    #[allow(unreachable_code)]
    norm_squared_scalar(a)
}

/// Cosine distance `1 - cos(a, b)`, in `[0, 2]`. All-zero vectors have no
/// direction, so the distance defaults to `1.0`.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot = dot_product(a, b);
    let norm_a = norm_squared(a).sqrt();
    let norm_b = norm_squared(b).sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    let cosine_sim = dot / (norm_a * norm_b);
    1.0 - cosine_sim.clamp(-1.0, 1.0)
}

/// Bulk f16 -> f32 conversion, dispatching to F16C on x86_64 when available.
#[inline]
pub fn f16_to_f32_bulk(input: &[u16], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    #[cfg(target_arch = "x86_64")]
    {
        if x86_quant::has_f16c() {
            unsafe { x86_quant::f16_to_f32_bulk_f16c(input, output) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        neon_quant::f16_to_f32_bulk_neon(input, output);
        return;
    }

    // Scalar fallback.
    #[allow(unreachable_code)]
    for (i, &bits) in input.iter().enumerate() {
        output[i] = half::f16::from_bits(bits).to_f32();
    }
}

/// Squared L2 distance between f16-encoded stored data and an f32 query.
#[inline]
pub fn l2_f16_vs_f32(f16_data: &[u16], query: &[f32]) -> f32 {
    debug_assert_eq!(f16_data.len(), query.len());

    #[cfg(target_arch = "x86_64")]
    {
        if x86_quant::has_f16c() && x86_simd::has_avx2() {
            return unsafe { x86_quant::l2_f16_vs_f32_f16c(f16_data, query) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_quant::l2_f16_vs_f32_neon(f16_data, query);
    }

    // Scalar fallback: dequantize and accumulate element by element.
    #[allow(unreachable_code)]
    {
        let mut sum = 0.0f32;
        for (i, &bits) in f16_data.iter().enumerate() {
            let f = half::f16::from_bits(bits).to_f32();
            let d = f - query[i];
            sum += d * d;
        }
        sum
    }
}

/// Squared L2 distance between per-dimension scaled u8 stored data and an f32
/// query. Each stored element is dequantized as `u8 * scale + offset`.
#[inline]
pub fn l2_u8_scaled_vs_f32(
    u8_data: &[u8],
    query: &[f32],
    scales: &[f32],
    offsets: &[f32],
) -> f32 {
    debug_assert_eq!(u8_data.len(), query.len());
    debug_assert_eq!(scales.len(), query.len());
    debug_assert_eq!(offsets.len(), query.len());

    #[cfg(target_arch = "aarch64")]
    {
        return neon_quant::l2_u8_scaled_vs_f32_neon(u8_data, query, scales, offsets);
    }

    // There is currently no dedicated x86_64 kernel for this path, so all
    // non-aarch64 targets use the scalar loop below.
    #[allow(unreachable_code)]
    {
        let mut sum = 0.0f32;
        for i in 0..u8_data.len() {
            let dequant = u8_data[i] as f32 * scales[i] + offsets[i];
            let d = dequant - query[i];
            sum += d * d;
        }
        sum
    }
}

impl Distance<f32> for SimdL2 {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        l2_squared(a, b)
    }
}

impl Distance<f32> for SimdDot {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        // 1 - dot product. For unit-normalized vectors this coincides with
        // the cosine distance; for unnormalized vectors it is a similarity
        // converted to a score, not a true metric.
        1.0 - dot_product(a, b)
    }
}

impl Distance<f32> for SimdCosine {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        cosine_distance(a, b)
    }
}

/// Reports which SIMD feature sets are usable on the current CPU.
pub fn simd_info() -> SimdInfo {
    SimdInfo {
        #[cfg(target_arch = "x86_64")]
        avx2: x86_simd::has_avx2(),
        #[cfg(not(target_arch = "x86_64"))]
        avx2: false,

        #[cfg(target_arch = "x86_64")]
        sse41: x86_simd::has_sse41(),
        #[cfg(not(target_arch = "x86_64"))]
        sse41: false,

        #[cfg(target_arch = "aarch64")]
        neon: true,
        #[cfg(not(target_arch = "aarch64"))]
        neon: false,
    }
}

/// SIMD capabilities detected at runtime.
#[derive(Debug, Clone)]
pub struct SimdInfo {
    pub avx2: bool,
    pub sse41: bool,
    pub neon: bool,
}

impl std::fmt::Display for SimdInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut features = Vec::new();
        if self.avx2 {
            features.push("AVX2");
        }
        if self.sse41 {
            features.push("SSE4.1");
        }
        if self.neon {
            features.push("NEON");
        }
        if features.is_empty() {
            write!(f, "SIMD: none (scalar fallback)")
        } else {
            write!(f, "SIMD: {}", features.join(", "))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();

        let result = l2_squared(&a, &b);
        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_l2_squared_large() {
        // 133 is deliberately not a multiple of the SIMD width, so the
        // remainder loop is exercised as well.
        let dim = 133;
        let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i * 2) as f32).collect();

        let expected = l2_squared_scalar(&a, &b);
        let result = l2_squared(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }
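
    // Checks that the dispatching `norm_squared` agrees with the scalar
    // reference on a length that also exercises the remainder loop.
    #[test]
    fn test_norm_squared_matches_scalar() {
        let a: Vec<f32> = (0..77).map(|i| (i as f32) * 0.1 - 3.0).collect();

        let expected = norm_squared_scalar(&a);
        let result = norm_squared(&a);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }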

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let result = dot_product(&a, &b);

        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_dot_product_large() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();

        let expected = dot_product_scalar(&a, &b);
        let result = dot_product(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }
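
    // Cross-checks the L2, dot-product, and norm kernels against the identity
    // ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a · b).
    #[test]
    fn test_l2_dot_norm_identity() {
        let dim = 96;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.05 - 1.0).collect();
        let b: Vec<f32> = (0..dim).map(|i| 2.0 - (i as f32) * 0.03).collect();

        let lhs = l2_squared(&a, &b);
        let rhs = norm_squared(&a) + norm_squared(&b) - 2.0 * dot_product(&a, &b);

        assert!(
            (lhs - rhs).abs() < 1e-2,
            "identity violated: lhs {lhs}, rhs {rhs}"
        );
    }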

    #[test]
    fn test_cosine_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = cosine_distance(&a, &a);
        assert!(result.abs() < 1e-5, "identical vectors should have distance ~0, got {result}");
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let result = cosine_distance(&a, &b);
        assert!((result - 1.0).abs() < 1e-5, "orthogonal vectors should have distance ~1, got {result}");
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-5, "opposite vectors should have distance ~2, got {result}");
    }

    #[test]
    fn test_simd_info() {
        let info = simd_info();
        println!("{}", info);
    }
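
    // Round-trips f32 values through f16 bits (using the `half` crate that the
    // quantized kernels already rely on) and checks the bulk conversion and
    // the f16 distance helper against a plain scalar computation.
    #[test]
    fn test_f16_bulk_and_l2_f16_vs_f32() {
        let dim = 70;
        let original: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.25 - 5.0).collect();
        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.2).collect();
        let bits: Vec<u16> = original
            .iter()
            .map(|&x| half::f16::from_f32(x).to_bits())
            .collect();

        // Bulk conversion must match per-element conversion.
        let mut decoded = vec![0.0f32; dim];
        f16_to_f32_bulk(&bits, &mut decoded);
        for (i, &b) in bits.iter().enumerate() {
            assert_eq!(decoded[i], half::f16::from_bits(b).to_f32());
        }

        // The f16-vs-f32 distance must match scalar L2 on the decoded values.
        let expected = l2_squared_scalar(&decoded, &query);
        let result = l2_f16_vs_f32(&bits, &query);
        assert!(
            (result - expected).abs() < 1e-2,
            "expected {expected}, got {result}"
        );
    }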

    #[test]
    fn test_distance_trait_impl() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let l2 = SimdL2;
        let result = l2.eval(&a, &b);
        assert!(result > 0.0);

        let cosine = SimdCosine;
        let result = cosine.eval(&a, &b);
        assert!(result >= 0.0 && result <= 2.0);
    }
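
    // Checks the scaled-u8 path against a direct dequantize-then-L2
    // computation with per-dimension scales and offsets.
    #[test]
    fn test_l2_u8_scaled_vs_f32() {
        let dim: usize = 37;
        let u8_data: Vec<u8> = (0..dim).map(|i| (i * 7 % 256) as u8).collect();
        let scales: Vec<f32> = (0..dim).map(|i| 0.01 + (i as f32) * 0.001).collect();
        let offsets: Vec<f32> = (0..dim).map(|i| -1.0 + (i as f32) * 0.05).collect();
        let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1).collect();

        let expected: f32 = (0..dim)
            .map(|i| {
                let dequant = u8_data[i] as f32 * scales[i] + offsets[i];
                let d = dequant - query[i];
                d * d
            })
            .sum();

        let result = l2_u8_scaled_vs_f32(&u8_data, &query, &scales, &offsets);
        assert!(
            (result - expected).abs() < 1e-2,
            "expected {expected}, got {result}"
        );
    }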
}