1use std::convert::AsRef;
7
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11#[cfg(target_arch = "aarch64")]
12use diskann_wide::arch::aarch64::{algorithms, Neon};
13
14use diskann_wide::{
15 arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDDotProduct, SIMDMulAdd,
16 SIMDSumTree, SIMDVector,
17};
18
19use crate::Half;
20
/// Conversion to `f32` that is allowed to lose precision (e.g. large `i32`
/// or `u32` values cannot be represented exactly in an `f32`).
pub trait LossyF32Conversion: Copy {
    /// Converts `self` to `f32`, possibly rounding.
    fn as_f32_lossy(self) -> f32;
}
25
// Identity conversion: already an `f32`.
impl LossyF32Conversion for f32 {
    fn as_f32_lossy(self) -> f32 {
        self
    }
}

// `as` cast: exact for |v| <= 2^24, rounds to nearest above that.
impl LossyF32Conversion for i32 {
    fn as_f32_lossy(self) -> f32 {
        self as f32
    }
}

// `as` cast: exact for v <= 2^24, rounds to nearest above that.
impl LossyF32Conversion for u32 {
    fn as_f32_lossy(self) -> f32 {
        self as f32
    }
}
43
// `force_eval` keeps the compiler from optimizing away a computed `f32`.
// NOTE(review): not referenced in this portion of the file — presumably used
// elsewhere (e.g. benchmarks); confirm before removing.
cfg_if::cfg_if! {
    if #[cfg(miri)] {
        // Miri cannot execute inline assembly; a no-op is sufficient there.
        fn force_eval(_x: f32) {}
    } else if #[cfg(target_arch = "x86_64")] {
        use std::arch::asm;

        #[inline(always)]
        fn force_eval(x: f32) {
            unsafe {
                // Empty asm block that takes `x` in an xmm register: the
                // compiler must materialize the value, so the computation
                // producing it cannot be dead-code eliminated.
                asm!(
                    "/* {0} */",
                    in(xmm_reg) x,
                    options(nostack, nomem, preserves_flags)
                )
            }
        }
    } else {
        // No dedicated implementation for other architectures: no-op.
        fn force_eval(_x: f32) {}
    }
}
85
/// Raw-pointer view over two parallel input sequences, used by [`MainLoop`]
/// implementations to issue SIMD loads addressed by (block, offset).
#[derive(Debug, Clone, Copy)]
pub struct Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    // Architecture selector passed through to the SIMD load routines.
    arch: A,
    // The distance schema driving accumulation.
    schema: Schema,
    // Base pointer of the left operand.
    left: *const Left,
    // Base pointer of the right operand.
    right: *const Right,
    // Element count of each operand (used only by the debug bounds check).
    len: usize,
}
108
impl<Schema, Left, Right, A> Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    /// Builds a loader over `len` elements starting at `left` / `right`.
    #[inline(always)]
    fn new(arch: A, schema: Schema, left: *const Left, right: *const Right, len: usize) -> Self {
        Self {
            arch,
            schema,
            left,
            right,
            len,
        }
    }

    /// Architecture selector this loader was created with.
    #[inline(always)]
    fn arch(&self) -> A {
        self.arch
    }

    /// Schema this loader was created with.
    #[inline(always)]
    fn schema(&self) -> Schema {
        self.schema
    }

    /// Loads one SIMD vector from each operand at element index
    /// `block * SIMDWidth * BLOCK_SIZE + offset * SIMDWidth`.
    ///
    /// # Safety
    /// The computed element range `[offset, offset + SIMDWidth)` must lie
    /// within both operands; only checked via `debug_assert!` below.
    #[inline(always)]
    unsafe fn load(&self, block: usize, offset: usize) -> (Schema::Left, Schema::Right) {
        let stride = Schema::SIMDWidth::value();
        let block_stride = stride * Schema::Main::BLOCK_SIZE;
        let offset = block_stride * block + stride * offset;

        debug_assert!(
            offset + stride <= self.len,
            "length = {}, offset = {}",
            self.len,
            offset
        );

        (
            Schema::Left::load_simd(self.arch, self.left.add(offset)),
            Schema::Right::load_simd(self.arch, self.right.add(offset)),
        )
    }
}
190
/// Unrolling strategy for the main accumulation loop driven by [`simd_op`].
pub trait MainLoop {
    /// Number of SIMD-width chunks consumed per logical block.
    const BLOCK_SIZE: usize;

    /// Runs `trip_count` full blocks, then folds in `epilogues` trailing
    /// full SIMD vectors (always `< BLOCK_SIZE`) left over past the last
    /// complete block.
    ///
    /// # Safety
    /// `loader` must cover at least
    /// `(trip_count * BLOCK_SIZE + epilogues) * SIMDWidth` elements.
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>;
}
/// Block size 1, a single accumulator chain.
pub struct Strategy1x1;

/// Block size 2, two independent accumulator chains.
pub struct Strategy2x1;

/// Block size 4, four independent accumulator chains.
pub struct Strategy4x1;

/// Block size 4, four chains, block loop unrolled by two.
pub struct Strategy4x2;

/// Block size 4, two chains, block loop unrolled by two.
pub struct Strategy2x4;
265impl MainLoop for Strategy1x1 {
266 const BLOCK_SIZE: usize = 1;
267
268 #[inline(always)]
269 unsafe fn main<S, L, R, A>(
270 loader: &Loader<S, L, R, A>,
271 trip_count: usize,
272 _epilogues: usize,
273 ) -> S::Accumulator
274 where
275 A: Architecture,
276 S: SIMDSchema<L, R, A>,
277 {
278 let arch = loader.arch();
279 let schema = loader.schema();
280
281 let mut s0 = schema.init(arch);
282 for i in 0..trip_count {
283 s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
284 }
285
286 s0
287 }
288}
289
290impl MainLoop for Strategy2x1 {
291 const BLOCK_SIZE: usize = 2;
292
293 #[inline(always)]
294 unsafe fn main<S, L, R, A>(
295 loader: &Loader<S, L, R, A>,
296 trip_count: usize,
297 epilogues: usize,
298 ) -> S::Accumulator
299 where
300 A: Architecture,
301 S: SIMDSchema<L, R, A>,
302 {
303 let arch = loader.arch();
304 let schema = loader.schema();
305
306 let mut s0 = schema.init(arch);
307 let mut s1 = schema.init(arch);
308
309 for i in 0..trip_count {
310 s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
311 s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
312 }
313
314 let mut s = schema.combine(s0, s1);
315 if epilogues != 0 {
316 s = schema.accumulate_tuple(s, loader.load(trip_count, 0));
317 }
318
319 s
320 }
321}
322
323impl MainLoop for Strategy4x1 {
324 const BLOCK_SIZE: usize = 4;
325
326 #[inline(always)]
327 unsafe fn main<S, L, R, A>(
328 loader: &Loader<S, L, R, A>,
329 trip_count: usize,
330 epilogues: usize,
331 ) -> S::Accumulator
332 where
333 A: Architecture,
334 S: SIMDSchema<L, R, A>,
335 {
336 let arch = loader.arch();
337 let schema = loader.schema();
338
339 let mut s0 = schema.init(arch);
340 let mut s1 = schema.init(arch);
341 let mut s2 = schema.init(arch);
342 let mut s3 = schema.init(arch);
343
344 for i in 0..trip_count {
345 s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
346 s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
347 s2 = schema.accumulate_tuple(s2, loader.load(i, 2));
348 s3 = schema.accumulate_tuple(s3, loader.load(i, 3));
349 }
350
351 if epilogues >= 1 {
352 s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
353 }
354
355 if epilogues >= 2 {
356 s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
357 }
358
359 if epilogues >= 3 {
360 s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
361 }
362
363 schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
364 }
365}
366
impl MainLoop for Strategy4x2 {
    const BLOCK_SIZE: usize = 4;

    /// Four accumulator chains with the block loop unrolled by two: each
    /// `for` iteration consumes blocks `j` and `j + 1`. Offsets 4..=7 on
    /// block `j` address block `j + 1` because the loader's block stride is
    /// `4 * SIMDWidth` (see `Loader::load`).
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            // Block j.
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));

            // Block j + 1 (offsets 4..=7 spill past BLOCK_SIZE into it).
            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 6));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 7));
        }

        // Odd trip count: one block was not covered by the unrolled loop.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));
        }

        // Up to three trailing SIMD vectors past the last complete block.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
425
impl MainLoop for Strategy2x4 {
    const BLOCK_SIZE: usize = 4;

    /// Two accumulator chains alternating over a 4-chunk block, with the
    /// block loop unrolled by two. Offsets 4..=7 on block `j` address block
    /// `j + 1` because the loader's block stride is `4 * SIMDWidth`.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);

        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            // Block j: chains alternate across the four chunks.
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));

            // Block j + 1.
            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 6));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 7));
        }

        // Odd trip count: one block was not covered by the unrolled loop.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));
        }

        // Up to three trailing SIMD vectors past the last complete block.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 2));
        }

        schema.combine(s0, s1)
    }
}
486
487pub trait SIMDSchema<T, U, A: Architecture = diskann_wide::arch::Current>: Copy {
495 type SIMDWidth: Constant<Type = usize>;
498
499 type Accumulator: std::ops::Add<Output = Self::Accumulator> + std::fmt::Debug + Copy;
501
502 type Left: SIMDVector<Arch = A, Scalar = T, ConstLanes = Self::SIMDWidth>;
504
505 type Right: SIMDVector<Arch = A, Scalar = U, ConstLanes = Self::SIMDWidth>;
507
508 type Return;
511
512 type Main: MainLoop;
514
515 fn init(&self, arch: A) -> Self::Accumulator;
517
518 fn accumulate(
520 &self,
521 x: Self::Left,
522 y: Self::Right,
523 acc: Self::Accumulator,
524 ) -> Self::Accumulator;
525
526 #[inline(always)]
528 fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
529 x + y
530 }
531
532 #[inline(always)]
550 unsafe fn epilogue(
551 &self,
552 arch: A,
553 x: *const T,
554 y: *const U,
555 len: usize,
556 acc: Self::Accumulator,
557 ) -> Self::Accumulator {
558 let a = Self::Left::load_simd_first(arch, x, len);
561
562 let b = Self::Right::load_simd_first(arch, y, len);
565 self.accumulate(a, b, acc)
566 }
567
568 fn reduce(&self, x: Self::Accumulator) -> Self::Return;
572
573 #[inline(always)]
578 fn get_simd_width() -> usize {
579 Self::SIMDWidth::value()
580 }
581
582 #[inline(always)]
588 fn get_main_bocksize() -> usize {
589 Self::Main::BLOCK_SIZE
590 }
591
592 #[doc(hidden)]
595 #[inline(always)]
596 fn accumulate_tuple(
597 &self,
598 acc: Self::Accumulator,
599 (x, y): (Self::Left, Self::Right),
600 ) -> Self::Accumulator {
601 self.accumulate(x, y, acc)
602 }
603}
604
/// A schema whose accumulation can be split across multiple `simd_op` calls:
/// intermediate accumulators are folded into carried state via
/// [`Self::combine_with`] and reduced to a scalar only at the end.
pub trait ResumableSIMDSchema<T, U, A = diskann_wide::arch::Current>: Copy
where
    A: Architecture,
{
    /// The underlying one-shot schema that does the per-call work.
    type NonResumable: SIMDSchema<T, U, A> + Default;
    /// Scalar produced by [`Self::sum`] once all calls are folded in.
    type FinalReturn;

    /// Creates fresh (identity) carried state.
    fn init(arch: A) -> Self;
    /// Folds one call's accumulator into the carried state.
    fn combine_with(&self, other: <Self::NonResumable as SIMDSchema<T, U, A>>::Accumulator)
        -> Self;
    /// Reduces the carried state to the final scalar.
    fn sum(&self) -> Self::FinalReturn;
}
625
/// Thin wrapper marking a value as resumable partial state.
#[derive(Debug, Clone, Copy)]
pub struct Resumable<T>(T);

impl<T> Resumable<T> {
    /// Wraps `val`.
    pub fn new(val: T) -> Self {
        Resumable(val)
    }

    /// Unwraps and returns the inner value.
    pub fn consume(self) -> T {
        let Resumable(inner) = self;
        inner
    }
}
638
/// Adapter that lets any [`ResumableSIMDSchema`] be driven by [`simd_op`]:
/// all per-vector work delegates to the stateless `NonResumable` schema,
/// while `reduce` folds the freshly computed accumulator into the carried
/// state instead of producing a final scalar.
impl<T, U, R, A> SIMDSchema<T, U, A> for Resumable<R>
where
    A: Architecture,
    R: ResumableSIMDSchema<T, U, A>,
{
    type SIMDWidth = <R::NonResumable as SIMDSchema<T, U, A>>::SIMDWidth;
    type Accumulator = <R::NonResumable as SIMDSchema<T, U, A>>::Accumulator;
    type Left = <R::NonResumable as SIMDSchema<T, U, A>>::Left;
    type Right = <R::NonResumable as SIMDSchema<T, U, A>>::Right;
    type Return = Self;
    type Main = <R::NonResumable as SIMDSchema<T, U, A>>::Main;

    fn init(&self, arch: A) -> Self::Accumulator {
        R::NonResumable::default().init(arch)
    }

    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        R::NonResumable::default().accumulate(x, y, acc)
    }

    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
        R::NonResumable::default().combine(x, y)
    }

    // Instead of reducing to a scalar, merge this call's accumulator into
    // the carried state and return it for a later `sum()`.
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        Self(self.0.combine_with(x))
    }
}
672
/// Panics with a mismatched-lengths message. Kept `#[inline(never)]` so the
/// cold panic path does not bloat the hot `simd_op` code.
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!(
        "lengths must be equal, instead got: xlen = {}, ylen = {}",
        xlen, ylen
    )
}
681
/// Computes `schema`'s pairwise reduction (e.g. squared-L2, inner product)
/// over the equal-length sequences `x` and `y`.
///
/// The input is split into three parts: `trip_count` full unrolled blocks
/// handled by the schema's [`MainLoop`], up to `BLOCK_SIZE - 1` trailing
/// full SIMD vectors (`epilogues`), and a final sub-SIMD-width `remainder`
/// handled by `schema.epilogue`.
///
/// # Panics
/// Panics if `x` and `y` have different lengths.
#[inline(always)]
pub fn simd_op<L, R, S, T, U, A>(schema: &S, arch: A, x: T, y: U) -> S::Return
where
    A: Architecture,
    T: AsRef<[L]>,
    U: AsRef<[R]>,
    S: SIMDSchema<L, R, A>,
{
    let x: &[L] = x.as_ref();
    let y: &[R] = y.as_ref();

    let len = x.len();

    if len != y.len() {
        emit_length_error(len, y.len());
    }
    let px = x.as_ptr();
    let py = y.as_ptr();

    let simd_width: usize = S::get_simd_width();
    let unroll: usize = S::get_main_bocksize();

    // Full blocks of `unroll` SIMD vectors, then whole leftover vectors.
    let trip_count = len / (simd_width * unroll);
    let epilogues = (len - simd_width * unroll * trip_count) / simd_width;

    let loader: Loader<S, L, R, A> = Loader::new(arch, *schema, px, py, len);

    // SAFETY: trip_count/epilogues are derived from `len` above, so every
    // load issued by the main loop stays within the slices.
    let mut s0 = unsafe { <S as SIMDSchema<L, R, A>>::Main::main(&loader, trip_count, epilogues) };

    // Final partial vector of fewer than `simd_width` elements.
    let remainder = len % simd_width;
    if remainder != 0 {
        let i = len - remainder;

        // SAFETY: `i + remainder == len`, so the tail pointers are valid
        // for exactly `remainder` reads.
        s0 = unsafe { schema.epilogue(arch, px.add(i), py.add(i), remainder, s0) };
    }

    schema.reduce(s0)
}
750
/// Scalar tail loop: reads `len` element pairs through raw pointers and
/// folds them into `acc` with `f`. Used by the aarch64 schema epilogues.
///
/// # Safety
/// `left` and `right` must each be valid for `len` unaligned reads.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn scalar_epilogue<L, R, F, Acc>(
    left: *const L,
    right: *const R,
    len: usize,
    acc: Acc,
    mut f: F,
) -> Acc
where
    L: Copy,
    R: Copy,
    F: FnMut(Acc, L, R) -> Acc,
{
    (0..len).fold(acc, |state, i| {
        // SAFETY: the caller guarantees both pointers are valid for `len`
        // reads, and `i < len`.
        let (l, r) = unsafe { (left.add(i).read_unaligned(), right.add(i).read_unaligned()) };
        f(state, l, r)
    })
}
778
/// Squared-Euclidean (L2) distance marker. The per-element-type and
/// per-architecture behavior lives in the `SIMDSchema` impls below.
#[derive(Debug, Default, Clone, Copy)]
pub struct L2;
786
787#[cfg(target_arch = "x86_64")]
788impl SIMDSchema<f32, f32, V4> for L2 {
789 type SIMDWidth = Const<8>;
790 type Accumulator = <V4 as Architecture>::f32x8;
791 type Left = <V4 as Architecture>::f32x8;
792 type Right = <V4 as Architecture>::f32x8;
793 type Return = f32;
794 type Main = Strategy4x1;
795
796 #[inline(always)]
797 fn init(&self, arch: V4) -> Self::Accumulator {
798 Self::Accumulator::default(arch)
799 }
800
801 #[inline(always)]
802 fn accumulate(
803 &self,
804 x: Self::Left,
805 y: Self::Right,
806 acc: Self::Accumulator,
807 ) -> Self::Accumulator {
808 let c = x - y;
809 c.mul_add_simd(c, acc)
810 }
811
812 #[inline(always)]
813 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
814 x.sum_tree()
815 }
816}
817
818#[cfg(target_arch = "x86_64")]
819impl SIMDSchema<f32, f32, V3> for L2 {
820 type SIMDWidth = Const<8>;
821 type Accumulator = <V3 as Architecture>::f32x8;
822 type Left = <V3 as Architecture>::f32x8;
823 type Right = <V3 as Architecture>::f32x8;
824 type Return = f32;
825 type Main = Strategy4x1;
826
827 #[inline(always)]
828 fn init(&self, arch: V3) -> Self::Accumulator {
829 Self::Accumulator::default(arch)
830 }
831
832 #[inline(always)]
833 fn accumulate(
834 &self,
835 x: Self::Left,
836 y: Self::Right,
837 acc: Self::Accumulator,
838 ) -> Self::Accumulator {
839 let c = x - y;
840 c.mul_add_simd(c, acc)
841 }
842
843 #[inline(always)]
844 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
845 x.sum_tree()
846 }
847}
848
/// Squared-L2 over `f32` on NEON: 4-lane f32 vectors with a scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    // Zeroed accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Lane-wise acc + (x - y)^2 via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Scalar tail instead of the default masked-load epilogue; the result
    // is placed in lane 0 of a fresh vector and added to `acc`.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // `len` is always < SIMDWidth here; the `min` is a defensive cap.
            len.min(Self::SIMDWidth::value() - 1),
            0.0f32,
            |acc, x, y| -> f32 {
                let c = x - y;
                c.mul_add(c, acc)
            },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0])
    }

    // Horizontal sum of the per-lane partial sums.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
901
/// Squared-L2 over `f32` with the portable fallback: 4-lane emulated vectors.
impl SIMDSchema<f32, f32, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    // Zeroed accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Lane-wise (x - y)^2 + acc; no fused multiply-add in the emulated path.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        (c * c) + acc
    }

    // Horizontal sum of the per-lane partial sums.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Plain scalar tail loop; the partial sum lands in lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: the caller guarantees `x`/`y` are valid for `len` reads.
            let vx = unsafe { x.add(i).read_unaligned() };
            let vy = unsafe { y.add(i).read_unaligned() };
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
953
954#[cfg(target_arch = "x86_64")]
955impl SIMDSchema<Half, Half, V4> for L2 {
956 type SIMDWidth = Const<8>;
957 type Accumulator = <V4 as Architecture>::f32x8;
958 type Left = <V4 as Architecture>::f16x8;
959 type Right = <V4 as Architecture>::f16x8;
960 type Return = f32;
961 type Main = Strategy2x4;
962
963 #[inline(always)]
964 fn init(&self, arch: V4) -> Self::Accumulator {
965 Self::Accumulator::default(arch)
966 }
967
968 #[inline(always)]
969 fn accumulate(
970 &self,
971 x: Self::Left,
972 y: Self::Right,
973 acc: Self::Accumulator,
974 ) -> Self::Accumulator {
975 diskann_wide::alias!(f32s = <V4>::f32x8);
976
977 let x: f32s = x.into();
978 let y: f32s = y.into();
979
980 let c = x - y;
981 c.mul_add_simd(c, acc)
982 }
983
984 #[inline(always)]
985 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
986 x.sum_tree()
987 }
988}
989
990#[cfg(target_arch = "x86_64")]
991impl SIMDSchema<Half, Half, V3> for L2 {
992 type SIMDWidth = Const<8>;
993 type Accumulator = <V3 as Architecture>::f32x8;
994 type Left = <V3 as Architecture>::f16x8;
995 type Right = <V3 as Architecture>::f16x8;
996 type Return = f32;
997 type Main = Strategy2x4;
998
999 #[inline(always)]
1000 fn init(&self, arch: V3) -> Self::Accumulator {
1001 Self::Accumulator::default(arch)
1002 }
1003
1004 #[inline(always)]
1005 fn accumulate(
1006 &self,
1007 x: Self::Left,
1008 y: Self::Right,
1009 acc: Self::Accumulator,
1010 ) -> Self::Accumulator {
1011 diskann_wide::alias!(f32s = <V3>::f32x8);
1012
1013 let x: f32s = x.into();
1014 let y: f32s = y.into();
1015
1016 let c = x - y;
1017 c.mul_add_simd(c, acc)
1018 }
1019
1020 #[inline(always)]
1022 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1023 x.sum_tree()
1024 }
1025}
1026
/// Squared-L2 over `Half` (f16) on NEON: 4-lane f16 loads widened to f32.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy4x1;

    // Zeroed f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen both inputs to f32, then lane-wise acc + (x - y)^2.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Scalar tail: each leftover Half pair is widened through a one-lane
    // vector, squared-differenced in f32, and summed into `acc`.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let rest = scalar_epilogue(
            x,
            y,
            // `len` is always < SIMDWidth here; the `min` is a defensive cap.
            len.min(Self::SIMDWidth::value() - 1),
            f32s::default(arch),
            |acc, x: Half, y: Half| -> f32s {
                let zero = Half::default();
                let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                let c: f32s = x - y;
                c.mul_add_simd(c, acc)
            },
        );
        acc + rest
    }

    // Horizontal sum of the per-lane partial sums.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1089
1090impl SIMDSchema<Half, Half, Scalar> for L2 {
1091 type SIMDWidth = Const<1>;
1092 type Accumulator = Emulated<f32, 1>;
1093 type Left = Emulated<Half, 1>;
1094 type Right = Emulated<Half, 1>;
1095 type Return = f32;
1096 type Main = Strategy1x1;
1097
1098 #[inline(always)]
1099 fn init(&self, arch: Scalar) -> Self::Accumulator {
1100 Self::Accumulator::default(arch)
1101 }
1102
1103 #[inline(always)]
1104 fn accumulate(
1105 &self,
1106 x: Self::Left,
1107 y: Self::Right,
1108 acc: Self::Accumulator,
1109 ) -> Self::Accumulator {
1110 let x: Self::Accumulator = x.into();
1111 let y: Self::Accumulator = y.into();
1112
1113 let c = x - y;
1114 acc + (c * c)
1115 }
1116
1117 #[inline(always)]
1118 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1119 x.to_array()[0]
1120 }
1121}
1122
1123impl<A> SIMDSchema<f32, Half, A> for L2
1124where
1125 A: Architecture,
1126{
1127 type SIMDWidth = Const<8>;
1128 type Accumulator = A::f32x8;
1129 type Left = A::f32x8;
1130 type Right = A::f16x8;
1131 type Return = f32;
1132 type Main = Strategy4x2;
1133
1134 #[inline(always)]
1135 fn init(&self, arch: A) -> Self::Accumulator {
1136 Self::Accumulator::default(arch)
1137 }
1138
1139 #[inline(always)]
1140 fn accumulate(
1141 &self,
1142 x: Self::Left,
1143 y: Self::Right,
1144 acc: Self::Accumulator,
1145 ) -> Self::Accumulator {
1146 let y: A::f32x8 = y.into();
1147 let c = x - y;
1148 c.mul_add_simd(c, acc)
1149 }
1150
1151 #[inline(always)]
1153 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1154 x.sum_tree()
1155 }
1156}
1157
1158#[cfg(target_arch = "x86_64")]
1159impl SIMDSchema<i8, i8, V4> for L2 {
1160 type SIMDWidth = Const<32>;
1161 type Accumulator = <V4 as Architecture>::i32x16;
1162 type Left = <V4 as Architecture>::i8x32;
1163 type Right = <V4 as Architecture>::i8x32;
1164 type Return = f32;
1165 type Main = Strategy4x1;
1166
1167 #[inline(always)]
1168 fn init(&self, arch: V4) -> Self::Accumulator {
1169 Self::Accumulator::default(arch)
1170 }
1171
1172 #[inline(always)]
1173 fn accumulate(
1174 &self,
1175 x: Self::Left,
1176 y: Self::Right,
1177 acc: Self::Accumulator,
1178 ) -> Self::Accumulator {
1179 diskann_wide::alias!(i16s = <V4>::i16x32);
1180
1181 let x: i16s = x.into();
1182 let y: i16s = y.into();
1183 let c = x - y;
1184 acc.dot_simd(c, c)
1185 }
1186
1187 #[inline(always)]
1188 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1189 x.sum_tree().as_f32_lossy()
1190 }
1191}
1192
1193#[cfg(target_arch = "x86_64")]
1194impl SIMDSchema<i8, i8, V3> for L2 {
1195 type SIMDWidth = Const<16>;
1196 type Accumulator = <V3 as Architecture>::i32x8;
1197 type Left = <V3 as Architecture>::i8x16;
1198 type Right = <V3 as Architecture>::i8x16;
1199 type Return = f32;
1200 type Main = Strategy4x1;
1201
1202 #[inline(always)]
1203 fn init(&self, arch: V3) -> Self::Accumulator {
1204 Self::Accumulator::default(arch)
1205 }
1206
1207 #[inline(always)]
1208 fn accumulate(
1209 &self,
1210 x: Self::Left,
1211 y: Self::Right,
1212 acc: Self::Accumulator,
1213 ) -> Self::Accumulator {
1214 diskann_wide::alias!(i16s = <V3>::i16x16);
1215
1216 let x: i16s = x.into();
1217 let y: i16s = y.into();
1218 let c = x - y;
1219 acc.dot_simd(c, c)
1220 }
1221
1222 #[inline(always)]
1224 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1225 x.sum_tree().as_f32_lossy()
1226 }
1227}
1228
/// Squared-L2 over `i8` on NEON, delegating to a specialized kernel.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::i32x8;
    type Left = diskann_wide::arch::aarch64::i8x16;
    type Right = diskann_wide::arch::aarch64::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    // Zeroed i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Dedicated NEON squared-euclidean accumulation kernel.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        algorithms::squared_euclidean_accum_i8x16(x, y, acc)
    }

    // Scalar tail: differences computed in i32 so they cannot overflow;
    // the partial sum lands in lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // `len` is always < SIMDWidth here; the `min` is a defensive cap.
            len.min(Self::SIMDWidth::value() - 1),
            0i32,
            |acc, x: i8, y: i8| -> i32 {
                let c = (x as i32) - (y as i32);
                acc + c * c
            },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0, 0, 0, 0, 0])
    }

    // Horizontal integer sum, converted (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1281
/// Squared-L2 over `i8` with the portable fallback: 4-lane emulated vectors
/// accumulating in i32.
impl SIMDSchema<i8, i8, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zeroed i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen both sides to i32 lanes, then acc + (x - y)^2.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    // Horizontal integer sum, converted (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Plain scalar tail loop in i32; the partial sum lands in lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: the caller guarantees `x`/`y` are valid for `len` reads.
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1335
1336#[cfg(target_arch = "x86_64")]
1337impl SIMDSchema<u8, u8, V4> for L2 {
1338 type SIMDWidth = Const<32>;
1339 type Accumulator = <V4 as Architecture>::i32x16;
1340 type Left = <V4 as Architecture>::u8x32;
1341 type Right = <V4 as Architecture>::u8x32;
1342 type Return = f32;
1343 type Main = Strategy4x1;
1344
1345 #[inline(always)]
1346 fn init(&self, arch: V4) -> Self::Accumulator {
1347 Self::Accumulator::default(arch)
1348 }
1349
1350 #[inline(always)]
1351 fn accumulate(
1352 &self,
1353 x: Self::Left,
1354 y: Self::Right,
1355 acc: Self::Accumulator,
1356 ) -> Self::Accumulator {
1357 diskann_wide::alias!(i16s = <V4>::i16x32);
1358
1359 let x: i16s = x.into();
1360 let y: i16s = y.into();
1361 let c = x - y;
1362 acc.dot_simd(c, c)
1363 }
1364
1365 #[inline(always)]
1366 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1367 x.sum_tree().as_f32_lossy()
1368 }
1369}
1370
1371#[cfg(target_arch = "x86_64")]
1372impl SIMDSchema<u8, u8, V3> for L2 {
1373 type SIMDWidth = Const<16>;
1374 type Accumulator = <V3 as Architecture>::i32x8;
1375 type Left = <V3 as Architecture>::u8x16;
1376 type Right = <V3 as Architecture>::u8x16;
1377 type Return = f32;
1378 type Main = Strategy4x1;
1379
1380 #[inline(always)]
1381 fn init(&self, arch: V3) -> Self::Accumulator {
1382 Self::Accumulator::default(arch)
1383 }
1384
1385 #[inline(always)]
1386 fn accumulate(
1387 &self,
1388 x: Self::Left,
1389 y: Self::Right,
1390 acc: Self::Accumulator,
1391 ) -> Self::Accumulator {
1392 diskann_wide::alias!(i16s = <V3>::i16x16);
1393
1394 let x: i16s = x.into();
1395 let y: i16s = y.into();
1396 let c = x - y;
1397 acc.dot_simd(c, c)
1398 }
1399
1400 #[inline(always)]
1402 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1403 x.sum_tree().as_f32_lossy()
1404 }
1405}
1406
/// Squared-L2 over `u8` on NEON, delegating to a specialized kernel.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<u8, u8, Neon> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::u32x8;
    type Left = diskann_wide::arch::aarch64::u8x16;
    type Right = diskann_wide::arch::aarch64::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    // Zeroed u32 accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Dedicated NEON squared-euclidean accumulation kernel.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        algorithms::squared_euclidean_accum_u8x16(x, y, acc)
    }

    // Scalar tail: differences computed in i32 (can be negative), squared
    // values are non-negative and cast back to u32; sum lands in lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // `len` is always < SIMDWidth here; the `min` is a defensive cap.
            len.min(Self::SIMDWidth::value() - 1),
            0u32,
            |acc, x: u8, y: u8| -> u32 {
                let c = (x as i32) - (y as i32);
                acc + ((c * c) as u32)
            },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0, 0, 0, 0, 0])
    }

    // Horizontal integer sum, converted (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1459
/// Squared-L2 over `u8` with the portable fallback: 4-lane emulated vectors
/// accumulating in i32.
impl SIMDSchema<u8, u8, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<u8, 4>;
    type Right = Emulated<u8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zeroed i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen both sides to i32 lanes, then acc + (x - y)^2.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    // Horizontal integer sum, converted (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Plain scalar tail loop in i32; the partial sum lands in lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: the caller guarantees `x`/`y` are valid for `len` reads.
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1513
/// Partially-computed squared-L2 state that can be carried across multiple
/// [`simd_op`] calls and reduced to a final `f32` via `sum()`.
#[derive(Clone, Copy, Debug)]
pub struct ResumableL2<A = diskann_wide::arch::Current>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A>,
{
    // Running accumulator of the underlying `L2` schema.
    acc: <L2 as SIMDSchema<f32, f32, A>>::Accumulator,
}
1524
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableL2<A>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = L2;
    type FinalReturn = f32;

    /// Fresh state holding `L2`'s identity accumulator.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self { acc: L2.init(arch) }
    }

    /// Folds one call's accumulator into the carried state (lane-wise add).
    #[inline(always)]
    fn combine_with(&self, other: <L2 as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
        Self {
            acc: self.acc + other,
        }
    }

    /// Final horizontal reduction to the scalar squared distance.
    #[inline(always)]
    fn sum(&self) -> f32 {
        L2.reduce(self.acc)
    }
}
1550
/// Inner-product marker: accumulates `sum(x[i] * y[i])`. The per-type and
/// per-architecture behavior lives in the `SIMDSchema` impls below.
#[derive(Clone, Copy, Debug, Default)]
pub struct IP;
1558
// f32 dot product on AVX-512-class (V4) hardware: 8 lanes, 4-way unrolled.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Pairwise (tree) horizontal sum of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1588
// f32 dot product on AVX2-class (V3) hardware; same shape as the V4 impl.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Pairwise (tree) horizontal sum of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1619
// f32 dot product on aarch64 NEON: 4 lanes, with an explicit scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Scalar tail: `len` is clamped to at most SIMDWidth - 1 leftovers, and
    // the scalar partial sum is folded into lane 0 of the SIMD accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            0.0f32,
            |acc, x: f32, y: f32| -> f32 { x.mul_add(y, acc) },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0])
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1668
// Portable fallback: f32 dot product over emulated 4-lane vectors.
impl SIMDSchema<f32, f32, Scalar> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Unfused multiply-then-add (emulated lanes; no FMA here).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x * y + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Scalar tail loop over the remaining `len` elements; partial sum folded
    // into lane 0 of the emulated accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            let vx = unsafe { x.add(i).read_unaligned() };
            let vy = unsafe { y.add(i).read_unaligned() };
            s += vx * vy;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
1718
// f16 dot product on V4: widen each f16x8 to f32x8, then fused multiply-add.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1752
// f16 dot product on V3; identical structure to the V4 impl above, but with
// a 2x4 unrolling strategy.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1787
// f16 dot product on NEON: widen 4-lane f16 to f32, FMA, scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();

        x.mul_add_simd(y, acc)
    }

    // Tail: process at most SIMDWidth - 1 leftover f16 pairs. Each pair is
    // broadcast into lane 0 of a zeroed vector so the SIMD FMA path is reused.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let rest = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            f32s::default(arch),
            |acc, x: Half, y: Half| -> f32s {
                let zero = Half::default();
                let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                x.mul_add_simd(y, acc)
            },
        );
        acc + rest
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1848
1849impl SIMDSchema<Half, Half, Scalar> for IP {
1850 type SIMDWidth = Const<1>;
1851 type Accumulator = Emulated<f32, 1>;
1852 type Left = Emulated<Half, 1>;
1853 type Right = Emulated<Half, 1>;
1854 type Return = f32;
1855 type Main = Strategy1x1;
1856
1857 #[inline(always)]
1858 fn init(&self, arch: Scalar) -> Self::Accumulator {
1859 Self::Accumulator::default(arch)
1860 }
1861
1862 #[inline(always)]
1863 fn accumulate(
1864 &self,
1865 x: Self::Left,
1866 y: Self::Right,
1867 acc: Self::Accumulator,
1868 ) -> Self::Accumulator {
1869 let x: Self::Accumulator = x.into();
1870 let y: Self::Accumulator = y.into();
1871 x * y + acc
1872 }
1873
1874 #[inline(always)]
1875 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1876 x.to_array()[0]
1877 }
1878}
1879
// Mixed-precision dot product (f32 left operand, f16 right operand), generic
// over architecture: the f16 side is widened to f32 before the FMA.
impl<A> SIMDSchema<f32, Half, A> for IP
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Only the right-hand side needs widening.
        let y: A::f32x8 = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1913
// i8 dot product on V4: widen i8 -> i16, then dot-accumulate into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for IP {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen to i16 so the products fit before the dot accumulation.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Integer horizontal sum, then lossy cast to the f32 return type.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1947
// i8 dot product on V3; same widen-then-dot shape as the V4 impl, narrower
// vectors.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen to i16 so the products fit before the dot accumulation.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1982
// i8 dot product on NEON: native i8 dot-accumulate plus a scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::i32x4;
    type Left = <Neon as Architecture>::i8x16;
    type Right = <Neon as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.dot_simd(x, y)
    }

    // Tail: at most SIMDWidth - 1 leftover pairs, multiplied in i32 to avoid
    // overflow, folded into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            0i32,
            |acc, x: i8, y: i8| -> i32 { acc + (x as i32) * (y as i32) },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0])
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2031
// Portable fallback: i8 dot product, one element at a time.
impl SIMDSchema<i8, i8, Scalar> for IP {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<i32, 1>;
    type Left = Emulated<i8, 1>;
    type Right = Emulated<i8, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen i8 -> i32, then multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // With a width of 1, the main loop consumes every element; the epilogue
    // is never invoked.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const i8,
        _y: *const i8,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
2075
// u8 dot product on V4: widen u8 -> i16 (lossless), then dot-accumulate.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for IP {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // u8 values fit losslessly in i16.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2109
// u8 dot product on V3; same widen-then-dot shape as the V4 impl.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // u8 values fit losslessly in i16.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2146
// u8 dot product on NEON: native u8 dot-accumulate into u32 lanes, plus a
// scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<u8, u8, Neon> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::u32x4;
    type Left = <Neon as Architecture>::u8x16;
    type Right = <Neon as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.dot_simd(x, y)
    }

    // Tail: at most SIMDWidth - 1 leftover pairs, multiplied in u32 to avoid
    // overflow, folded into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            0u32,
            |acc, x: u8, y: u8| -> u32 { acc + (x as u32) * (y as u32) },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0])
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2195
// Portable fallback: u8 dot product, one element at a time.
impl SIMDSchema<u8, u8, Scalar> for IP {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<i32, 1>;
    type Left = Emulated<u8, 1>;
    type Right = Emulated<u8, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen u8 -> i32 (lossless), then multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // With a width of 1, the main loop consumes every element; the epilogue
    // is never invoked.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const u8,
        _y: *const u8,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
2239
/// Resumable wrapper around the [`IP`] schema: holds a partial SIMD
/// accumulator so an inner-product computation can be split across multiple
/// slices and combined later. Mirrors `ResumableL2` above.
#[derive(Clone, Copy, Debug)]
pub struct ResumableIP<A = diskann_wide::arch::Current>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A>,
{
    // Partial accumulator produced by one or more `IP` passes.
    acc: <IP as SIMDSchema<f32, f32, A>>::Accumulator,
}
2249
// Resumable protocol for the inner product: start empty, fold in per-slice
// accumulators, and reduce to a final f32 via the underlying `IP` schema.
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableIP<A>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = IP;
    type FinalReturn = f32;

    // Start from the zero accumulator of the underlying schema.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self { acc: IP.init(arch) }
    }

    // Fold another partial accumulator into this one (lane-wise add).
    #[inline(always)]
    fn combine_with(&self, other: <IP as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
        Self {
            acc: self.acc + other,
        }
    }

    // Final horizontal reduction to a scalar value.
    #[inline(always)]
    fn sum(&self) -> f32 {
        IP.reduce(self.acc)
    }
}
2275
/// Accumulator for cosine similarity: tracks the three running sums needed
/// to compute `x·y / (‖x‖·‖y‖)` in a single pass.
#[derive(Debug, Clone, Copy)]
pub struct FullCosineAccumulator<T> {
    // Running sum of x*x (squared norm of the left operand).
    normx: T,
    // Running sum of y*y (squared norm of the right operand).
    normy: T,
    // Running sum of x*y (the dot product).
    xy: T,
}
2288
impl<T> FullCosineAccumulator<T>
where
    T: SIMDVector
        + SIMDSumTree
        + SIMDMulAdd
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    T::Scalar: LossyF32Conversion,
{
    /// All-zero accumulator for the given architecture token.
    #[inline(always)]
    pub fn new(arch: T::Arch) -> Self {
        let zero = T::default(arch);
        Self {
            normx: zero,
            normy: zero,
            xy: zero,
        }
    }

    /// Fold one vector pair into all three running sums using fused
    /// multiply-add.
    #[inline(always)]
    pub fn add_with(&self, x: T, y: T) -> Self {
        FullCosineAccumulator {
            normx: x.mul_add_simd(x, self.normx),
            normy: y.mul_add_simd(y, self.normy),
            xy: x.mul_add_simd(y, self.xy),
        }
    }

    /// Same as [`Self::add_with`] but with separate multiply and add
    /// (no FMA) — used by the scalar fallback schemas.
    #[inline(always)]
    pub fn add_with_unfused(&self, x: T, y: T) -> Self {
        FullCosineAccumulator {
            normx: x * x + self.normx,
            normy: y * y + self.normy,
            xy: x * y + self.xy,
        }
    }

    /// Reduce to the cosine similarity, clamped to `[-1, 1]`.
    ///
    /// Returns `0.0` when either norm is below `f32::MIN_POSITIVE` (i.e. a
    /// zero or denormal-norm operand), so degenerate inputs never divide by
    /// (nearly) zero.
    #[inline(always)]
    pub fn sum(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();

        let denominator = normx.sqrt() * normy.sqrt();
        let prod = self.xy.sum_tree().as_f32_lossy();

        // NOTE(review): `force_eval` is an inline-asm black box (see top of
        // file) — presumably it pins the evaluation of both values before
        // the branch below so the optimizer cannot reorder or fold them;
        // confirm against the original intent before changing.
        force_eval(denominator);
        force_eval(prod);

        if normx < f32::MIN_POSITIVE || normy < f32::MIN_POSITIVE {
            return 0.0;
        }

        let v = prod / denominator;
        // Clamp via min/max; note this ordering also maps a NaN quotient
        // to 1.0 (f32::min/max ignore a NaN operand), unlike `f32::clamp`.
        (-1.0f32).max(1.0f32.min(v))
    }

    /// Reinterpret the same three sums as a squared-L2 distance:
    /// `‖x‖² + ‖y‖² - 2·x·y`.
    #[inline(always)]
    pub fn sum_as_l2(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();
        let xy = self.xy.sum_tree().as_f32_lossy();
        normx + normy - (xy + xy)
    }
}
2377
2378impl<T> std::ops::Add for FullCosineAccumulator<T>
2379where
2380 T: std::ops::Add<Output = T>,
2381{
2382 type Output = Self;
2383 #[inline(always)]
2384 fn add(self, other: Self) -> Self {
2385 FullCosineAccumulator {
2386 normx: self.normx + other.normx,
2387 normy: self.normy + other.normy,
2388 xy: self.xy + other.xy,
2389 }
2390 }
2391}
2392
/// Stateless schema computing cosine similarity in a single pass, without
/// relying on precomputed norms (see `FullCosineAccumulator`).
#[derive(Default, Clone, Copy)]
pub struct CosineStateless;
2396
// f32 cosine similarity on V4: 16 lanes, three fused-multiply-add streams.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Update x·x, y·y, and x·y simultaneously.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    // Clamped cosine similarity (see FullCosineAccumulator::sum).
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2430
// f32 cosine similarity on V3; same structure as the V4 impl, 8 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for CosineStateless {
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Update x·x, y·y, and x·y simultaneously.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2464
// f32 cosine similarity on NEON, with an explicit scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::f32x4>;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Update x·x, y·y, and x·y simultaneously.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    // Scalar tail: accumulate all three sums for at most SIMDWidth - 1
    // leftover elements, matching the main loop's FMA rounding.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: f32 = 0.0;
        let mut yy: f32 = 0.0;
        let mut xy: f32 = 0.0;
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            let vx = unsafe { x.add(i).read_unaligned() };
            let vy = unsafe { y.add(i).read_unaligned() };
            xx = vx.mul_add(vx, xx);
            yy = vy.mul_add(vy, yy);
            xy = vx.mul_add(vy, xy);
        }
        type V = <Neon as Architecture>::f32x4;
        // Fold the scalar partial sums into lane 0 of each SIMD stream.
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0.0, 0.0, 0.0]),
            normy: V::from_array(arch, [yy, 0.0, 0.0, 0.0]),
            xy: V::from_array(arch, [xy, 0.0, 0.0, 0.0]),
        }
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2527
// Portable fallback: f32 cosine similarity over emulated 4-lane vectors.
impl SIMDSchema<f32, f32, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<f32, 4>>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;

    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Unfused multiply/add variant (emulated lanes; no FMA).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with_unfused(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2557
// f16 cosine similarity on V4: widen f16x16 -> f32x16, then accumulate.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f16x16;
    type Right = <V4 as Architecture>::f16x16;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x16);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2591
// f16 cosine similarity on V3; same widen-then-accumulate shape as V4.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for CosineStateless {
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2626
// f16 cosine similarity on NEON: widen 4-lane f16 to f32, with a tail that
// reuses the SIMD path one element at a time.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::f32x4>;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        // Widen f16 -> f32 so accumulation happens in single precision.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    // Tail: at most SIMDWidth - 1 leftover pairs; each pair is placed in
    // lane 0 of a zeroed vector so the SIMD accumulation path is reused.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        type V = <Neon as Architecture>::f32x4;

        let rest = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            FullCosineAccumulator::<V>::new(arch),
            |acc, x: Half, y: Half| -> FullCosineAccumulator<V> {
                let zero = Half::default();
                let x: V = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: V = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                acc.add_with(x, y)
            },
        );
        acc + rest
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2687
2688impl SIMDSchema<Half, Half, Scalar> for CosineStateless {
2689 type SIMDWidth = Const<1>;
2690 type Accumulator = FullCosineAccumulator<Emulated<f32, 1>>;
2691 type Left = Emulated<Half, 1>;
2692 type Right = Emulated<Half, 1>;
2693 type Return = f32;
2694 type Main = Strategy1x1;
2695
2696 #[inline(always)]
2697 fn init(&self, arch: Scalar) -> Self::Accumulator {
2698 Self::Accumulator::new(arch)
2699 }
2700
2701 #[inline(always)]
2702 fn accumulate(
2703 &self,
2704 x: Self::Left,
2705 y: Self::Right,
2706 acc: Self::Accumulator,
2707 ) -> Self::Accumulator {
2708 let x: Emulated<f32, 1> = x.into();
2709 let y: Emulated<f32, 1> = y.into();
2710 acc.add_with_unfused(x, y)
2711 }
2712
2713 #[inline(always)]
2714 fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2715 acc.sum()
2716 }
2717}
// Mixed-precision cosine similarity (f32 left, f16 right), generic over
// architecture: only the f16 side needs widening.
impl<A> SIMDSchema<f32, Half, A> for CosineStateless
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<A::f32x8>;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen the right-hand side f16 lanes to f32.
        let y: A::f32x8 = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2750
// i8 cosine similarity on V4: widen i8 -> i16, then run three integer dot
// streams (x·x, y·y, x·y) in parallel.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen to i16 so the products fit before the dot accumulation.
        let x: i16s = x.into();
        let y: i16s = y.into();

        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2790
// i8 cosine similarity on V3; same widen-then-triple-dot shape as V4.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen to i16 so the products fit before the dot accumulation.
        let x: i16s = x.into();
        let y: i16s = y.into();

        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2830
// i8 cosine similarity on NEON: native i8 dot-accumulate for all three
// streams, plus a scalar tail.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::i32x4>;
    type Left = <Neon as Architecture>::i8x16;
    type Right = <Neon as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Scalar tail: at most SIMDWidth - 1 leftover pairs, multiplied in i32
    // to avoid overflow, folded into lane 0 of each stream.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;
        let mut xy: i32 = 0;
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }
        type V = <Neon as Architecture>::i32x4;
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0, 0, 0]),
            normy: V::from_array(arch, [yy, 0, 0, 0]),
            xy: V::from_array(arch, [xy, 0, 0, 0]),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2893
// Portable (emulated-SIMD) cosine kernel for `i8` vectors; the reference
// fallback when no real vector ISA is available.
impl SIMDSchema<i8, i8, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen i8 -> i32 before multiplying so products cannot overflow.
        let x: Emulated<i32, 4> = x.into();
        let y: Emulated<i32, 4> = y.into();
        // `add_with` presumably accumulates x*x, y*y, and x*y into the three
        // norm fields -- confirm against `FullCosineAccumulator`'s definition.
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }

    // Scalar tail for the ragged remainder.
    //
    // SAFETY (caller contract): `x` and `y` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xy: i32 = 0;
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;

        for i in 0..len {
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();

            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }

        // Fold the scalar sums back in through lane 0.
        acc + FullCosineAccumulator {
            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
        }
    }
}
2956
#[cfg(target_arch = "x86_64")]
// Cosine-similarity kernel for `u8` vectors on AVX-512 (V4). Inputs are
// widened to i16 before the dot products, so u8*u8 (max 65025) fits and the
// running sums live in i32 lanes.
impl SIMDSchema<u8, u8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen u8 -> i16 so the i16 dot product cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2996
#[cfg(target_arch = "x86_64")]
// Cosine-similarity kernel for `u8` vectors on AVX2 (V3); same widening
// scheme as the V4 version, at half the vector width.
impl SIMDSchema<u8, u8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen u8 -> i16 so the i16 dot product cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
3036
#[cfg(target_arch = "aarch64")]
// Cosine-similarity kernel for `u8` vectors on NEON. Unlike the x86 variants
// this accumulates in unsigned u32 lanes (u8 products are non-negative).
impl SIMDSchema<u8, u8, Neon> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::u32x4>;
    type Left = <Neon as Architecture>::u8x16;
    type Right = <Neon as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Scalar tail for the ragged remainder.
    //
    // SAFETY (caller contract): `x` and `y` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: u32 = 0;
        let mut yy: u32 = 0;
        let mut xy: u32 = 0;
        // NOTE(review): assumes `len < SIMDWidth` (see the i8/Neon epilogue);
        // at most 15 u8*u8 products, so u32 cannot overflow.
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            let vx: u32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: u32 = unsafe { y.add(i).read_unaligned() }.into();
            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }
        type V = <Neon as Architecture>::u32x4;
        // Fold the scalar sums back in through lane 0.
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0, 0, 0]),
            normy: V::from_array(arch, [yy, 0, 0, 0]),
            xy: V::from_array(arch, [xy, 0, 0, 0]),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
3099
// Portable (emulated-SIMD) cosine kernel for `u8` vectors. Accumulates in
// i32: u8*u8 (max 65025) always fits.
impl SIMDSchema<u8, u8, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
    type Left = Emulated<u8, 4>;
    type Right = Emulated<u8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen u8 -> i32 before multiplying so products cannot overflow.
        let x: Emulated<i32, 4> = x.into();
        let y: Emulated<i32, 4> = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }

    // Scalar tail for the ragged remainder.
    //
    // SAFETY (caller contract): `x` and `y` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xy: i32 = 0;
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;

        for i in 0..len {
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();

            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }

        // Fold the scalar sums back in through lane 0.
        acc + FullCosineAccumulator {
            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
        }
    }
}
3162
#[derive(Debug, Clone, Copy)]
/// Resumable cosine-similarity state: wraps the `CosineStateless` accumulator
/// so a distance computation can be split across multiple slices and combined
/// later (see `ResumableSIMDSchema`).
pub struct ResumableCosine<A = diskann_wide::arch::Current>(
    <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
)
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A>;
3171
// Delegates everything to `CosineStateless`; the wrapper only adds the
// ability to merge partial accumulators across chunks.
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableCosine<A>
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = CosineStateless;
    type FinalReturn = f32;

    #[inline(always)]
    fn init(arch: A) -> Self {
        Self(CosineStateless.init(arch))
    }

    // Merge a freshly computed partial accumulator into the running state.
    #[inline(always)]
    fn combine_with(
        &self,
        other: <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
    ) -> Self {
        Self(self.0 + other)
    }

    // Final reduction to the cosine value.
    #[inline(always)]
    fn sum(&self) -> f32 {
        CosineStateless.reduce(self.0)
    }
}
3198
#[derive(Clone, Copy, Debug, Default)]
/// Schema computing the L1 (Manhattan) norm of the *left* operand; the right
/// operand is ignored by every implementation below.
pub struct L1Norm;
3215
#[cfg(target_arch = "x86_64")]
// L1 norm for f32 on AVX-512 (V4): sum of |x| lanes, reduced with a tree sum.
impl SIMDSchema<f32, f32, V4> for L1Norm {
    type SIMDWidth = Const<16>;
    type Accumulator = <V4 as Architecture>::f32x16;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3246
#[cfg(target_arch = "x86_64")]
// L1 norm for f32 on AVX2 (V3); same shape as the V4 version at half width.
impl SIMDSchema<f32, f32, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3277
#[cfg(target_arch = "aarch64")]
// L1 norm for f32 on NEON, with an explicit scalar epilogue for the tail.
impl SIMDSchema<f32, f32, Neon> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    // Scalar tail for the ragged remainder.
    //
    // SAFETY (caller contract): `x` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        _y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        // NOTE(review): assumes `len < SIMDWidth` (see the cosine epilogues).
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            let vx = unsafe { x.add(i).read_unaligned() };
            s += vx.abs();
        }
        // Fold the scalar sum back in through lane 0.
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3325
// Portable (emulated-SIMD) L1 norm for f32.
impl SIMDSchema<f32, f32, Scalar> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Scalar tail for the ragged remainder.
    //
    // SAFETY (caller contract): `x` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        _y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            let vx = unsafe { x.add(i).read_unaligned() };
            s += vx.abs();
        }
        // Fold the scalar sum back in through lane 0.
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
3373
#[cfg(target_arch = "x86_64")]
// L1 norm for f16 on AVX-512 (V4): lanes are widened to f32 before the
// absolute value so accumulation happens in full precision.
impl SIMDSchema<Half, Half, V4> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32 before |x| + acc.
        let x: <V4 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3405
#[cfg(target_arch = "x86_64")]
// L1 norm for f16 on AVX2 (V3); identical structure to the V4 version.
impl SIMDSchema<Half, Half, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32 before |x| + acc.
        let x: <V3 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3437
#[cfg(target_arch = "aarch64")]
// L1 norm for f16 on NEON; f16 lanes widen to f32 for accumulation.
impl SIMDSchema<Half, Half, Neon> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32 before |x| + acc.
        let x: <Neon as Architecture>::f32x4 = x.into();
        x.abs_simd() + acc
    }

    // Scalar tail for the ragged remainder, implemented via the shared
    // `scalar_epilogue` helper.
    //
    // SAFETY (caller contract): `x` must be valid for `len` reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        _y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // `y` is unused by the L1 norm, so `x` is passed for both pointers.
        let rest = scalar_epilogue(
            x,
            x,
            len.min(Self::SIMDWidth::value() - 1),
            Self::Accumulator::default(arch),
            |acc, x: Half, _: Half| -> Self::Accumulator {
                // Lift the scalar f16 into lane 0 of an f32 vector, then |x| + acc.
                let zero = Half::default();
                let x: Self::Accumulator =
                    Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                x.abs_simd() + acc
            },
        );
        acc + rest
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3493
// Portable L1 norm for f16. With a SIMD width of 1 every element goes through
// the main loop, so the epilogue can never be invoked.
impl SIMDSchema<Half, Half, Scalar> for L1Norm {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Only `x` participates; `_y` exists to satisfy the two-operand schema.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32 before |x| + acc.
        let x: Self::Accumulator = x.into();
        x.abs_simd() + acc
    }

    // Single-lane accumulator: the sum *is* element 0.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }

    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const Half,
        _y: *const Half,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
3536
3537#[cfg(test)]
3542mod tests {
3543 use std::{collections::HashMap, sync::LazyLock};
3544
3545 use approx::assert_relative_eq;
3546 use diskann_wide::{arch::Target1, ARCH};
3547 use half::f16;
3548 use rand::{distr::StandardUniform, rngs::StdRng, Rng, SeedableRng};
3549 use rand_distr;
3550
3551 use super::*;
3552 use crate::{distance::reference, norm::LInfNorm, test_util};
3553
3554 fn cosine_norm_check_impl<A>(arch: A)
3559 where
3560 A: diskann_wide::Architecture,
3561 CosineStateless:
3562 SIMDSchema<f32, f32, A, Return = f32> + SIMDSchema<Half, Half, A, Return = f32>,
3563 {
3564 {
3566 let x: [f32; 2] = [0.0, 0.0];
3567 let y: [f32; 2] = [0.0, 1.0];
3568 assert_eq!(
3569 simd_op(&CosineStateless {}, arch, x, x),
3570 0.0,
3571 "when both vectors are zero, similarity should be zero",
3572 );
3573 assert_eq!(
3574 simd_op(&CosineStateless {}, arch, x, y),
3575 0.0,
3576 "when one vector is zero, similarity should be zero",
3577 );
3578 assert_eq!(
3579 simd_op(&CosineStateless {}, arch, y, x),
3580 0.0,
3581 "when one vector is zero, similarity should be zero",
3582 );
3583 }
3584
3585 {
3587 let x: [f32; 4] = [0.0, 0.0, 2.938736e-39f32, 0.0];
3588 let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
3589 assert_eq!(
3590 simd_op(&CosineStateless {}, arch, x, x),
3591 0.0,
3592 "when both vectors are almost zero, similarity should be zero",
3593 );
3594 assert_eq!(
3595 simd_op(&CosineStateless {}, arch, x, y),
3596 0.0,
3597 "when one vector is almost zero, similarity should be zero",
3598 );
3599 assert_eq!(
3600 simd_op(&CosineStateless {}, arch, y, x),
3601 0.0,
3602 "when one vector is almost zero, similarity should be zero",
3603 );
3604 }
3605
3606 {
3608 let x: [f32; 4] = [0.0, 0.0, 1.0842022e-19f32, 0.0];
3609 let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
3610 assert_eq!(
3611 simd_op(&CosineStateless {}, arch, x, x),
3612 1.0,
3613 "cosine-stateless should handle vectors this small",
3614 );
3615 assert_eq!(
3616 simd_op(&CosineStateless {}, arch, x, y),
3617 1.0,
3618 "cosine-stateless should handle vectors this small",
3619 );
3620 assert_eq!(
3621 simd_op(&CosineStateless {}, arch, y, x),
3622 1.0,
3623 "cosine-stateless should handle vectors this small",
3624 );
3625 }
3626
3627 let cvt = diskann_wide::cast_f32_to_f16;
3628
3629 {
3631 let x: [Half; 2] = [Half::default(), Half::default()];
3632 let y: [Half; 2] = [Half::default(), cvt(1.0)];
3633 assert_eq!(
3634 simd_op(&CosineStateless {}, arch, x, x),
3635 0.0,
3636 "when both vectors are zero, similarity should be zero",
3637 );
3638 assert_eq!(
3639 simd_op(&CosineStateless {}, arch, x, y),
3640 0.0,
3641 "when one vector is zero, similarity should be zero",
3642 );
3643 assert_eq!(
3644 simd_op(&CosineStateless {}, arch, y, x),
3645 0.0,
3646 "when one vector is zero, similarity should be zero",
3647 );
3648 }
3649
3650 {
3652 let x: [Half; 4] = [
3653 Half::default(),
3654 Half::default(),
3655 Half::MIN_POSITIVE_SUBNORMAL,
3656 Half::default(),
3657 ];
3658 let y: [Half; 4] = [Half::default(), Half::default(), cvt(1.0), Half::default()];
3659 assert_eq!(
3660 simd_op(&CosineStateless {}, arch, x, x),
3661 1.0,
3662 "when both vectors are almost zero, similarity should be zero",
3663 );
3664 assert_eq!(
3665 simd_op(&CosineStateless {}, arch, x, y),
3666 1.0,
3667 "when one vector is almost zero, similarity should be zero",
3668 );
3669 assert_eq!(
3670 simd_op(&CosineStateless {}, arch, y, x),
3671 1.0,
3672 "when one vector is almost zero, similarity should be zero",
3673 );
3674
3675 let threshold = f32::MIN_POSITIVE;
3681 let bound = 50;
3682 let values = {
3683 let mut down = threshold;
3684 let mut up = threshold;
3685 for _ in 0..bound {
3686 down = down.next_down();
3687 up = up.next_up();
3688 }
3689 assert!(down > 0.0);
3690 let min = down.sqrt();
3691 let max = up.sqrt();
3692 let mut v = min;
3693 let mut values = Vec::new();
3694 while v <= max {
3695 values.push(v);
3696 v = v.next_up();
3697 }
3698 values
3699 };
3700
3701 let mut lo = 0;
3702 let mut hi = 0;
3703 for i in values.iter() {
3704 for j in values.iter() {
3705 let s: f32 = simd_op(&CosineStateless {}, arch, [*i], [*j]);
3706 if i * i < threshold || j * j < threshold {
3707 lo += 1;
3708 assert_eq!(s, 0.0, "failed for i = {}, j = {}", i, j);
3709 } else {
3710 hi += 1;
3711 assert_eq!(s, 1.0, "failed for i = {}, j = {}", i, j);
3712 }
3713 }
3714 }
3715 assert_ne!(lo, 0);
3716 assert_ne!(hi, 0);
3717 }
3718 }
3719
    // Runs the degenerate-input checks on the best available backend and on
    // the portable scalar backend.
    #[test]
    fn cosine_norm_check() {
        cosine_norm_check_impl::<diskann_wide::arch::Current>(diskann_wide::arch::current());
        cosine_norm_check_impl::<diskann_wide::arch::Scalar>(diskann_wide::arch::Scalar::new());
    }
3725
    // Runs the degenerate-input checks on the x86-64 vector backends when the
    // host CPU actually supports them (the `new_checked*` probes return None
    // otherwise, so unsupported levels are skipped rather than failed).
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn cosine_norm_check_x86_64() {
        if let Some(arch) = V3::new_checked() {
            cosine_norm_check_impl::<V3>(arch);
        }

        if let Some(arch) = V4::new_checked_miri() {
            cosine_norm_check_impl::<V4>(arch);
        }
    }
3737
    // Computes the distance between `x` and `y` by feeding equal-sized chunks
    // through the resumable schema `T` and merging the partial accumulators,
    // returning the final reduced value. Used to check that chunked
    // accumulation agrees with a single full pass.
    fn test_resumable<T, L, R, A>(arch: A, x: &[L], y: &[R], chunk_size: usize) -> f32
    where
        A: Architecture,
        T: ResumableSIMDSchema<L, R, A, FinalReturn = f32>,
    {
        let mut acc = Resumable(<T as ResumableSIMDSchema<L, R, A>>::init(arch));
        let iter = std::iter::zip(x.chunks(chunk_size), y.chunks(chunk_size));
        for (a, b) in iter {
            acc = simd_op(&acc, arch, a, b);
        }
        acc.0.sum()
    }
3755
    // Randomized stress test for an f32 schema `O` and its resumable wrapper
    // `T`: checks the SIMD result against `reference`, then re-runs the same
    // inputs chunked at several sizes and checks the chunked result matches.
    fn stress_test_with_resumable<
        A: Architecture,
        O: Default + SIMDSchema<f32, f32, A, Return = f32>,
        T: ResumableSIMDSchema<f32, f32, A, NonResumable = O, FinalReturn = f32>,
        Rand: Rng,
    >(
        arch: A,
        reference: fn(&[f32], &[f32]) -> f32,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) {
        // Divisors chosen to hit chunk sizes around SIMD-width boundaries.
        let chunk_divisors: Vec<usize> = vec![1, 2, 3, 4, 16, 54, 64, 65, 70, 77];
        let checker = test_util::AdHocChecker::<f32, f32>::new(|a: &[f32], b: &[f32]| {
            let expected = reference(a, b);
            let got = simd_op(&O::default(), arch, a, b);
            // Printed so a failing assertion below identifies the dimension.
            println!("dim = {}", dim);
            assert_relative_eq!(
                expected,
                got,
                epsilon = epsilon,
                max_relative = max_relative,
            );

            if dim == 0 {
                return;
            }

            for d in &chunk_divisors {
                // Ceiling division: chunks must cover the whole vector.
                let chunk_size = dim / d + (!dim.is_multiple_of(*d) as usize);
                let chunked = test_resumable::<T, f32, f32, _>(arch, a, b, chunk_size);
                assert_relative_eq!(chunked, got, epsilon = epsilon, max_relative = max_relative);
            }
        });

        test_util::test_distance_function(
            checker,
            rand_distr::Normal::new(0.0, 10.0).unwrap(),
            rand_distr::Normal::new(0.0, 10.0).unwrap(),
            dim,
            10,
            rng,
        )
    }
3802
    // Randomized stress test for a (non-resumable) schema `O` over element
    // types `L`/`R`: compares the SIMD result against `reference` within the
    // given tolerances, with inputs drawn from the supplied distributions.
    #[allow(clippy::too_many_arguments)]
    fn stress_test<L, R, DistLeft, DistRight, O, Rand, A>(
        arch: A,
        reference: fn(&[L], &[R]) -> f32,
        left_dist: DistLeft,
        right_dist: DistRight,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) where
        L: test_util::CornerCases,
        R: test_util::CornerCases,
        DistLeft: test_util::GenerateRandomArguments<L>,
        DistRight: test_util::GenerateRandomArguments<R>,
        O: Default + SIMDSchema<L, R, A, Return = f32>,
        Rand: Rng,
        A: Architecture,
    {
        let checker = test_util::Checker::<L, R, f32>::new(
            |x: &[L], y: &[R]| simd_op(&O::default(), arch, x, y),
            reference,
            |got, expected| {
                assert_relative_eq!(
                    expected,
                    got,
                    epsilon = epsilon,
                    max_relative = max_relative
                );
            },
        );

        // Under miri only the corner cases run; random trials are too slow.
        let trials = if cfg!(miri) { 0 } else { 10 };

        test_util::test_distance_function(checker, left_dist, right_dist, dim, trials, rng);
    }
3839
    // Randomized stress test for the L-infinity norm: `LInfNorm` is a unary
    // reduction, so the checker ignores its second argument and compares the
    // SIMD norm of `x` against the scalar `reference`.
    fn stress_test_linf<L, Dist, Rand, A>(
        arch: A,
        reference: fn(&[L]) -> f32,
        dist: Dist,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) where
        L: test_util::CornerCases + Copy,
        Dist: Clone + test_util::GenerateRandomArguments<L>,
        Rand: Rng,
        A: Architecture,
        LInfNorm: for<'a> Target1<A, f32, &'a [L]>,
    {
        let checker = test_util::Checker::<L, L, f32>::new(
            |x: &[L], _y: &[L]| (LInfNorm).run(arch, x),
            |x: &[L], _y: &[L]| reference(x),
            |got, expected| {
                assert_relative_eq!(
                    expected,
                    got,
                    epsilon = epsilon,
                    max_relative = max_relative
                );
            },
        );

        println!("checking {dim}");
        test_util::test_distance_function(checker, dist.clone(), dist, dim, 10, rng);
    }
3871
    // Generates a #[test] that stress-tests an f32 schema (and its resumable
    // wrapper) against a reference implementation for every dimension in
    // `0..$upper`, on the backend produced by the `$arch` expression (the
    // test is a no-op when the probe returns None).
    macro_rules! float_test {
        ($name:ident,
         $impl:ty,
         $resumable:ident,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    // Fixed seed so failures are reproducible.
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_with_resumable::<_, $impl, $resumable<_>, StdRng>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
3905
    // Squared-L2 (f32) stress tests, one per backend. All share the same seed
    // so any failure reproduces across platforms; the hardware backends cover
    // larger dimensions (256) than the generic ones (64).
    float_test!(
        test_l2_f32_current,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_l2_f32_scalar,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v3,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v4,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    float_test!(
        test_l2_f32_aarch64_neon,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );
3972
    // Inner-product (f32) stress tests, one per backend; looser tolerances
    // than L2 because IP has no cancellation-free form.
    float_test!(
        test_ip_f32_current,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_ip_f32_scalar,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v3,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v4,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    float_test!(
        test_ip_f32_aarch64_neon,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        Neon::new_checked()
    );
4039
    // Cosine (f32) stress tests, one per backend.
    float_test!(
        test_cosine_f32_current,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_cosine_f32_scalar,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_cosine_f32_x86_64_v3,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_cosine_f32_x86_64_v4,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    float_test!(
        test_cosine_f32_aarch64_neon,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        256,
        Neon::new_checked()
    );
4106
    // Generates a #[test] that stress-tests an f16 (`Half`) schema against a
    // reference implementation for every dimension in `0..$upper`, with
    // normally distributed inputs, on the backend produced by `$arch`.
    macro_rules! half_test {
        ($name:ident,
         $impl:ty,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    // Fixed seed so failures are reproducible.
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test::<
                            Half,
                            Half,
                            rand_distr::Normal<f32>,
                            rand_distr::Normal<f32>,
                            $impl,
                            StdRng,
                            _
                        >(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng
                        );
                    }
                }
            }
        }
    }
4149
    // Squared-L2 (f16) stress tests, one per backend.
    half_test!(
        test_l2_f16_current,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_l2_f16_scalar,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v3,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v4,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_l2_f16_aarch64_neon,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        Neon::new_checked()
    );
4211
    // Inner-product (f16) stress tests, one per backend.
    half_test!(
        test_ip_f16_current,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_ip_f16_scalar,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v3,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v4,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_ip_f16_aarch64_neon,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        Neon::new_checked()
    );
4273
    // Cosine distance over f16 inputs: current-best arch, scalar fallback,
    // and each explicit backend, all sharing one seed so input streams match.
    half_test!(
        test_cosine_f16_current,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_cosine_f16_scalar,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v3,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v4,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_cosine_f16_aarch64_neon,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        Neon::new_checked()
    );
4335
    // Generates a #[test] that stress-tests an integer distance kernel against
    // a scalar reference over every dimension in `0..$upper`.
    //
    //   $name      - name of the generated test function
    //   $T         - element type of both operands (u8 / i8 below)
    //   $impl      - distance functor under test (L2 / IP / CosineStateless)
    //   $reference - reference function; its newtype result is unwrapped via
    //                `into_inner`
    //   $seed      - fixed StdRng seed for reproducible inputs
    //   $upper     - exclusive upper bound of the dimension sweep
    //   $arch      - brace-wrapped expression producing an Option<arch>; when
    //                it evaluates to None (feature set unavailable at runtime)
    //                the test body is skipped silently
    macro_rules! int_test {
        (
            $name:ident,
            $T:ty,
            $impl:ty,
            $reference:path,
            $seed:literal,
            $upper:literal,
            { $($arch:tt)* }
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test::<$T, $T, _, _, $impl, _, _>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            StandardUniform,
                            StandardUniform,
                            dim,
                            // Zero tolerances: integer kernels must match the
                            // reference exactly.
                            0.0,
                            0.0,
                            &mut rng,
                        )
                    }
                }
            }
        }
    }
4370
    // u8 squared-L2: current-best arch, scalar fallback, then each explicit
    // backend. The SIMD backends sweep further (256/320) so the per-dimension
    // loop covers full SIMD blocks plus every possible tail length.
    int_test!(
        test_l2_u8_current,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x945bdc37d8279d4b,
        128,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_u8_scalar,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        128,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v3,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v4,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_l2_u8_aarch64_neon,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        320,
        { Neon::new_checked() }
    );
4427
    // u8 inner-product: same structure as the L2 sweep above.
    int_test!(
        test_ip_u8_current,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0xcbe0342c75085fd5,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_u8_scalar,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v3,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v4,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_ip_u8_aarch64_neon,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        320,
        { Neon::new_checked() }
    );
4480
    // u8 cosine distance: same structure as the L2/IP sweeps above.
    int_test!(
        test_cosine_u8_current,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0x96867b6aff616b28,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_u8_scalar,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v3,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v4,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_cosine_u8_aarch64_neon,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        320,
        { Neon::new_checked() }
    );
4533
    // i8 squared-L2: mirrors the u8 sweeps with the signed reference.
    int_test!(
        test_l2_i8_current,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0xa60136248cd3c2f0,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_i8_scalar,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v3,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v4,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_l2_i8_aarch64_neon,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        320,
        { Neon::new_checked() }
    );
4590
    // i8 inner-product: mirrors the u8 sweeps with the signed reference.
    int_test!(
        test_ip_i8_current,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0xe8306104740509e1,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_i8_scalar,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v3,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v4,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_ip_i8_aarch64_neon,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        320,
        { Neon::new_checked() }
    );
4643
    // i8 cosine distance: mirrors the u8 sweeps with the signed reference.
    int_test!(
        test_cosine_i8_current,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x818c210190701e4b,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_i8_scalar,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v3,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v4,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_cosine_i8_aarch64_neon,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        320,
        { Neon::new_checked() }
    );
4696
    // Generates a #[test] that stress-tests the single-operand `linf`
    // reduction for element type `$T` against a scalar reference, sweeping
    // every dimension in `0..$upper`.
    //
    //   $name      - name of the generated test function
    //   $T         - element type of the operand
    //   $reference - reference function; newtype result unwrapped via
    //                `into_inner`
    //   $eps / $relative - absolute / relative comparison tolerances
    //   $seed      - fixed StdRng seed for reproducible inputs
    //   $upper     - exclusive upper bound of the dimension sweep
    //   $arch      - expression producing Option<arch>; None skips the test
    //
    // NOTE(review): inputs are drawn from `Normal::new(-10.0, 10.0)`, i.e.
    // (mean = -10, std_dev = 10) — confirm a uniform -10..10 range was not
    // intended instead.
    macro_rules! linf_test {
        ($name:ident,
            $T:ty,
            $reference:path,
            $eps:literal,
            $relative:literal,
            $seed:literal,
            $upper:literal,
            $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_linf::<$T, _, StdRng, _>(
                            arch,
                            |l| $reference(l).into_inner(),
                            rand_distr::Normal::new(-10.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
4730
    // linf reduction over f32 inputs: scalar plus each explicit backend, all
    // sharing one seed so the input streams match across backends.
    linf_test!(
        test_linf_f32_scalar,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v3,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v4,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    linf_test!(
        test_linf_f32_neon,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );
4777
    // linf reduction over f16 inputs: mirrors the f32 sweeps above.
    linf_test!(
        test_linf_f16_scalar,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v3,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v4,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    linf_test!(
        test_linf_f16_neon,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );
4824
    /// Runtime tag for the element types exercised by the bounds tests;
    /// combined with [`Arch`] to form a [`Key`] into `MIRI_BOUNDS`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum DataType {
        Float32,
        Float16,
        UInt8,
        Int8,
    }
4836
    /// Maps a compile-time element type to its runtime [`DataType`] tag.
    trait AsDataType {
        fn as_data_type() -> DataType;
    }
4840
    // One impl per element type under test; `test_bounds!` uses these to
    // build the `Key` for its `MIRI_BOUNDS` lookups.
    impl AsDataType for f32 {
        fn as_data_type() -> DataType {
            DataType::Float32
        }
    }

    impl AsDataType for f16 {
        fn as_data_type() -> DataType {
            DataType::Float16
        }
    }

    impl AsDataType for u8 {
        fn as_data_type() -> DataType {
            DataType::UInt8
        }
    }

    impl AsDataType for i8 {
        fn as_data_type() -> DataType {
            DataType::Int8
        }
    }
4864
    /// Runtime tag for each backend exercised by the bounds tests; combined
    /// with two [`DataType`]s to form a [`Key`] into `MIRI_BOUNDS`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum Arch {
        Scalar,
        // Variant names mirror the x86-64 microarchitecture level names, so
        // the non-camel-case lint is expected rather than fixed.
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
        Aarch64Neon,
    }
4874
    /// Lookup key for `MIRI_BOUNDS`: the (backend, left element type,
    /// right element type) combination a bounds sweep runs under.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    struct Key {
        arch: Arch,
        left: DataType,
        right: DataType,
    }

    impl Key {
        /// Convenience constructor used when populating and querying the table.
        fn new(arch: Arch, left: DataType, right: DataType) -> Self {
            Self { arch, left, right }
        }
    }
4887
4888 static MIRI_BOUNDS: LazyLock<HashMap<Key, usize>> = LazyLock::new(|| {
4889 use Arch::{Aarch64Neon, Scalar, X86_64_V3, X86_64_V4};
4890 use DataType::{Float16, Float32, Int8, UInt8};
4891
4892 [
4893 (Key::new(Scalar, Float32, Float32), 64),
4894 (Key::new(X86_64_V3, Float32, Float32), 256),
4895 (Key::new(X86_64_V4, Float32, Float32), 256),
4896 (Key::new(Aarch64Neon, Float32, Float32), 128),
4897 (Key::new(Scalar, Float16, Float16), 64),
4898 (Key::new(X86_64_V3, Float16, Float16), 256),
4899 (Key::new(X86_64_V4, Float16, Float16), 256),
4900 (Key::new(Aarch64Neon, Float16, Float16), 128),
4901 (Key::new(Scalar, Float32, Float16), 64),
4902 (Key::new(X86_64_V3, Float32, Float16), 256),
4903 (Key::new(X86_64_V4, Float32, Float16), 256),
4904 (Key::new(Aarch64Neon, Float32, Float16), 128),
4905 (Key::new(Scalar, UInt8, UInt8), 64),
4906 (Key::new(X86_64_V3, UInt8, UInt8), 256),
4907 (Key::new(X86_64_V4, UInt8, UInt8), 320),
4908 (Key::new(Aarch64Neon, UInt8, UInt8), 128),
4909 (Key::new(Scalar, Int8, Int8), 64),
4910 (Key::new(X86_64_V3, Int8, Int8), 256),
4911 (Key::new(X86_64_V4, Int8, Int8), 320),
4912 (Key::new(Aarch64Neon, Int8, Int8), 128),
4913 ]
4914 .into_iter()
4915 .collect()
4916 });
4917
    // Generates a #[test] that runs every distance functor (L2, IP, cosine)
    // over all dimensions `0..max` for one fixed (left, right) element
    // pairing, on every backend available at runtime. `max` comes from
    // `MIRI_BOUNDS` and caps the sweep per (arch, left type, right type).
    // No result values are checked: the point is that the kernels execute
    // over every possible tail length without out-of-bounds reads —
    // presumably the main payoff is when this runs under Miri (cf. the
    // `MIRI_BOUNDS` name and `new_checked_miri`); confirm with the harness.
    //
    //   $function            - generated test name
    //   $left  / $left_ex    - left element type and a fill value for it
    //   $right / $right_ex   - right element type and a fill value for it
    macro_rules! test_bounds {
        (
            $function:ident,
            $left:ty,
            $left_ex:expr,
            $right:ty,
            $right_ex:expr
        ) => {
            #[test]
            fn $function() {
                let left: $left = $left_ex;
                let right: $right = $right_ex;

                let left_type = <$left>::as_data_type();
                let right_type = <$right>::as_data_type();

                // Scalar backend: always available.
                {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Scalar, left_type, right_type)];
                    for dim in 0..max {
                        // Shadowing: constant-filled vectors of length `dim`.
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        let arch = diskann_wide::arch::Scalar;
                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V3 backend, when the running CPU supports it.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V3::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V3, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V4 backend (miri-aware constructor).
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V4::new_checked_miri() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V4, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // aarch64 NEON backend, when available.
                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = Neon::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Aarch64Neon, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }
            }
        };
    }
4989
    // Bounds sweeps for each supported (left, right) element pairing; the
    // fill values are arbitrary constants — only the lengths matter.
    test_bounds!(miri_test_bounds_f32xf32, f32, 1.0f32, f32, 2.0f32);
    test_bounds!(
        miri_test_bounds_f16xf16,
        f16,
        diskann_wide::cast_f32_to_f16(1.0f32),
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(
        miri_test_bounds_f32xf16,
        f32,
        1.0f32,
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(miri_test_bounds_u8xu8, u8, 1u8, u8, 1u8);
    test_bounds!(miri_test_bounds_i8xi8, i8, 1i8, i8, 1i8);
5007}