1use diskann_vector::PureDistanceFunction;
111use diskann_wide::{ARCH, Architecture, arch::Target2};
112#[cfg(target_arch = "x86_64")]
113use diskann_wide::{
114 SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
115};
116
117use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
118use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
119
/// Shorthand for a bit-packed slice of `N`-bit *unsigned* elements, with the
/// storage permutation defaulting to `Dense`.
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
122
/// Implements `Target2<$arch, …>` for `$op` by *retargeting*: the call is
/// forwarded unchanged to the implementation for `arch.retarget()` (a
/// different architecture tier that already has a concrete kernel).
///
/// Accepted shapes:
/// - `($arch, $op, ($N, $M))` — mixed element widths for `x` and `y`;
/// - `($arch, $op, $N)`       — same width on both sides;
/// - a comma-separated list of either of the above.
macro_rules! retarget {
    ($arch:path, $op:ty, ($N:literal, $M:literal)) => {
        impl Target2<
            $arch,
            MathematicalResult<u32>,
            USlice<'_, $N>,
            USlice<'_, $M>,
        > for $op {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: USlice<'_, $N>,
                y: USlice<'_, $M>
            ) -> MathematicalResult<u32> {
                // Delegate to whichever tier `retarget()` maps this arch to.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Single width: expand to the symmetric pair form.
    ($arch:path, $op:ty, $N:literal) => {
        retarget!($arch, $op, ($N, $N));
    };
    // List form: expand each argument independently.
    ($arch:path, $op:ty, $($args:tt),+ $(,)?) => {
        $(retarget!($arch, $op, $args);)+
    };
}
154
/// Implements `PureDistanceFunction` for `$op` over `USlice` operands by
/// dispatching through the process-wide `diskann_wide::ARCH` selector, which
/// picks the best available `Target2` implementation at runtime.
///
/// Argument shapes mirror `retarget!`: `($N, $M)`, a bare `$N` (symmetric),
/// or a comma-separated list of either.
macro_rules! dispatch_pure {
    ($op:ty, ($N:literal, $M:literal)) => {
        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $M>, MathematicalResult<u32>> for $op {
            #[inline(always)]
            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $M>) -> MathematicalResult<u32> {
                (diskann_wide::ARCH).run2(Self, x, y)
            }
        }
    };
    ($op:ty, $N:literal) => {
        dispatch_pure!($op, ($N, $N));
    };
    ($op:ty, $($args:tt),+ $(,)?) => {
        $(dispatch_pure!($op, $args);)+
    }
}
172
/// Reads one byte at `ptr`, zero-extends it to `u32`, and applies `f`.
///
/// # Safety
/// `ptr` must point to at least 1 readable byte. No alignment requirement:
/// the read is unaligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, f: F) -> R
where
    // `f` is invoked exactly once, so `FnOnce` is the most general bound
    // (every `FnMut`/`Fn` closure still satisfies it).
    F: FnOnce(u32) -> R,
{
    // SAFETY: caller guarantees one readable byte at `ptr`.
    f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
}
187
/// Reads two bytes at `ptr` as a little-endian `u16`, zero-extends to `u32`,
/// and applies `f`.
///
/// # Safety
/// `ptr` must point to at least 2 readable bytes. No alignment requirement:
/// the read is unaligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, f: F) -> R
where
    // Called exactly once; `FnOnce` is the most general bound.
    F: FnOnce(u32) -> R,
{
    // SAFETY: caller guarantees two readable bytes at `ptr`.
    f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
}
202
/// Reads three bytes at `ptr` as a little-endian 24-bit value, zero-extends
/// to `u32`, and applies `f`.
///
/// Assembled as a `u16` load (bits 0..16) plus a single-byte load placed in
/// bits 16..24, avoiding any read past the third byte.
///
/// # Safety
/// `ptr` must point to at least 3 readable bytes. No alignment requirement:
/// both reads are unaligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, f: F) -> R
where
    // Called exactly once; `FnOnce` is the most general bound.
    F: FnOnce(u32) -> R,
{
    // SAFETY: caller guarantees three readable bytes; the u16 covers bytes
    // 0..2 and the u8 covers byte 2.
    let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
    let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
    f(lo | (hi << 16))
}
220
/// Reads a full (native-endian) `u32` at `ptr` and applies `f`.
///
/// # Safety
/// `ptr` must point to at least 4 readable bytes. No alignment requirement:
/// the read is unaligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, f: F) -> R
where
    // Called exactly once; `FnOnce` is the most general bound.
    F: FnOnce(u32) -> R,
{
    // SAFETY: caller guarantees four readable bytes at `ptr`.
    f(unsafe { ptr.read_unaligned() })
}
235
/// Per-chunk kernels used by `bitvector_op` to combine two bit-packed
/// (1-bit-per-element) vectors.
///
/// The driver loop walks the data in 64-bit chunks, then 8-bit chunks, then a
/// masked tail byte, calling these hooks and summing their `u32` partials.
trait BitVectorOp<Repr>
where
    Repr: Representation<1>,
{
    /// Combines two 64-bit chunks into a partial result.
    fn on_u64(x: u64, y: u64) -> u32;

    /// Combines two 8-bit chunks (tail bytes) into a partial result.
    fn on_u8(x: u8, y: u8) -> u32;
}
260
261impl BitVectorOp<Unsigned> for SquaredL2 {
263 #[inline(always)]
264 fn on_u64(x: u64, y: u64) -> u32 {
265 (x ^ y).count_ones()
266 }
267 #[inline(always)]
268 fn on_u8(x: u8, y: u8) -> u32 {
269 (x ^ y).count_ones()
270 }
271}
272
273impl BitVectorOp<Binary> for Hamming {
275 #[inline(always)]
276 fn on_u64(x: u64, y: u64) -> u32 {
277 (x ^ y).count_ones()
278 }
279 #[inline(always)]
280 fn on_u8(x: u8, y: u8) -> u32 {
281 (x ^ y).count_ones()
282 }
283}
284
285impl BitVectorOp<Unsigned> for InnerProduct {
293 #[inline(always)]
294 fn on_u64(x: u64, y: u64) -> u32 {
295 (x & y).count_ones()
296 }
297 #[inline(always)]
298 fn on_u8(x: u8, y: u8) -> u32 {
299 (x & y).count_ones()
300 }
301}
302
/// Scalar driver for 1-bit-element distance kernels.
///
/// Walks both slices in 64-bit chunks, then whole bytes, then a masked tail
/// byte, accumulating `Op::on_u64` / `Op::on_u8` partials. `len` is measured
/// in *bits* (elements).
///
/// NOTE(review): the unaligned reads assume the `BitSlice` backing storage is
/// at least `ceil(len / 8)` bytes, so the tail-byte read at the end is in
/// bounds — confirm against the `BitSlice` allocation invariant.
#[inline(always)]
fn bitvector_op<Op, Repr>(
    x: BitSlice<'_, 1, Repr>,
    y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
    Repr: Representation<1>,
    Op: BitVectorOp<Repr>,
{
    // Mismatched lengths are reported as an error before touching the data.
    let len = check_lengths!(x, y)?;

    let px: *const u64 = x.as_ptr().cast();
    let py: *const u64 = y.as_ptr().cast();

    let mut i = 0;
    let mut s: u32 = 0;

    // Main pass: whole 64-bit words (64 elements each).
    let blocks = len / 64;
    while i < blocks {
        // SAFETY: `i < len / 64`, so each u64 read lies inside the slice.
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };

        s += Op::on_u64(vx, vy);
        i += 1;
    }

    // Convert the u64-word index into a byte index for the byte-wise pass.
    i *= 8;
    let px: *const u8 = x.as_ptr();
    let py: *const u8 = y.as_ptr();

    // Second pass: remaining whole bytes (8 elements each).
    let blocks = len / 8;
    while i < blocks {
        // SAFETY: `i < len / 8`, so each byte read lies inside the slice.
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        s += Op::on_u8(vx, vy);
        i += 1;
    }

    // Tail: fewer than 8 elements left in a final partial byte. Mask off the
    // unused high bits so they cannot contribute to the result.
    if i * 8 != len {
        // SAFETY: see the NOTE above — relies on byte-rounded backing storage.
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        // `len - 8 * i` is in 1..=7, so the shift and mask are well defined.
        let m = (0x01u8 << (len - 8 * i)) - 1;

        s += Op::on_u8(vx & m, vy & m)
    }
    Ok(MV::new(s))
}
373
374impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
378 for Hamming
379{
380 fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
381 bitvector_op::<Hamming, Binary>(x, y)
382 }
383}
384
/// Squared L2 for 8-bit unsigned elements on any architecture that has a
/// byte-slice L2 kernel: 8-bit packed data is laid out exactly like `&[u8]`,
/// so delegate to `diskann_vector`'s kernel and convert the result.
///
/// NOTE(review): the f32 -> u32 `as` cast truncates; presumably the exact
/// integer distance always fits in f32's 24-bit exact-integer range for the
/// supported vector lengths — confirm against the caller's length limits.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
    A: Architecture,
    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;

        // Delegate to the float-returning byte kernel for this architecture.
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::SquaredL2 {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        Ok(MV::new(r.into_inner() as u32))
    }
}
416
/// AVX2 (`V3`) squared-L2 kernel for 4-bit unsigned elements.
///
/// Each `u32` word packs eight nibbles. The vector loop loads 8 words at a
/// time and makes four shifted passes (shift 0, 4, 8, 12 with mask
/// `0x000f000f`), so each pass extracts one nibble per 16-bit lane; together
/// the passes cover all eight nibbles of every word. Differences are squared
/// and accumulated with 16-bit dot products into four i32x8 accumulators.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Whole u32 words; each holds 8 elements.
        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Low nibble of each 16-bit lane.
            let mask = u32s::splat(arch, 0x000f000f);
            // Strict `<` leaves 1..=8 words for the masked remainder load.
            while i + 8 < blocks {
                // SAFETY: `i + 8 < blocks`, so a full 8-word load is in bounds.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // 1..=8 whole words remain; a masked load zero-fills the rest
            // (zero nibbles contribute nothing to the sum).
            let remainder = blocks - i;

            // SAFETY: `remainder` words starting at `i` are in bounds.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of all four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index.
        i *= 8;

        // Up to 7 trailing elements in a partial word: scalar cleanup.
        if i != len {
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
560
/// AVX2 (`V3`) squared-L2 kernel for 2-bit unsigned elements.
///
/// Each `u32` word packs sixteen crumbs. Eight shifted passes (shift 0..14 in
/// steps of 2 with mask `0x00030003`) extract one crumb per 16-bit lane per
/// pass; the eight passes share the four accumulators in pairs.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Whole u32 words; each holds 16 elements.
        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Low crumb (2 bits) of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Strict `<` leaves 1..=8 words for the masked remainder load.
            while i + 8 < blocks {
                // SAFETY: `i + 8 < blocks`, so a full 8-word load is in bounds.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // 1..=8 whole words remain; masked load zero-fills the rest.
            let remainder = blocks - i;

            // SAFETY: `remainder` words starting at `i` are in bounds.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of all four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index.
        i *= 16;

        // Up to 15 trailing elements in a partial word: scalar cleanup.
        if i != len {
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
743
744impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
748where
749 A: Architecture,
750{
751 fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
752 bitvector_op::<Self, Unsigned>(x, y)
753 }
754}
755
/// Generates the scalar (`Scalar` architecture) squared-L2 fallback for the
/// given element width(s): element-by-element unpack, subtract, square, sum.
///
/// `#[inline(never)]` keeps these cold fallbacks out of the SIMD fast paths.
macro_rules! impl_fallback_l2 {
    ($N:literal) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: i32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and the lengths were checked equal.
                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
                    let diff = ix - iy;
                    accum += diff * diff;
                }
                Ok(MV::new(accum as u32))
            }
        }
    };
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_l2!($N);)+
    };
}
793
// Scalar fallbacks for every width that lacks a dedicated SIMD kernel.
impl_fallback_l2!(7, 6, 5, 4, 3, 2);

// Widths without a V3-specific kernel forward to the next tier down
// (4-bit and 2-bit have dedicated V3 kernels above).
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

// V4 has no dedicated L2 kernels; retarget everything.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);

// Public entry points: runtime-dispatched `PureDistanceFunction` impls.
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
// Neon has no dedicated kernels either; retarget all widths.
#[cfg(target_arch = "aarch64")]
retarget!(
    diskann_wide::arch::aarch64::Neon,
    SquaredL2,
    7,
    6,
    5,
    4,
    3,
    2
);
814
/// Inner product for 8-bit unsigned elements on any architecture with a
/// byte-slice IP kernel: 8-bit packed data is laid out exactly like `&[u8]`,
/// so delegate and convert the result.
///
/// NOTE(review): the f32 -> u32 `as` cast truncates; presumably the exact
/// integer product always fits in f32's 24-bit exact-integer range for the
/// supported vector lengths — confirm against the caller's length limits.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
where
    A: Architecture,
    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;
        // Delegate to the float-returning byte kernel for this architecture.
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::InnerProduct {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        Ok(MV::new(r.into_inner() as u32))
    }
}
845
/// AVX-512 (`V4`) inner-product kernel for 2-bit unsigned elements.
///
/// Works at byte granularity: each byte packs four crumbs, extracted by four
/// shifted passes (shift 0, 2, 4, 6 with byte mask `0x03030303`) and combined
/// with u8*i8 dot products. Values are at most 3, so reinterpreting one side
/// as `i8` cannot change the products.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Rounded *up*: the trailing partial word is covered by the masked
        // byte-granular remainder load below.
        let blocks = len.div_ceil(16);
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Low crumb of every byte.
            let mask = u32s::splat(arch, 0x03030303);
            // Strict `<` leaves at least one (possibly partial) word for the
            // masked remainder load.
            while i + 16 < blocks {
                // SAFETY: `i + 16 < blocks`, so a full 16-word load is in bounds.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Remaining *whole bytes* (4 elements each); a trailing partial
            // byte is left for the scalar loop below.
            let remainder = len / 4 - 4 * i;

            // SAFETY: `remainder` bytes starting at word `i` are in bounds;
            // the masked load zero-fills the rest of the vector.
            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            // Horizontal reduction, then record progress in *bytes*.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            i = (4 * i) + remainder;
        }

        // Byte index -> element index.
        i *= 4;

        // Only a trailing partial byte (at most 3 elements) can remain.
        debug_assert!(len - i <= 3);
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                // SAFETY: `i + j < len` and the lengths were checked equal.
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
983
/// AVX2 (`V3`) inner-product kernel for 4-bit unsigned elements.
///
/// Same extraction scheme as the 4-bit L2 kernel (four shifted passes with
/// mask `0x000f000f` over 8-word vectors), but accumulating products instead
/// of squared differences.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Whole u32 words; each holds 8 elements.
        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Low nibble of each 16-bit lane.
            let mask = u32s::splat(arch, 0x000f000f);
            // Strict `<` leaves 1..=8 words for the masked remainder load.
            while i + 8 < blocks {
                // SAFETY: `i + 8 < blocks`, so a full 8-word load is in bounds.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // 1..=8 whole words remain; masked load zero-fills the rest
            // (zero elements contribute nothing to the product sum).
            let remainder = blocks - i;

            // SAFETY: `remainder` words starting at `i` are in bounds.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of all four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index.
        i *= 8;

        // Up to 7 trailing elements in a partial word: scalar cleanup.
        if i != len {
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1117
/// AVX2 (`V3`) inner-product kernel for 2-bit unsigned elements.
///
/// Same extraction scheme as the 2-bit L2 kernel (eight shifted passes with
/// mask `0x00030003`, sharing the four accumulators in pairs), but
/// accumulating products instead of squared differences.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Whole u32 words; each holds 16 elements.
        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Low crumb (2 bits) of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Strict `<` leaves 1..=8 words for the masked remainder load.
            while i + 8 < blocks {
                // SAFETY: `i + 8 < blocks`, so a full 8-word load is in bounds.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // 1..=8 whole words remain; masked load zero-fills the rest.
            let remainder = blocks - i;

            // SAFETY: `remainder` words starting at `i` are in bounds.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of all four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index.
        i *= 16;

        // Up to 15 trailing elements in a partial word: scalar cleanup.
        if i != len {
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1283
/// Inner product for 1-bit unsigned elements on any architecture: a position
/// contributes only when the bit is set on both sides, so this reduces to the
/// scalar AND/popcount driver.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
        bitvector_op::<Self, Unsigned>(x, y)
    }
}
1296
/// Generates the scalar (`Scalar` architecture) inner-product fallback for the
/// given width pair(s): element-by-element unpack, multiply, sum.
///
/// Unlike `impl_fallback_l2!`, this supports asymmetric widths (e.g. 8-bit
/// query against 4-/2-/1-bit data). `#[inline(never)]` keeps these cold
/// fallbacks out of the SIMD fast paths.
macro_rules! impl_fallback_ip {
    (($N:literal, $M:literal)) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $M>> for InnerProduct {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $M>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: u32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and the lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    accum += ix * iy;
                }
                Ok(MV::new(accum))
            }
        }
    };
    ($N:literal) => {
        impl_fallback_ip!(($N, $N));
    };
    ($($args:tt),+ $(,)?) => {
        $(impl_fallback_ip!($args);)+
    };
}
1336
// Scalar fallbacks: symmetric widths plus the asymmetric 8-bit-query forms.
impl_fallback_ip!(7, 6, 5, 4, 3, 2, (8, 4), (8, 2), (8, 1));

// Widths without a V3-specific kernel forward to the next tier down
// (4-bit and 2-bit have dedicated V3 kernels above).
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);

// V4 only has a dedicated (2, 2) kernel; everything else retargets,
// including the asymmetric pairs.
#[cfg(target_arch = "x86_64")]
retarget!(
    diskann_wide::arch::x86_64::V4,
    InnerProduct,
    7,
    6,
    5,
    4,
    3,
    (8, 4),
    (8, 2),
    (8, 1)
);

// Public entry points: runtime-dispatched `PureDistanceFunction` impls,
// including the asymmetric 8-bit-query forms.
dispatch_pure!(
    InnerProduct,
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    (8, 8),
    (8, 4),
    (8, 2),
    (8, 1)
);
1370
/// AVX2 (`V3`) asymmetric inner-product kernel: 8-bit query `x` against
/// 4-bit data `y`.
///
/// Per 32-element block, 32 query bytes are loaded directly while the 16
/// data bytes are expanded nibble-to-byte by `unpack_half`. Products are
/// formed with `_mm256_maddubs_epi16` (u8 * i8 pairwise to i16); since the
/// unpacked nibbles are at most 15, treating them as i8 is exact.
///
/// NOTE(review): up to four i16 product vectors are summed before widening
/// to i32 — presumably 4 * 2 * 255 * 15 < i16::MAX keeps this free of
/// saturation; confirm the bound.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        use std::arch::x86_64::_mm256_maddubs_epi16;

        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // Expands 16 packed nibble pairs into 32 bytes: low nibbles and high
        // nibbles are interleaved by `zip`, then masked down to 4 bits.
        #[inline(always)]
        fn unpack_half(input: u8s_16) -> u8s_32 {
            let combined = diskann_wide::LoHi::new(input, input >> 4).zip::<u8s_32>();
            combined & u8s_32::splat(input.arch(), (1u8 << 4) - 1)
        }

        // 32 elements per block: 32 query bytes, 16 data bytes.
        let blocks = len / 32;
        if blocks > 0 {
            let mut acc = i32s::default(arch);

            // Pairwise u8 * i8 products summed into i16 lanes.
            let products = |x: u8s_32, y: u8s_32| -> i16s {
                i16s::from_underlying(arch, unsafe {
                    _mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
                })
            };

            let ones = i16s::splat(arch, 1);

            // 4x-unrolled main loop; the four i16 partials are summed and
            // widened into `acc` with a single dot against ones.
            while i + 4 <= blocks {
                // SAFETY: `i + 4 <= blocks`, so all four block loads below
                // stay inside both slices.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
                let m0 = products(vx, unpack_half(vy));

                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 1))) };
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 1))) };
                let m1 = products(vx, unpack_half(vy));

                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 2))) };
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 2))) };
                let m2 = products(vx, unpack_half(vy));

                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 3))) };
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 3))) };
                let m3 = products(vx, unpack_half(vy));

                acc = acc.dot_simd(m0 + m1 + m2 + m3, ones);
                i += 4;
            }

            // Up to 3 remaining whole blocks, one at a time.
            while i < blocks {
                // SAFETY: `i < blocks`, so one full block load is in bounds.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
                acc = acc.dot_simd(products(vx, unpack_half(vy)), ones);
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Block index -> element index.
        i *= 32;

        // Fewer than 32 trailing elements: scalar cleanup.
        if i != len {
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and lengths were checked equal.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1507
#[cfg(target_arch = "x86_64")]
// AVX2 (x86-64-v3) inner product between an 8-bit unsigned slice and a 2-bit
// unsigned ("crumb") slice, accumulated exactly in u32.
//
// Strategy: process 64 elements per block — 64 bytes of `x` against 16 bytes
// of `y` (64 crumbs). Crumbs are unpacked to one byte each, multiplied and
// pair-summed with `maddubs`, then widened into a 32-bit accumulator.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        use diskann_wide::SplitJoin;
        use std::arch::x86_64::_mm256_maddubs_epi16;

        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // One block = 64 elements = 64 bytes of `x`, 16 bytes of `y`.
        let blocks = len / 64;
        if blocks > 0 {
            let mut acc = i32s::default(arch);

            // `maddubs` multiplies unsigned bytes by signed bytes and adds
            // adjacent pairs into i16 lanes. Here both operands are
            // non-negative (x <= 255, crumbs <= 3), so each i16 lane holds at
            // most 2 * 255 * 3 = 1530 — no saturation for a single call.
            let products = |x: u8s_32, y: u8s_32| -> i16s {
                i16s::from_underlying(arch, unsafe {
                    _mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
                })
            };

            // Split each byte into its low-N and high-N bit halves, widening
            // 16 packed values into 32 masked values.
            #[inline(always)]
            fn unpack_sub<const N: u8>(input: u8s_16) -> u8s_32 {
                let combined = diskann_wide::LoHi::new(input, input >> N).zip::<u8s_32>();
                combined & u8s_32::splat(input.arch(), (1u8 << N) - 1)
            }

            // Expand 16 bytes (64 crumbs) into two 32-byte vectors of
            // one-crumb-per-byte, via nibbles first and then crumbs.
            let unpack_crumbs = |x: u8s_16| -> (u8s_32, u8s_32) {
                let nibbles = unpack_sub::<4>(x);

                let diskann_wide::LoHi { lo, hi } = nibbles.split();
                let lower = unpack_sub::<2>(lo);
                let upper = unpack_sub::<2>(hi);

                (lower, upper)
            };

            let ones = i16s::splat(arch, 1);

            // 4x unrolled main loop: sum 8 `maddubs` results in i16 before
            // widening (8 * 1530 = 12240 < i16::MAX, so still no overflow),
            // then widen once per 4 blocks via `dot_simd` against 1s.
            while i + 4 <= blocks {
                // SAFETY: `i + 4 <= blocks` keeps all loads in bounds of the
                // two slices' backing storage.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * i)),
                        u8s_32::load_simd(arch, px.add(64 * i + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
                    )
                };
                let m0a = products(vx0, vy0);
                let m0b = products(vx1, vy1);

                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 1))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 1) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 1)))),
                    )
                };
                let m1a = products(vx0, vy0);
                let m1b = products(vx1, vy1);

                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 2))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 2) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 2)))),
                    )
                };
                let m2a = products(vx0, vy0);
                let m2b = products(vx1, vy1);

                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 3))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 3) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 3)))),
                    )
                };
                let m3a = products(vx0, vy0);
                let m3b = products(vx1, vy1);

                acc = acc.dot_simd((m0a + m0b + m1a + m1b) + (m2a + m2b + m3a + m3b), ones);
                i += 4;
            }

            // Drain the remaining (< 4) whole blocks one at a time.
            while i < blocks {
                // SAFETY: `i < blocks` keeps the loads in bounds.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * i)),
                        u8s_32::load_simd(arch, px.add(64 * i + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
                    )
                };
                acc = acc.dot_simd(products(vx0, vy0) + products(vx1, vy1), ones);
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Convert the block index back into an element index for the tail.
        i *= 64;

        if i != len {
            // Scalar tail; kept out-of-line so the hot SIMD path stays small.
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len() == y.len()` (checked above).
                    let (ix, iy) =
                        unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1680
#[cfg(target_arch = "x86_64")]
// AVX2 (x86-64-v3) inner product between an 8-bit unsigned slice and a 1-bit
// slice. Since each `y` element is 0 or 1, the product reduces to summing the
// `x` bytes whose corresponding bit is set: expand 32 bits to a byte mask,
// AND with `x`, then horizontally sum bytes with `sad` against zero.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 1>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 1>,
    ) -> MathematicalResult<u32> {
        use diskann_wide::{FromInt, SIMDMask};
        use std::arch::x86_64::_mm256_sad_epu8;

        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);

        type Mask32 = diskann_wide::BitMask<32, diskann_wide::arch::x86_64::V3>;
        type Mask8x32 = diskann_wide::arch::x86_64::v3::masks::mask8x32;

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // One block = 32 elements = 32 bytes of `x`, 4 bytes (32 bits) of `y`.
        let blocks = len / 32;
        if blocks > 0 {
            let mut acc = i32s::default(arch);
            let zero = u8s_32::default(arch);

            // Mask `vx` by the 32 selection bits, then sum the surviving
            // bytes. `_mm256_sad_epu8` against zero produces four 64-bit
            // partial sums; each is at most 8 * 255 = 2040, so reinterpreting
            // the register as i32x8 is lossless (odd lanes are zero).
            let masked_sad = |vx: u8s_32, bits: u32| -> i32s {
                let byte_mask: Mask8x32 = Mask32::from_int(arch, bits).into();

                let masked = vx & u8s_32::from_underlying(arch, byte_mask.to_underlying());

                i32s::from_underlying(arch, unsafe {
                    _mm256_sad_epu8(masked.to_underlying(), zero.to_underlying())
                })
            };

            // 4x unrolled main loop.
            while i + 4 <= blocks {
                // SAFETY: `i + 4 <= blocks` keeps every load in bounds.
                let s0 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * i));
                    let bits = (py.add(4 * i) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                let s1 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 1)));
                    let bits = (py.add(4 * (i + 1)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                let s2 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 2)));
                    let bits = (py.add(4 * (i + 2)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                let s3 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 3)));
                    let bits = (py.add(4 * (i + 3)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                acc = acc + s0 + s1 + s2 + s3;
                i += 4;
            }

            // Drain the remaining (< 4) whole blocks.
            while i < blocks {
                // SAFETY: `i < blocks` keeps the loads in bounds.
                let si = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * i));
                    let bits = (py.add(4 * i) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };
                acc = acc + si;
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Convert the block index back into an element index for the tail.
        i *= 32;

        if i != len {
            // Scalar tail; kept out-of-line so the hot SIMD path stays small.
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 1>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len() == y.len()` (checked above).
                    let (ix, iy) =
                        unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1820
#[cfg(target_arch = "aarch64")]
// No Neon-specialized kernels exist for these bit-width combinations, so
// delegate via `arch.retarget()` (presumably down to the scalar
// implementations — confirm against `Architecture::retarget`).
retarget!(
    diskann_wide::arch::aarch64::Neon,
    InnerProduct,
    7,
    6,
    5,
    4,
    3,
    2,
    (8, 4),
    (8, 2),
    (8, 1)
);
1835
// Architecture-independent inner product between a 4-bit *bit-transposed*
// slice and a 1-bit dense slice. In the transposed layout each 64-element
// block of `x` is stored as four u64 bit-planes (plane b holds bit b of every
// element), so the product against a 0/1 vector `y` is
// sum_b popcount(y_bits & plane_b) << b.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        // `px` walks u64 bit-plane words, `py` walks u64 selection words.
        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // One block = 64 elements: one u64 of `y` bits, four u64 planes of `x`.
        let blocks = len / 64;
        while i < blocks {
            // SAFETY: `i < blocks` keeps all five u64 reads in bounds.
            let bits = unsafe { py.add(i).read_unaligned() };

            // Plane b contributes its popcount weighted by 2^b.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        if 64 * i == len {
            // Length was an exact multiple of 64 — no partial block.
            return Ok(MV::new(s));
        }

        // Bytes of `y` already consumed by the whole-block loop.
        let k = i * 8;

        // Assemble the partial selection word byte-by-byte: `y` may not have
        // a full u64 left, so only `bytes_remaining` (capped at 8) bytes are
        // readable past `k`.
        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        for j in 0..bytes_remaining.min(8) {
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Zero out bits beyond the logical length (the backing bytes may
        // contain garbage padding). `len - 64 * i` is in 1..=63 here, so the
        // shift cannot overflow.
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        // The transposed layout stores four full u64 plane words even for a
        // partial block — assumed from these unconditional reads; TODO
        // confirm against the `BitTranspose` allocation logic.
        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1968
1969impl
1970 PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
1971 for InnerProduct
1972{
1973 fn evaluate(
1974 x: USlice<'_, 4, BitTranspose>,
1975 y: USlice<'_, 1, Dense>,
1976 ) -> MathematicalResult<u32> {
1977 (diskann_wide::ARCH).run2(Self, x, y)
1978 }
1979}
1980
#[cfg(target_arch = "x86_64")]
// AVX2 (x86-64-v3) inner product between a dense f32 slice and a 1-bit slice.
// Bits are turned into 0.0/1.0 lanes with `vpermilps` table lookups, then
// fused-multiply-added against the floats 8 lanes at a time.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 1>,
    ) -> MathematicalResult<f32> {
        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Lookup table for `_mm256_permutevar_ps`: the permute only honors
        // the low 2 bits of each lane index, and the table repeats with
        // period 2, so only the parity (bit 0) of the index matters —
        // stray higher bits left by `prep` are harmless.
        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        // Lane j shifts bit j of the packed word down to its LSB.
        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a packed 32-bit word and align bit j with lane j.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        // Map each lane's selector bit to 0.0 or 1.0 via the table above.
        let to_f32 = |v: u32s| -> f32s {
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // One block = 32 elements: one u32 of `y` bits, four f32x8 of `x`.
        let blocks = len / 32;
        if i < blocks {
            // Two accumulators to relax the FMA dependency chain.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            while i < blocks {
                // SAFETY: `i < blocks` keeps all loads in bounds.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };

                // Successive byte groups of the selector word feed
                // successive 8-float groups.
                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);

                i += 1;
            }
            s = s0 + s1;
        }

        let remainder = len % 32;
        if remainder != 0 {
            // Number of valid floats in the final partial f32x8 load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };

            let py = unsafe { py.add(blocks) };

            // Choose how many selector bytes may be safely read from `y`
            // (see `load_one`..`load_four` above), then FMA 8..32 floats.
            if remainder <= 8 {
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    })
                }
            } else if remainder <= 16 {
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                    })
                }
            } else if remainder <= 24 {
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                    })
                }
            } else {
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
                    })
                }
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
2150
#[cfg(target_arch = "x86_64")]
// AVX2 (x86-64-v3) inner product between a dense f32 slice and a 2-bit
// ("crumb") slice. Crumbs (0..=3) are converted to floats with a
// `vpermilps` table lookup, then fused-multiply-added 8 lanes at a time.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 2>,
    ) -> MathematicalResult<f32> {
        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Lookup table for `_mm256_permutevar_ps`: it honors only the low
        // 2 bits of each lane index, which after `prep` hold exactly one
        // crumb, so each lane selects its crumb's value 0.0..=3.0.
        let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);

        // Lane j shifts crumb j of the packed word down to the LSBs.
        let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a packed 32-bit word (16 crumbs) and align crumb j with
        // lane j; the low 16 bits cover lanes directly, the high 16 via
        // `iy >> 16` below.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        let to_f32 = |v: u32s| -> f32s {
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // One block = 16 elements: one u32 of `y` crumbs, two f32x8 of `x`.
        let blocks = len / 16;
        if blocks != 0 {
            // Two accumulators to relax the FMA dependency chain.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            // 2x unrolled main loop.
            while i + 2 <= blocks {
                // SAFETY: `i + 2 <= blocks` keeps all loads in bounds.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                let iy = prep(unsafe { py.add(i + 1).read_unaligned() });

                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * (i + 1))),
                        f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                i += 2;
            }

            // At most one whole block left after unrolling.
            if i < blocks {
                // SAFETY: `i < blocks` keeps the loads in bounds.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
            }

            s = s0 + s1;
        }

        let remainder = len % 16;
        if remainder != 0 {
            // Number of valid floats in the final partial f32x8 load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };
            let py = unsafe { py.add(blocks) };

            // Each selector byte covers 4 crumbs, so `remainder` determines
            // how many bytes of `y` may be safely read (see `load_one`..
            // `load_four` above).
            if remainder <= 4 {
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 8 {
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 12 {
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            } else {
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
2314
#[cfg(target_arch = "x86_64")]
// AVX2 (x86-64-v3) inner product between a dense f32 slice and a 4-bit
// ("nibble") slice. Nibbles are isolated with per-lane shifts and a mask,
// converted to f32 via `simd_cast`, and fused-multiply-added 8 at a time.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 4>,
    ) -> MathematicalResult<f32> {
        // Fail fast on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);

        // Lane j shifts nibble j of the packed word down to the LSBs.
        let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
        let mask = i32s::splat(arch, 0x0f);

        // Broadcast one u32 (8 nibbles), align nibble j with lane j, mask to
        // 0..=15, and convert to f32.
        let to_f32 = |v: u32| -> f32s {
            ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
        };

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // One block = 8 elements: one u32 of `y` nibbles, one f32x8 of `x`.
        let blocks = len / 8;
        while i < blocks {
            // SAFETY: `i < blocks` keeps both loads in bounds.
            let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
            let iy = to_f32(unsafe { py.add(i).read_unaligned() });
            s = ix.mul_add_simd(iy, s);

            i += 1;
        }

        let remainder = len % 8;
        if remainder != 0 {
            // FMA the final partial vector; `load_simd_first` reads only the
            // first `remainder` floats.
            let f = |iy| {
                let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
                s = ix.mul_add_simd(to_f32(iy), s);
            };

            let py = unsafe { py.add(blocks) };

            // Each selector byte covers 2 nibbles, so `remainder` determines
            // how many bytes of `y` may be safely read (see `load_one`..
            // `load_four` above).
            if remainder <= 2 {
                unsafe { load_one(py, f) };
            } else if remainder <= 4 {
                unsafe { load_two(py, f) };
            } else if remainder <= 6 {
                unsafe { load_three(py, f) };
            } else {
                unsafe { load_four(py, f) };
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
2391
2392impl<const N: usize>
2393 Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
2394 for InnerProduct
2395where
2396 Unsigned: Representation<N>,
2397{
2398 #[inline(always)]
2401 fn run(
2402 self,
2403 _: diskann_wide::arch::Scalar,
2404 x: &[f32],
2405 y: USlice<'_, N>,
2406 ) -> MathematicalResult<f32> {
2407 check_lengths!(x, y)?;
2408
2409 let mut s = 0.0;
2410 for (i, x) in x.iter().enumerate() {
2411 let y = unsafe { y.get_unchecked(i) } as f32;
2414 s += x * y;
2415 }
2416
2417 Ok(MV::new(s))
2418 }
2419}
2420
// Generates `Target2` impls for the f32 × N-bit inner product that have no
// dedicated kernel on `$arch`: each one simply forwards to another
// architecture via `arch.retarget()`.
macro_rules! ip_retarget {
    ($arch:path, $N:literal) => {
        impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
            for InnerProduct
        {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: &[f32],
                y: USlice<'_, $N>,
            ) -> MathematicalResult<f32> {
                // Delegate to the retargeted architecture's implementation.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Variadic form: one impl per listed bit width.
    ($arch:path, $($Ns:literal),*) => {
        $(ip_retarget!($arch, $Ns);)*
    }
}
2442
// V3 has dedicated f32 kernels for 1-, 2-, and 4-bit; the rest retarget.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);

// V4 and Neon have no dedicated f32 kernels here; retarget every width.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);

#[cfg(target_arch = "aarch64")]
ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8);
2451
// Generates the public `PureDistanceFunction` entry points for the
// f32 × N-bit inner product, dispatching on the runtime-detected `ARCH`.
macro_rules! dispatch_full_ip {
    ($N:literal) => {
        impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
            for InnerProduct
        {
            fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
                Self.run(ARCH, x, y)
            }
        }
    };
    // Variadic form: one impl per listed bit width.
    ($($Ns:literal),*) => {
        $(dispatch_full_ip!($Ns);)*
    }
}

// Every supported bit width gets a pure entry point.
dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
2473
2474#[cfg(test)]
2479mod tests {
2480 use std::{collections::HashMap, fmt::Display, sync::LazyLock};
2481
2482 use diskann_utils::{Reborrow, lazy_format};
2483 use rand::{
2484 Rng, SeedableRng,
2485 distr::{Distribution, Uniform},
2486 rngs::StdRng,
2487 seq::IndexedRandom,
2488 };
2489
2490 use super::*;
2491 use crate::bits::{BoxedBitSlice, Representation, Unsigned};
2492
    /// Shorthand for the integer distance result used throughout these tests.
    type MR = MathematicalResult<u32>;
2494
2495 #[inline(always)]
2496 fn should_check_this_dimension(dim: usize) -> bool {
2497 if cfg!(miri) {
2498 return dim.is_power_of_two()
2499 || (dim > 1 && (dim - 1).is_power_of_two())
2500 || (dim < 64 && (dim % 8 == 7));
2501 }
2502
2503 true
2504 }
2505
2506 fn test_bitslice_distances<const NBITS: usize, R>(
2516 dim_max: usize,
2517 trials_per_dim: usize,
2518 evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2519 evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2520 context: &str,
2521 rng: &mut R,
2522 ) where
2523 Unsigned: Representation<NBITS>,
2524 R: Rng,
2525 {
2526 let domain = Unsigned::domain_const::<NBITS>();
2527 let min: i64 = *domain.start();
2528 let max: i64 = *domain.end();
2529
2530 let dist = Uniform::new_inclusive(min, max).unwrap();
2531
2532 for dim in 0..dim_max {
2533 if !should_check_this_dimension(dim) {
2534 continue;
2535 }
2536
2537 let mut x_reference: Vec<u8> = vec![0; dim];
2538 let mut y_reference: Vec<u8> = vec![0; dim];
2539
2540 let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2541 let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2542
2543 for trial in 0..trials_per_dim {
2544 x_reference
2545 .iter_mut()
2546 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2547 y_reference
2548 .iter_mut()
2549 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2550
2551 x.as_mut_slice().fill(u8::MAX);
2554 y.as_mut_slice().fill(u8::MAX);
2555
2556 for i in 0..dim {
2557 x.set(i, x_reference[i].into()).unwrap();
2558 y.set(i, y_reference[i].into()).unwrap();
2559 }
2560
2561 let expected: MV<f32> =
2563 diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2564
2565 let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2566
2567 assert_eq!(
2569 expected.into_inner(),
2570 got.into_inner() as f32,
2571 "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2572 NBITS,
2573 dim,
2574 trial,
2575 context,
2576 );
2577
2578 let expected: MV<f32> =
2580 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2581
2582 let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2583
2584 assert_eq!(
2586 expected.into_inner(),
2587 got.into_inner() as f32,
2588 "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2589 NBITS,
2590 dim,
2591 trial,
2592 context,
2593 );
2594 }
2595 }
2596
2597 let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2599 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2600
2601 assert!(
2602 evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2603 "context: {}",
2604 context
2605 );
2606 assert!(
2607 evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2608 "context: {}",
2609 context
2610 );
2611
2612 assert!(
2613 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2614 "context: {}",
2615 context
2616 );
2617 assert!(
2618 evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2619 "context: {}",
2620 context
2621 );
2622 }
2623
    // Test sizing: miri is orders of magnitude slower, so shrink both the
    // dimension sweep and the trials per dimension there.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 132;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2633
    // Per-(bit-width, architecture) dimension limits for the bit-slice tests.
    // SIMD paths with long unrolled blocks need larger dimensions to exercise
    // every code path, while miri runs get smaller caps to stay tractable.
    static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
        use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4};
        [
            (Key::new(1, Scalar), Bounds::new(64, 64)),
            (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
            (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
            (Key::new(1, Neon), Bounds::new(64, 64)),
            (Key::new(2, Scalar), Bounds::new(64, 64)),
            (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
            (Key::new(2, X86_64_V4), Bounds::new(768, 600)),
            (Key::new(2, Neon), Bounds::new(64, 64)),
            (Key::new(3, Scalar), Bounds::new(64, 64)),
            (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(3, Neon), Bounds::new(64, 64)),
            (Key::new(4, Scalar), Bounds::new(64, 64)),
            (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
            (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
            (Key::new(4, Neon), Bounds::new(64, 64)),
            (Key::new(5, Scalar), Bounds::new(64, 64)),
            (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(5, Neon), Bounds::new(64, 64)),
            (Key::new(6, Scalar), Bounds::new(64, 64)),
            (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(6, Neon), Bounds::new(64, 64)),
            (Key::new(7, Scalar), Bounds::new(64, 64)),
            (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(7, Neon), Bounds::new(64, 64)),
            (Key::new(8, Scalar), Bounds::new(64, 64)),
            (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(8, Neon), Bounds::new(64, 64)),
        ]
        .into_iter()
        .collect()
    });
2684
2685 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
2686 enum ArchKey {
2687 Scalar,
2688 #[expect(non_camel_case_types)]
2689 X86_64_V3,
2690 #[expect(non_camel_case_types)]
2691 X86_64_V4,
2692 Neon,
2693 }
2694
2695 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
2696 struct Key {
2697 nbits: usize,
2698 arch: ArchKey,
2699 }
2700
2701 impl Key {
2702 fn new(nbits: usize, arch: ArchKey) -> Self {
2703 Self { nbits, arch }
2704 }
2705 }
2706
2707 #[derive(Debug, Clone, Copy)]
2708 struct Bounds {
2709 standard: usize,
2710 miri: usize,
2711 }
2712
2713 impl Bounds {
2714 fn new(standard: usize, miri: usize) -> Self {
2715 Self { standard, miri }
2716 }
2717
2718 fn get(&self) -> usize {
2719 if cfg!(miri) { self.miri } else { self.standard }
2720 }
2721 }
2722
    // Generates one `#[test]` per bit width. Each test runs the distance
    // sweep against the pure (auto-dispatched) entry points, the scalar
    // backend, and every SIMD backend available at runtime, each with its
    // own dimension bound from `BITSLICE_TEST_BOUNDS`.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps failures reproducible.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                // Pure entry points (dispatch through ARCH).
                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                // Forced scalar backend.
                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // x86-64-v3 backend, only if the CPU supports it.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // x86-64-v4 backend (miri-aware detection).
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                // Neon backend on aarch64.
                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2791
    // One test per supported bit width; seeds are arbitrary but fixed so
    // failures reproduce deterministically.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2800
    /// Checks the Hamming distance on binary (±1-encoded) slices against the
    /// byte-level SquaredL2 reference: for vectors with entries in {-1, +1},
    /// squared L2 distance equals 4 × Hamming distance. Also verifies that
    /// mismatched lengths are rejected.
    fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
    where
        R: Rng,
    {
        // The binary representation encodes the two values -1 and +1.
        let dist: [i8; 2] = [-1, 1];

        for dim in 0..dim_max {
            if !should_check_this_dimension(dim) {
                continue;
            }

            let mut x_reference: Vec<i8> = vec![1; dim];
            let mut y_reference: Vec<i8> = vec![1; dim];

            let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
            let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);

            for _ in 0..trials_per_dim {
                x_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());
                y_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());

                // Pre-fill backing storage with all-ones so stale padding
                // bits would be caught if the kernel read past `dim`.
                x.as_mut_slice().fill(u8::MAX);
                y.as_mut_slice().fill(u8::MAX);

                for i in 0..dim {
                    x.set(i, x_reference[i].into()).unwrap();
                    y.set(i, y_reference[i].into()).unwrap();
                }

                // Each differing ±1 coordinate contributes (±2)^2 = 4 to the
                // squared L2 distance, hence the factor of 4.
                let expected: MV<f32> =
                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
                let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
                assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
            }
        }

        // Mismatched lengths must be rejected in both argument orders.
        let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
        let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
        assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
        assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
    }
2857
    // Entry point for the Hamming sweep; fixed seed for reproducibility.
    #[test]
    fn test_hamming_distance() {
        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
    }
2863
2864 fn test_bit_transpose_distances<R>(
2869 dim_max: usize,
2870 trials_per_dim: usize,
2871 evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2872 context: &str,
2873 rng: &mut R,
2874 ) where
2875 R: Rng,
2876 {
2877 let dist_4bit = {
2878 let domain = Unsigned::domain_const::<4>();
2879 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2880 };
2881
2882 let dist_1bit = {
2883 let domain = Unsigned::domain_const::<1>();
2884 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2885 };
2886
2887 for dim in 0..dim_max {
2888 if !should_check_this_dimension(dim) {
2889 continue;
2890 }
2891
2892 let mut x_reference: Vec<u8> = vec![0; dim];
2893 let mut y_reference: Vec<u8> = vec![0; dim];
2894
2895 let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2896 let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2897
2898 for trial in 0..trials_per_dim {
2899 x_reference
2900 .iter_mut()
2901 .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2902 y_reference
2903 .iter_mut()
2904 .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2905
2906 x.as_mut_slice().fill(u8::MAX);
2908 y.as_mut_slice().fill(u8::MAX);
2909
2910 for i in 0..dim {
2911 x.set(i, x_reference[i].into()).unwrap();
2912 y.set(i, y_reference[i].into()).unwrap();
2913 }
2914
2915 let expected: MV<f32> =
2917 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2918
2919 let got = evaluate_ip(x.reborrow(), y.reborrow());
2920
2921 assert_eq!(
2923 expected.into_inner(),
2924 got.unwrap().into_inner() as f32,
2925 "faild InnerProduct for dim = {}, trial = {} -- context {}",
2926 dim,
2927 trial,
2928 context,
2929 );
2930 }
2931 }
2932
2933 let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2934 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2935 assert!(
2936 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2937 "context: {}",
2938 context
2939 );
2940
2941 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2942 assert!(
2943 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2944 "context: {}",
2945 context
2946 );
2947 }
2948
    // Runs the bit-transpose sweep against the pure entry point, the scalar
    // backend, and every SIMD backend available at runtime.
    #[test]
    fn test_bit_transpose_distance() {
        // Fixed seed keeps failures reproducible.
        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);

        // Pure (auto-dispatched) entry point.
        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| InnerProduct::evaluate(x, y),
            "pure distance function",
            &mut rng,
        );

        // Forced scalar backend.
        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
            "scalar",
            &mut rng,
        );

        // x86-64-v3 backend, only if the CPU supports it.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v3",
                &mut rng,
            );
        }

        // x86-64-v4 backend (miri-aware detection).
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v4",
                &mut rng,
            );
        }

        // Neon backend on aarch64.
        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "neon",
                &mut rng,
            );
        }
    }
3005
3006 fn test_full_distances<const NBITS: usize>(
3011 dim_max: usize,
3012 trials_per_dim: usize,
3013 evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
3014 context: &str,
3015 rng: &mut impl Rng,
3016 ) where
3017 Unsigned: Representation<NBITS>,
3018 {
3019 let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
3021 let dist_bit = {
3022 let domain = Unsigned::domain_const::<NBITS>();
3023 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
3024 };
3025
3026 for dim in 0..dim_max {
3027 if !should_check_this_dimension(dim) {
3028 continue;
3029 }
3030
3031 let mut x: Vec<f32> = vec![0.0; dim];
3032
3033 let mut y_reference: Vec<u8> = vec![0; dim];
3034 let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
3035
3036 for trial in 0..trials_per_dim {
3037 x.iter_mut()
3038 .for_each(|i| *i = *dist_float.choose(rng).unwrap());
3039 y_reference
3040 .iter_mut()
3041 .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
3042
3043 y.as_mut_slice().fill(u8::MAX);
3045
3046 let mut expected = 0.0;
3047 for i in 0..dim {
3048 y.set(i, y_reference[i].into()).unwrap();
3049 expected += y_reference[i] as f32 * x[i];
3050 }
3051
3052 let got = evaluate_ip(&x, y.reborrow()).unwrap();
3054
3055 assert_eq!(
3057 expected,
3058 got.into_inner(),
3059 "faild InnerProduct for dim = {}, trial = {} -- context {}",
3060 dim,
3061 trial,
3062 context,
3063 );
3064
3065 let scalar: MV<f32> = InnerProduct
3068 .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
3069 .unwrap();
3070 assert_eq!(got.into_inner(), scalar.into_inner());
3071 }
3072 }
3073
3074 let x = vec![0.0; 10];
3076 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
3077 assert!(
3078 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
3079 "context: {}",
3080 context
3081 );
3082
3083 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
3084 assert!(
3085 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
3086 "context: {}",
3087 context
3088 );
3089 }
3090
    /// Generates a `#[test]` that runs `test_full_distances` for one bit
    /// width through every available dispatch path.
    ///
    /// * `$name` - name of the generated test function.
    /// * `$nbits` - quantization width `NBITS`.
    /// * `$seed` - fixed RNG seed so failures are reproducible.
    ///
    /// The stanzas share one RNG, so they are order-sensitive.
    macro_rules! test_full {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);

                // Pure-distance-function front door.
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                // Portable scalar path (always available).
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar",
                    &mut rng,
                );

                // x86-64-v3 path when the CPU supports it.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // x86-64-v4 path; `new_checked_miri` presumably also gates
                // this off under Miri -- confirm in `diskann_wide`.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                // NEON path on aarch64.
                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
3149
    // Instantiate the full-distance (f32 x quantized) inner-product tests for
    // every supported bit width, each with its own fixed seed so failures are
    // reproducible in isolation.
    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
3158
    /// One heterogeneous inner-product test case: parallel per-element values
    /// for an 8-bit operand (`x_vals`) and an `M`-bit operand (`y_vals`),
    /// stored as `i64` and encoded into bit slices when the case is checked.
    struct HetCase<const M: usize> {
        /// Values for the 8-bit operand; call sites keep these in 0..=255.
        x_vals: Vec<i64>,
        /// Values for the M-bit operand; call sites keep these within the
        /// M-bit domain (0..=2^M - 1) -- confirm at each call site.
        y_vals: Vec<i64>,
    }
3169
3170 impl<const M: usize> HetCase<M>
3171 where
3172 Unsigned: Representation<M>,
3173 {
3174 fn new(dim: usize, fill: impl FnMut(usize) -> (i64, i64)) -> Self {
3175 let (x_vals, y_vals) = (0..dim).map(fill).unzip();
3176 Self { x_vals, y_vals }
3177 }
3178
3179 fn check_with(
3180 &self,
3181 label: impl Display,
3182 evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3183 ) {
3184 let dim = self.x_vals.len();
3185 let mut x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
3186 let mut y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim);
3187 x.as_mut_slice().fill(u8::MAX);
3189 y.as_mut_slice().fill(u8::MAX);
3190 for (i, (&xv, &yv)) in self.x_vals.iter().zip(&self.y_vals).enumerate() {
3191 x.set(i, xv).unwrap();
3192 y.set(i, yv).unwrap();
3193 }
3194 let expected: u32 = self
3195 .x_vals
3196 .iter()
3197 .zip(&self.y_vals)
3198 .map(|(&a, &b)| a as u32 * b as u32)
3199 .sum();
3200 let got = evaluate(x.reborrow(), y.reborrow()).unwrap().into_inner();
3201 assert_eq!(expected, got, "{} failed for dim = {}", label, dim);
3202 }
3203 }
3204
    /// Randomized ("fuzz") test of the heterogeneous 8-bit x `M`-bit inner
    /// product for every dimension below `dim_max`.
    ///
    /// * `max_val` - inclusive upper bound for the M-bit operand's values;
    ///   callers pass 2^M - 1.
    /// * `evaluate_ip` - the dispatch path under test.
    /// * `context` - label included in failure messages.
    /// * `rng` - caller-seeded RNG; draw order is load-bearing for
    ///   reproducibility, so do not reorder the sampling below.
    ///
    /// Also checks that mismatched lengths (dim vs dim + 1) are reported as
    /// an error rather than silently truncated.
    fn fuzz_heterogeneous_ip<const M: usize>(
        dim_max: usize,
        trials_per_dim: usize,
        max_val: i64,
        evaluate_ip: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
        context: &str,
        rng: &mut impl Rng,
    ) where
        Unsigned: Representation<M>,
    {
        // Element distributions: full u8 range for x, 0..=max_val for y.
        let dist_8bit = Uniform::new_inclusive(0i64, 255i64).unwrap();
        let dist_mbit = Uniform::new_inclusive(0i64, max_val).unwrap();

        for dim in 0..dim_max {
            for trial in 0..trials_per_dim {
                HetCase::<M>::new(dim, |_| {
                    (dist_8bit.sample(&mut *rng), dist_mbit.sample(&mut *rng))
                })
                .check_with(
                    lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M),
                    evaluate_ip,
                );
            }

            // Length mismatch must surface as an error.
            let x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
            let y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim + 1);
            assert!(
                evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
                "context: {}",
                context,
            );
        }
    }
3240
3241 fn het_test_max_values<const M: usize>(
3244 max_val: i64,
3245 context: &str,
3246 evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3247 ) where
3248 Unsigned: Representation<M>,
3249 {
3250 let dims = [127, 128, 129, 255, 256, 512, 768, 896, 3072];
3251 for &dim in &dims {
3252 let case = HetCase::<M>::new(dim, |_| (255, max_val));
3253 case.check_with(lazy_format!("max-value {context} dim={dim}"), evaluate);
3254 }
3255 }
3256
3257 fn het_test_known_answers<const M: usize>(
3259 max_val: i64,
3260 evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3261 ) where
3262 Unsigned: Representation<M>,
3263 {
3264 HetCase::<M>::new(64, |_| (200, max_val)).check_with("vpmaddubsw operand-order", evaluate);
3267
3268 let y_val = (max_val / 2).max(1);
3270 HetCase::<M>::new(128, |i| ((i % 256) as i64, y_val))
3271 .check_with("ascending-x constant-y", evaluate);
3272
3273 HetCase::<M>::new(1, |_| (200, max_val)).check_with("single element", evaluate);
3275 }
3276
3277 fn het_test_edge_cases<const M: usize>(
3279 max_val: i64,
3280 block_size: usize,
3281 evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3282 ) where
3283 Unsigned: Representation<M>,
3284 {
3285 let y_half = (max_val / 2).max(1);
3286
3287 HetCase::<M>::new(64, |_| (0, max_val)).check_with("x-zero y-nonzero", evaluate);
3289 HetCase::<M>::new(64, |_| (255, 0)).check_with("y-zero x-nonzero", evaluate);
3290
3291 for dim in 0..=(block_size + 1) {
3293 HetCase::<M>::new(dim, |_| (3, y_half)).check_with("uniform fill", evaluate);
3294 }
3295
3296 for &dim in &[block_size, 2 * block_size, 4 * block_size, 8 * block_size] {
3298 HetCase::<M>::new(dim, |_| (100, max_val)).check_with("exact block boundary", evaluate);
3299 }
3300
3301 HetCase::<M>::new(300, |i| ((i % 256) as i64, 1))
3303 .check_with("x-varies y-constant", evaluate);
3304
3305 HetCase::<M>::new(300, |i| (1, (i as i64) % (max_val + 1)))
3307 .check_with("x-constant y-varies", evaluate);
3308
3309 HetCase::<M>::new(128, |i| if i % 2 == 0 { (255, max_val) } else { (0, 0) })
3311 .check_with("alternating pattern", evaluate);
3312
3313 HetCase::<M>::new(128, |i| if i % 2 == 0 { (0, 0) } else { (255, max_val) })
3315 .check_with("opposite alternating", evaluate);
3316
3317 HetCase::<M>::new(1024, |_| (255, max_val)).check_with("large accumulation", evaluate);
3319
3320 for x_val in [128i64, 170, 200, 240, 255] {
3322 HetCase::<M>::new(block_size, move |_| (x_val, y_half))
3323 .check_with(lazy_format!("x > 127 (x_val={x_val})"), evaluate);
3324 }
3325
3326 HetCase::<M>::new(block_size - 1, |i| {
3328 (
3329 ((i * 7 + 3) % 256) as i64,
3330 ((i * 11 + 5) as i64) % (max_val + 1),
3331 )
3332 })
3333 .check_with("dim=block_size-1 (all scalar)", evaluate);
3334
3335 let unroll4 = 4 * block_size;
3337 for &dim in &[
3338 unroll4,
3339 unroll4 + 1,
3340 unroll4 + block_size,
3341 unroll4 + block_size + 1,
3342 ] {
3343 HetCase::<M>::new(dim, |i| {
3344 (((i + 1) % 256) as i64, ((i + 1) as i64) % (max_val + 1))
3345 })
3346 .check_with("unroll boundary", evaluate);
3347 }
3348 }
3349
    /// Generates a module of tests for the heterogeneous 8-bit x `$M`-bit
    /// inner product.
    ///
    /// * `mod_name` - name of the generated module.
    /// * `M` - bit width of the second operand.
    /// * `max_val` - largest value representable in `M` bits (2^M - 1).
    /// * `block_size` - per-iteration element count of the kernel under test,
    ///   used to target boundary dimensions -- confirm against the kernels.
    /// * `seed_fuzz` - fixed RNG seed for the fuzz test.
    macro_rules! heterogeneous_ip_tests_8xM {
        (
            mod_name: $mod:ident,
            M: $M:literal,
            max_val: $max_val:literal,
            block_size: $block_size:literal,
            seed_fuzz: $seed_fuzz:literal,
        ) => {
            mod $mod {
                use super::*;

                // Fuzz every dispatch path; the paths share one seeded RNG,
                // so the stanzas are order-sensitive.
                #[test]
                fn all_ip_dispatches() {
                    let mut rng = StdRng::seed_from_u64($seed_fuzz);

                    fuzz_heterogeneous_ip::<$M>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        $max_val,
                        &|x, y| InnerProduct::evaluate(x, y),
                        "pure distance function",
                        &mut rng,
                    );
                    fuzz_heterogeneous_ip::<$M>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        $max_val,
                        &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                        "scalar arch",
                        &mut rng,
                    );
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                        fuzz_heterogeneous_ip::<$M>(
                            MAX_DIM,
                            TRIALS_PER_DIM,
                            $max_val,
                            &|x, y| arch.run2(InnerProduct, x, y),
                            "x86-64-v3",
                            &mut rng,
                        );
                    }
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                        fuzz_heterogeneous_ip::<$M>(
                            MAX_DIM,
                            TRIALS_PER_DIM,
                            $max_val,
                            &|x, y| arch.run2(InnerProduct, x, y),
                            "x86-64-v4",
                            &mut rng,
                        );
                    }
                }

                // Saturated inputs across boundary dimensions.
                #[test]
                fn max_values() {
                    het_test_max_values::<$M>($max_val, "dispatch", &|x, y| {
                        InnerProduct::evaluate(x, y)
                    });
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                        het_test_max_values::<$M>($max_val, "V3", &|x, y| {
                            arch.run2(InnerProduct, x, y)
                        });
                    }
                }

                // Hand-picked regression cases with known answers.
                #[test]
                fn known_answers() {
                    het_test_known_answers::<$M>($max_val, &|x, y| InnerProduct::evaluate(x, y));
                }

                // Structured edge cases: zeros, boundaries, lane patterns.
                #[test]
                fn edge_cases() {
                    het_test_edge_cases::<$M>($max_val, $block_size, &|x, y| {
                        InnerProduct::evaluate(x, y)
                    });
                }
            }
        };
    }
3432
    // Instantiate the heterogeneous 8xM suites. For each instantiation,
    // `max_val` equals 2^M - 1; `block_size` mirrors the per-iteration
    // element count of the corresponding kernel -- confirm against the
    // kernel implementations if those change.
    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x4,
        M: 4,
        max_val: 15,
        block_size: 32,
        seed_fuzz: 0xd3a7f1c09b2e4856,
    }

    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x2,
        M: 2,
        max_val: 3,
        block_size: 64,
        seed_fuzz: 0x82c4a6e809f1d3b5,
    }

    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x1,
        M: 1,
        max_val: 1,
        block_size: 32,
        seed_fuzz: 0x1b17_a5e7c2d0f839,
    }
3456}