1#[cfg(feature = "alloc")]
25extern crate alloc;
26
/// Flags describing which SIMD instruction sets are considered available.
///
/// When built through [`SimdCapabilities::detect`], every flag mirrors a
/// `cfg!` check, i.e. the target features the crate was compiled with rather
/// than a runtime CPU probe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdCapabilities {
    /// AVX2 is enabled (`detect` sets this only for `x86_64` targets).
    pub avx2: bool,
    /// AVX-512 Foundation (`avx512f`) is enabled (`detect` sets this only for `x86_64` targets).
    pub avx512: bool,
    /// NEON is enabled (`detect` sets this only for `aarch64` targets).
    pub neon: bool,
    /// The `fma` target feature is enabled (`detect` checks it without an architecture filter).
    pub fma: bool,
}
39
40impl SimdCapabilities {
41 #[must_use]
43 pub fn detect() -> Self {
44 Self {
45 avx2: cfg!(all(target_arch = "x86_64", target_feature = "avx2")),
46 avx512: cfg!(all(target_arch = "x86_64", target_feature = "avx512f")),
47 neon: cfg!(all(target_arch = "aarch64", target_feature = "neon")),
48 fma: cfg!(target_feature = "fma"),
49 }
50 }
51
52 #[must_use]
54 pub const fn lane_width(&self) -> usize {
55 if self.avx512 {
56 16 } else if self.avx2 {
58 8 } else if self.neon {
60 4 } else {
62 1 }
64 }
65
66 #[must_use]
68 pub const fn has_simd(&self) -> bool {
69 self.avx2 || self.avx512 || self.neon
70 }
71}
72
/// The default value is whatever [`SimdCapabilities::detect`] reports.
impl Default for SimdCapabilities {
    fn default() -> Self {
        Self::detect()
    }
}
78
79#[inline]
87pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
88 assert_eq!(a.len(), b.len(), "Vectors must have same length");
89
90 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
91 {
92 cosine_similarity_neon(a, b)
93 }
94
95 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
96 {
97 cosine_similarity_avx2(a, b)
98 }
99
100 #[cfg(not(any(
101 all(target_arch = "aarch64", target_feature = "neon"),
102 all(target_arch = "x86_64", target_feature = "avx2")
103 )))]
104 {
105 cosine_similarity_scalar(a, b)
106 }
107}
108
109#[inline]
117pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
118 assert_eq!(a.len(), b.len(), "Vectors must have same length");
119
120 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
121 {
122 euclidean_distance_squared_neon(a, b)
123 }
124
125 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
126 {
127 euclidean_distance_squared_avx2(a, b)
128 }
129
130 #[cfg(not(any(
131 all(target_arch = "aarch64", target_feature = "neon"),
132 all(target_arch = "x86_64", target_feature = "avx2")
133 )))]
134 {
135 euclidean_distance_squared_scalar(a, b)
136 }
137}
138
139#[inline]
145pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
146 assert_eq!(a.len(), b.len(), "Vectors must have same length");
147
148 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
149 {
150 dot_product_neon(a, b)
151 }
152
153 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
154 {
155 dot_product_avx2(a, b)
156 }
157
158 #[cfg(not(any(
159 all(target_arch = "aarch64", target_feature = "neon"),
160 all(target_arch = "x86_64", target_feature = "avx2")
161 )))]
162 {
163 dot_product_scalar(a, b)
164 }
165}
166
167#[inline]
169pub fn l2_norm(a: &[f32]) -> f32 {
170 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
171 {
172 l2_norm_neon(a)
173 }
174
175 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
176 {
177 l2_norm_avx2(a)
178 }
179
180 #[cfg(not(any(
181 all(target_arch = "aarch64", target_feature = "neon"),
182 all(target_arch = "x86_64", target_feature = "avx2")
183 )))]
184 {
185 l2_norm_scalar(a)
186 }
187}
188
/// NEON kernel for cosine similarity.
///
/// Processes four `f32` lanes per iteration with fused multiply-add, then
/// folds the remaining (fewer than four) elements in scalar code. Returns
/// `0.0` when either vector has zero norm.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_quads = a.chunks_exact(4);
    let b_quads = b.chunks_exact(4);
    let a_tail = a_quads.remainder();
    let b_tail = b_quads.remainder();

    // SAFETY: NEON is enabled for this build (see the `cfg` above), and every
    // pointer handed to `vld1q_f32` is the start of a `chunks_exact(4)` item,
    // i.e. exactly four contiguous, initialised `f32` values.
    let (dot, norm_a, norm_b) = unsafe {
        let mut dot_sum = vdupq_n_f32(0.0);
        let mut norm_a_sum = vdupq_n_f32(0.0);
        let mut norm_b_sum = vdupq_n_f32(0.0);

        for (qa, qb) in a_quads.zip(b_quads) {
            let va = vld1q_f32(qa.as_ptr());
            let vb = vld1q_f32(qb.as_ptr());

            dot_sum = vfmaq_f32(dot_sum, va, vb);
            norm_a_sum = vfmaq_f32(norm_a_sum, va, va);
            norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb);
        }

        // Horizontal add of the four lanes of each accumulator.
        (
            vaddvq_f32(dot_sum),
            vaddvq_f32(norm_a_sum),
            vaddvq_f32(norm_b_sum),
        )
    };

    // Scalar tail, accumulated separately and then added to the vector sums.
    let mut dot_tail = 0.0f32;
    let mut norm_a_tail = 0.0f32;
    let mut norm_b_tail = 0.0f32;
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot_tail += ai * bi;
        norm_a_tail += ai * ai;
        norm_b_tail += bi * bi;
    }

    let total_dot = dot + dot_tail;
    let total_norm_a = (norm_a + norm_a_tail).sqrt();
    let total_norm_b = (norm_b + norm_b_tail).sqrt();

    if total_norm_a == 0.0 || total_norm_b == 0.0 {
        0.0
    } else {
        total_dot / (total_norm_a * total_norm_b)
    }
}
248
/// NEON kernel for the squared Euclidean distance.
///
/// Processes four `f32` lanes per iteration (subtract, then fused
/// multiply-add of the difference with itself) and finishes the remaining
/// (fewer than four) elements in scalar code.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn euclidean_distance_squared_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_quads = a.chunks_exact(4);
    let b_quads = b.chunks_exact(4);
    let a_tail = a_quads.remainder();
    let b_tail = b_quads.remainder();

    // SAFETY: NEON is enabled for this build (see the `cfg` above), and every
    // pointer handed to `vld1q_f32` is the start of a `chunks_exact(4)` item,
    // i.e. exactly four contiguous, initialised `f32` values.
    let mut total = unsafe {
        let mut sum = vdupq_n_f32(0.0);

        for (qa, qb) in a_quads.zip(b_quads) {
            let va = vld1q_f32(qa.as_ptr());
            let vb = vld1q_f32(qb.as_ptr());

            let diff = vsubq_f32(va, vb);
            sum = vfmaq_f32(sum, diff, diff);
        }

        // Horizontal add of the four accumulator lanes.
        vaddvq_f32(sum)
    };

    // Scalar tail.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
283
/// NEON kernel for the dot product.
///
/// Processes four `f32` lanes per iteration with fused multiply-add and
/// finishes the remaining (fewer than four) elements in scalar code.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_quads = a.chunks_exact(4);
    let b_quads = b.chunks_exact(4);
    let a_tail = a_quads.remainder();
    let b_tail = b_quads.remainder();

    // SAFETY: NEON is enabled for this build (see the `cfg` above), and every
    // pointer handed to `vld1q_f32` is the start of a `chunks_exact(4)` item,
    // i.e. exactly four contiguous, initialised `f32` values.
    let mut total = unsafe {
        let mut sum = vdupq_n_f32(0.0);

        for (qa, qb) in a_quads.zip(b_quads) {
            let va = vld1q_f32(qa.as_ptr());
            let vb = vld1q_f32(qb.as_ptr());
            sum = vfmaq_f32(sum, va, vb);
        }

        // Horizontal add of the four accumulator lanes.
        vaddvq_f32(sum)
    };

    // Scalar tail.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }

    total
}
311
/// NEON kernel for the L2 norm.
///
/// Accumulates squared elements four lanes at a time with fused multiply-add,
/// adds the squares of the remaining (fewer than four) elements in scalar
/// code, and returns the square root of the total.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline]
fn l2_norm_neon(a: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let quads = a.chunks_exact(4);
    let leftover = quads.remainder();

    // SAFETY: NEON is enabled for this build (see the `cfg` above), and every
    // pointer handed to `vld1q_f32` is the start of a `chunks_exact(4)` item,
    // i.e. exactly four contiguous, initialised `f32` values.
    let mut squared_sum = unsafe {
        let mut acc = vdupq_n_f32(0.0);
        for quad in quads {
            let lanes = vld1q_f32(quad.as_ptr());
            acc = vfmaq_f32(acc, lanes, lanes);
        }
        vaddvq_f32(acc)
    };

    for &x in leftover {
        squared_sum += x * x;
    }

    squared_sum.sqrt()
}
338
/// AVX2 kernel for cosine similarity.
///
/// Processes eight `f32` lanes per iteration with fused multiply-add, then
/// folds the remaining (fewer than eight) elements in scalar code. Returns
/// `0.0` when either vector has zero norm.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
///
/// NOTE(review): `_mm256_fmadd_ps` additionally needs the `fma` target
/// feature, which this function's `cfg` does not check — callers should only
/// route here when FMA is enabled as well.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_blocks = a.chunks_exact(8);
    let b_blocks = b.chunks_exact(8);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    // SAFETY: AVX2 is enabled for this build (see the `cfg` above), and every
    // pointer handed to `_mm256_loadu_ps` (an unaligned load) is the start of
    // a `chunks_exact(8)` item, i.e. exactly eight contiguous, initialised
    // `f32` values.
    let (dot, norm_a, norm_b) = unsafe {
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm_a_sum = _mm256_setzero_ps();
        let mut norm_b_sum = _mm256_setzero_ps();

        for (ca, cb) in a_blocks.zip(b_blocks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());

            dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
            norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum);
            norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum);
        }

        (
            horizontal_sum_avx2(dot_sum),
            horizontal_sum_avx2(norm_a_sum),
            horizontal_sum_avx2(norm_b_sum),
        )
    };

    // Scalar tail, accumulated separately and then added to the vector sums.
    let mut dot_tail = 0.0f32;
    let mut norm_a_tail = 0.0f32;
    let mut norm_b_tail = 0.0f32;
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot_tail += ai * bi;
        norm_a_tail += ai * ai;
        norm_b_tail += bi * bi;
    }

    let total_dot = dot + dot_tail;
    let total_norm_a = (norm_a + norm_a_tail).sqrt();
    let total_norm_b = (norm_b + norm_b_tail).sqrt();

    if total_norm_a == 0.0 || total_norm_b == 0.0 {
        0.0
    } else {
        total_dot / (total_norm_a * total_norm_b)
    }
}
396
/// AVX2 kernel for the squared Euclidean distance.
///
/// Processes eight `f32` lanes per iteration (subtract, then fused
/// multiply-add of the difference with itself) and finishes the remaining
/// (fewer than eight) elements in scalar code.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
///
/// NOTE(review): `_mm256_fmadd_ps` additionally needs the `fma` target
/// feature, which this function's `cfg` does not check — callers should only
/// route here when FMA is enabled as well.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn euclidean_distance_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_blocks = a.chunks_exact(8);
    let b_blocks = b.chunks_exact(8);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    // SAFETY: AVX2 is enabled for this build (see the `cfg` above), and every
    // pointer handed to `_mm256_loadu_ps` (an unaligned load) is the start of
    // a `chunks_exact(8)` item, i.e. exactly eight contiguous, initialised
    // `f32` values.
    let mut total = unsafe {
        let mut sum = _mm256_setzero_ps();

        for (ca, cb) in a_blocks.zip(b_blocks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());

            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }

        horizontal_sum_avx2(sum)
    };

    // Scalar tail.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
427
/// AVX2 kernel for the dot product.
///
/// Processes eight `f32` lanes per iteration with fused multiply-add and
/// finishes the remaining (fewer than eight) elements in scalar code.
///
/// `a` and `b` are expected to have the same length; this is checked in debug
/// builds only. All vector loads are taken from `chunks_exact` items, so a
/// length mismatch can never cause an out-of-bounds read.
///
/// NOTE(review): `_mm256_fmadd_ps` additionally needs the `fma` target
/// feature, which this function's `cfg` does not check — callers should only
/// route here when FMA is enabled as well.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let a_blocks = a.chunks_exact(8);
    let b_blocks = b.chunks_exact(8);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    // SAFETY: AVX2 is enabled for this build (see the `cfg` above), and every
    // pointer handed to `_mm256_loadu_ps` (an unaligned load) is the start of
    // a `chunks_exact(8)` item, i.e. exactly eight contiguous, initialised
    // `f32` values.
    let mut total = unsafe {
        let mut sum = _mm256_setzero_ps();

        for (ca, cb) in a_blocks.zip(b_blocks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            sum = _mm256_fmadd_ps(va, vb, sum);
        }

        horizontal_sum_avx2(sum)
    };

    // Scalar tail.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += x * y;
    }

    total
}
455
/// AVX2 kernel for the L2 norm.
///
/// Accumulates squared elements eight lanes at a time with fused
/// multiply-add, adds the squares of the remaining (fewer than eight)
/// elements in scalar code, and returns the square root of the total.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn l2_norm_avx2(a: &[f32]) -> f32 {
    use core::arch::x86_64::*;

    let blocks = a.chunks_exact(8);
    let leftover = blocks.remainder();

    // SAFETY: AVX2 is enabled for this build (see the `cfg` above), and every
    // pointer handed to `_mm256_loadu_ps` (an unaligned load) is the start of
    // a `chunks_exact(8)` item, i.e. exactly eight contiguous, initialised
    // `f32` values.
    let mut squared_sum = unsafe {
        let mut acc = _mm256_setzero_ps();
        for block in blocks {
            let lanes = _mm256_loadu_ps(block.as_ptr());
            acc = _mm256_fmadd_ps(lanes, lanes, acc);
        }
        horizontal_sum_avx2(acc)
    };

    for &x in leftover {
        squared_sum += x * x;
    }

    squared_sum.sqrt()
}
482
/// Adds the eight `f32` lanes of `v` together and returns the scalar total.
///
/// # Safety
///
/// The body executes AVX/SSE intrinsics, so the caller must ensure those
/// instructions are available. The `cfg` on this function only compiles it
/// for `x86_64` builds that have the `avx2` target feature enabled.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
unsafe fn horizontal_sum_avx2(v: core::arch::x86_64::__m256) -> f32 {
    use core::arch::x86_64::*;

    // Fold the 256-bit register in half: upper four lanes + lower four lanes.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(high, low);

    // Duplicate the odd lanes over the even ones so that, after the add,
    // lanes 0 and 2 hold the two pairwise sums.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // Move the upper pair sum down to lane 0 and add it to the lower pair sum.
    let shuf2 = _mm_movehl_ps(sums, sums);
    let sums2 = _mm_add_ss(sums, shuf2);

    // Lane 0 now holds the sum of all eight input lanes.
    _mm_cvtss_f32(sums2)
}
501
/// Portable scalar implementation of cosine similarity.
///
/// Walks the first `a.len()` elements of both slices, accumulating the dot
/// product and both squared norms in `f32`. Returns `0.0` when either norm is
/// zero. Elements of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Restrict `b` to the prefix that pairs up with `a`; this is also where a
    // too-short `b` is rejected.
    let b = &b[..a.len()];

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let len_a = sq_a.sqrt();
    let len_b = sq_b.sqrt();

    if len_a != 0.0 && len_b != 0.0 {
        dot / (len_a * len_b)
    } else {
        0.0
    }
}
530
/// Portable scalar implementation of the squared Euclidean distance.
///
/// Sums the squared differences of the first `a.len()` element pairs in
/// `f32`. Elements of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn euclidean_distance_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Restrict `b` to the prefix that pairs up with `a`; this is also where a
    // too-short `b` is rejected.
    let paired = &b[..a.len()];

    a.iter().zip(paired).fold(0.0f32, |acc, (&x, &y)| {
        let delta = x - y;
        acc + delta * delta
    })
}
541
/// Portable scalar implementation of the dot product.
///
/// Sums the products of the first `a.len()` element pairs in `f32`. Elements
/// of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Restrict `b` to the prefix that pairs up with `a`; this is also where a
    // too-short `b` is rejected.
    let paired = &b[..a.len()];

    a.iter()
        .zip(paired)
        .fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
551
/// Portable scalar implementation of the L2 norm: the square root of the sum
/// of squared elements, accumulated in `f32`.
#[inline]
pub fn l2_norm_scalar(a: &[f32]) -> f32 {
    let squared_total = a.iter().fold(0.0f32, |acc, &value| acc + value * value);
    squared_total.sqrt()
}
561
/// A computed distance together with the index identifying its vector.
#[derive(Debug, Clone)]
pub struct DistanceResult {
    /// Identifier of the vector the distance was computed for; in
    /// `batch_cosine_distances` it is passed through from the caller's iterator.
    pub index: usize,
    /// Distance value; `batch_cosine_distances` stores `1.0 - cosine_similarity`.
    pub distance: f32,
}
574
/// Computes the cosine distance (`1.0 - cosine_similarity`) from `query` to
/// every vector in `batch` and returns the `k` smallest, sorted ascending.
///
/// `batch` yields `(index, vector)` pairs; each index is copied unchanged into
/// the matching [`DistanceResult`]. When `batch` holds `k` or fewer vectors,
/// all of them are returned. NaN distances are ordered after every number, so
/// they are only returned when fewer than `k` finite-or-infinite distances
/// exist.
///
/// # Panics
///
/// Panics if any vector in `batch` has a different length than `query`
/// (the check is performed by `cosine_similarity`).
#[cfg(feature = "alloc")]
pub fn batch_cosine_distances<'a, I>(
    query: &[f32],
    batch: I,
    k: usize,
) -> alloc::vec::Vec<DistanceResult>
where
    I: Iterator<Item = (usize, &'a [f32])>,
{
    use alloc::vec::Vec;

    /// Ascending order by distance with NaN placed last.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN is
    /// involved, and the slice sorting/selection routines may then produce an
    /// unspecified order or panic. Treating NaN as equal to NaN and greater
    /// than every number keeps the ordering consistent.
    fn by_distance(lhs: &DistanceResult, rhs: &DistanceResult) -> core::cmp::Ordering {
        match (lhs.distance.is_nan(), rhs.distance.is_nan()) {
            // Two non-NaN floats always compare; the fallback is never taken.
            (false, false) => lhs
                .distance
                .partial_cmp(&rhs.distance)
                .unwrap_or(core::cmp::Ordering::Equal),
            (true, true) => core::cmp::Ordering::Equal,
            (true, false) => core::cmp::Ordering::Greater,
            (false, true) => core::cmp::Ordering::Less,
        }
    }

    let mut results: Vec<DistanceResult> = batch
        .map(|(index, vector)| DistanceResult {
            index,
            distance: 1.0 - cosine_similarity(query, vector),
        })
        .collect();

    if results.len() > k {
        // Partition so the first `k` entries are the `k` smallest (in any
        // order), then drop the rest before paying for a full sort.
        results.select_nth_unstable_by(k, by_distance);
        results.truncate(k);
    }

    results.sort_by(by_distance);

    results
}
624
/// Unit tests for the dispatching entry points and their scalar references.
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;

    /// Closeness check whose tolerance scales with the magnitude of the
    /// operands (and never drops below an absolute `1e-4`).
    fn close(x: f32, y: f32) -> bool {
        (x - y).abs() <= 1e-4 * x.abs().max(y.abs()).max(1.0)
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [1.0f32, 2.0, 3.0, 4.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0f32, 0.0, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // Every implementation defines the similarity with a zero vector as 0.
        let zero = [0.0f32; 5];
        let other = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &zero), 0.0);
    }

    #[test]
    fn test_euclidean_distance_zero() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [1.0f32, 2.0, 3.0, 4.0];
        let dist = euclidean_distance_squared(&a, &b);
        assert!(dist.abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_known() {
        let a = [0.0f32, 0.0, 0.0, 0.0];
        let b = [3.0f32, 4.0, 0.0, 0.0];
        let dist = euclidean_distance_squared(&a, &b);
        // 3^2 + 4^2
        assert!((dist - 25.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [2.0f32, 3.0, 4.0, 5.0];
        let dot = dot_product(&a, &b);
        // 2 + 6 + 12 + 20
        assert!((dot - 40.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_norm() {
        let a = [3.0f32, 4.0, 0.0, 0.0];
        let norm = l2_norm(&a);
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_input() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
        assert_eq!(euclidean_distance_squared(&empty, &empty), 0.0);
        assert_eq!(dot_product(&empty, &empty), 0.0);
        assert_eq!(l2_norm(&empty), 0.0);
    }

    #[test]
    #[should_panic(expected = "Vectors must have same length")]
    fn test_length_mismatch_panics() {
        let _ = dot_product(&[1.0f32, 2.0], &[1.0f32]);
    }

    #[test]
    fn test_large_vector() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();

        // `b` is a positive multiple of `a`, so the vectors are parallel.
        let sim = cosine_similarity(&a, &b);
        assert!(sim > 0.99);

        let dot = dot_product(&a, &b);
        assert!(dot > 0.0);
    }

    #[test]
    fn test_scalar_matches_simd() {
        let a: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) * 0.1).collect();

        let scalar_sim = cosine_similarity_scalar(&a, &b);
        let simd_sim = cosine_similarity(&a, &b);

        assert!((scalar_sim - simd_sim).abs() < 1e-5);

        let scalar_dist = euclidean_distance_squared_scalar(&a, &b);
        let simd_dist = euclidean_distance_squared(&a, &b);

        assert!((scalar_dist - simd_dist).abs() < 1e-4);
    }

    #[test]
    fn test_tail_lengths_match_scalar() {
        // Lengths that are not multiples of 4 or 8 exercise the scalar tail
        // that follows the vectorised loop in the SIMD kernels.
        for &len in &[0usize, 1, 3, 4, 5, 7, 8, 9, 13, 17, 31] {
            let a: Vec<f32> = (0..len).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..len).map(|i| ((i + 1) as f32) * 0.1).collect();

            assert!(
                close(cosine_similarity(&a, &b), cosine_similarity_scalar(&a, &b)),
                "cosine similarity differs for len {}",
                len
            );
            assert!(
                close(
                    euclidean_distance_squared(&a, &b),
                    euclidean_distance_squared_scalar(&a, &b)
                ),
                "squared distance differs for len {}",
                len
            );
            assert!(
                close(dot_product(&a, &b), dot_product_scalar(&a, &b)),
                "dot product differs for len {}",
                len
            );
            assert!(
                close(l2_norm(&a), l2_norm_scalar(&a)),
                "l2 norm differs for len {}",
                len
            );
        }
    }
}