1use diskann_vector::PureDistanceFunction;
107use diskann_wide::{ARCH, Architecture, arch::Target2};
108#[cfg(target_arch = "x86_64")]
109use diskann_wide::{
110 SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
111};
112
113use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
114use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
115
/// Convenience alias: a [`BitSlice`] of unsigned `N`-bit elements with layout
/// `Perm` (bit-packed `Dense` by default).
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
118
/// Implements [`Target2`] for `$op` over `$N`-bit unsigned slices on `$arch`
/// by delegating to the implementation selected by `arch.retarget()` — i.e.
/// reusing a kernel written for a related architecture level instead of
/// duplicating it.
///
/// The variadic arm expands the single-width arm once per listed `$N`.
macro_rules! retarget {
    ($arch:path, $op:ty, $N:literal) => {
        impl Target2<
            $arch,
            MathematicalResult<u32>,
            USlice<'_, $N>,
            USlice<'_, $N>,
        > for $op {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                // Forward to the retargeted architecture's kernel.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    ($arch:path, $op:ty, $($N:literal),+ $(,)?) => {
        $(retarget!($arch, $op, $N);)+
    }
}
147
/// Implements [`PureDistanceFunction`] for `$op` over `$N`-bit unsigned
/// slices by dispatching through the runtime-detected [`ARCH`] via `run2`,
/// which picks the best available `Target2` implementation for this CPU.
macro_rules! dispatch_pure {
    ($op:ty, $N:literal) => {
        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $N>, MathematicalResult<u32>> for $op {
            #[inline(always)]
            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $N>) -> MathematicalResult<u32> {
                (diskann_wide::ARCH).run2(Self, x, y)
            }
        }
    };
    ($op:ty, $($N:literal),+ $(,)?) => {
        $(dispatch_pure!($op, $N);)+
    }
}
163
/// Reads a single byte from `ptr` and passes it, zero-extended to `u32`, to `f`.
///
/// # Safety
/// `ptr` must be valid for an unaligned read of at least 1 byte.
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: caller guarantees at least one readable byte at `ptr`.
    f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
}
178
/// Reads two bytes from `ptr` (as a little-endian `u16`) and passes them,
/// zero-extended to `u32`, to `f`.
///
/// # Safety
/// `ptr` must be valid for an unaligned read of at least 2 bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: caller guarantees at least two readable bytes at `ptr`.
    f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
}
193
/// Reads three bytes from `ptr` (a `u16` plus one trailing byte) and passes
/// them to `f` assembled into the low 24 bits of a `u32`.
///
/// # Safety
/// `ptr` must be valid for an unaligned read of at least 3 bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: bytes 0..2 are within the caller-guaranteed 3-byte range.
    let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
    // SAFETY: byte 2 is within the caller-guaranteed 3-byte range.
    let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
    // Assemble: `lo` occupies bits 0..16, `hi` bits 16..24.
    f(lo | hi << 16)
}
211
/// Reads a full (possibly unaligned) `u32` from `ptr` and passes it to `f`.
///
/// # Safety
/// `ptr` must be valid for an unaligned read of at least 4 bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: caller guarantees four readable bytes at `ptr`.
    f(unsafe { ptr.read_unaligned() })
}
226
/// Word-level kernel for distance computations over 1-bit-per-element
/// vectors. Driven by [`bitvector_op`], which applies [`Self::on_u64`] to
/// full 64-bit words and [`Self::on_u8`] to the remaining tail bytes.
trait BitVectorOp<Repr>
where
    Repr: Representation<1>,
{
    /// Combines 64 packed 1-bit lanes, returning this word's contribution.
    fn on_u64(x: u64, y: u64) -> u32;

    /// Combines 8 packed 1-bit lanes, returning this byte's contribution.
    fn on_u8(x: u8, y: u8) -> u32;
}
251
/// For 1-bit values the per-lane squared difference is exactly `x ^ y`, so
/// squared L2 distance reduces to the popcount of the XOR.
impl BitVectorOp<Unsigned> for SquaredL2 {
    #[inline(always)]
    fn on_u64(x: u64, y: u64) -> u32 {
        (x ^ y).count_ones()
    }
    #[inline(always)]
    fn on_u8(x: u8, y: u8) -> u32 {
        (x ^ y).count_ones()
    }
}
263
/// Hamming distance: count of lanes where the two bits differ, i.e. the
/// popcount of the XOR.
impl BitVectorOp<Binary> for Hamming {
    #[inline(always)]
    fn on_u64(x: u64, y: u64) -> u32 {
        (x ^ y).count_ones()
    }
    #[inline(always)]
    fn on_u8(x: u8, y: u8) -> u32 {
        (x ^ y).count_ones()
    }
}
275
/// Inner product of 1-bit values: each lane contributes `x * y = x & y`, so
/// the dot product reduces to the popcount of the AND.
impl BitVectorOp<Unsigned> for InnerProduct {
    #[inline(always)]
    fn on_u64(x: u64, y: u64) -> u32 {
        (x & y).count_ones()
    }
    #[inline(always)]
    fn on_u8(x: u8, y: u8) -> u32 {
        (x & y).count_ones()
    }
}
293
/// Scalar driver for 1-bit distance kernels: walks the packed bit data in
/// 64-bit words, then single bytes, then masks off the final partial byte,
/// feeding each chunk to `Op`.
///
/// Returns an error if the two slices have different lengths (via
/// `check_lengths!`); otherwise the accumulated `u32` result.
#[inline(always)]
fn bitvector_op<Op, Repr>(
    x: BitSlice<'_, 1, Repr>,
    y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
    Repr: Representation<1>,
    Op: BitVectorOp<Repr>,
{
    let len = check_lengths!(x, y)?;

    let px: *const u64 = x.as_ptr().cast();
    let py: *const u64 = y.as_ptr().cast();

    // `i` counts u64 words here; it is rescaled to bytes below.
    let mut i = 0;
    let mut s: u32 = 0;

    // Full 64-bit words.
    let blocks = len / 64;
    while i < blocks {
        // SAFETY: word `i` lies within the first `len / 64` complete words.
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };

        s += Op::on_u64(vx, vy);
        i += 1;
    }

    // Convert the word index to a byte index and continue byte-wise.
    i *= 8;
    let px: *const u8 = x.as_ptr();
    let py: *const u8 = y.as_ptr();

    let blocks = len / 8;
    while i < blocks {
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        s += Op::on_u8(vx, vy);
        i += 1;
    }

    // Final partial byte: mask to the `len % 8` valid low bits so stale
    // storage bits cannot contaminate the result.
    if i * 8 != len {
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        let m = (0x01u8 << (len - 8 * i)) - 1;

        s += Op::on_u8(vx & m, vy & m)
    }
    Ok(MV::new(s))
}
364
365impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
369 for Hamming
370{
371 fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
372 bitvector_op::<Hamming, Binary>(x, y)
373 }
374}
375
/// Squared L2 for 8-bit elements: a `USlice<8>` stores one byte per element,
/// so this delegates to the general `&[u8]` squared-L2 kernel and truncates
/// its `f32` result to `u32`.
///
/// NOTE(review): the delegate accumulates in `f32`, which represents integers
/// exactly only up to 2^24 — presumably supported vector lengths keep the
/// result below that; confirm against callers.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
    A: Architecture,
    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;

        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::SquaredL2 {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        Ok(MV::new(r.into_inner() as u32))
    }
}
407
/// AVX2 (V3) squared L2 for 4-bit elements.
///
/// Each `u32` word packs eight 4-bit lanes. The kernel loads 8 words at a
/// time (64 elements), extracts the four nibble positions with the mask
/// `0x000f000f` at shifts 0/4/8/12 into `i16` lanes, and accumulates the
/// squared differences with `dot_simd` into four independent accumulators.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (8 elements each) inside the SIMD section.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects one nibble per 16-bit lane; shifted copies cover the
            // other three nibble positions.
            let mask = u32s::splat(arch, 0x000f000f);
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // Final 1..=8 words via a masked (partial) vector load.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert the word index to an element index for the scalar tail.
        i *= 8;

        if i != len {
            // Cold path for the final `len % 8` elements.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
551
/// AVX2 (V3) squared L2 for 2-bit elements.
///
/// Each `u32` word packs sixteen 2-bit lanes. The kernel loads 8 words at a
/// time (128 elements), extracts the eight 2-bit positions with the mask
/// `0x00030003` at shifts 0/2/4/.../14 into `i16` lanes, and accumulates the
/// squared differences across four rotating accumulators.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (16 elements each) inside the SIMD section.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects one 2-bit field per 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // Final 1..=8 words via a masked (partial) vector load.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert the word index to an element index for the scalar tail.
        i *= 16;

        if i != len {
            // Cold path for the final `len % 16` elements.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
734
735impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
739where
740 A: Architecture,
741{
742 fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
743 bitvector_op::<Self, Unsigned>(x, y)
744 }
745}
746
/// Generates a scalar (no-SIMD) `Target2` squared-L2 implementation for the
/// `Scalar` architecture over `$N`-bit unsigned slices: a straightforward
/// element-by-element loop accumulating squared differences in `i32`.
macro_rules! impl_fallback_l2 {
    ($N:literal) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: i32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and both slices have length `len`.
                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
                    let diff = ix - iy;
                    accum += diff * diff;
                }
                Ok(MV::new(accum as u32))
            }
        }
    };
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_l2!($N);)+
    };
}
784
// Scalar squared-L2 fallbacks for the multi-bit widths; 8-bit and 1-bit are
// covered by the architecture-generic implementations above.
impl_fallback_l2!(7, 6, 5, 4, 3, 2);
786
// Route the V3 widths without a hand-written AVX2 kernel through `retarget!`
// (4-bit and 2-bit have dedicated V3 implementations above).
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);
789
// No dedicated V4 (AVX-512) squared-L2 kernels exist; retarget every width.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);
792
// No dedicated NEON squared-L2 kernels exist; retarget every multi-bit width.
#[cfg(target_arch = "aarch64")]
retarget!(
    diskann_wide::arch::aarch64::Neon,
    SquaredL2,
    7,
    6,
    5,
    4,
    3,
    2
);
804
// Public runtime-dispatched squared-L2 entry points for all supported widths.
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
806
/// Inner product for 8-bit elements: a `USlice<8>` stores one byte per
/// element, so this delegates to the general `&[u8]` inner-product kernel and
/// truncates its `f32` result to `u32`.
///
/// NOTE(review): as with the 8-bit squared-L2 above, `f32` accumulation is
/// exact only up to 2^24 — presumably fine for supported lengths; confirm.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
where
    A: Architecture,
    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::InnerProduct {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        Ok(MV::new(r.into_inner() as u32))
    }
}
837
/// AVX-512 (V4) inner product for 2-bit elements.
///
/// Values are extracted byte-wise: the mask `0x03030303` keeps one 2-bit
/// field per byte lane, and the four shift positions (0/2/4/6) cover all
/// fields. `dot_simd` on `u8x64` × `i8x64` accumulates into `i32` lanes
/// (VNNI-style); the 2-bit values fit both operand ranges.
///
/// NOTE: the index `i` changes units in this function — u32 words in the
/// main loop, then bytes after the masked tail load, then elements for the
/// final scalar cleanup.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // `div_ceil` so a trailing partial u32 word is also covered by the
        // masked byte-granular tail load below.
        let blocks = len.div_ceil(16);
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // One 2-bit field per byte lane.
            let mask = u32s::splat(arch, 0x03030303);
            while i + 16 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Remaining whole bytes (each holds 4 elements); `i` was counting
            // u32 words, so `4 * i` converts to bytes.
            let remainder = len / 4 - 4 * i;

            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            // `i` now counts bytes.
            i = (4 * i) + remainder;
        }

        // Convert bytes to elements for the scalar cleanup.
        i *= 4;

        debug_assert!(len - i <= 3);
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
975
/// AVX2 (V3) inner product for 4-bit elements.
///
/// Mirrors the 4-bit squared-L2 kernel above: each `u32` packs eight 4-bit
/// lanes, extracted via mask `0x000f000f` at shifts 0/4/8/12 into `i16`
/// lanes, with `dot_simd` accumulating products of corresponding lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (8 elements each) inside the SIMD section.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            let mask = u32s::splat(arch, 0x000f000f);
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Final 1..=8 words via a masked (partial) vector load.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert the word index to an element index for the scalar tail.
        i *= 8;

        if i != len {
            // Cold path for the final `len % 8` elements.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1109
/// AVX2 (V3) inner product for 2-bit elements.
///
/// Mirrors the 2-bit squared-L2 kernel above: each `u32` packs sixteen 2-bit
/// lanes, extracted via mask `0x00030003` at shifts 0/2/.../14 into `i16`
/// lanes, with `dot_simd` accumulating products across four rotating
/// accumulators.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (16 elements each) inside the SIMD section.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            let mask = u32s::splat(arch, 0x00030003);
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Final 1..=8 words via a masked (partial) vector load.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert the word index to an element index for the scalar tail.
        i *= 16;

        if i != len {
            // Cold path for the final `len % 16` elements.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1275
1276impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
1280where
1281 A: Architecture,
1282{
1283 #[inline(always)]
1284 fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
1285 bitvector_op::<Self, Unsigned>(x, y)
1286 }
1287}
1288
/// Generates a scalar (no-SIMD) `Target2` inner-product implementation for
/// the `Scalar` architecture over `$N`-bit unsigned slices: a straightforward
/// element-by-element multiply-accumulate in `u32`.
macro_rules! impl_fallback_ip {
    ($N:literal) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for InnerProduct {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: u32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and both slices have length `len`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    accum += ix * iy;
                }
                Ok(MV::new(accum))
            }
        }
    };
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_ip!($N);)+
    };
}
1325
// Scalar inner-product fallbacks for the multi-bit widths; 8-bit and 1-bit
// are covered by the architecture-generic implementations above.
impl_fallback_ip!(7, 6, 5, 4, 3, 2);
1327
// Route the V3 widths without a hand-written AVX2 kernel through `retarget!`
// (4-bit and 2-bit have dedicated V3 implementations above).
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);
1330
1331#[cfg(target_arch = "x86_64")]
1332retarget!(diskann_wide::arch::x86_64::V4, InnerProduct, 7, 6, 4, 5, 3);
1333
// No dedicated NEON inner-product kernels exist; retarget every multi-bit
// width. Widths listed in descending order for consistency with the other
// `retarget!` invocations; the generated impl set is order-independent.
#[cfg(target_arch = "aarch64")]
retarget!(
    diskann_wide::arch::aarch64::Neon,
    InnerProduct,
    7,
    6,
    5,
    4,
    3,
    2
);
1345
// Public runtime-dispatched inner-product entry points for all supported
// widths.
dispatch_pure!(InnerProduct, 1, 2, 3, 4, 5, 6, 7, 8);
1347
/// Inner product between bit-transposed 4-bit values `x` and a dense binary
/// vector `y`.
///
/// In the transposed layout, plane `j` (of four consecutive `u64` words per
/// group) holds bit `j` of 64 consecutive elements, so the dot product is
/// `sum_j 2^j * popcount(y_bits & plane_j)`.
///
/// NOTE(review): the tail path reads all four full `u64` planes of the final
/// partial group — assumes `BitTranspose` storage pads planes to whole `u64`
/// words; confirm against the layout definition.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        // `i` counts 64-element groups (one u64 of `y`, four planes of `x`).
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 64;
        while i < blocks {
            let bits = unsafe { py.add(i).read_unaligned() };

            // Plane j contributes popcount(bits & plane_j) weighted by 2^j.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        if 64 * i == len {
            return Ok(MV::new(s));
        }

        // Tail: assemble the remaining `y` bytes into a u64 without reading
        // past the end of `y`'s storage.
        let k = i * 8;

        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        for j in 0..bytes_remaining.min(8) {
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Keep only the `len % 64` valid low bits.
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1480
1481impl
1482 PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
1483 for InnerProduct
1484{
1485 fn evaluate(
1486 x: USlice<'_, 4, BitTranspose>,
1487 y: USlice<'_, 1, Dense>,
1488 ) -> MathematicalResult<u32> {
1489 (diskann_wide::ARCH).run2(Self, x, y)
1490 }
1491}
1492
/// AVX2 (V3) inner product between a dense `f32` vector and a packed binary
/// vector.
///
/// Expansion trick: `prep` broadcasts a 32-bit chunk of `y` and right-shifts
/// each lane by its index, putting bit `j` in the LSB of lane `j`.
/// `_mm256_permutevar_ps` then indexes into `values` — which repeats
/// `[0.0, 1.0]` so only the selector's low bit matters — turning each bit
/// into 0.0 or 1.0 for a fused multiply-add against the `f32` data.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 1>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Lookup table for permutevar: repeating [0.0, 1.0] means the result
        // depends only on the low bit of each selector lane.
        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        // Per-lane shifts that move bit j of the broadcast word into lane j's
        // least-significant bit.
        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        // `i` counts 32-element blocks (one u32 of `y` per block).
        let mut i = 0;
        let mut s = f32s::default(arch);

        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        let to_f32 = |v: u32s| -> f32s {
            // SAFETY: V3 implies AVX, which provides vpermilps.
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        let blocks = len / 32;
        if i < blocks {
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            while i < blocks {
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };

                // Each shift of `iy` by 8 exposes the next byte's bits in the
                // per-lane LSB positions.
                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);

                i += 1;
            }
            s = s0 + s1;
        }

        // Tail: 1..=31 elements. Load only the bytes of `y` that exist
        // (load_one/two/three/four), and mask the final partial f32 vector.
        let remainder = len % 32;
        if remainder != 0 {
            let tail = if len % 8 == 0 { 8 } else { len % 8 };

            let py = unsafe { py.add(blocks) };

            if remainder <= 8 {
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    })
                }
            } else if remainder <= 16 {
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                    })
                }
            } else if remainder <= 24 {
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                    })
                }
            } else {
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
                    })
                }
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
1662
#[cfg(target_arch = "x86_64")]
// Mixed inner product between a dense `f32` slice and a 2-bit unsigned
// bit-slice for x86-64-v3 (AVX2 + FMA). Each `u32` word read from `y` packs
// sixteen 2-bit values, pairing with two 8-lane `f32` vectors of `x`.
// Decoding uses a per-lane table lookup (`_mm256_permutevar_ps`).
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 2>,
    ) -> MathematicalResult<f32> {
        // Error out early on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Table mapping a 2-bit code (0..=3) to its `f32` value, duplicated
        // across both 128-bit halves because `_mm256_permutevar_ps` permutes
        // within each half independently.
        let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);

        // Per-lane shift amounts that bring eight consecutive 2-bit fields of
        // a packed word into each lane's low bits.
        let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a packed word and pre-shift it so lane `k` holds the
        // `k`-th field (of the low eight) in its low bits.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        // Decode each lane's low bits to `f32` via the table lookup; the
        // permute index uses only the low 2 bits per lane, so no masking is
        // needed.
        let to_f32 = |v: u32s| -> f32s {
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // One packed `u32` word covers 16 elements.
        let blocks = len / 16;
        if blocks != 0 {
            // Two accumulators break the FMA dependency chain.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            // Main loop, unrolled over two packed words (32 elements).
            while i + 2 <= blocks {
                // SAFETY: word `i < blocks` spans elements `16*i..16*i+16`,
                // all `< len`, so the read stays inside `y`'s storage.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: elements `16*i..16*i+16` are in bounds of `x`.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                // The upper eight 2-bit fields occupy bits 16..32 of the word.
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                // SAFETY: `i + 1 < blocks`; same bounds argument as above.
                let iy = prep(unsafe { py.add(i + 1).read_unaligned() });

                // SAFETY: elements `16*(i+1)..16*(i+2)` are in bounds of `x`.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * (i + 1))),
                        f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                i += 2;
            }

            // At most one full word remains after the unrolled loop.
            if i < blocks {
                // SAFETY: `i < blocks`; same bounds argument as above.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: elements `16*i..16*i+16` are in bounds of `x`.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
            }

            s = s0 + s1;
        }

        // Final partial word (1..=15 elements): read only as many bytes of
        // `y` as the remainder actually occupies (4 values per byte).
        let remainder = len % 16;
        if remainder != 0 {
            // Lane count for the last (partial) `f32` vector load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };
            // SAFETY: `blocks` whole words precede the tail bytes, so the
            // result still points inside `y`'s storage.
            let py = unsafe { py.add(blocks) };

            if remainder <= 4 {
                // SAFETY: 1..=4 trailing values occupy a single byte of `y`,
                // and `tail` elements of `x` remain past `16 * blocks`.
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    })
                }
            } else if remainder <= 8 {
                // SAFETY: 5..=8 trailing values occupy two bytes of `y`; they
                // still fit in one (partial) 8-lane vector of `x`.
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                    })
                }
            } else if remainder <= 12 {
                // SAFETY: 9..=12 trailing values occupy three bytes of `y`:
                // one full vector of `x` plus a partial second one.
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(16 * blocks + 8));
                        let ix2 = f32s::load_simd_first(arch, px.add(16 * blocks + 16), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                    })
                }
            } else {
                // SAFETY: 13..=15 trailing values occupy four bytes of `y`.
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(16 * blocks + 8));
                        let ix2 = f32s::load_simd(arch, px.add(16 * blocks + 16));
                        let ix3 = f32s::load_simd_first(arch, px.add(16 * blocks + 24), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
                    })
                }
            }
        }

        // Horizontal reduction of the 8 partial sums.
        Ok(MV::new(s.sum_tree()))
    }
}
1826
#[cfg(target_arch = "x86_64")]
// Mixed inner product between a dense `f32` slice and a 4-bit unsigned
// bit-slice for x86-64-v3. Each `u32` word of `y` packs eight 4-bit values —
// exactly one 8-lane `f32` vector of `x` — decoded with a shift-mask-convert.
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 4>,
    ) -> MathematicalResult<f32> {
        // Error out early on mismatched logical lengths.
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);

        // Per-lane shifts that bring the eight 4-bit fields of a word into
        // each lane's low bits.
        let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
        let mask = i32s::splat(arch, 0x0f);

        // Broadcast, shift, and mask out one nibble per lane, then convert
        // lanewise to `f32`. The mask also clears the sign bits smeared in by
        // the arithmetic `i32` shift when extracting the top nibble.
        let to_f32 = |v: u32| -> f32s {
            ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
        };

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // One packed word covers 8 elements.
        let blocks = len / 8;
        while i < blocks {
            // SAFETY: elements `8*i..8*i+8` are in bounds of `x`.
            let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
            // SAFETY: word `i < blocks` spans elements `8*i..8*i+8`, all
            // `< len`, so the read stays inside `y`'s storage.
            let iy = to_f32(unsafe { py.add(i).read_unaligned() });
            s = ix.mul_add_simd(iy, s);

            i += 1;
        }

        // Final partial word: read only as many bytes of `y` as the
        // remaining 1..=7 values occupy (2 values per byte).
        let remainder = len % 8;
        if remainder != 0 {
            let f = |iy| {
                // SAFETY: a partial load of exactly the last `remainder`
                // floats of `x`.
                let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
                s = ix.mul_add_simd(to_f32(iy), s);
            };

            // SAFETY: `blocks` whole words precede the tail bytes.
            let py = unsafe { py.add(blocks) };

            if remainder <= 2 {
                // SAFETY: 1..=2 trailing values occupy one byte of `y`.
                unsafe { load_one(py, f) };
            } else if remainder <= 4 {
                // SAFETY: 3..=4 trailing values occupy two bytes of `y`.
                unsafe { load_two(py, f) };
            } else if remainder <= 6 {
                // SAFETY: 5..=6 trailing values occupy three bytes of `y`.
                unsafe { load_three(py, f) };
            } else {
                // SAFETY: 7 trailing values occupy four bytes of `y`.
                unsafe { load_four(py, f) };
            }
        }

        // Horizontal reduction of the 8 partial sums.
        Ok(MV::new(s.sum_tree()))
    }
}
1903
1904impl<const N: usize>
1905 Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
1906 for InnerProduct
1907where
1908 Unsigned: Representation<N>,
1909{
1910 #[inline(always)]
1913 fn run(
1914 self,
1915 _: diskann_wide::arch::Scalar,
1916 x: &[f32],
1917 y: USlice<'_, N>,
1918 ) -> MathematicalResult<f32> {
1919 check_lengths!(x, y)?;
1920
1921 let mut s = 0.0;
1922 for (i, x) in x.iter().enumerate() {
1923 let y = unsafe { y.get_unchecked(i) } as f32;
1926 s += x * y;
1927 }
1928
1929 Ok(MV::new(s))
1930 }
1931}
1932
// Retargeting fallbacks for `InnerProduct` over `&[f32]` × `USlice<N>`:
// architectures without a dedicated kernel for a given bit width delegate to
// the implementation of the architecture they retarget to.
macro_rules! ip_retarget {
    ($arch:path, $N:literal) => {
        impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
            for InnerProduct
        {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: &[f32],
                y: USlice<'_, $N>,
            ) -> MathematicalResult<f32> {
                // Forward to the next architecture down the retarget chain.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Accept an optional trailing comma, consistent with `retarget!` above.
    ($arch:path, $($Ns:literal),+ $(,)?) => {
        $(ip_retarget!($arch, $Ns);)+
    }
}
1954
// Bit widths without a dedicated x86-64-v3 kernel fall back down the
// retarget chain (2-bit and 4-bit AVX2 kernels are defined above).
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);

// x86-64-v4 delegates every width to the next architecture down the chain.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);

// NEON delegates every width down the chain as well.
#[cfg(target_arch = "aarch64")]
ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8);
1963
// Implements the `PureDistanceFunction` entry point for `InnerProduct` over
// `&[f32]` × `USlice<N>`: dispatch on the detected architecture (`ARCH`) and
// run the best available kernel.
macro_rules! dispatch_full_ip {
    ($N:literal) => {
        impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
            for InnerProduct
        {
            // Inlined for consistency with `dispatch_pure!` above.
            #[inline(always)]
            fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
                Self.run(ARCH, x, y)
            }
        }
    };
    // Accept an optional trailing comma, consistent with `dispatch_pure!`.
    ($($Ns:literal),+ $(,)?) => {
        $(dispatch_full_ip!($Ns);)+
    }
}
1983
1984dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
1985
1986#[cfg(test)]
1991mod tests {
1992 use std::{collections::HashMap, sync::LazyLock};
1993
1994 use diskann_utils::Reborrow;
1995 use rand::{
1996 Rng, SeedableRng,
1997 distr::{Distribution, Uniform},
1998 rngs::StdRng,
1999 seq::IndexedRandom,
2000 };
2001
2002 use super::*;
2003 use crate::bits::{BoxedBitSlice, Representation, Unsigned};
2004
2005 type MR = MathematicalResult<u32>;
2006
2007 fn test_bitslice_distances<const NBITS: usize, R>(
2017 dim_max: usize,
2018 trials_per_dim: usize,
2019 evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2020 evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2021 context: &str,
2022 rng: &mut R,
2023 ) where
2024 Unsigned: Representation<NBITS>,
2025 R: Rng,
2026 {
2027 let domain = Unsigned::domain_const::<NBITS>();
2028 let min: i64 = *domain.start();
2029 let max: i64 = *domain.end();
2030
2031 let dist = Uniform::new_inclusive(min, max).unwrap();
2032
2033 for dim in 0..dim_max {
2034 let mut x_reference: Vec<u8> = vec![0; dim];
2035 let mut y_reference: Vec<u8> = vec![0; dim];
2036
2037 let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2038 let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2039
2040 for trial in 0..trials_per_dim {
2041 x_reference
2042 .iter_mut()
2043 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2044 y_reference
2045 .iter_mut()
2046 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2047
2048 x.as_mut_slice().fill(u8::MAX);
2051 y.as_mut_slice().fill(u8::MAX);
2052
2053 for i in 0..dim {
2054 x.set(i, x_reference[i].into()).unwrap();
2055 y.set(i, y_reference[i].into()).unwrap();
2056 }
2057
2058 let expected: MV<f32> =
2060 diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2061
2062 let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2063
2064 assert_eq!(
2066 expected.into_inner(),
2067 got.into_inner() as f32,
2068 "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2069 NBITS,
2070 dim,
2071 trial,
2072 context,
2073 );
2074
2075 let expected: MV<f32> =
2077 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2078
2079 let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2080
2081 assert_eq!(
2083 expected.into_inner(),
2084 got.into_inner() as f32,
2085 "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2086 NBITS,
2087 dim,
2088 trial,
2089 context,
2090 );
2091 }
2092 }
2093
2094 let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2096 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2097
2098 assert!(
2099 evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2100 "context: {}",
2101 context
2102 );
2103 assert!(
2104 evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2105 "context: {}",
2106 context
2107 );
2108
2109 assert!(
2110 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2111 "context: {}",
2112 context
2113 );
2114 assert!(
2115 evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2116 "context: {}",
2117 context
2118 );
2119 }
2120
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri is far slower than native execution, so shrink the search
            // space while still exercising the unsafe code paths.
            const MAX_DIM: usize = 128;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2130
2131 static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
2141 use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4};
2142 [
2143 (Key::new(1, Scalar), Bounds::new(64, 64)),
2144 (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
2145 (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
2146 (Key::new(1, Neon), Bounds::new(64, 64)),
2147 (Key::new(2, Scalar), Bounds::new(64, 64)),
2148 (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
2150 (Key::new(2, X86_64_V4), Bounds::new(768, 600)), (Key::new(2, Neon), Bounds::new(64, 64)),
2152 (Key::new(3, Scalar), Bounds::new(64, 64)),
2153 (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
2154 (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
2155 (Key::new(3, Neon), Bounds::new(64, 64)),
2156 (Key::new(4, Scalar), Bounds::new(64, 64)),
2157 (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
2159 (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
2160 (Key::new(4, Neon), Bounds::new(64, 64)),
2161 (Key::new(5, Scalar), Bounds::new(64, 64)),
2162 (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
2163 (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
2164 (Key::new(5, Neon), Bounds::new(64, 64)),
2165 (Key::new(6, Scalar), Bounds::new(64, 64)),
2166 (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
2167 (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
2168 (Key::new(6, Neon), Bounds::new(64, 64)),
2169 (Key::new(7, Scalar), Bounds::new(64, 64)),
2170 (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
2171 (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
2172 (Key::new(7, Neon), Bounds::new(64, 64)),
2173 (Key::new(8, Scalar), Bounds::new(64, 64)),
2174 (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
2175 (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
2176 (Key::new(8, Neon), Bounds::new(64, 64)),
2177 ]
2178 .into_iter()
2179 .collect()
2180 });
2181
    /// Identifies the SIMD architecture a test bound applies to; used as part
    /// of the key into `BITSLICE_TEST_BOUNDS`.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ArchKey {
        Scalar,
        // Variant names mirror the x86-64 microarchitecture level names.
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
        Neon,
    }
2191
2192 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
2193 struct Key {
2194 nbits: usize,
2195 arch: ArchKey,
2196 }
2197
2198 impl Key {
2199 fn new(nbits: usize, arch: ArchKey) -> Self {
2200 Self { nbits, arch }
2201 }
2202 }
2203
2204 #[derive(Debug, Clone, Copy)]
2205 struct Bounds {
2206 standard: usize,
2207 miri: usize,
2208 }
2209
2210 impl Bounds {
2211 fn new(standard: usize, miri: usize) -> Self {
2212 Self { standard, miri }
2213 }
2214
2215 fn get(&self) -> usize {
2216 if cfg!(miri) { self.miri } else { self.standard }
2217 }
2218 }
2219
    // Generates one `#[test]` per bit width that runs the bit-slice distance
    // checks against every entry point available on the current machine: the
    // auto-dispatched pure functions, the scalar kernels, and whichever
    // x86-64 / NEON kernels the CPU supports.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps the randomized checks reproducible.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // SIMD kernels are exercised only when the CPU supports them.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // NOTE(review): the `_miri` constructor presumably also
                // reports availability when running under Miri — confirm.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2288
    // One reproducible randomized test per supported bit width.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2297
2298 fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
2303 where
2304 R: Rng,
2305 {
2306 let dist: [i8; 2] = [-1, 1];
2307
2308 for dim in 0..dim_max {
2309 let mut x_reference: Vec<i8> = vec![1; dim];
2310 let mut y_reference: Vec<i8> = vec![1; dim];
2311
2312 let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
2313 let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);
2314
2315 for _ in 0..trials_per_dim {
2316 x_reference
2317 .iter_mut()
2318 .for_each(|i| *i = *dist.choose(rng).unwrap());
2319 y_reference
2320 .iter_mut()
2321 .for_each(|i| *i = *dist.choose(rng).unwrap());
2322
2323 x.as_mut_slice().fill(u8::MAX);
2326 y.as_mut_slice().fill(u8::MAX);
2327
2328 for i in 0..dim {
2329 x.set(i, x_reference[i].into()).unwrap();
2330 y.set(i, y_reference[i].into()).unwrap();
2331 }
2332
2333 let expected: MV<f32> =
2339 diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2340 let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
2341 assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
2342 }
2343 }
2344
2345 let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
2346 let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
2347 assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
2348 assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
2349 }
2350
    #[test]
    fn test_hamming_distance() {
        // Fixed seed keeps the randomized property test reproducible.
        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
    }
2356
2357 fn test_bit_transpose_distances<R>(
2362 dim_max: usize,
2363 trials_per_dim: usize,
2364 evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2365 context: &str,
2366 rng: &mut R,
2367 ) where
2368 R: Rng,
2369 {
2370 let dist_4bit = {
2371 let domain = Unsigned::domain_const::<4>();
2372 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2373 };
2374
2375 let dist_1bit = {
2376 let domain = Unsigned::domain_const::<1>();
2377 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2378 };
2379
2380 for dim in 0..dim_max {
2381 let mut x_reference: Vec<u8> = vec![0; dim];
2382 let mut y_reference: Vec<u8> = vec![0; dim];
2383
2384 let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2385 let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2386
2387 for trial in 0..trials_per_dim {
2388 x_reference
2389 .iter_mut()
2390 .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2391 y_reference
2392 .iter_mut()
2393 .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2394
2395 x.as_mut_slice().fill(u8::MAX);
2397 y.as_mut_slice().fill(u8::MAX);
2398
2399 for i in 0..dim {
2400 x.set(i, x_reference[i].into()).unwrap();
2401 y.set(i, y_reference[i].into()).unwrap();
2402 }
2403
2404 let expected: MV<f32> =
2406 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2407
2408 let got = evaluate_ip(x.reborrow(), y.reborrow());
2409
2410 assert_eq!(
2412 expected.into_inner(),
2413 got.unwrap().into_inner() as f32,
2414 "faild InnerProduct for dim = {}, trial = {} -- context {}",
2415 dim,
2416 trial,
2417 context,
2418 );
2419 }
2420 }
2421
2422 let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2423 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2424 assert!(
2425 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2426 "context: {}",
2427 context
2428 );
2429
2430 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2431 assert!(
2432 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2433 "context: {}",
2434 context
2435 );
2436 }
2437
    #[test]
    fn test_bit_transpose_distance() {
        // Fixed seed keeps the randomized checks reproducible.
        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);

        // Auto-dispatched entry point.
        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| InnerProduct::evaluate(x, y),
            "pure distance function",
            &mut rng,
        );

        // Scalar kernel.
        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
            "scalar",
            &mut rng,
        );

        // SIMD kernels are exercised only when the CPU supports them.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v3",
                &mut rng,
            );
        }

        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v4",
                &mut rng,
            );
        }

        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "neon",
                &mut rng,
            );
        }
    }
2494
2495 fn test_full_distances<const NBITS: usize>(
2500 dim_max: usize,
2501 trials_per_dim: usize,
2502 evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
2503 context: &str,
2504 rng: &mut impl Rng,
2505 ) where
2506 Unsigned: Representation<NBITS>,
2507 {
2508 let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
2510 let dist_bit = {
2511 let domain = Unsigned::domain_const::<NBITS>();
2512 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2513 };
2514
2515 for dim in 0..dim_max {
2516 let mut x: Vec<f32> = vec![0.0; dim];
2517
2518 let mut y_reference: Vec<u8> = vec![0; dim];
2519 let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
2520
2521 for trial in 0..trials_per_dim {
2522 x.iter_mut()
2523 .for_each(|i| *i = *dist_float.choose(rng).unwrap());
2524 y_reference
2525 .iter_mut()
2526 .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
2527
2528 y.as_mut_slice().fill(u8::MAX);
2530
2531 let mut expected = 0.0;
2532 for i in 0..dim {
2533 y.set(i, y_reference[i].into()).unwrap();
2534 expected += y_reference[i] as f32 * x[i];
2535 }
2536
2537 let got = evaluate_ip(&x, y.reborrow()).unwrap();
2539
2540 assert_eq!(
2542 expected,
2543 got.into_inner(),
2544 "faild InnerProduct for dim = {}, trial = {} -- context {}",
2545 dim,
2546 trial,
2547 context,
2548 );
2549
2550 let scalar: MV<f32> = InnerProduct
2553 .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
2554 .unwrap();
2555 assert_eq!(got.into_inner(), scalar.into_inner());
2556 }
2557 }
2558
2559 let x = vec![0.0; 10];
2561 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2562 assert!(
2563 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2564 "context: {}",
2565 context
2566 );
2567
2568 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
2569 assert!(
2570 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2571 "context: {}",
2572 context
2573 );
2574 }
2575
    // Generates one `#[test]` per bit width that runs the mixed
    // `&[f32]` × bit-slice inner-product checks against every entry point
    // available on the current machine.
    macro_rules! test_full {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps the randomized checks reproducible.
                let mut rng = StdRng::seed_from_u64($seed);

                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar",
                    &mut rng,
                );

                // SIMD kernels are exercised only when the CPU supports them.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2634
    // One reproducible randomized mixed-precision test per bit width.
    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
2643}