1use std::convert::AsRef;
7
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11#[cfg(not(target_arch = "aarch64"))]
12use diskann_wide::SIMDDotProduct;
13use diskann_wide::{
14 arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDMulAdd, SIMDSumTree,
15 SIMDVector,
16};
17
18use crate::Half;
19
/// Lossy conversion of a kernel's scalar accumulator value to `f32`.
///
/// Distance kernels that accumulate in integers (see the `i8`/`u8` schemas
/// below) still return `f32`; this trait provides the widening cast.
pub trait LossyF32Conversion: Copy {
    /// Convert `self` to `f32`, possibly losing precision.
    fn as_f32_lossy(self) -> f32;
}

impl LossyF32Conversion for f32 {
    // Identity: already an f32, nothing can be lost.
    fn as_f32_lossy(self) -> f32 {
        self
    }
}

impl LossyF32Conversion for i32 {
    // `as` cast; values with magnitude above 2^24 round to the nearest
    // representable f32.
    fn as_f32_lossy(self) -> f32 {
        self as f32
    }
}
36
// `force_eval` forces `x` to be materialized by the compiler instead of being
// optimized away. Three variants:
//  * under Miri, inline asm is unsupported, so it is a no-op;
//  * on x86_64, an empty asm template that claims to read an xmm register
//    makes the value observable to the optimizer without emitting any code;
//  * elsewhere, a no-op fallback.
cfg_if::cfg_if! {
    if #[cfg(miri)] {
        fn force_eval(_x: f32) {}
    } else if #[cfg(target_arch = "x86_64")] {
        use std::arch::asm;

        #[inline(always)]
        fn force_eval(x: f32) {
            unsafe {
                // SAFETY: the template is only a comment — no instructions are
                // emitted. `nostack`/`nomem`/`preserves_flags` tell the
                // compiler the asm touches neither stack, memory, nor flags.
                asm!(
                    "/* {0} */",
                    in(xmm_reg) x,
                    options(nostack, nomem, preserves_flags)
                )
            }
        }
    } else {
        fn force_eval(_x: f32) {}
    }
}
78
/// Bundles everything a [`MainLoop`] needs to stream paired SIMD loads from
/// two raw operand arrays: the architecture token, the accumulation schema,
/// the two base pointers, and the total element count.
#[derive(Debug, Clone, Copy)]
pub struct Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    // Architecture token for the SIMD load/accumulate operations.
    arch: A,
    // Kernel schema driving accumulation.
    schema: Schema,
    // Base pointer of the left operand.
    left: *const Left,
    // Base pointer of the right operand.
    right: *const Right,
    // Element count of each operand; only consulted by debug bounds checks.
    len: usize,
}
101
impl<Schema, Left, Right, A> Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    /// Construct a loader over two operand pointers of `len` elements each.
    #[inline(always)]
    fn new(arch: A, schema: Schema, left: *const Left, right: *const Right, len: usize) -> Self {
        Self {
            arch,
            schema,
            left,
            right,
            len,
        }
    }

    /// Copy of the architecture token.
    #[inline(always)]
    fn arch(&self) -> A {
        self.arch
    }

    /// Copy of the kernel schema.
    #[inline(always)]
    fn schema(&self) -> Schema {
        self.schema
    }

    /// Load one SIMD vector from each operand at the `offset`-th SIMD-width
    /// chunk of the `block`-th unrolled block, i.e. element index
    /// `block * SIMD_WIDTH * BLOCK_SIZE + offset * SIMD_WIDTH`.
    ///
    /// # Safety
    /// The computed element offset plus one SIMD width must lie within both
    /// operands; this is verified only via `debug_assert` here.
    #[inline(always)]
    unsafe fn load(&self, block: usize, offset: usize) -> (Schema::Left, Schema::Right) {
        let stride = Schema::SIMDWidth::value();
        let block_stride = stride * Schema::Main::BLOCK_SIZE;
        let offset = block_stride * block + stride * offset;

        debug_assert!(
            offset + stride <= self.len,
            "length = {}, offset = {}",
            self.len,
            offset
        );

        (
            Schema::Left::load_simd(self.arch, self.left.add(offset)),
            Schema::Right::load_simd(self.arch, self.right.add(offset)),
        )
    }
}
183
/// Unroll strategy for the main loop of [`simd_op`]: how many SIMD-width
/// chunks are consumed per trip and how leftover whole-width "epilogue"
/// chunks are folded in.
pub trait MainLoop {
    /// Number of SIMD-width chunks consumed per main-loop iteration.
    const BLOCK_SIZE: usize;

    /// Run `trip_count` full blocks plus `epilogues` (`< BLOCK_SIZE`) extra
    /// whole-SIMD-width chunks, returning the (uncombined) accumulator.
    ///
    /// # Safety
    /// `trip_count * BLOCK_SIZE + epilogues` SIMD widths must be in bounds
    /// for both of the loader's operands.
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>;
}
/// No unrolling: one accumulator, one SIMD width per trip.
pub struct Strategy1x1;

/// Two independent accumulators, two SIMD widths per trip.
pub struct Strategy2x1;

/// Four independent accumulators, four SIMD widths per trip.
pub struct Strategy4x1;

/// Four accumulators; the inner loop processes two 4-wide blocks
/// (8 SIMD widths) per trip.
pub struct Strategy4x2;

/// Two accumulators; the inner loop processes two 4-wide blocks
/// (8 SIMD widths) per trip.
pub struct Strategy2x4;
257
258impl MainLoop for Strategy1x1 {
259 const BLOCK_SIZE: usize = 1;
260
261 #[inline(always)]
262 unsafe fn main<S, L, R, A>(
263 loader: &Loader<S, L, R, A>,
264 trip_count: usize,
265 _epilogues: usize,
266 ) -> S::Accumulator
267 where
268 A: Architecture,
269 S: SIMDSchema<L, R, A>,
270 {
271 let arch = loader.arch();
272 let schema = loader.schema();
273
274 let mut s0 = schema.init(arch);
275 for i in 0..trip_count {
276 s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
277 }
278
279 s0
280 }
281}
282
impl MainLoop for Strategy2x1 {
    const BLOCK_SIZE: usize = 2;

    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        // Two independent accumulators break the serial dependency chain
        // between consecutive accumulate operations.
        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);

        for i in 0..trip_count {
            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
        }

        // With BLOCK_SIZE == 2 at most one epilogue chunk exists; it sits at
        // offset 0 of the block just past the main loop.
        let mut s = schema.combine(s0, s1);
        if epilogues != 0 {
            s = schema.accumulate_tuple(s, loader.load(trip_count, 0));
        }

        s
    }
}
315
impl MainLoop for Strategy4x1 {
    const BLOCK_SIZE: usize = 4;

    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        // Four independent accumulators for instruction-level parallelism.
        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        for i in 0..trip_count {
            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(i, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(i, 3));
        }

        // Up to 3 leftover SIMD-width chunks (offsets 0..epilogues of the
        // block past the main loop), spread round-robin over the
        // accumulators.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        // Pairwise combine keeps the reduction tree balanced.
        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
359
impl MainLoop for Strategy4x2 {
    const BLOCK_SIZE: usize = 4;

    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        // Each trip covers two consecutive 4-wide blocks: offsets 4..=7 of
        // block `j` are exactly offsets 0..=3 of block `j + 1`.
        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));

            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 6));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 7));
        }

        // Odd trip count: one remaining full 4-wide block.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));
        }

        // Up to 3 leftover SIMD-width chunks past the last full block.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
418
impl MainLoop for Strategy2x4 {
    const BLOCK_SIZE: usize = 4;

    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        // Only two accumulators, alternated over 8 SIMD-width chunks per
        // trip (two consecutive 4-wide blocks).
        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);

        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));

            // Offsets 4..=7 of block `j` alias offsets 0..=3 of block `j+1`.
            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 6));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 7));
        }

        // Odd trip count: one remaining full 4-wide block.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));
        }

        // Up to 3 leftover SIMD-width chunks, alternating accumulators.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 2));
        }

        schema.combine(s0, s1)
    }
}
479
480pub trait SIMDSchema<T, U, A: Architecture = diskann_wide::arch::Current>: Copy {
488 type SIMDWidth: Constant<Type = usize>;
491
492 type Accumulator: std::ops::Add<Output = Self::Accumulator> + std::fmt::Debug + Copy;
494
495 type Left: SIMDVector<Arch = A, Scalar = T, ConstLanes = Self::SIMDWidth>;
497
498 type Right: SIMDVector<Arch = A, Scalar = U, ConstLanes = Self::SIMDWidth>;
500
501 type Return;
504
505 type Main: MainLoop;
507
508 fn init(&self, arch: A) -> Self::Accumulator;
510
511 fn accumulate(
513 &self,
514 x: Self::Left,
515 y: Self::Right,
516 acc: Self::Accumulator,
517 ) -> Self::Accumulator;
518
519 #[inline(always)]
521 fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
522 x + y
523 }
524
525 #[inline(always)]
543 unsafe fn epilogue(
544 &self,
545 arch: A,
546 x: *const T,
547 y: *const U,
548 len: usize,
549 acc: Self::Accumulator,
550 ) -> Self::Accumulator {
551 let a = Self::Left::load_simd_first(arch, x, len);
554
555 let b = Self::Right::load_simd_first(arch, y, len);
558 self.accumulate(a, b, acc)
559 }
560
561 fn reduce(&self, x: Self::Accumulator) -> Self::Return;
565
566 #[inline(always)]
571 fn get_simd_width() -> usize {
572 Self::SIMDWidth::value()
573 }
574
575 #[inline(always)]
581 fn get_main_bocksize() -> usize {
582 Self::Main::BLOCK_SIZE
583 }
584
585 #[doc(hidden)]
588 #[inline(always)]
589 fn accumulate_tuple(
590 &self,
591 acc: Self::Accumulator,
592 (x, y): (Self::Left, Self::Right),
593 ) -> Self::Accumulator {
594 self.accumulate(x, y, acc)
595 }
596}
597
/// A kernel whose accumulator can be carried across multiple `simd_op`
/// invocations (e.g. distance computed piecewise over chunks of a vector).
pub trait ResumableSIMDSchema<T, U, A = diskann_wide::arch::Current>: Copy
where
    A: Architecture,
{
    /// The underlying single-shot schema used for each chunk.
    type NonResumable: SIMDSchema<T, U, A> + Default;
    /// Result type of the final reduction in [`ResumableSIMDSchema::sum`].
    type FinalReturn;

    /// Fresh state with an identity accumulator.
    fn init(arch: A) -> Self;
    /// Fold another chunk's accumulator into the carried state.
    fn combine_with(&self, other: <Self::NonResumable as SIMDSchema<T, U, A>>::Accumulator)
        -> Self;
    /// Reduce the carried state to the final scalar.
    fn sum(&self) -> Self::FinalReturn;
}
618
/// Transparent wrapper marking a value as resumable kernel state.
#[derive(Debug, Clone, Copy)]
pub struct Resumable<T>(T);

impl<T> Resumable<T> {
    /// Wrap `val`.
    pub fn new(val: T) -> Self {
        Resumable(val)
    }

    /// Unwrap and return the inner value.
    pub fn consume(self) -> T {
        let Resumable(inner) = self;
        inner
    }
}
631
// Any resumable schema acts as a `SIMDSchema` whose `Return` is itself: each
// per-chunk accumulator is produced via the underlying non-resumable schema
// (obtained through `Default`) and folded into the carried state in `reduce`.
impl<T, U, R, A> SIMDSchema<T, U, A> for Resumable<R>
where
    A: Architecture,
    R: ResumableSIMDSchema<T, U, A>,
{
    type SIMDWidth = <R::NonResumable as SIMDSchema<T, U, A>>::SIMDWidth;
    type Accumulator = <R::NonResumable as SIMDSchema<T, U, A>>::Accumulator;
    type Left = <R::NonResumable as SIMDSchema<T, U, A>>::Left;
    type Right = <R::NonResumable as SIMDSchema<T, U, A>>::Right;
    type Return = Self;
    type Main = <R::NonResumable as SIMDSchema<T, U, A>>::Main;

    // Delegate to the non-resumable kernel for the per-chunk operations.
    fn init(&self, arch: A) -> Self::Accumulator {
        R::NonResumable::default().init(arch)
    }

    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        R::NonResumable::default().accumulate(x, y, acc)
    }

    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
        R::NonResumable::default().combine(x, y)
    }

    // Instead of reducing to a scalar, fold the chunk accumulator into the
    // carried state so the computation can be resumed later.
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        Self(self.0.combine_with(x))
    }
}
665
/// Cold panic path for mismatched operand lengths, kept out of line
/// (`inline(never)`) so the hot `simd_op` body stays small.
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!(
        "lengths must be equal, instead got: xlen = {}, ylen = {}",
        xlen, ylen
    )
}
674
/// Apply the SIMD kernel `schema` element-wise over `x` and `y` and reduce
/// to a single result.
///
/// The input length is decomposed as
/// `len = trip_count * simd_width * unroll + epilogues * simd_width + remainder`:
/// the main loop handles the first two parts and [`SIMDSchema::epilogue`]
/// handles the final `remainder < simd_width` scalars.
///
/// # Panics
/// Panics if `x` and `y` have different lengths.
#[inline(always)]
pub fn simd_op<L, R, S, T, U, A>(schema: &S, arch: A, x: T, y: U) -> S::Return
where
    A: Architecture,
    T: AsRef<[L]>,
    U: AsRef<[R]>,
    S: SIMDSchema<L, R, A>,
{
    let x: &[L] = x.as_ref();
    let y: &[R] = y.as_ref();

    let len = x.len();

    if len != y.len() {
        emit_length_error(len, y.len());
    }
    let px = x.as_ptr();
    let py = y.as_ptr();

    let simd_width: usize = S::get_simd_width();
    let unroll: usize = S::get_main_bocksize();

    // Full unrolled blocks, then whole SIMD-width chunks left over after them.
    let trip_count = len / (simd_width * unroll);
    let epilogues = (len - simd_width * unroll * trip_count) / simd_width;

    let loader: Loader<S, L, R, A> = Loader::new(arch, *schema, px, py, len);

    // SAFETY: `trip_count` and `epilogues` were derived from `len` above, so
    // every load issued by the main loop stays within both slices.
    let mut s0 = unsafe { <S as SIMDSchema<L, R, A>>::Main::main(&loader, trip_count, epilogues) };

    // Sub-SIMD-width tail, handled by the schema's (possibly overridden)
    // scalar/partial-load epilogue.
    let remainder = len % simd_width;
    if remainder != 0 {
        let i = len - remainder;

        // SAFETY: `i + remainder == len`, so both pointers are valid for
        // reads of `remainder` elements.
        s0 = unsafe { schema.epilogue(arch, px.add(i), py.add(i), remainder, s0) };
    }

    schema.reduce(s0)
}
743
/// Squared Euclidean (L2) distance kernel: `sum((x_i - y_i)^2)`.
#[derive(Debug, Default, Clone, Copy)]
pub struct L2;
751
// f32 x f32 squared-L2 on x86_64 V4: 8 f32 lanes, 4-way unroll.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// f32 x f32 squared-L2 on x86_64 V3: same shape as the V4 kernel.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// f32 x f32 squared-L2 scalar fallback, emulated 4-wide.
impl SIMDSchema<f32, f32, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2; plain multiply + add (no FMA in the emulated path).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        (c * c) + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Scalar remainder: accumulate the tail element-by-element and inject the
    // sum into lane 0 of the emulated accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: caller guarantees both pointers are valid for `len`
            // elements (see the trait's `epilogue` contract).
            let vx = unsafe { x.add(i).read() };
            let vy = unsafe { y.add(i).read() };
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
865
// Half x Half squared-L2 on x86_64 V4: f16 values are widened to f32 before
// the subtract/FMA, accumulating in f32.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen f16 -> f32, then acc += (x - y)^2.
        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// Half x Half squared-L2 on x86_64 V3: same widening scheme as V4.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32, then acc += (x - y)^2.
        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// Half x Half squared-L2 scalar fallback: width 1, so no epilogue is ever
// needed (len % 1 == 0) and `reduce` just reads the single lane.
impl SIMDSchema<Half, Half, Scalar> for L2 {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen Half -> f32, then acc += (x - y)^2.
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();

        let c = x - y;
        acc + (c * c)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }
}
971
// Mixed-precision squared-L2 (f32 left, Half right) for any architecture:
// the right operand is widened to f32 before the subtract/FMA.
impl<A> SIMDSchema<f32, Half, A> for L2
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32, then acc += (x - y)^2.
        let y: A::f32x8 = y.into();
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1006
// i8 x i8 squared-L2 on x86_64 V4: widen to i16, difference, then a dot
// product of the difference with itself accumulated into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for L2 {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen i8 -> i16 so the difference cannot overflow, then
        // acc += dot(c, c) == sum of squared differences.
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}

// i8 x i8 squared-L2 on x86_64 V3: half the lane count of the V4 kernel.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen i8 -> i16, then acc += dot(c, c).
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}

// i8 x i8 squared-L2 scalar fallback, emulated 4-wide with i32 accumulation.
impl SIMDSchema<i8, i8, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen i8 -> i32, then acc += (x - y)^2.
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Scalar remainder: accumulate the tail element-by-element and inject the
    // sum into lane 0 of the emulated accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: caller guarantees both pointers are valid for `len`
            // elements (see the trait's `epilogue` contract).
            let vx: i32 = unsafe { x.add(i).read() }.into();
            let vy: i32 = unsafe { y.add(i).read() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1131
// u8 x u8 squared-L2 on x86_64 V4: identical scheme to the i8 kernel — widen
// to i16, difference, dot-product of the difference with itself.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for L2 {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen u8 -> i16 so the difference cannot overflow, then
        // acc += dot(c, c).
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}

// u8 x u8 squared-L2 on x86_64 V3: half the lane count of the V4 kernel.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen u8 -> i16, then acc += dot(c, c).
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}

// u8 x u8 squared-L2 scalar fallback, emulated 4-wide with i32 accumulation.
impl SIMDSchema<u8, u8, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<u8, 4>;
    type Right = Emulated<u8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen u8 -> i32, then acc += (x - y)^2.
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Scalar remainder: accumulate the tail element-by-element and inject the
    // sum into lane 0 of the emulated accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: caller guarantees both pointers are valid for `len`
            // elements (see the trait's `epilogue` contract).
            let vx: i32 = unsafe { x.add(i).read() }.into();
            let vy: i32 = unsafe { y.add(i).read() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1256
/// Resumable f32 squared-L2 state: carries the partial accumulator across
/// multiple chunked `simd_op` calls.
#[derive(Clone, Copy, Debug)]
pub struct ResumableL2<A = diskann_wide::arch::Current>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A>,
{
    // Running accumulator of all chunks combined so far.
    acc: <L2 as SIMDSchema<f32, f32, A>>::Accumulator,
}

impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableL2<A>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = L2;
    type FinalReturn = f32;

    /// Start from the L2 kernel's identity (zero) accumulator.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self { acc: L2.init(arch) }
    }

    /// Add a chunk's accumulator into the carried state.
    #[inline(always)]
    fn combine_with(&self, other: <L2 as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
        Self {
            acc: self.acc + other,
        }
    }

    /// Final horizontal reduction via the L2 kernel.
    #[inline(always)]
    fn sum(&self) -> f32 {
        L2.reduce(self.acc)
    }
}
1293
/// Inner-product (dot product) kernel: `sum(x_i * y_i)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct IP;
1301
// f32 x f32 inner product on x86_64 V4: 8 lanes, 4-way unroll, FMA.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// f32 x f32 inner product on x86_64 V3: same shape as the V4 kernel.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// f32 x f32 inner product scalar fallback, emulated 4-wide.
impl SIMDSchema<f32, f32, Scalar> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y; plain multiply + add (no FMA in the emulated path).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x * y + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Scalar remainder: accumulate the tail element-by-element and inject the
    // sum into lane 0 of the emulated accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: caller guarantees both pointers are valid for `len`
            // elements (see the trait's `epilogue` contract).
            let vx = unsafe { x.add(i).read() };
            let vy = unsafe { y.add(i).read() };
            s += vx * vy;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
1412
// Half x Half inner product on x86_64 V4: f16 values widened to f32, FMA
// accumulation in f32.
// NOTE(review): uses Strategy4x1 while the V3 kernel below and both L2 Half
// kernels use Strategy2x4 — confirm this asymmetry is an intentional tuning
// choice.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen f16 -> f32, then acc += x * y.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// Half x Half inner product on x86_64 V3: same widening scheme as V4.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32, then acc += x * y.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}

// Half x Half inner product scalar fallback: width 1, so no epilogue is ever
// needed (len % 1 == 0) and `reduce` just reads the single lane.
impl SIMDSchema<Half, Half, Scalar> for IP {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen Half -> f32, then acc += x * y.
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }
}
1512
// Mixed-precision inner product (f32 left, Half right) for any architecture:
// the right operand is widened to f32 before the FMA.
impl<A> SIMDSchema<f32, Half, A> for IP
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32, then acc += x * y.
        let y: A::f32x8 = y.into();
        x.mul_add_simd(y, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1546
// V4 (AVX-512-class) inner product over signed bytes: 32 i8 lanes are widened
// to i16 and accumulated with a widening dot product into 16 i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for IP {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen i8 -> i16; `dot_simd` folds the widened products into the
        // i32 accumulator lanes.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        // Sum the i32 lanes, then convert to f32 (may round for large sums).
        x.sum_tree().as_f32_lossy()
    }
}
1580
// V3 (AVX2-class) counterpart of the signed-byte inner product: 16 i8 lanes
// widened to i16, accumulated into 8 i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen i8 -> i16 and fold the widened products into the accumulator.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1615
1616impl SIMDSchema<i8, i8, Scalar> for IP {
1617 type SIMDWidth = Const<1>;
1618 type Accumulator = Emulated<i32, 1>;
1619 type Left = Emulated<i8, 1>;
1620 type Right = Emulated<i8, 1>;
1621 type Return = f32;
1622 type Main = Strategy1x1;
1623
1624 #[inline(always)]
1625 fn init(&self, arch: Scalar) -> Self::Accumulator {
1626 Self::Accumulator::default(arch)
1627 }
1628
1629 #[inline(always)]
1630 fn accumulate(
1631 &self,
1632 x: Self::Left,
1633 y: Self::Right,
1634 acc: Self::Accumulator,
1635 ) -> Self::Accumulator {
1636 let x: Self::Accumulator = x.into();
1637 let y: Self::Accumulator = y.into();
1638 x * y + acc
1639 }
1640
1641 #[inline(always)]
1643 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1644 x.to_array().into_iter().sum::<i32>().as_f32_lossy()
1645 }
1646
1647 #[inline(always)]
1648 unsafe fn epilogue(
1649 &self,
1650 _arch: Scalar,
1651 _x: *const i8,
1652 _y: *const i8,
1653 _len: usize,
1654 _acc: Self::Accumulator,
1655 ) -> Self::Accumulator {
1656 unreachable!("The SIMD width is 1, so there should be no epilogue")
1657 }
1658}
1659
// V4 inner product over unsigned bytes: 32 u8 lanes widened to i16 and
// accumulated with a widening dot product into 16 i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for IP {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen u8 -> i16 and fold the widened products into the accumulator.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1693
// V3 counterpart of the unsigned-byte inner product: 16 u8 lanes widened to
// i16, accumulated into 8 i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen u8 -> i16 and fold the widened products into the accumulator.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1730
1731impl SIMDSchema<u8, u8, Scalar> for IP {
1732 type SIMDWidth = Const<1>;
1733 type Accumulator = Emulated<i32, 1>;
1734 type Left = Emulated<u8, 1>;
1735 type Right = Emulated<u8, 1>;
1736 type Return = f32;
1737 type Main = Strategy1x1;
1738
1739 #[inline(always)]
1740 fn init(&self, arch: Scalar) -> Self::Accumulator {
1741 Self::Accumulator::default(arch)
1742 }
1743
1744 #[inline(always)]
1745 fn accumulate(
1746 &self,
1747 x: Self::Left,
1748 y: Self::Right,
1749 acc: Self::Accumulator,
1750 ) -> Self::Accumulator {
1751 let x: Self::Accumulator = x.into();
1752 let y: Self::Accumulator = y.into();
1753 x * y + acc
1754 }
1755
1756 #[inline(always)]
1758 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1759 x.to_array().into_iter().sum::<i32>().as_f32_lossy()
1760 }
1761
1762 #[inline(always)]
1763 unsafe fn epilogue(
1764 &self,
1765 _arch: Scalar,
1766 _x: *const u8,
1767 _y: *const u8,
1768 _len: usize,
1769 _acc: Self::Accumulator,
1770 ) -> Self::Accumulator {
1771 unreachable!("The SIMD width is 1, so there should be no epilogue")
1772 }
1773}
1774
// Partial inner-product state that can be carried across multiple chunked
// invocations; `IP` supplies the per-chunk accumulator type.
#[derive(Clone, Copy, Debug)]
pub struct ResumableIP<A = diskann_wide::arch::Current>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A>,
{
    // Running accumulator, merged chunk-by-chunk and reduced at the end.
    acc: <IP as SIMDSchema<f32, f32, A>>::Accumulator,
}
1784
1785impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableIP<A>
1786where
1787 A: Architecture,
1788 IP: SIMDSchema<f32, f32, A, Return = f32>,
1789{
1790 type NonResumable = IP;
1791 type FinalReturn = f32;
1792
1793 #[inline(always)]
1794 fn init(arch: A) -> Self {
1795 Self { acc: IP.init(arch) }
1796 }
1797
1798 #[inline(always)]
1799 fn combine_with(&self, other: <IP as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
1800 Self {
1801 acc: self.acc + other,
1802 }
1803 }
1804
1805 #[inline(always)]
1806 fn sum(&self) -> f32 {
1807 IP.reduce(self.acc)
1808 }
1809}
1810
// Running sums required for a single-pass cosine similarity.
#[derive(Debug, Clone, Copy)]
pub struct FullCosineAccumulator<T> {
    // Sum of squares of the left operand (pre-reduction lanes).
    normx: T,
    // Sum of squares of the right operand (pre-reduction lanes).
    normy: T,
    // Sum of elementwise products (pre-reduction lanes).
    xy: T,
}
1823
impl<T> FullCosineAccumulator<T>
where
    T: SIMDVector
        + SIMDSumTree
        + SIMDMulAdd
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    T::Scalar: LossyF32Conversion,
{
    /// Create a zeroed accumulator for the given architecture token.
    #[inline(always)]
    pub fn new(arch: T::Arch) -> Self {
        let zero = T::default(arch);
        Self {
            normx: zero,
            normy: zero,
            xy: zero,
        }
    }

    /// Update all three running sums with one pair of vectors using fused
    /// multiply-adds.
    #[inline(always)]
    pub fn add_with(&self, x: T, y: T) -> Self {
        FullCosineAccumulator {
            normx: x.mul_add_simd(x, self.normx),
            normy: y.mul_add_simd(y, self.normy),
            xy: x.mul_add_simd(y, self.xy),
        }
    }

    /// Same update as [`Self::add_with`] but with separate multiply and add
    /// steps (used by the scalar schemas in this file).
    #[inline(always)]
    pub fn add_with_unfused(&self, x: T, y: T) -> Self {
        FullCosineAccumulator {
            normx: x * x + self.normx,
            normy: y * y + self.normy,
            xy: x * y + self.xy,
        }
    }

    /// Finish the reduction as a cosine similarity clamped to [-1, 1].
    ///
    /// Returns 0.0 when either squared norm is below `f32::MIN_POSITIVE`
    /// (zero or subnormal), avoiding a division by a (near-)zero
    /// denominator.
    #[inline(always)]
    pub fn sum(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();

        let denominator = normx.sqrt() * normy.sqrt();
        let prod = self.xy.sum_tree().as_f32_lossy();

        // Force both temporaries to be evaluated before the early-return
        // branch below (see `force_eval` at the top of this file).
        force_eval(denominator);
        force_eval(prod);

        if normx < f32::MIN_POSITIVE || normy < f32::MIN_POSITIVE {
            return 0.0;
        }

        let v = prod / denominator;
        // Clamp to the mathematically valid range for a cosine.
        (-1.0f32).max(1.0f32.min(v))
    }

    /// Finish the reduction as the squared-L2 identity:
    /// `||x||^2 + ||y||^2 - 2 * <x, y>`.
    #[inline(always)]
    pub fn sum_as_l2(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();
        let xy = self.xy.sum_tree().as_f32_lossy();
        normx + normy - (xy + xy)
    }
}
1912
1913impl<T> std::ops::Add for FullCosineAccumulator<T>
1914where
1915 T: std::ops::Add<Output = T>,
1916{
1917 type Output = Self;
1918 #[inline(always)]
1919 fn add(self, other: Self) -> Self {
1920 FullCosineAccumulator {
1921 normx: self.normx + other.normx,
1922 normy: self.normy + other.normy,
1923 xy: self.xy + other.xy,
1924 }
1925 }
1926}
1927
// Cosine-similarity schema computed in a single pass: both squared norms and
// the dot product are accumulated together (see `FullCosineAccumulator`).
#[derive(Default, Clone, Copy)]
pub struct CosineStateless;
1931
// V4 cosine over f32: 16-wide fused-multiply-add updates of all three sums.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        // Clamped cosine similarity; 0.0 for (near-)zero norms.
        acc.sum()
    }
}
1965
// V3 cosine over f32: 8-wide fused-multiply-add updates of all three sums.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for CosineStateless {
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;

    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
1999
// Scalar fallback for f32 cosine: 4-wide emulated vectors, using the
// separate multiply/add update (`add_with_unfused`).
impl SIMDSchema<f32, f32, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<f32, 4>>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;

    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with_unfused(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2029
// V4 cosine over f16: 16 f16 lanes widened to f32 before the fused update.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f16x16;
    type Right = <V4 as Architecture>::f16x16;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x16);

        // Widen f16 -> f32 so accumulation happens in f32.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2063
// V3 cosine over f16: 8 f16 lanes widened to f32 before the fused update.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for CosineStateless {
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32 so accumulation happens in f32.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2098
// Scalar (width-1) fallback for f16 cosine: widen to f32 and use the
// separate multiply/add update.
impl SIMDSchema<Half, Half, Scalar> for CosineStateless {
    type SIMDWidth = Const<1>;
    type Accumulator = FullCosineAccumulator<Emulated<f32, 1>>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32 so accumulation happens in f32.
        let x: Emulated<f32, 1> = x.into();
        let y: Emulated<f32, 1> = y.into();
        acc.add_with_unfused(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
// Mixed-precision cosine (f32 left operand against f16 right operand),
// generic over architecture; the f16 side is widened to f32 lane-for-lane.
impl<A> SIMDSchema<f32, Half, A> for CosineStateless
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<A::f32x8>;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32, then fused update of all three running sums.
        let y: A::f32x8 = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2161
// V4 cosine over signed bytes: widen i8 -> i16 and update all three running
// sums with widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();

        // x.x -> ||x||^2, y.y -> ||y||^2, x.y -> <x, y>, each accumulated.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2201
// V3 cosine over signed bytes: widen i8 -> i16 and update all three running
// sums with widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        let x: i16s = x.into();
        let y: i16s = y.into();

        // x.x -> ||x||^2, y.y -> ||y||^2, x.y -> <x, y>, each accumulated.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2241
// Scalar fallback for signed-byte cosine: four lanes widened to i32, plus a
// plain integer tail loop in `epilogue`.
impl SIMDSchema<i8, i8, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen i8 -> i32 so squares and products cannot overflow.
        let x: Emulated<i32, 4> = x.into();
        let y: Emulated<i32, 4> = y.into();
        acc.add_with(x, y)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }

    // Handles the trailing `len < 4` elements with plain integer arithmetic.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xy: i32 = 0;
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;

        for i in 0..len {
            // SAFETY(review): relies on the caller passing pointers valid
            // for `len` reads — matches the unsafe-fn contract.
            let vx: i32 = unsafe { x.add(i).read() }.into();
            let vy: i32 = unsafe { y.add(i).read() }.into();

            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }

        // Fold the scalar tail sums into lane 0 of the vector accumulator.
        acc + FullCosineAccumulator {
            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
        }
    }
}
2304
// V4 cosine over unsigned bytes: widen u8 -> i16 and update all three
// running sums with widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();

        // x.x -> ||x||^2, y.y -> ||y||^2, x.y -> <x, y>, each accumulated.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2344
// V3 cosine over unsigned bytes: widen u8 -> i16 and update all three
// running sums with widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        let x: i16s = x.into();
        let y: i16s = y.into();

        // x.x -> ||x||^2, y.y -> ||y||^2, x.y -> <x, y>, each accumulated.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2384
2385impl SIMDSchema<u8, u8, Scalar> for CosineStateless {
2386 type SIMDWidth = Const<4>;
2387 type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
2388 type Left = Emulated<u8, 4>;
2389 type Right = Emulated<u8, 4>;
2390 type Return = f32;
2391 type Main = Strategy1x1;
2392
2393 #[inline(always)]
2394 fn init(&self, arch: Scalar) -> Self::Accumulator {
2395 Self::Accumulator::new(arch)
2396 }
2397
2398 #[inline(always)]
2399 fn accumulate(
2400 &self,
2401 x: Self::Left,
2402 y: Self::Right,
2403 acc: Self::Accumulator,
2404 ) -> Self::Accumulator {
2405 let x: Emulated<i32, 4> = x.into();
2406 let y: Emulated<i32, 4> = y.into();
2407 acc.add_with(x, y)
2408 }
2409
2410 #[inline(always)]
2412 fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2413 x.sum()
2414 }
2415
2416 #[inline(always)]
2417 unsafe fn epilogue(
2418 &self,
2419 arch: Scalar,
2420 x: *const u8,
2421 y: *const u8,
2422 len: usize,
2423 acc: Self::Accumulator,
2424 ) -> Self::Accumulator {
2425 let mut xy: i32 = 0;
2426 let mut xx: i32 = 0;
2427 let mut yy: i32 = 0;
2428
2429 for i in 0..len {
2430 let vx: i32 = unsafe { x.add(i).read() }.into();
2432 let vy: i32 = unsafe { y.add(i).read() }.into();
2434
2435 xx += vx * vx;
2436 xy += vx * vy;
2437 yy += vy * vy;
2438 }
2439
2440 acc + FullCosineAccumulator {
2441 normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
2442 normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
2443 xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
2444 }
2445 }
2446}
2447
// Partial cosine state that can be carried across multiple chunked
// invocations; wraps the accumulator produced by `CosineStateless`.
#[derive(Debug, Clone, Copy)]
pub struct ResumableCosine<A = diskann_wide::arch::Current>(
    <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
)
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A>;
2456
2457impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableCosine<A>
2458where
2459 A: Architecture,
2460 CosineStateless: SIMDSchema<f32, f32, A, Return = f32>,
2461{
2462 type NonResumable = CosineStateless;
2463 type FinalReturn = f32;
2464
2465 #[inline(always)]
2466 fn init(arch: A) -> Self {
2467 Self(CosineStateless.init(arch))
2468 }
2469
2470 #[inline(always)]
2471 fn combine_with(
2472 &self,
2473 other: <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
2474 ) -> Self {
2475 Self(self.0 + other)
2476 }
2477
2478 #[inline(always)]
2479 fn sum(&self) -> f32 {
2480 CosineStateless.reduce(self.0)
2481 }
2482}
2483
// Schema that sums absolute values of the LEFT operand only; the right
// operand is ignored by every implementation below.
#[derive(Clone, Copy, Debug, Default)]
pub struct L1Norm;
2500
// V4 L1 norm over f32: 16-wide |x| accumulation; `_y` is unused.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for L1Norm {
    type SIMDWidth = Const<16>;
    type Accumulator = <V4 as Architecture>::f32x16;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Only the left operand contributes to the norm.
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2531
// V3 L1 norm over f32: 8-wide |x| accumulation; `_y` is unused.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Only the left operand contributes to the norm.
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2562
// Scalar fallback for the f32 L1 norm: 4-wide emulated vectors plus a plain
// scalar tail loop in `epilogue`.
impl SIMDSchema<f32, f32, Scalar> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Only the left operand contributes to the norm.
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Handles the trailing `len < 4` elements; only `x` is read.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        _y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY(review): relies on the caller passing a pointer valid
            // for `len` reads — matches the unsafe-fn contract.
            let vx = unsafe { x.add(i).read() };
            s += vx.abs();
        }
        // Fold the scalar tail into lane 0 of the vector accumulator.
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
2610
// V4 L1 norm over f16: 8 f16 lanes widened to f32, then |x| accumulation.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32; only the left operand contributes to the norm.
        let x: <V4 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2642
// V3 L1 norm over f16: 8 f16 lanes widened to f32, then |x| accumulation.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32; only the left operand contributes to the norm.
        let x: <V3 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2674
// Scalar (width-1) fallback for the f16 L1 norm: widen to f32 then |x|.
impl SIMDSchema<Half, Half, Scalar> for L1Norm {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen f16 -> f32; only the left operand contributes to the norm.
        let x: Self::Accumulator = x.into();
        x.abs_simd() + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        // Single lane: the accumulator already holds the full sum.
        x.to_array()[0]
    }

    // With a SIMD width of 1 there is never a remainder to process.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const Half,
        _y: *const Half,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
2717
2718#[cfg(test)]
2723mod tests {
2724 use std::{collections::HashMap, sync::LazyLock};
2725
2726 use approx::assert_relative_eq;
2727 use diskann_wide::{arch::Target1, ARCH};
2728 use half::f16;
2729 use rand::{distr::StandardUniform, rngs::StdRng, Rng, SeedableRng};
2730 use rand_distr;
2731
2732 use super::*;
2733 use crate::{distance::reference, norm::LInfNorm, test_util};
2734
2735 fn cosine_norm_check_impl<A>(arch: A)
2740 where
2741 A: diskann_wide::Architecture,
2742 CosineStateless:
2743 SIMDSchema<f32, f32, A, Return = f32> + SIMDSchema<Half, Half, A, Return = f32>,
2744 {
2745 {
2747 let x: [f32; 2] = [0.0, 0.0];
2748 let y: [f32; 2] = [0.0, 1.0];
2749 assert_eq!(
2750 simd_op(&CosineStateless {}, arch, x, x),
2751 0.0,
2752 "when both vectors are zero, similarity should be zero",
2753 );
2754 assert_eq!(
2755 simd_op(&CosineStateless {}, arch, x, y),
2756 0.0,
2757 "when one vector is zero, similarity should be zero",
2758 );
2759 assert_eq!(
2760 simd_op(&CosineStateless {}, arch, y, x),
2761 0.0,
2762 "when one vector is zero, similarity should be zero",
2763 );
2764 }
2765
2766 {
2768 let x: [f32; 4] = [0.0, 0.0, 2.938736e-39f32, 0.0];
2769 let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
2770 assert_eq!(
2771 simd_op(&CosineStateless {}, arch, x, x),
2772 0.0,
2773 "when both vectors are almost zero, similarity should be zero",
2774 );
2775 assert_eq!(
2776 simd_op(&CosineStateless {}, arch, x, y),
2777 0.0,
2778 "when one vector is almost zero, similarity should be zero",
2779 );
2780 assert_eq!(
2781 simd_op(&CosineStateless {}, arch, y, x),
2782 0.0,
2783 "when one vector is almost zero, similarity should be zero",
2784 );
2785 }
2786
2787 {
2789 let x: [f32; 4] = [0.0, 0.0, 1.0842022e-19f32, 0.0];
2790 let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
2791 assert_eq!(
2792 simd_op(&CosineStateless {}, arch, x, x),
2793 1.0,
2794 "cosine-stateless should handle vectors this small",
2795 );
2796 assert_eq!(
2797 simd_op(&CosineStateless {}, arch, x, y),
2798 1.0,
2799 "cosine-stateless should handle vectors this small",
2800 );
2801 assert_eq!(
2802 simd_op(&CosineStateless {}, arch, y, x),
2803 1.0,
2804 "cosine-stateless should handle vectors this small",
2805 );
2806 }
2807
2808 let cvt = diskann_wide::cast_f32_to_f16;
2809
2810 {
2812 let x: [Half; 2] = [Half::default(), Half::default()];
2813 let y: [Half; 2] = [Half::default(), cvt(1.0)];
2814 assert_eq!(
2815 simd_op(&CosineStateless {}, arch, x, x),
2816 0.0,
2817 "when both vectors are zero, similarity should be zero",
2818 );
2819 assert_eq!(
2820 simd_op(&CosineStateless {}, arch, x, y),
2821 0.0,
2822 "when one vector is zero, similarity should be zero",
2823 );
2824 assert_eq!(
2825 simd_op(&CosineStateless {}, arch, y, x),
2826 0.0,
2827 "when one vector is zero, similarity should be zero",
2828 );
2829 }
2830
2831 {
2833 let x: [Half; 4] = [
2834 Half::default(),
2835 Half::default(),
2836 Half::MIN_POSITIVE_SUBNORMAL,
2837 Half::default(),
2838 ];
2839 let y: [Half; 4] = [Half::default(), Half::default(), cvt(1.0), Half::default()];
2840 assert_eq!(
2841 simd_op(&CosineStateless {}, arch, x, x),
2842 1.0,
2843 "when both vectors are almost zero, similarity should be zero",
2844 );
2845 assert_eq!(
2846 simd_op(&CosineStateless {}, arch, x, y),
2847 1.0,
2848 "when one vector is almost zero, similarity should be zero",
2849 );
2850 assert_eq!(
2851 simd_op(&CosineStateless {}, arch, y, x),
2852 1.0,
2853 "when one vector is almost zero, similarity should be zero",
2854 );
2855
2856 let threshold = f32::MIN_POSITIVE;
2862 let bound = 50;
2863 let values = {
2864 let mut down = threshold;
2865 let mut up = threshold;
2866 for _ in 0..bound {
2867 down = down.next_down();
2868 up = up.next_up();
2869 }
2870 assert!(down > 0.0);
2871 let min = down.sqrt();
2872 let max = up.sqrt();
2873 let mut v = min;
2874 let mut values = Vec::new();
2875 while v <= max {
2876 values.push(v);
2877 v = v.next_up();
2878 }
2879 values
2880 };
2881
2882 let mut lo = 0;
2883 let mut hi = 0;
2884 for i in values.iter() {
2885 for j in values.iter() {
2886 let s: f32 = simd_op(&CosineStateless {}, arch, [*i], [*j]);
2887 if i * i < threshold || j * j < threshold {
2888 lo += 1;
2889 assert_eq!(s, 0.0, "failed for i = {}, j = {}", i, j);
2890 } else {
2891 hi += 1;
2892 assert_eq!(s, 1.0, "failed for i = {}, j = {}", i, j);
2893 }
2894 }
2895 }
2896 assert_ne!(lo, 0);
2897 assert_ne!(hi, 0);
2898 }
2899 }
2900
    // Run the norm checks on the current preferred architecture and on the
    // portable scalar fallback.
    #[test]
    fn cosine_norm_check() {
        cosine_norm_check_impl::<diskann_wide::arch::Current>(diskann_wide::arch::current());
        cosine_norm_check_impl::<diskann_wide::arch::Scalar>(diskann_wide::arch::Scalar::new());
    }
2906
    /// Run the cosine norm-threshold checks on the explicit x86-64 V3/V4 SIMD
    /// architectures. Each one is silently skipped when its checked constructor
    /// returns `None` (architecture unavailable on this host / under Miri).
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn cosine_norm_check_x86_64() {
        if let Some(arch) = V3::new_checked() {
            cosine_norm_check_impl::<V3>(arch);
        }

        if let Some(arch) = V4::new_checked_miri() {
            cosine_norm_check_impl::<V4>(arch);
        }
    }
2918
2919 fn test_resumable<T, L, R, A>(arch: A, x: &[L], y: &[R], chunk_size: usize) -> f32
2925 where
2926 A: Architecture,
2927 T: ResumableSIMDSchema<L, R, A, FinalReturn = f32>,
2928 {
2929 let mut acc = Resumable(<T as ResumableSIMDSchema<L, R, A>>::init(arch));
2930 let iter = std::iter::zip(x.chunks(chunk_size), y.chunks(chunk_size));
2931 for (a, b) in iter {
2932 acc = simd_op(&acc, arch, a, b);
2933 }
2934 acc.0.sum()
2935 }
2936
2937 fn stress_test_with_resumable<
2938 A: Architecture,
2939 O: Default + SIMDSchema<f32, f32, A, Return = f32>,
2940 T: ResumableSIMDSchema<f32, f32, A, NonResumable = O, FinalReturn = f32>,
2941 Rand: Rng,
2942 >(
2943 arch: A,
2944 reference: fn(&[f32], &[f32]) -> f32,
2945 dim: usize,
2946 epsilon: f32,
2947 max_relative: f32,
2948 rng: &mut Rand,
2949 ) {
2950 let chunk_divisors: Vec<usize> = vec![1, 2, 3, 4, 16, 54, 64, 65, 70, 77];
2952 let checker = test_util::AdHocChecker::<f32, f32>::new(|a: &[f32], b: &[f32]| {
2953 let expected = reference(a, b);
2954 let got = simd_op(&O::default(), arch, a, b);
2955 println!("dim = {}", dim);
2956 assert_relative_eq!(
2957 expected,
2958 got,
2959 epsilon = epsilon,
2960 max_relative = max_relative,
2961 );
2962
2963 if dim == 0 {
2964 return;
2965 }
2966
2967 for d in &chunk_divisors {
2968 let chunk_size = dim / d + (!dim.is_multiple_of(*d) as usize);
2969 let chunked = test_resumable::<T, f32, f32, _>(arch, a, b, chunk_size);
2970 assert_relative_eq!(chunked, got, epsilon = epsilon, max_relative = max_relative);
2971 }
2972 });
2973
2974 test_util::test_distance_function(
2975 checker,
2976 rand_distr::Normal::new(0.0, 10.0).unwrap(),
2977 rand_distr::Normal::new(0.0, 10.0).unwrap(),
2978 dim,
2979 10,
2980 rng,
2981 )
2982 }
2983
2984 #[allow(clippy::too_many_arguments)]
2985 fn stress_test<L, R, DistLeft, DistRight, O, Rand, A>(
2986 arch: A,
2987 reference: fn(&[L], &[R]) -> f32,
2988 left_dist: DistLeft,
2989 right_dist: DistRight,
2990 dim: usize,
2991 epsilon: f32,
2992 max_relative: f32,
2993 rng: &mut Rand,
2994 ) where
2995 L: test_util::CornerCases,
2996 R: test_util::CornerCases,
2997 DistLeft: test_util::GenerateRandomArguments<L>,
2998 DistRight: test_util::GenerateRandomArguments<R>,
2999 O: Default + SIMDSchema<L, R, A, Return = f32>,
3000 Rand: Rng,
3001 A: Architecture,
3002 {
3003 let checker = test_util::Checker::<L, R, f32>::new(
3004 |x: &[L], y: &[R]| simd_op(&O::default(), arch, x, y),
3005 reference,
3006 |got, expected| {
3007 assert_relative_eq!(
3008 expected,
3009 got,
3010 epsilon = epsilon,
3011 max_relative = max_relative
3012 );
3013 },
3014 );
3015
3016 let trials = if cfg!(miri) { 0 } else { 10 };
3017
3018 test_util::test_distance_function(checker, left_dist, right_dist, dim, trials, rng);
3019 }
3020
3021 fn stress_test_linf<L, Dist, Rand, A>(
3022 arch: A,
3023 reference: fn(&[L]) -> f32,
3024 dist: Dist,
3025 dim: usize,
3026 epsilon: f32,
3027 max_relative: f32,
3028 rng: &mut Rand,
3029 ) where
3030 L: test_util::CornerCases + Copy,
3031 Dist: Clone + test_util::GenerateRandomArguments<L>,
3032 Rand: Rng,
3033 A: Architecture,
3034 LInfNorm: for<'a> Target1<A, f32, &'a [L]>,
3035 {
3036 let checker = test_util::Checker::<L, L, f32>::new(
3037 |x: &[L], _y: &[L]| (LInfNorm).run(arch, x),
3038 |x: &[L], _y: &[L]| reference(x),
3039 |got, expected| {
3040 assert_relative_eq!(
3041 expected,
3042 got,
3043 epsilon = epsilon,
3044 max_relative = max_relative
3045 );
3046 },
3047 );
3048
3049 println!("checking {dim}");
3050 test_util::test_distance_function(checker, dist.clone(), dist, dim, 10, rng);
3051 }
3052
3053 macro_rules! float_test {
3058 ($name:ident,
3059 $impl:ty,
3060 $resumable:ident,
3061 $reference:path,
3062 $eps:literal,
3063 $relative:literal,
3064 $seed:literal,
3065 $upper:literal,
3066 $($arch:tt)*
3067 ) => {
3068 #[test]
3069 fn $name() {
3070 if let Some(arch) = $($arch)* {
3071 let mut rng = StdRng::seed_from_u64($seed);
3072 for dim in 0..$upper {
3073 stress_test_with_resumable::<_, $impl, $resumable<_>, StdRng>(
3074 arch,
3075 |l, r| $reference(l, r).into_inner(),
3076 dim,
3077 $eps,
3078 $relative,
3079 &mut rng,
3080 );
3081 }
3082 }
3083 }
3084 }
3085 }
3086
    // Squared L2 over f32 — current, scalar, and explicit x86-64 V3/V4 architectures.
    // The x86-64 variants cover larger dimension ranges and only run when supported.
    float_test!(
        test_l2_f32_current,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_l2_f32_scalar,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v3,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v4,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    // Inner product over f32 — note the looser tolerances than L2/cosine.
    float_test!(
        test_ip_f32_current,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_ip_f32_scalar,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v3,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v4,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V4::new_checked_miri()
    );

    // Cosine similarity over f32.
    float_test!(
        test_cosine_f32_current,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_cosine_f32_scalar,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_cosine_f32_x86_64_v3,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_cosine_f32_x86_64_v4,
        CosineStateless,
        ResumableCosine,
        reference::reference_cosine_f32_mathematical,
        1e-5,
        1e-5,
        0xe860e9dc65f38bb8,
        256,
        V4::new_checked_miri()
    );
3248
3249 macro_rules! half_test {
3254 ($name:ident,
3255 $impl:ty,
3256 $reference:path,
3257 $eps:literal,
3258 $relative:literal,
3259 $seed:literal,
3260 $upper:literal,
3261 $($arch:tt)*
3262 ) => {
3263 #[test]
3264 fn $name() {
3265 if let Some(arch) = $($arch)* {
3266 let mut rng = StdRng::seed_from_u64($seed);
3267 for dim in 0..$upper {
3268 stress_test::<
3269 Half,
3270 Half,
3271 rand_distr::Normal<f32>,
3272 rand_distr::Normal<f32>,
3273 $impl,
3274 StdRng,
3275 _
3276 >(
3277 arch,
3278 |l, r| $reference(l, r).into_inner(),
3279 rand_distr::Normal::new(0.0, 10.0).unwrap(),
3280 rand_distr::Normal::new(0.0, 10.0).unwrap(),
3281 dim,
3282 $eps,
3283 $relative,
3284 &mut rng
3285 );
3286 }
3287 }
3288 }
3289 }
3290 }
3291
    // Squared L2 over f16 — current, scalar, and explicit x86-64 V3/V4 architectures.
    half_test!(
        test_l2_f16_current,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_l2_f16_scalar,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v3,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v4,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V4::new_checked_miri()
    );

    // Inner product over f16 — looser tolerances than L2/cosine.
    half_test!(
        test_ip_f16_current,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_ip_f16_scalar,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v3,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v4,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V4::new_checked_miri()
    );

    // Cosine similarity over f16.
    half_test!(
        test_cosine_f16_current,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_cosine_f16_scalar,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v3,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v4,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V4::new_checked_miri()
    );
3441
3442 macro_rules! int_test {
3447 (
3448 $name:ident,
3449 $T:ty,
3450 $impl:ty,
3451 $reference:path,
3452 $seed:literal,
3453 $upper:literal,
3454 { $($arch:tt)* }
3455 ) => {
3456 #[test]
3457 fn $name() {
3458 if let Some(arch) = $($arch)* {
3459 let mut rng = StdRng::seed_from_u64($seed);
3460 for dim in 0..$upper {
3461 stress_test::<$T, $T, _, _, $impl, _, _>(
3462 arch,
3463 |l, r| $reference(l, r).into_inner(),
3464 StandardUniform,
3465 StandardUniform,
3466 dim,
3467 0.0,
3468 0.0,
3469 &mut rng,
3470 )
3471 }
3472 }
3473 }
3474 }
3475 }
3476
    // Squared L2 over u8 — exact-match tests on every architecture.
    int_test!(
        test_l2_u8_current,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x945bdc37d8279d4b,
        128,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_u8_scalar,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        128,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v3,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v4,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        320,
        { V4::new_checked_miri() }
    );

    // Inner product over u8.
    int_test!(
        test_ip_u8_current,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0xcbe0342c75085fd5,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_u8_scalar,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v3,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v4,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        320,
        { V4::new_checked_miri() }
    );

    // Cosine similarity over u8.
    int_test!(
        test_cosine_u8_current,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0x96867b6aff616b28,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_u8_scalar,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v3,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v4,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        320,
        { V4::new_checked_miri() }
    );

    // Squared L2 over i8.
    int_test!(
        test_l2_i8_current,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0xa60136248cd3c2f0,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_i8_scalar,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v3,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v4,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        320,
        { V4::new_checked_miri() }
    );

    // Inner product over i8.
    int_test!(
        test_ip_i8_current,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0xe8306104740509e1,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_i8_scalar,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v3,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v4,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        320,
        { V4::new_checked_miri() }
    );

    // Cosine similarity over i8.
    int_test!(
        test_cosine_i8_current,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x818c210190701e4b,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_i8_scalar,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v3,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v4,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        320,
        { V4::new_checked_miri() }
    );
3736
3737 macro_rules! linf_test {
3742 ($name:ident,
3743 $T:ty,
3744 $reference:path,
3745 $eps:literal,
3746 $relative:literal,
3747 $seed:literal,
3748 $upper:literal,
3749 $($arch:tt)*
3750 ) => {
3751 #[test]
3752 fn $name() {
3753 if let Some(arch) = $($arch)* {
3754 let mut rng = StdRng::seed_from_u64($seed);
3755 for dim in 0..$upper {
3756 stress_test_linf::<$T, _, StdRng, _>(
3757 arch,
3758 |l| $reference(l).into_inner(),
3759 rand_distr::Normal::new(-10.0, 10.0).unwrap(),
3760 dim,
3761 $eps,
3762 $relative,
3763 &mut rng,
3764 );
3765 }
3766 }
3767 }
3768 }
3769 }
3770
    // L-infinity norm over f32 — scalar plus explicit x86-64 V3/V4 architectures.
    linf_test!(
        test_linf_f32_scalar,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v3,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v4,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    // L-infinity norm over f16.
    linf_test!(
        test_linf_f16_scalar,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v3,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v4,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );
3840
    /// Element data types exercised by the bounds-check tests; used (together with
    /// an architecture tag) to key the per-combination dimension limits.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum DataType {
        Float32,
        Float16,
        UInt8,
        Int8,
    }
3852
    /// Maps a scalar element type to its corresponding `DataType` tag at compile time.
    trait AsDataType {
        // Returns the `DataType` variant describing `Self`.
        fn as_data_type() -> DataType;
    }
3856
3857 impl AsDataType for f32 {
3858 fn as_data_type() -> DataType {
3859 DataType::Float32
3860 }
3861 }
3862
3863 impl AsDataType for f16 {
3864 fn as_data_type() -> DataType {
3865 DataType::Float16
3866 }
3867 }
3868
3869 impl AsDataType for u8 {
3870 fn as_data_type() -> DataType {
3871 DataType::UInt8
3872 }
3873 }
3874
3875 impl AsDataType for i8 {
3876 fn as_data_type() -> DataType {
3877 DataType::Int8
3878 }
3879 }
3880
    /// Architecture tags for keying the bounds-check dimension limits. Mirrors the
    /// architectures exercised elsewhere in these tests: the scalar fallback and the
    /// explicit x86-64 V3/V4 SIMD levels.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum Arch {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
    }
3889
    /// Lookup key for `MIRI_BOUNDS`: an (architecture, left element type, right
    /// element type) triple.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    struct Key {
        arch: Arch,
        left: DataType,
        right: DataType,
    }

    impl Key {
        // Convenience constructor to keep the `MIRI_BOUNDS` table terse.
        fn new(arch: Arch, left: DataType, right: DataType) -> Self {
            Self { arch, left, right }
        }
    }
3902
    /// Exclusive upper dimension bound tested by `test_bounds!` for each
    /// (architecture, left type, right type) combination. The values match the
    /// `$upper` limits used by the corresponding stress tests above; presumably they
    /// are sized to cover SIMD block/tail boundaries while staying tractable under
    /// Miri — confirm before widening.
    static MIRI_BOUNDS: LazyLock<HashMap<Key, usize>> = LazyLock::new(|| {
        use Arch::{Scalar, X86_64_V3, X86_64_V4};
        use DataType::{Float16, Float32, Int8, UInt8};

        [
            (Key::new(Scalar, Float32, Float32), 64),
            (Key::new(X86_64_V3, Float32, Float32), 256),
            (Key::new(X86_64_V4, Float32, Float32), 256),
            (Key::new(Scalar, Float16, Float16), 64),
            (Key::new(X86_64_V3, Float16, Float16), 256),
            (Key::new(X86_64_V4, Float16, Float16), 256),
            (Key::new(Scalar, Float32, Float16), 64),
            (Key::new(X86_64_V3, Float32, Float16), 256),
            (Key::new(X86_64_V4, Float32, Float16), 256),
            (Key::new(Scalar, UInt8, UInt8), 64),
            (Key::new(X86_64_V3, UInt8, UInt8), 256),
            (Key::new(X86_64_V4, UInt8, UInt8), 320),
            (Key::new(Scalar, Int8, Int8), 64),
            (Key::new(X86_64_V3, Int8, Int8), 256),
            (Key::new(X86_64_V4, Int8, Int8), 320),
        ]
        .into_iter()
        .collect()
    });
3927
    /// Generates a `#[test]` that runs the L2, IP, and cosine kernels on vectors of
    /// every dimension below the `MIRI_BOUNDS` limit for each architecture, filled
    /// with the constant values `$left_ex` / `$right_ex`. No results are asserted —
    /// the purpose is to execute every (dimension, architecture) combination so that
    /// memory errors in the SIMD tail handling surface (e.g. under Miri).
    macro_rules! test_bounds {
        (
            $function:ident,
            $left:ty,
            $left_ex:expr,
            $right:ty,
            $right_ex:expr
        ) => {
            #[test]
            fn $function() {
                let left: $left = $left_ex;
                let right: $right = $right_ex;

                let left_type = <$left>::as_data_type();
                let right_type = <$right>::as_data_type();

                // Scalar fallback: always available, no feature detection needed.
                {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Scalar, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        let arch = diskann_wide::arch::Scalar;
                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V3: skipped when the host does not support it.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V3::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V3, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V4: skipped when the host (or Miri) does not support it.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V4::new_checked_miri() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V4, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }
            }
        };
    }
3986
    // Bounds-check tests for each supported (left, right) element-type pairing.
    test_bounds!(miri_test_bounds_f32xf32, f32, 1.0f32, f32, 2.0f32);
    test_bounds!(
        miri_test_bounds_f16xf16,
        f16,
        diskann_wide::cast_f32_to_f16(1.0f32),
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    // Mixed-precision pairing: f32 on the left, f16 on the right.
    test_bounds!(
        miri_test_bounds_f32xf16,
        f32,
        1.0f32,
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(miri_test_bounds_u8xu8, u8, 1u8, u8, 1u8);
    test_bounds!(miri_test_bounds_i8xi8, i8, 1i8, i8, 1i8);
4004}