1#![allow(unreachable_code)]
34
35use crate::types::DistanceMetric;
36
37#[cfg(target_arch = "x86_64")]
39use std::arch::x86_64::*;
40
41#[cfg(target_arch = "aarch64")]
43use std::arch::aarch64::*;
44
#[cfg(target_arch = "x86_64")]
/// Runtime check for AVX2 support on the executing CPU (x86_64 builds).
#[inline]
pub fn is_avx2_available() -> bool {
    is_x86_feature_detected!("avx2")
}
51
#[cfg(not(target_arch = "x86_64"))]
/// AVX2 is an x86 extension; always `false` on other architectures.
#[inline]
pub fn is_avx2_available() -> bool {
    false
}
58
#[cfg(target_arch = "x86_64")]
/// Runtime check for FMA (fused multiply-add) support (x86_64 builds).
#[inline]
pub fn is_fma_available() -> bool {
    is_x86_feature_detected!("fma")
}
65
#[cfg(not(target_arch = "x86_64"))]
/// FMA detection is only meaningful on x86_64; always `false` elsewhere.
#[inline]
pub fn is_fma_available() -> bool {
    false
}
72
#[cfg(target_arch = "aarch64")]
/// NEON is a mandatory part of AArch64, so no runtime probe is needed.
#[inline]
pub fn is_neon_available() -> bool {
    true
}
80
#[cfg(not(target_arch = "aarch64"))]
/// NEON is an Arm extension; always `false` on other architectures.
#[inline]
pub fn is_neon_available() -> bool {
    false
}
87
#[cfg(target_arch = "x86_64")]
/// Runtime check for AVX-512 Foundation support (x86_64 builds).
/// NOTE(review): only `avx512f` is probed; kernels must stick to F-level
/// intrinsics or this guard is insufficient.
#[inline]
pub fn is_avx512_available() -> bool {
    is_x86_feature_detected!("avx512f")
}
94
#[cfg(not(target_arch = "x86_64"))]
/// AVX-512 is an x86 extension; always `false` on other architectures.
#[inline]
pub fn is_avx512_available() -> bool {
    false
}
101
102#[cfg(target_arch = "x86_64")]
108#[inline]
109unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
110 let low = _mm512_castps512_ps256(v); let high = _mm512_extractf32x8_ps(v, 1); let sum256 = _mm256_add_ps(low, high);
116
117 horizontal_sum_avx2(sum256)
119}
120
#[cfg(target_arch = "x86_64")]
/// AVX-512 dot product: 16 f32 lanes per iteration via fused multiply-add,
/// with a scalar loop for the trailing `len % 16` elements.
///
/// # Safety
/// Requires AVX-512F at runtime. `b` must be at least `a.len()` long.
/// Loads are unaligned (`loadu`), so no alignment requirement.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    // Main vector loop: 16 elements per iteration.
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        sum = _mm512_fmadd_ps(a_vec, b_vec, sum);
    }

    // Collapse the 16 partial sums, then finish the tail in scalar code.
    let mut total = horizontal_sum_avx512(sum);

    for i in (chunks * 16)..len {
        total += a[i] * b[i];
    }

    total
}
152
#[cfg(target_arch = "x86_64")]
/// AVX-512 cosine similarity: accumulates dot product and both squared norms
/// in one pass (16 lanes per iteration), then combines them.
///
/// # Safety
/// Requires AVX-512F at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cosine_similarity_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm512_setzero_ps();
    let mut norm_a_sum = _mm512_setzero_ps();
    let mut norm_b_sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);

        // a·b, a·a and b·b accumulated simultaneously via FMA.
        dot_sum = _mm512_fmadd_ps(a_vec, b_vec, dot_sum);
        norm_a_sum = _mm512_fmadd_ps(a_vec, a_vec, norm_a_sum);
        norm_b_sum = _mm512_fmadd_ps(b_vec, b_vec, norm_b_sum);
    }

    let mut dot = horizontal_sum_avx512(dot_sum);
    let mut norm_a = horizontal_sum_avx512(norm_a_sum);
    let mut norm_b = horizontal_sum_avx512(norm_b_sum);

    // Scalar tail for the remaining len % 16 elements.
    for i in (chunks * 16)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    // Clamp the denominator so zero-norm inputs yield 0 instead of NaN.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}
194
#[cfg(target_arch = "x86_64")]
/// AVX-512 Euclidean (L2) distance: sums squared differences 16 lanes at a
/// time via FMA, then takes the square root.
///
/// # Safety
/// Requires AVX-512F at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn euclidean_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        let diff = _mm512_sub_ps(a_vec, b_vec);
        // sum_sq += diff * diff (fused).
        sum_sq = _mm512_fmadd_ps(diff, diff, sum_sq);
    }

    let mut total = horizontal_sum_avx512(sum_sq);

    // Scalar tail for the remaining len % 16 elements.
    for i in (chunks * 16)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
228
#[cfg(target_arch = "x86_64")]
/// AVX-512 Manhattan (L1) distance: sums absolute differences 16 lanes at a
/// time using the hardware `abs` of the lane-wise difference.
///
/// # Safety
/// Requires AVX-512F at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn manhattan_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm512_setzero_ps();

    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm512_loadu_ps(a_ptr);
        let b_vec = _mm512_loadu_ps(b_ptr);
        let diff = _mm512_sub_ps(a_vec, b_vec);
        let abs_diff = _mm512_abs_ps(diff);
        sum = _mm512_add_ps(sum, abs_diff);
    }

    let mut total = horizontal_sum_avx512(sum);

    // Scalar tail for the remaining len % 16 elements.
    for i in (chunks * 16)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}
262
/// Reduces the 4 lanes of a NEON `float32x4_t` to a single `f32` sum.
///
/// Simplification: `vaddvq_f32` performs the full across-vector reduction in
/// one instruction, replacing the previous two-step `vpaddq_f32` sequence.
/// It computes the same pairwise tree `(a0+a1)+(a2+a3)`, so results are
/// bit-identical. This also matches the file's existing use of `vaddvq_u32`
/// in the quantized kernels.
///
/// # Safety
/// NEON is mandatory on aarch64, so this is always safe to execute there;
/// the fn stays `unsafe` to match its intrinsic-based callers.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
    vaddvq_f32(v)
}
278
#[cfg(target_arch = "aarch64")]
/// NEON dot product: 4 f32 lanes per iteration via multiply-accumulate
/// (`vmlaq_f32`), with a scalar loop for the trailing `len % 4` elements.
///
/// # Safety
/// `b` must be at least `a.len()` long. Loads are unaligned-tolerant.
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        // sum += a * b, lane-wise.
        sum = vmlaq_f32(sum, a_vec, b_vec);
    }

    let mut total = horizontal_sum_neon(sum);

    // Scalar tail.
    for i in (chunks * 4)..len {
        total += a[i] * b[i];
    }

    total
}
309
#[cfg(target_arch = "aarch64")]
/// NEON cosine similarity: accumulates dot product and both squared norms in
/// one pass (4 lanes per iteration), then combines them.
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = vdupq_n_f32(0.0);
    let mut norm_a_sum = vdupq_n_f32(0.0);
    let mut norm_b_sum = vdupq_n_f32(0.0);

    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);

        // a·b, a·a and b·b accumulated simultaneously.
        dot_sum = vmlaq_f32(dot_sum, a_vec, b_vec);
        norm_a_sum = vmlaq_f32(norm_a_sum, a_vec, a_vec);
        norm_b_sum = vmlaq_f32(norm_b_sum, b_vec, b_vec);
    }

    let mut dot = horizontal_sum_neon(dot_sum);
    let mut norm_a = horizontal_sum_neon(norm_a_sum);
    let mut norm_b = horizontal_sum_neon(norm_b_sum);

    // Scalar tail for the remaining len % 4 elements.
    for i in (chunks * 4)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    // Clamp the denominator so zero-norm inputs yield 0 instead of NaN.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}
350
#[cfg(target_arch = "aarch64")]
/// NEON Euclidean (L2) distance: sums squared lane differences 4 at a time,
/// then takes the square root.
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = vdupq_n_f32(0.0);

    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        let diff = vsubq_f32(a_vec, b_vec);
        // sum_sq += diff * diff, lane-wise.
        sum_sq = vmlaq_f32(sum_sq, diff, diff);
    }

    let mut total = horizontal_sum_neon(sum_sq);

    // Scalar tail.
    for i in (chunks * 4)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
383
#[cfg(target_arch = "aarch64")]
/// NEON Manhattan (L1) distance: sums absolute lane differences 4 at a time.
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn manhattan_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);

    let chunks = len / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = vld1q_f32(a_ptr);
        let b_vec = vld1q_f32(b_ptr);
        let diff = vsubq_f32(a_vec, b_vec);
        let abs_diff = vabsq_f32(diff);
        sum = vaddq_f32(sum, abs_diff);
    }

    let mut total = horizontal_sum_neon(sum);

    // Scalar tail.
    for i in (chunks * 4)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}
415
#[cfg(target_arch = "x86_64")]
/// AVX2+FMA dot product: 8 f32 lanes per iteration using fused multiply-add,
/// with a scalar loop for the trailing `len % 8` elements.
///
/// # Safety
/// Requires AVX2 and FMA at runtime. `b` must be at least `a.len()` long.
/// Loads are unaligned (`loadu`).
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
    }

    let mut total = horizontal_sum_avx2(sum);

    // Scalar tail.
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}
451
#[cfg(target_arch = "x86_64")]
/// AVX2+FMA cosine similarity: accumulates dot product and both squared
/// norms in one pass (8 lanes per iteration), then combines them.
///
/// # Safety
/// Requires AVX2 and FMA at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn cosine_similarity_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a_sum = _mm256_setzero_ps();
    let mut norm_b_sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);

        // a·b, a·a and b·b accumulated simultaneously via FMA.
        dot_sum = _mm256_fmadd_ps(a_vec, b_vec, dot_sum);
        norm_a_sum = _mm256_fmadd_ps(a_vec, a_vec, norm_a_sum);
        norm_b_sum = _mm256_fmadd_ps(b_vec, b_vec, norm_b_sum);
    }

    let mut dot = horizontal_sum_avx2(dot_sum);
    let mut norm_a = horizontal_sum_avx2(norm_a_sum);
    let mut norm_b = horizontal_sum_avx2(norm_b_sum);

    // Scalar tail for the remaining len % 8 elements.
    for i in (chunks * 8)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    // Clamp the denominator so zero-norm inputs yield 0 instead of NaN.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}
493
#[cfg(target_arch = "x86_64")]
/// AVX2+FMA Euclidean (L2) distance: sums squared differences 8 lanes at a
/// time via FMA, then takes the square root.
///
/// # Safety
/// Requires AVX2 and FMA at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn euclidean_distance_fma(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        // sum_sq += diff * diff (fused).
        sum_sq = _mm256_fmadd_ps(diff, diff, sum_sq);
    }

    let mut total = horizontal_sum_avx2(sum_sq);

    // Scalar tail.
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
527
/// Reduces the 8 lanes of an AVX `__m256` to a single `f32` sum.
///
/// Fix: added `#[target_feature(enable = "avx2")]`. A function that receives
/// a 256-bit vector by value and uses AVX intrinsics must be compiled with
/// the feature enabled for correct ABI/codegen; it also allows inlining into
/// the AVX2/FMA kernels that call it.
///
/// # Safety
/// Caller must ensure AVX2 is available on the running CPU.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // Add the upper 128-bit half onto the lower half.
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(lo, hi);
    // Two horizontal adds collapse 4 partial sums -> 2 -> 1.
    let sum64 = _mm_hadd_ps(sum128, sum128);
    let sum32 = _mm_hadd_ps(sum64, sum64);
    _mm_cvtss_f32(sum32)
}
547
#[cfg(target_arch = "x86_64")]
/// AVX2 (no FMA) dot product: separate multiply and add, 8 f32 lanes per
/// iteration, with a scalar loop for the trailing `len % 8` elements.
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        // Unfused multiply + add (FMA variant exists separately).
        let mul = _mm256_mul_ps(a_vec, b_vec);
        sum = _mm256_add_ps(sum, mul);
    }

    let mut total = horizontal_sum_avx2(sum);

    // Scalar tail.
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}
579
#[cfg(target_arch = "x86_64")]
/// AVX2 (no FMA) cosine similarity: accumulates dot product and both squared
/// norms in one pass, 8 lanes per iteration.
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a_sum = _mm256_setzero_ps();
    let mut norm_b_sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);

        // Unfused multiply + add for a·b, a·a and b·b.
        dot_sum = _mm256_add_ps(dot_sum, _mm256_mul_ps(a_vec, b_vec));
        norm_a_sum = _mm256_add_ps(norm_a_sum, _mm256_mul_ps(a_vec, a_vec));
        norm_b_sum = _mm256_add_ps(norm_b_sum, _mm256_mul_ps(b_vec, b_vec));
    }

    let mut dot = horizontal_sum_avx2(dot_sum);
    let mut norm_a = horizontal_sum_avx2(norm_a_sum);
    let mut norm_b = horizontal_sum_avx2(norm_b_sum);

    // Scalar tail for the remaining len % 8 elements.
    for i in (chunks * 8)..len {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    // Clamp the denominator so zero-norm inputs yield 0 instead of NaN.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot / denominator
}
620
#[cfg(target_arch = "x86_64")]
/// AVX2 (no FMA) Euclidean (L2) distance: sums squared differences 8 lanes
/// at a time, then takes the square root.
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum_sq = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        sum_sq = _mm256_add_ps(sum_sq, _mm256_mul_ps(diff, diff));
    }

    let mut total = horizontal_sum_avx2(sum_sq);

    // Scalar tail.
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
653
#[cfg(target_arch = "x86_64")]
/// AVX2 Manhattan (L1) distance: absolute value is computed by masking off
/// the IEEE-754 sign bit (`andnot` with `-0.0`), since AVX2 has no f32 `abs`.
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn manhattan_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    // -0.0 has only the sign bit set; andnot(sign_mask, x) clears it -> |x|.
    let sign_mask = _mm256_set1_ps(-0.0);
    let chunks = len / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let a_ptr = a.as_ptr().add(offset);
        let b_ptr = b.as_ptr().add(offset);

        let a_vec = _mm256_loadu_ps(a_ptr);
        let b_vec = _mm256_loadu_ps(b_ptr);
        let diff = _mm256_sub_ps(a_vec, b_vec);
        let abs_diff = _mm256_andnot_ps(sign_mask, diff);
        sum = _mm256_add_ps(sum, abs_diff);
    }

    let mut total = horizontal_sum_avx2(sum);

    // Scalar tail.
    for i in (chunks * 8)..len {
        total += (a[i] - b[i]).abs();
    }

    total
}
688
/// Scalar fallback for cosine similarity, relying on LLVM auto-vectorization.
///
/// Soundness fix: the original indexed with `get_unchecked` guarded only by a
/// `debug_assert!`, which is undefined behavior in release builds whenever
/// `b.len() < a.len()`. Iterator traversal is safe, bounds-check-free, and
/// keeps the same sequential accumulation order (`zip` stops at the shorter
/// slice, matching the documented same-dimension contract).
#[inline]
#[allow(dead_code)]
fn cosine_similarity_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let mut dot_product = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    // Single pass accumulating a·b, a·a and b·b.
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot_product += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Clamp the denominator so zero-norm inputs yield 0 instead of NaN.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot_product / denominator
}
736
/// Cosine similarity using the fastest SIMD path available at runtime:
/// AVX-512 -> FMA -> AVX2 -> scalar on x86_64; NEON on aarch64; scalar
/// elsewhere.
///
/// NOTE(review): the SIMD kernels assume `a.len() == b.len()`; lengths are
/// not checked here — confirm callers guarantee this.
#[inline]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: the required CPU feature was verified at runtime.
            unsafe { cosine_similarity_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { cosine_similarity_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { cosine_similarity_avx2(a, b) }
        } else {
            cosine_similarity_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { cosine_similarity_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        cosine_similarity_autovec(a, b)
    }
}
766
/// Scalar fallback for Euclidean (L2) distance, relying on LLVM
/// auto-vectorization.
///
/// Soundness fix: the original indexed with `get_unchecked` guarded only by a
/// `debug_assert!` — undefined behavior in release builds on mismatched
/// lengths. Iterator traversal is safe and bounds-check-free, with the same
/// sequential accumulation order.
#[inline]
#[allow(dead_code)]
fn euclidean_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    // Sum of squared differences; `zip` stops at the shorter slice.
    let sum_sq: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();

    sum_sq.sqrt()
}
797
/// Euclidean (L2) distance using the fastest SIMD path available at runtime
/// (see `cosine_similarity_simd` for the dispatch order).
///
/// NOTE(review): assumes `a.len() == b.len()`; not checked here.
#[inline]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: the required CPU feature was verified at runtime.
            unsafe { euclidean_distance_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { euclidean_distance_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { euclidean_distance_avx2(a, b) }
        } else {
            euclidean_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { euclidean_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        euclidean_distance_autovec(a, b)
    }
}
827
/// Scalar fallback for the dot product, relying on LLVM auto-vectorization.
///
/// Soundness fix: the original indexed with `get_unchecked` guarded only by a
/// `debug_assert!` — undefined behavior in release builds on mismatched
/// lengths. Iterator traversal is safe and bounds-check-free, with the same
/// sequential accumulation order.
#[inline]
#[allow(dead_code)]
fn dot_product_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    // `zip` stops at the shorter slice, so no index can go out of bounds.
    a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
856
/// Dot product using the fastest SIMD path available at runtime
/// (see `cosine_similarity_simd` for the dispatch order).
///
/// NOTE(review): assumes `a.len() == b.len()`; not checked here.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: the required CPU feature was verified at runtime.
            unsafe { dot_product_avx512(a, b) }
        } else if is_fma_available() {
            unsafe { dot_product_fma(a, b) }
        } else if is_avx2_available() {
            unsafe { dot_product_avx2(a, b) }
        } else {
            dot_product_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { dot_product_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dot_product_autovec(a, b)
    }
}
886
/// Scalar fallback for Manhattan (L1) distance, relying on LLVM
/// auto-vectorization.
///
/// Soundness fix: the original indexed with `get_unchecked` guarded only by a
/// `debug_assert!` — undefined behavior in release builds on mismatched
/// lengths. Iterator traversal is safe and bounds-check-free, with the same
/// sequential accumulation order.
#[inline]
#[allow(dead_code)]
fn manhattan_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    // Sum of absolute differences; `zip` stops at the shorter slice.
    a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).abs()).sum()
}
915
/// Manhattan (L1) distance using the fastest SIMD path available at runtime.
/// Note: no FMA variant exists for L1 (no multiplies), so x86_64 dispatch is
/// AVX-512 -> AVX2 -> scalar.
///
/// NOTE(review): assumes `a.len() == b.len()`; not checked here.
#[inline]
pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: the required CPU feature was verified at runtime.
            unsafe { manhattan_distance_avx512(a, b) }
        } else if is_avx2_available() {
            unsafe { manhattan_distance_avx2(a, b) }
        } else {
            manhattan_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe { manhattan_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        manhattan_distance_autovec(a, b)
    }
}
943
/// Computes a score for `metric` where a larger value always means "more
/// similar": the true distances (Euclidean, Manhattan) are negated, while
/// Cosine and DotProduct are returned as-is. See
/// `compute_distance_lower_is_better_simd` for the opposite convention.
pub fn compute_distance_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        DistanceMetric::Cosine => cosine_similarity_simd(a, b),
        DistanceMetric::Euclidean => -euclidean_distance_simd(a, b),
        DistanceMetric::DotProduct => dot_product_simd(a, b),
        DistanceMetric::Manhattan => -manhattan_distance_simd(a, b),
    }
}
955
/// Computes a score for `metric` where a smaller value always means "more
/// similar": similarities are inverted (cosine distance = 1 - similarity,
/// dot product negated), while true distances pass through unchanged.
#[inline]
pub fn compute_distance_lower_is_better_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        DistanceMetric::Cosine => {
            // Cosine distance: 0 for identical direction, up to 2 for opposite.
            1.0 - cosine_similarity_simd(a, b)
        }
        DistanceMetric::Euclidean => euclidean_distance_simd(a, b),
        DistanceMetric::DotProduct => {
            // Negate so larger dot products rank as "closer".
            -dot_product_simd(a, b)
        }
        DistanceMetric::Manhattan => manhattan_distance_simd(a, b),
    }
}
974
/// Manhattan (L1) distance between two u8-quantized vectors, dispatching to
/// AVX2 / NEON when available and a scalar loop otherwise.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn quantized_manhattan_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 verified at runtime; lengths asserted equal above.
            return unsafe { quantized_manhattan_distance_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64; lengths asserted equal above.
        return unsafe { quantized_manhattan_distance_neon(a, b) };
    }

    // Reached on non-AVX2 x86_64 and on other architectures.
    quantized_manhattan_distance_scalar(a, b)
}
1002
/// Dot product between two u8-quantized vectors, dispatching to AVX2 / NEON
/// when available and a scalar loop otherwise.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn quantized_dot_product_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 verified at runtime; lengths asserted equal above.
            return unsafe { quantized_dot_product_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64; lengths asserted equal above.
        return unsafe { quantized_dot_product_neon(a, b) };
    }

    // Reached on non-AVX2 x86_64 and on other architectures.
    quantized_dot_product_scalar(a, b)
}
1026
/// Squared Euclidean distance between two u8-quantized vectors, dispatching
/// to AVX2 / NEON when available and a scalar loop otherwise. Returns the
/// squared distance (no square root — useful for ranking).
///
/// # Panics
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn quantized_euclidean_squared_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 verified at runtime; lengths asserted equal above.
            return unsafe { quantized_euclidean_squared_avx2(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64; lengths asserted equal above.
        return unsafe { quantized_euclidean_squared_neon(a, b) };
    }

    // Reached on non-AVX2 x86_64 and on other architectures.
    quantized_euclidean_squared_scalar(a, b)
}
1050
/// Scalar fallback: Manhattan (L1) distance over u8-quantized vectors.
/// Differences are widened to `i32` before taking the absolute value.
#[inline]
fn quantized_manhattan_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut total = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        total += (i32::from(x) - i32::from(y)).unsigned_abs();
    }
    total
}
1062
/// Scalar fallback: dot product over u8-quantized vectors, accumulated in
/// `u32` (each product fits: 255 * 255 < 2^16).
#[inline]
fn quantized_dot_product_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut acc = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        acc += u32::from(x) * u32::from(y);
    }
    acc
}
1070
/// Scalar fallback: squared Euclidean distance over u8-quantized vectors.
/// Each squared difference is non-negative, so the `as u32` cast is lossless.
#[inline]
fn quantized_euclidean_squared_scalar(a: &[u8], b: &[u8]) -> u32 {
    a.iter().zip(b).fold(0u32, |acc, (&x, &y)| {
        let diff = i32::from(x) - i32::from(y);
        acc + (diff * diff) as u32
    })
}
1081
/// AVX2 Manhattan (L1) distance between two u8-quantized vectors.
///
/// Correctness fix: the previous version accumulated byte differences in
/// 16-bit lanes (`_mm256_add_epi16`), which can silently wrap once a lane's
/// running total exceeds 65535 — reachable from roughly 4K elements with
/// large per-element differences. `_mm256_sad_epu8` computes the sum of
/// absolute byte differences directly into four 64-bit counters, which
/// cannot realistically overflow, and removes the unpack/widen steps.
///
/// # Safety
/// Caller must ensure AVX2 is available. `b` must be at least `a.len()`
/// long (the public wrapper asserts equal lengths).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_manhattan_distance_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    while i + 32 <= len {
        let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);

        // Sum of |a - b| per 8-byte group, widened into four u64 lanes.
        let sad = _mm256_sad_epu8(va, vb);
        sum = _mm256_add_epi64(sum, sad);

        i += 32;
    }

    // Reduce the four 64-bit partial sums.
    let mut lanes = [0u64; 4];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, sum);
    let mut result = lanes.iter().sum::<u64>() as u32;

    // Scalar tail for the final len % 32 elements.
    while i < len {
        result += (a[i] as i32 - b[i] as i32).unsigned_abs();
        i += 1;
    }

    result
}
1133
#[cfg(target_arch = "x86_64")]
/// AVX2 dot product between two u8-quantized vectors: 16 bytes per iteration,
/// widened u8 -> i16, multiplied and pair-summed into i32 lanes via
/// `_mm256_madd_epi16` (each pair sum <= 2 * 255^2, well within i16*i16->i32).
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_dot_product_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    while i + 16 <= len {
        let va_128 = _mm_loadu_si128(a.as_ptr().add(i) as *const __m128i);
        let vb_128 = _mm_loadu_si128(b.as_ptr().add(i) as *const __m128i);

        // Zero-extend 16 x u8 -> 16 x i16 (values 0..=255, so sign is safe).
        let va = _mm256_cvtepu8_epi16(va_128);
        let vb = _mm256_cvtepu8_epi16(vb_128);

        // madd: a0*b0 + a1*b1 per i32 lane.
        let prod = _mm256_madd_epi16(va, vb);
        sum = _mm256_add_epi32(sum, prod);

        i += 16;
    }

    // Reduce the eight 32-bit partial sums.
    let mut result_arr = [0u32; 8];
    _mm256_storeu_si256(result_arr.as_mut_ptr() as *mut __m256i, sum);
    let mut result: u32 = result_arr.iter().sum();

    // Scalar tail for the final len % 16 elements.
    while i < len {
        result += a[i] as u32 * b[i] as u32;
        i += 1;
    }

    result
}
1172
#[cfg(target_arch = "x86_64")]
/// AVX2 squared Euclidean distance between two u8-quantized vectors:
/// widened u8 -> i16 differences (range -255..=255), squared and pair-summed
/// into i32 lanes via `_mm256_madd_epi16` (each pair sum <= 2 * 255^2).
///
/// # Safety
/// Requires AVX2 at runtime. `b` must be at least `a.len()` long.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_euclidean_squared_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    let mut i = 0;
    while i + 16 <= len {
        let va_128 = _mm_loadu_si128(a.as_ptr().add(i) as *const __m128i);
        let vb_128 = _mm_loadu_si128(b.as_ptr().add(i) as *const __m128i);

        // Zero-extend 16 x u8 -> 16 x i16 so the subtraction cannot wrap.
        let va = _mm256_cvtepu8_epi16(va_128);
        let vb = _mm256_cvtepu8_epi16(vb_128);

        let diff = _mm256_sub_epi16(va, vb);

        // madd(diff, diff): d0^2 + d1^2 per i32 lane.
        let squared = _mm256_madd_epi16(diff, diff);
        sum = _mm256_add_epi32(sum, squared);

        i += 16;
    }

    // Reduce the eight 32-bit partial sums.
    let mut result_arr = [0u32; 8];
    _mm256_storeu_si256(result_arr.as_mut_ptr() as *mut __m256i, sum);
    let mut result: u32 = result_arr.iter().sum();

    // Scalar tail for the final len % 16 elements.
    while i < len {
        let diff = a[i] as i32 - b[i] as i32;
        result += (diff * diff) as u32;
        i += 1;
    }

    result
}
1214
#[cfg(target_arch = "aarch64")]
/// NEON Manhattan (L1) distance between two u8-quantized vectors: 16 bytes
/// per iteration using the hardware absolute-difference (`vabdq_u8`), widened
/// u8 -> u16 -> accumulated into four u32 lanes.
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn quantized_manhattan_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    while i + 16 <= len {
        let va = vld1q_u8(a.as_ptr().add(i));
        let vb = vld1q_u8(b.as_ptr().add(i));

        // |a - b| per byte, no overflow possible.
        let abs_diff = vabdq_u8(va, vb);

        // Widen u8 -> u16 halves, then widening-add each u16 half into u32.
        let abs_diff_lo = vmovl_u8(vget_low_u8(abs_diff));
        let abs_diff_hi = vmovl_u8(vget_high_u8(abs_diff));

        sum = vaddw_u16(sum, vget_low_u16(abs_diff_lo));
        sum = vaddw_u16(sum, vget_high_u16(abs_diff_lo));
        sum = vaddw_u16(sum, vget_low_u16(abs_diff_hi));
        sum = vaddw_u16(sum, vget_high_u16(abs_diff_hi));

        i += 16;
    }

    // Across-vector reduction of the four u32 partial sums.
    let mut result = vaddvq_u32(sum);

    // Scalar tail for the final len % 16 elements.
    while i < len {
        result += (a[i] as i32 - b[i] as i32).unsigned_abs();
        i += 1;
    }

    result
}
1258
#[cfg(target_arch = "aarch64")]
/// NEON dot product between two u8-quantized vectors: 8 bytes per iteration,
/// widened u8 -> u16, then widening-multiplied (`vmull_u16`) into u32 lanes.
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn quantized_dot_product_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    while i + 8 <= len {
        let va = vld1_u8(a.as_ptr().add(i));
        let vb = vld1_u8(b.as_ptr().add(i));

        // Widen u8 -> u16 so the multiply cannot wrap (255*255 < 2^16).
        let va_16 = vmovl_u8(va);
        let vb_16 = vmovl_u8(vb);

        // Widening multiply of the low 4 lanes into u32, then accumulate.
        let prod = vmull_u16(vget_low_u16(va_16), vget_low_u16(vb_16));
        sum = vaddq_u32(sum, prod);

        // Same for the high 4 lanes.
        let prod_hi = vmull_u16(vget_high_u16(va_16), vget_high_u16(vb_16));
        sum = vaddq_u32(sum, prod_hi);

        i += 8;
    }

    // Across-vector reduction of the four u32 partial sums.
    let mut result = vaddvq_u32(sum);

    // Scalar tail for the final len % 8 elements.
    while i < len {
        result += a[i] as u32 * b[i] as u32;
        i += 1;
    }

    result
}
1296
#[cfg(target_arch = "aarch64")]
/// NEON squared Euclidean distance between two u8-quantized vectors:
/// `|a - b|` per byte (`vabd_u8`), widened to u16, then squared via the
/// widening multiply into u32 lanes. (|d|^2 == d^2, so abs first is valid.)
///
/// # Safety
/// `b` must be at least `a.len()` long.
#[inline]
unsafe fn quantized_euclidean_squared_neon(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut sum = vdupq_n_u32(0);

    let mut i = 0;
    while i + 8 <= len {
        let va = vld1_u8(a.as_ptr().add(i));
        let vb = vld1_u8(b.as_ptr().add(i));

        // |a - b| per byte, then widen u8 -> u16 for the squaring multiply.
        let abs_diff = vabd_u8(va, vb);
        let diff_16 = vmovl_u8(abs_diff);

        // Widening square of the low 4 lanes into u32, then accumulate.
        let squared = vmull_u16(vget_low_u16(diff_16), vget_low_u16(diff_16));
        sum = vaddq_u32(sum, squared);

        // Same for the high 4 lanes.
        let squared_hi = vmull_u16(vget_high_u16(diff_16), vget_high_u16(diff_16));
        sum = vaddq_u32(sum, squared_hi);

        i += 8;
    }

    // Across-vector reduction of the four u32 partial sums.
    let mut result = vaddvq_u32(sum);

    // Scalar tail for the final len % 8 elements.
    while i < len {
        let diff = a[i] as i32 - b[i] as i32;
        result += (diff * diff) as u32;
        i += 1;
    }

    result
}
1335
/// Normalizes `vec` to unit L2 length in place using the SIMD dot product
/// and SIMD scaling. Vectors with near-zero norm (<= 1e-10) are left
/// unchanged to avoid division by zero.
#[inline]
pub fn normalize_vector_simd(vec: &mut [f32]) {
    // ||v||^2 = v · v, computed with the fastest available SIMD path.
    let norm_squared = dot_product_simd(vec, vec);
    let norm = norm_squared.sqrt();

    if norm > 1e-10 {
        let inv_norm = 1.0 / norm;
        scale_vector_simd(vec, inv_norm);
    }
}
1355
/// Multiplies every element of `vec` by `scalar` in place, dispatching to
/// AVX-512 / AVX2 / NEON when available with a scalar loop as fallback.
///
/// The trailing scalar loop is unreachable on aarch64 (the NEON branch
/// returns unconditionally) — this is why the crate carries
/// `#![allow(unreachable_code)]`.
#[inline]
pub fn scale_vector_simd(vec: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F support was verified at runtime.
            unsafe {
                scale_vector_avx512(vec, scalar);
            }
            return;
        }
        if is_avx2_available() {
            // SAFETY: AVX2 support was verified at runtime.
            unsafe {
                scale_vector_avx2(vec, scalar);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is mandatory on aarch64.
        unsafe {
            scale_vector_neon(vec, scalar);
        }
        return;
    }

    // Scalar fallback (non-AVX x86_64 and other architectures).
    for x in vec.iter_mut() {
        *x *= scalar;
    }
}
1390
#[cfg(target_arch = "x86_64")]
/// AVX-512 in-place scaling: multiplies 16 f32 lanes per iteration by the
/// broadcast `scalar`, finishing the tail in scalar code.
///
/// # Safety
/// Requires AVX-512F at runtime. Unaligned loads/stores are used.
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn scale_vector_avx512(vec: &mut [f32], scalar: f32) {
    let len = vec.len();
    let scalar_vec = _mm512_set1_ps(scalar);
    let mut i = 0;

    while i + 16 <= len {
        let ptr = vec.as_mut_ptr().add(i);
        let v = _mm512_loadu_ps(ptr);
        let scaled = _mm512_mul_ps(v, scalar_vec);
        _mm512_storeu_ps(ptr, scaled);
        i += 16;
    }

    // Scalar tail for the final len % 16 elements.
    while i < len {
        vec[i] *= scalar;
        i += 1;
    }
}
1414
#[cfg(target_arch = "x86_64")]
/// AVX2 in-place scaling: multiplies 8 f32 lanes per iteration by the
/// broadcast `scalar`, finishing the tail in scalar code.
///
/// # Safety
/// Requires AVX2 at runtime. Unaligned loads/stores are used.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scale_vector_avx2(vec: &mut [f32], scalar: f32) {
    let len = vec.len();
    let scalar_vec = _mm256_set1_ps(scalar);
    let mut i = 0;

    while i + 8 <= len {
        let ptr = vec.as_mut_ptr().add(i);
        let v = _mm256_loadu_ps(ptr);
        let scaled = _mm256_mul_ps(v, scalar_vec);
        _mm256_storeu_ps(ptr, scaled);
        i += 8;
    }

    // Scalar tail for the final len % 8 elements.
    while i < len {
        vec[i] *= scalar;
        i += 1;
    }
}
1438
#[cfg(target_arch = "aarch64")]
/// NEON in-place scaling: multiplies 4 f32 lanes per iteration by the
/// broadcast `scalar`, finishing the tail in scalar code.
///
/// # Safety
/// NEON is mandatory on aarch64; `unsafe` only for the raw-pointer intrinsics.
#[inline]
unsafe fn scale_vector_neon(vec: &mut [f32], scalar: f32) {
    let len = vec.len();
    let scalar_vec = vdupq_n_f32(scalar);
    let mut i = 0;

    while i + 4 <= len {
        let ptr = vec.as_mut_ptr().add(i);
        let v = vld1q_f32(ptr);
        let scaled = vmulq_f32(v, scalar_vec);
        vst1q_f32(ptr, scaled);
        i += 4;
    }

    // Scalar tail for the final len % 4 elements.
    while i < len {
        vec[i] *= scalar;
        i += 1;
    }
}
1461
1462#[cfg(test)]
1463mod tests {
1464 use super::*;
1465
    // Identical unit vectors -> similarity 1; orthogonal vectors -> 0.
    #[test]
    fn test_cosine_similarity_simd() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!((sim - 1.0).abs() < 1e-6);

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim.abs() < 1e-6);
    }
1478
    // 100-element input exercises the vectorized main loop plus scalar tail.
    #[test]
    fn test_cosine_similarity_simd_large() {
        let v1: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim > 0.99);
    }
1487
    // Classic 3-4-5 triangle: distance must be exactly 5.
    #[test]
    fn test_euclidean_distance_simd() {
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![3.0, 4.0, 0.0];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 5.0).abs() < 1e-6);
    }
1495
    // 100 unit differences -> sqrt(100) = 10; exercises SIMD loop + tail.
    #[test]
    fn test_euclidean_distance_simd_large() {
        let v1 = vec![0.0; 100];
        let v2 = vec![1.0; 100];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 10.0).abs() < 1e-6);
    }
1503
    // 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product_simd() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dot = dot_product_simd(&v1, &v2);
        assert!((dot - 32.0).abs() < 1e-6);
    }
1511
    // Sum of squares 1..=100; looser tolerance for f32 accumulation order.
    #[test]
    fn test_dot_product_simd_large() {
        let v1: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let dot = dot_product_simd(&v1, &v2);
        let expected: f32 = (1..=100).map(|i| (i * i) as f32).sum();
        assert!((dot - expected).abs() < 1e-3);
    }
1520
1521 #[test]
1522 fn test_manhattan_distance_simd() {
1523 let v1 = vec![1.0, 2.0, 3.0];
1524 let v2 = vec![4.0, 5.0, 6.0];
1525 let dist = manhattan_distance_simd(&v1, &v2);
1526 assert!((dist - 9.0).abs() < 1e-6); }
1528
1529 #[test]
1530 fn test_manhattan_distance_simd_large() {
1531 let v1 = vec![0.0; 100];
1532 let v2 = vec![1.0; 100];
1533 let dist = manhattan_distance_simd(&v1, &v2);
1534 assert!((dist - 100.0).abs() < 1e-6);
1535 }
1536
1537 #[test]
1538 fn test_compute_distance_simd() {
1539 let v1 = vec![1.0, 0.0, 0.0];
1540 let v2 = vec![1.0, 0.0, 0.0];
1541
1542 let sim = compute_distance_simd(DistanceMetric::Cosine, &v1, &v2);
1543 assert!((sim - 1.0).abs() < 1e-6);
1544
1545 let dist = compute_distance_simd(DistanceMetric::Euclidean, &v1, &v2);
1546 assert!(dist.abs() < 1e-6); let dot = compute_distance_simd(DistanceMetric::DotProduct, &v1, &v2);
1549 assert!((dot - 1.0).abs() < 1e-6);
1550
1551 let manhattan = compute_distance_simd(DistanceMetric::Manhattan, &v1, &v2);
1552 assert!(manhattan.abs() < 1e-6);
1553 }
1554
1555 #[test]
1556 fn test_is_avx2_available() {
1557 let _available = is_avx2_available();
1559 #[cfg(not(target_arch = "x86_64"))]
1562 assert!(!is_avx2_available());
1563 }
1564
1565 #[test]
1566 fn test_is_neon_available() {
1567 let available = is_neon_available();
1569 #[cfg(target_arch = "aarch64")]
1571 assert!(available, "NEON should always be available on aarch64");
1572 #[cfg(not(target_arch = "aarch64"))]
1574 assert!(!available, "NEON should not be available on non-aarch64");
1575 }
1576
1577 #[test]
1578 fn test_is_avx512_available() {
1579 let _available = is_avx512_available();
1581 #[cfg(not(target_arch = "x86_64"))]
1584 assert!(!is_avx512_available());
1585 }
1586
1587 #[test]
1588 fn test_avx2_correctness() {
1589 let v1: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
1591 let v2: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();
1592
1593 let cosine = cosine_similarity_simd(&v1, &v2);
1594 let euclidean = euclidean_distance_simd(&v1, &v2);
1595 let dot = dot_product_simd(&v1, &v2);
1596 let manhattan = manhattan_distance_simd(&v1, &v2);
1597
1598 assert!(cosine > 0.0 && cosine <= 1.0);
1600 assert!(euclidean > 0.0);
1601 assert!(dot > 0.0);
1602 assert!(manhattan > 0.0);
1603
1604 let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
1606 let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
1607 let dot_autovec = dot_product_autovec(&v1, &v2);
1608 let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);
1609
1610 let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
1612 assert!(relative_error(cosine, cosine_autovec) < 1e-5);
1613 assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
1614 assert!(relative_error(dot, dot_autovec) < 1e-5);
1615 assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
1616 }
1617
1618 #[test]
1619 fn test_neon_correctness() {
1620 let v1: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
1622 let v2: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02).collect();
1623
1624 let cosine = cosine_similarity_simd(&v1, &v2);
1625 let euclidean = euclidean_distance_simd(&v1, &v2);
1626 let dot = dot_product_simd(&v1, &v2);
1627 let manhattan = manhattan_distance_simd(&v1, &v2);
1628
1629 assert!(cosine > 0.0 && cosine <= 1.0);
1631 assert!(euclidean > 0.0);
1632 assert!(dot > 0.0);
1633 assert!(manhattan > 0.0);
1634
1635 let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
1637 let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
1638 let dot_autovec = dot_product_autovec(&v1, &v2);
1639 let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);
1640
1641 let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
1643 assert!(relative_error(cosine, cosine_autovec) < 1e-5);
1644 assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
1645 assert!(relative_error(dot, dot_autovec) < 1e-5);
1646 assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
1647 }
1648
1649 #[test]
1650 fn test_avx512_correctness() {
1651 let v1: Vec<f32> = (0..1024).map(|i| (i as f32) * 0.01).collect();
1653 let v2: Vec<f32> = (0..1024).map(|i| (i as f32) * 0.02).collect();
1654
1655 let cosine = cosine_similarity_simd(&v1, &v2);
1656 let euclidean = euclidean_distance_simd(&v1, &v2);
1657 let dot = dot_product_simd(&v1, &v2);
1658 let manhattan = manhattan_distance_simd(&v1, &v2);
1659
1660 assert!(cosine > 0.0 && cosine <= 1.0);
1662 assert!(euclidean > 0.0);
1663 assert!(dot > 0.0);
1664 assert!(manhattan > 0.0);
1665
1666 let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
1668 let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
1669 let dot_autovec = dot_product_autovec(&v1, &v2);
1670 let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);
1671
1672 let relative_error = |a: f32, b: f32| (a - b).abs() / a.max(b).max(1.0);
1674 assert!(relative_error(cosine, cosine_autovec) < 1e-5);
1675 assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
1676 assert!(relative_error(dot, dot_autovec) < 1e-5);
1677 assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
1678 }
1679
1680 #[test]
1681 fn test_quantized_manhattan_distance() {
1682 let a = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
1684 let b = vec![15u8, 25, 35, 45, 55, 65, 75, 85];
1685
1686 let distance_simd = quantized_manhattan_distance_simd(&a, &b);
1687 let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);
1688
1689 assert_eq!(distance_simd, distance_scalar);
1690 assert_eq!(distance_simd, 40); }
1692
1693 #[test]
1694 fn test_quantized_manhattan_distance_large() {
1695 let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
1697 let b: Vec<u8> = (0..768).map(|i| ((i + 10) % 256) as u8).collect();
1698
1699 let distance_simd = quantized_manhattan_distance_simd(&a, &b);
1700 let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);
1701
1702 assert_eq!(distance_simd, distance_scalar);
1703 }
1704
1705 #[test]
1706 fn test_quantized_dot_product() {
1707 let a = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
1709 let b = vec![8u8, 7, 6, 5, 4, 3, 2, 1];
1710
1711 let dot_simd = quantized_dot_product_simd(&a, &b);
1712 let dot_scalar = quantized_dot_product_scalar(&a, &b);
1713
1714 assert_eq!(dot_simd, dot_scalar);
1715 assert_eq!(dot_simd, 120);
1717 }
1718
1719 #[test]
1720 fn test_quantized_dot_product_large() {
1721 let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
1723 let b: Vec<u8> = (0..768).map(|i| ((255 - i) % 256) as u8).collect();
1724
1725 let dot_simd = quantized_dot_product_simd(&a, &b);
1726 let dot_scalar = quantized_dot_product_scalar(&a, &b);
1727
1728 assert_eq!(dot_simd, dot_scalar);
1729 }
1730
1731 #[test]
1732 fn test_quantized_euclidean_squared() {
1733 let a = vec![10u8, 20, 30, 40];
1735 let b = vec![13u8, 24, 27, 45];
1736
1737 let dist_simd = quantized_euclidean_squared_simd(&a, &b);
1738 let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);
1739
1740 assert_eq!(dist_simd, dist_scalar);
1741 assert_eq!(dist_simd, 59);
1743 }
1744
1745 #[test]
1746 fn test_quantized_euclidean_squared_large() {
1747 let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
1749 let b: Vec<u8> = (0..768).map(|i| ((i + 5) % 256) as u8).collect();
1750
1751 let dist_simd = quantized_euclidean_squared_simd(&a, &b);
1752 let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);
1753
1754 assert_eq!(dist_simd, dist_scalar);
1755 }
1756
1757 #[test]
1758 fn test_quantized_edge_cases() {
1759 let a = vec![100u8; 100];
1761 let b = vec![100u8; 100];
1762
1763 assert_eq!(quantized_manhattan_distance_simd(&a, &b), 0);
1764 assert_eq!(quantized_euclidean_squared_simd(&a, &b), 0);
1765
1766 let c = vec![0u8; 100];
1768 let d = vec![255u8; 100];
1769
1770 assert_eq!(quantized_manhattan_distance_simd(&c, &d), 255 * 100);
1771 assert_eq!(quantized_euclidean_squared_simd(&c, &d), 255 * 255 * 100);
1772 }
1773
1774 #[test]
1775 fn test_quantized_simd_correctness() {
1776 let a: Vec<u8> = (0..1024).map(|i| ((i * 17 + 42) % 256) as u8).collect();
1778 let b: Vec<u8> = (0..1024).map(|i| ((i * 23 + 99) % 256) as u8).collect();
1779
1780 let manhattan_simd = quantized_manhattan_distance_simd(&a, &b);
1782 let manhattan_scalar = quantized_manhattan_distance_scalar(&a, &b);
1783 assert_eq!(manhattan_simd, manhattan_scalar);
1784
1785 let dot_simd = quantized_dot_product_simd(&a, &b);
1786 let dot_scalar = quantized_dot_product_scalar(&a, &b);
1787 assert_eq!(dot_simd, dot_scalar);
1788
1789 let euclidean_simd = quantized_euclidean_squared_simd(&a, &b);
1790 let euclidean_scalar = quantized_euclidean_squared_scalar(&a, &b);
1791 assert_eq!(euclidean_simd, euclidean_scalar);
1792 }
1793
1794 #[test]
1795 fn test_normalize_vector_simd() {
1796 let mut vec = vec![3.0, 4.0, 0.0];
1797 normalize_vector_simd(&mut vec);
1798
1799 assert!((vec[0] - 0.6).abs() < 1e-6);
1801 assert!((vec[1] - 0.8).abs() < 1e-6);
1802 assert!((vec[2] - 0.0).abs() < 1e-6);
1803
1804 let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
1806 assert!((norm_squared - 1.0).abs() < 1e-6);
1807 }
1808
1809 #[test]
1810 fn test_normalize_vector_simd_large() {
1811 let mut vec: Vec<f32> = (0..768).map(|i| (i % 100) as f32).collect();
1813 normalize_vector_simd(&mut vec);
1814
1815 let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
1817 assert!((norm_squared - 1.0).abs() < 1e-5);
1818 }
1819
1820 #[test]
1821 fn test_normalize_vector_simd_zero() {
1822 let mut vec = vec![0.0, 0.0, 0.0];
1824 normalize_vector_simd(&mut vec);
1825
1826 assert_eq!(vec, vec![0.0, 0.0, 0.0]);
1828 }
1829
1830 #[test]
1831 fn test_scale_vector_simd() {
1832 let mut vec = vec![1.0, 2.0, 3.0, 4.0];
1833 scale_vector_simd(&mut vec, 2.0);
1834
1835 assert_eq!(vec, vec![2.0, 4.0, 6.0, 8.0]);
1836 }
1837
1838 #[test]
1839 fn test_scale_vector_simd_large() {
1840 let mut vec: Vec<f32> = (0..1024).map(|i| i as f32).collect();
1842 scale_vector_simd(&mut vec, 0.5);
1843
1844 for (i, &value) in vec.iter().enumerate() {
1845 assert!((value - (i as f32 * 0.5)).abs() < 1e-5);
1846 }
1847 }
1848}