1use diskann_vector::PureDistanceFunction;
107use diskann_wide::{ARCH, Architecture, arch::Target2};
108#[cfg(target_arch = "x86_64")]
109use diskann_wide::{
110 SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
111};
112
113use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
114use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
115
/// Shorthand for an `N`-bit-per-element unsigned `BitSlice` with a
/// configurable bit-permutation layout (defaults to the `Dense` layout).
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
118
/// Implements `Target2<$arch, …>` for `$op` on `$N`-bit unsigned slices by
/// forwarding to the next-lower architecture tier via `arch.retarget()`.
///
/// Used for bit widths that have no dedicated kernel at `$arch`, so they
/// fall through to whichever tier does implement them. The multi-width arm
/// simply expands the single-width arm once per `$N`.
#[cfg(target_arch = "x86_64")]
macro_rules! retarget {
    ($arch:path, $op:ty, $N:literal) => {
        impl Target2<
            $arch,
            MathematicalResult<u32>,
            USlice<'_, $N>,
            USlice<'_, $N>,
        > for $op {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                // Delegate to the same op at the lower SIMD tier.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    ($arch:path, $op:ty, $($N:literal),+ $(,)?) => {
        $(retarget!($arch, $op, $N);)+
    }
}
145
/// Implements `PureDistanceFunction` for `$op` on `$N`-bit unsigned slices
/// by dispatching through the best architecture detected at runtime
/// (`diskann_wide::ARCH`), which selects among the `Target2` impls above.
macro_rules! dispatch_pure {
    ($op:ty, $N:literal) => {
        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $N>, MathematicalResult<u32>> for $op {
            #[inline(always)]
            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $N>) -> MathematicalResult<u32> {
                (diskann_wide::ARCH).run2(Self, x, y)
            }
        }
    };
    ($op:ty, $($N:literal),+ $(,)?) => {
        $(dispatch_pure!($op, $N);)+
    }
}
161
/// Reads a single byte from `ptr`, zero-extends it to `u32`, and applies `f`.
///
/// # Safety
/// `ptr` must point to at least one readable byte.
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: the caller guarantees one readable byte at `ptr`.
    let byte = unsafe { ptr.cast::<u8>().read_unaligned() };
    f(u32::from(byte))
}
176
/// Reads two bytes from `ptr` (little-endian `u16`), zero-extends to `u32`,
/// and applies `f`.
///
/// # Safety
/// `ptr` must point to at least two readable bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: the caller guarantees two readable bytes at `ptr`.
    let half = unsafe { ptr.cast::<u16>().read_unaligned() };
    f(u32::from(half))
}
191
/// Reads three bytes from `ptr` (little-endian), zero-extends to `u32`,
/// and applies `f`.
///
/// # Safety
/// `ptr` must point to at least three readable bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: the caller guarantees three readable bytes at `ptr`.
    // Assemble bytes [0,1] and byte [2] into the low 24 bits.
    let low = unsafe { ptr.cast::<u16>().read_unaligned() };
    let high = unsafe { ptr.cast::<u8>().add(2).read_unaligned() };
    f(u32::from(low) | (u32::from(high) << 16))
}
209
/// Reads a full little-endian `u32` from `ptr` and applies `f`.
///
/// # Safety
/// `ptr` must point to at least four readable bytes.
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, mut f: F) -> R
where
    F: FnMut(u32) -> R,
{
    // SAFETY: the caller guarantees four readable bytes at `ptr`.
    let word = unsafe { ptr.read_unaligned() };
    f(word)
}
224
/// Per-word combining operation used by `bitvector_op` to reduce two packed
/// 1-bit slices into a scalar sum. Implementors define how a pair of words
/// contributes to the accumulator (e.g. XOR-popcount or AND-popcount).
trait BitVectorOp<Repr>
where
    Repr: Representation<1>,
{
    /// Combines two 64-bit words of packed 1-bit elements into a partial sum.
    fn on_u64(x: u64, y: u64) -> u32;

    /// Combines two bytes of packed 1-bit elements into a partial sum.
    fn on_u8(x: u8, y: u8) -> u32;
}
249
250impl BitVectorOp<Unsigned> for SquaredL2 {
252 #[inline(always)]
253 fn on_u64(x: u64, y: u64) -> u32 {
254 (x ^ y).count_ones()
255 }
256 #[inline(always)]
257 fn on_u8(x: u8, y: u8) -> u32 {
258 (x ^ y).count_ones()
259 }
260}
261
262impl BitVectorOp<Binary> for Hamming {
264 #[inline(always)]
265 fn on_u64(x: u64, y: u64) -> u32 {
266 (x ^ y).count_ones()
267 }
268 #[inline(always)]
269 fn on_u8(x: u8, y: u8) -> u32 {
270 (x ^ y).count_ones()
271 }
272}
273
274impl BitVectorOp<Unsigned> for InnerProduct {
282 #[inline(always)]
283 fn on_u64(x: u64, y: u64) -> u32 {
284 (x & y).count_ones()
285 }
286 #[inline(always)]
287 fn on_u8(x: u8, y: u8) -> u32 {
288 (x & y).count_ones()
289 }
290}
291
/// Reduces two equal-length packed 1-bit slices with `Op`, returning the
/// accumulated sum (e.g. XOR-popcount for distances, AND-popcount for
/// inner products).
///
/// Works in three stages for speed: unaligned 64-bit words, then single
/// bytes, then a final masked byte covering the `len % 8` tail bits. The
/// tail mask zeroes bits at or past `len`, so storage padding never
/// contributes to the sum.
#[inline(always)]
fn bitvector_op<Op, Repr>(
    x: BitSlice<'_, 1, Repr>,
    y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
    Repr: Representation<1>,
    Op: BitVectorOp<Repr>,
{
    let len = check_lengths!(x, y)?;

    let px: *const u64 = x.as_ptr().cast();
    let py: *const u64 = y.as_ptr().cast();

    let mut i = 0;
    let mut s: u32 = 0;

    // Stage 1: whole 64-bit words.
    let blocks = len / 64;
    while i < blocks {
        // SAFETY: `i < len / 64`, so the full u64 is inside both slices.
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };

        s += Op::on_u64(vx, vy);
        i += 1;
    }

    // Convert the u64 word index into a byte index for the next stages.
    i *= 8;
    let px: *const u8 = x.as_ptr();
    let py: *const u8 = y.as_ptr();

    // Stage 2: remaining whole bytes.
    let blocks = len / 8;
    while i < blocks {
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        s += Op::on_u8(vx, vy);
        i += 1;
    }

    // Stage 3: the final partial byte, if `len` is not a byte multiple.
    // NOTE(review): this reads a full byte past the last whole byte;
    // presumably `BitSlice` always backs `len` bits with ceil(len / 8)
    // bytes of storage — confirm against the `BitSlice` definition.
    if i * 8 != len {
        let vx = unsafe { px.add(i).read_unaligned() };

        let vy = unsafe { py.add(i).read_unaligned() };
        // Keep only the `len % 8` valid low bits (shift count is 1..=7).
        let m = (0x01u8 << (len - 8 * i)) - 1;

        s += Op::on_u8(vx & m, vy & m)
    }
    Ok(MV::new(s))
}
362
363impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
367 for Hamming
368{
369 fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
370 bitvector_op::<Hamming, Binary>(x, y)
371 }
372}
373
/// `SquaredL2` over 8-bit unsigned slices for any architecture: 8-bit
/// elements are plain bytes, so this forwards to the byte-level
/// `SquaredL2` kernel from `diskann_vector` and converts the `f32`
/// result back to `u32`.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
    A: Architecture,
    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;

        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::SquaredL2 {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        // NOTE(review): the integer result round-trips through f32, which
        // is exact only up to 2^24; very long vectors could lose low-order
        // bits here — confirm this is acceptable for callers.
        Ok(MV::new(r.into_inner() as u32))
    }
}
405
/// `SquaredL2` over 4-bit unsigned slices, specialized for AVX2 (`V3`).
///
/// Each `u32` word packs eight 4-bit elements. The main loop consumes 8
/// words (64 elements) per iteration: the four nibble positions within
/// each 16-bit lane are extracted by shift-and-mask into `i16` lanes, the
/// differences are squared and pairwise-accumulated into four independent
/// `i32x8` accumulators via `dot_simd` (presumably a 16-bit multiply-add
/// — confirm against `SIMDDotProduct`). A masked load handles the last
/// 1..=8 words, and a scalar loop handles any elements past a full word.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words through the SIMD stage; it is converted to
        // an element index (x8) before the scalar tail.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit lane; shifted copies of
            // the input expose the other three nibble positions.
            let mask = u32s::splat(arch, 0x000f000f);
            // Leaves 1..=8 words for the masked remainder load below.
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Nibble position 0 (bits 0..4 of each 16-bit lane).
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // Nibble position 1 (bits 4..8).
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                // Nibble position 2 (bits 8..12).
                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                // Nibble position 3 (bits 12..16).
                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // Masked load of the trailing 1..=8 words; absent lanes read
            // as zero and contribute nothing to the squared differences.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of the four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index (8 elements per u32).
        i *= 8;

        if i != len {
            // Scalar tail for the final partial word; kept out of line so
            // the hot SIMD path stays compact.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and both slices share `len`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
549
/// `SquaredL2` over 2-bit unsigned slices, specialized for AVX2 (`V3`).
///
/// Each `u32` word packs sixteen 2-bit elements. The structure mirrors the
/// 4-bit kernel above, but there are eight 2-bit positions per 16-bit
/// lane, so eight shift-and-mask phases cycle through the same four
/// `i32x8` accumulators.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (16 elements each) through the SIMD stage.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2-bit field of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Leaves 1..=8 words for the masked remainder load below.
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Eight phases: 2-bit fields at bit offsets 0, 2, ..., 14.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // Masked load covering the trailing 1..=8 words; absent lanes
            // read as zero and contribute nothing.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of the four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index (16 elements per u32).
        i *= 16;

        if i != len {
            // Scalar tail for the final partial word.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and both slices share `len`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
732
733impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
737where
738 A: Architecture,
739{
740 fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
741 bitvector_op::<Self, Unsigned>(x, y)
742 }
743}
744
/// Generates the `Scalar`-architecture fallback of `SquaredL2` for `$N`-bit
/// unsigned slices: a plain element-by-element loop with no SIMD.
macro_rules! impl_fallback_l2 {
    ($N:literal) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
            // Kept out of line: this is the cold path behind the SIMD
            // kernels.
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: i32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and both slices have length `len`.
                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
                    let diff = ix - iy;
                    accum += diff * diff;
                }
                Ok(MV::new(accum as u32))
            }
        }
    };
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_l2!($N);)+
    };
}
782
// Scalar fallbacks for every width without a dedicated impl above
// (widths 1 and 8 are covered by the generic impls).
impl_fallback_l2!(7, 6, 5, 4, 3, 2);

// V3 retargets down a tier for widths lacking a V3 kernel; widths 4 and 2
// have dedicated V3 implementations above, so they are not listed here.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

// V4 has no dedicated SquaredL2 kernels; every width retargets to V3.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);

// Runtime-dispatch entry points for all supported widths.
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
792
/// `InnerProduct` over 8-bit unsigned slices for any architecture: 8-bit
/// elements are plain bytes, so this forwards to the byte-level
/// `InnerProduct` kernel from `diskann_vector` and converts the `f32`
/// result back to `u32`.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
where
    A: Architecture,
    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::InnerProduct {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        // NOTE(review): the integer result round-trips through f32 (exact
        // only up to 2^24); long vectors could lose low-order bits here.
        Ok(MV::new(r.into_inner() as u32))
    }
}
823
/// `InnerProduct` over 2-bit unsigned slices, specialized for AVX-512
/// (`V4`), using 512-bit vectors and a byte-wise u8 x i8 dot product.
///
/// Each `u32` word packs sixteen 2-bit elements (four per byte), so the
/// mask selects the low 2-bit field of every byte and four shift phases
/// cover all positions. `dot_simd` accumulates u8 x i8 products into
/// `i32x16` lanes (presumably `vpdpbusd`-style — confirm against
/// `SIMDDotProduct`); values are 0..=3, so the i8 reinterpret is safe.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    // The lowercase SIMD type aliases below intentionally violate the
    // camel-case convention.
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words in the main loop, is switched to a byte
        // index for the remainder, then to an element index for the tail.
        let mut i = 0;
        let mut s: u32 = 0;

        // div_ceil: the remainder is loaded at byte granularity, so a
        // partial final word can still be covered by the SIMD stage.
        let blocks = len.div_ceil(16);
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2-bit field of every byte.
            let mask = u32s::splat(arch, 0x03030303);
            // Leaves 1..=16 words (<= 64 bytes) for the remainder load.
            while i + 16 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Four phases: 2-bit fields at byte offsets 0, 2, 4, 6.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Remaining *complete* bytes (4 elements per byte); up to 3
            // trailing elements are handled by the scalar loop below.
            let remainder = len / 4 - 4 * i;

            // Byte-granular masked loads; absent lanes read as zero.
            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            // Horizontal reduction, then convert `i` to a byte index.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            i = (4 * i) + remainder;
        }

        // Byte index -> element index (4 elements per byte).
        i *= 4;

        // At most 3 elements (a partial byte) remain.
        debug_assert!(len - i <= 3);
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                // SAFETY: `i + j < len` by the bound established above.
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
961
/// `InnerProduct` over 4-bit unsigned slices, specialized for AVX2 (`V3`).
///
/// Same shift-and-mask decomposition as the 4-bit `SquaredL2` kernel
/// above, but `dot_simd` multiplies the extracted values directly instead
/// of squaring differences. Values are 0..=15, so i16 products cannot
/// overflow.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (8 elements each) through the SIMD stage.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 8;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit lane.
            let mask = u32s::splat(arch, 0x000f000f);
            // Leaves 1..=8 words for the masked remainder load below.
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Four phases: nibble positions 0, 4, 8, 12.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Masked load of the trailing 1..=8 words; absent lanes read
            // as zero and contribute nothing to the products.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of the four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index (8 elements per u32).
        i *= 8;

        if i != len {
            // Scalar tail for the final partial word.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and both slices share `len`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1095
/// `InnerProduct` over 2-bit unsigned slices, specialized for AVX2 (`V3`).
///
/// Same eight-phase shift-and-mask decomposition as the 2-bit `SquaredL2`
/// kernel above, but `dot_simd` multiplies the extracted values directly.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts u32 words (16 elements each) through the SIMD stage.
        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 16;
        if i < blocks {
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2-bit field of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Leaves 1..=8 words for the masked remainder load below.
            while i + 8 < blocks {
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Eight phases: 2-bit fields at bit offsets 0, 2, ..., 14.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Masked load covering the trailing 1..=8 words.
            let remainder = blocks - i;

            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of the four accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Word index -> element index (16 elements per u32).
        i *= 16;

        if i != len {
            // Scalar tail for the final partial word.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i < x.len()` and both slices share `len`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1261
1262impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
1266where
1267 A: Architecture,
1268{
1269 #[inline(always)]
1270 fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
1271 bitvector_op::<Self, Unsigned>(x, y)
1272 }
1273}
1274
/// Generates the `Scalar`-architecture fallback of `InnerProduct` for
/// `$N`-bit unsigned slices: a plain element-by-element loop with no SIMD.
macro_rules! impl_fallback_ip {
    ($N:literal) => {
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for InnerProduct {
            // Kept out of line: this is the cold path behind the SIMD
            // kernels.
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: u32 = 0;
                for i in 0..len {
                    // SAFETY: `i < len` and both slices have length `len`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    accum += ix * iy;
                }
                Ok(MV::new(accum))
            }
        }
    };
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_ip!($N);)+
    };
}
1311
1312impl_fallback_ip!(7, 6, 5, 4, 3, 2);
1313
1314#[cfg(target_arch = "x86_64")]
1315retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);
1316
1317#[cfg(target_arch = "x86_64")]
1318retarget!(diskann_wide::arch::x86_64::V4, InnerProduct, 7, 6, 4, 5, 3);
1319
1320dispatch_pure!(InnerProduct, 1, 2, 3, 4, 5, 6, 7, 8);
1321
/// `InnerProduct` between a bit-transposed 4-bit slice and a dense 1-bit
/// slice, for any architecture.
///
/// In the transposed layout, each group of 64 elements of `x` is stored as
/// four consecutive `u64` bit-planes, where plane `k` holds bit `k` of all
/// 64 values. Multiplying a 4-bit value by a 0/1 weight and summing is
/// then `sum_k 2^k * popcount(plane_k & weight_bits)`.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // Full 64-element groups: one u64 of weights against four planes.
        let blocks = len / 64;
        while i < blocks {
            let bits = unsafe { py.add(i).read_unaligned() };

            // Plane 0 contributes with weight 2^0, plane 1 with 2^1, etc.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        if 64 * i == len {
            return Ok(MV::new(s));
        }

        // Tail: fewer than 64 elements remain. Gather the remaining bytes
        // of `y` into a u64 (byte reads avoid running past its storage).
        let k = i * 8;

        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        for j in 0..bytes_remaining.min(8) {
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Zero the weight bits at or past `len` so padding in the final
        // byte cannot contribute (shift count is 1..=63 here).
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        // NOTE(review): the four plane reads below are full u64s even for
        // a partial group; presumably the transposed layout always
        // allocates complete 4-plane groups — confirm against BitTranspose.
        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1454
1455impl
1456 PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
1457 for InnerProduct
1458{
1459 fn evaluate(
1460 x: USlice<'_, 4, BitTranspose>,
1461 y: USlice<'_, 1, Dense>,
1462 ) -> MathematicalResult<u32> {
1463 (diskann_wide::ARCH).run2(Self, x, y)
1464 }
1465}
1466
/// `InnerProduct` between an `f32` slice and a packed 1-bit slice,
/// specialized for AVX2 (`V3`).
///
/// Each `u32` of `y` supplies 32 one-bit weights covering 32 f32s of `x`.
/// A broadcast + per-lane variable shift places weight bit `j` in bit 0 of
/// lane `j`, and `_mm256_permutevar_ps` maps that bit through an alternating
/// {0.0, 1.0} table (permutevar indexes with the low bits of each lane, so
/// bit 0 selects the value). The weighted sum is then four f32x8 FMAs per
/// u32 of weights.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 1>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Lookup table: odd indices (bit 0 set) map to 1.0, even to 0.0.
        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        // Per-lane shifts so lane j ends up with weight bit j in bit 0.
        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a weight word and expose one bit per lane.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        // Map each lane's low bit to 0.0 or 1.0 via the table above.
        let to_f32 = |v: u32s| -> f32s {
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // Main loop: one u32 of weights (32 elements) per iteration, split
        // into four 8-lane FMAs over two accumulators.
        let blocks = len / 32;
        if i < blocks {
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            while i < blocks {
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };

                // Shifting `iy` by 8/16/24 exposes the next byte of weight
                // bits in each lane's low bit.
                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);

                i += 1;
            }
            s = s0 + s1;
        }

        // Tail: fewer than 32 elements. Read only as many weight bytes as
        // needed (load_one..load_four), and mask-load the final partial
        // f32 vector (`tail` = lanes in the last, possibly full, vector).
        let remainder = len % 32;
        if remainder != 0 {
            let tail = if len % 8 == 0 { 8 } else { len % 8 };

            // Points at the first unconsumed weight byte group.
            let py = unsafe { py.add(blocks) };

            if remainder <= 8 {
                // SAFETY: at least ceil(remainder / 8) weight bytes remain.
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    })
                }
            } else if remainder <= 16 {
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                    })
                }
            } else if remainder <= 24 {
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                    })
                }
            } else {
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
                    })
                }
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
1636
1637#[cfg(target_arch = "x86_64")]
1641impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
1642 for InnerProduct
1643{
1644 #[inline(always)]
1645 fn run(
1646 self,
1647 arch: diskann_wide::arch::x86_64::V3,
1648 x: &[f32],
1649 y: USlice<'_, 2>,
1650 ) -> MathematicalResult<f32> {
1651 let len = check_lengths!(x, y)?;
1652
1653 use std::arch::x86_64::*;
1654
1655 diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
1656 diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
1657
1658 let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);
1662
1663 let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);
1665
1666 let px: *const f32 = x.as_ptr();
1667 let py: *const u32 = y.as_ptr().cast();
1668
1669 let mut i = 0;
1670 let mut s = f32s::default(arch);
1671
1672 let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
1673 let to_f32 = |v: u32s| -> f32s {
1674 f32s::from_underlying(arch, unsafe {
1677 _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
1678 })
1679 };
1680
1681 let blocks = len / 16;
1682 if blocks != 0 {
1683 let mut s0 = f32s::default(arch);
1684 let mut s1 = f32s::default(arch);
1685
1686 while i + 2 <= blocks {
1688 let iy = prep(unsafe { py.add(i).read_unaligned() });
1691
1692 let (ix0, ix1) = unsafe {
1695 (
1696 f32s::load_simd(arch, px.add(16 * i)),
1697 f32s::load_simd(arch, px.add(16 * i + 8)),
1698 )
1699 };
1700
1701 s0 = ix0.mul_add_simd(to_f32(iy), s0);
1702 s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
1703
1704 let iy = prep(unsafe { py.add(i + 1).read_unaligned() });
1707
1708 let (ix0, ix1) = unsafe {
1710 (
1711 f32s::load_simd(arch, px.add(16 * (i + 1))),
1712 f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
1713 )
1714 };
1715
1716 s0 = ix0.mul_add_simd(to_f32(iy), s0);
1717 s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
1718
1719 i += 2;
1720 }
1721
1722 if i < blocks {
1724 let iy = prep(unsafe { py.add(i).read_unaligned() });
1727
1728 let (ix0, ix1) = unsafe {
1730 (
1731 f32s::load_simd(arch, px.add(16 * i)),
1732 f32s::load_simd(arch, px.add(16 * i + 8)),
1733 )
1734 };
1735
1736 s0 = ix0.mul_add_simd(to_f32(iy), s0);
1737 s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
1738 }
1739
1740 s = s0 + s1;
1741 }
1742
1743 let remainder = len % 16;
1744 if remainder != 0 {
1745 let tail = if len % 8 == 0 { 8 } else { len % 8 };
1746 let py = unsafe { py.add(blocks) };
1749
1750 if remainder <= 4 {
1751 unsafe {
1754 load_one(py, |iy| {
1755 let iy = prep(iy);
1756 let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
1757 s = ix.mul_add_simd(to_f32(iy), s);
1758 });
1759 }
1760 } else if remainder <= 8 {
1761 unsafe {
1764 load_two(py, |iy| {
1765 let iy = prep(iy);
1766 let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
1767 s = ix.mul_add_simd(to_f32(iy), s);
1768 });
1769 }
1770 } else if remainder <= 12 {
1771 unsafe {
1774 load_three(py, |iy| {
1775 let iy = prep(iy);
1776 let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
1777 let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
1778 s = ix0.mul_add_simd(to_f32(iy), s);
1779 s = ix1.mul_add_simd(to_f32(iy >> 16), s);
1780 });
1781 }
1782 } else {
1783 unsafe {
1786 load_four(py, |iy| {
1787 let iy = prep(iy);
1788 let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
1789 let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
1790 s = ix0.mul_add_simd(to_f32(iy), s);
1791 s = ix1.mul_add_simd(to_f32(iy >> 16), s);
1792 });
1793 }
1794 }
1795 }
1796
1797 Ok(MV::new(s.sum_tree()))
1798 }
1799}
1800
1801#[cfg(target_arch = "x86_64")]
1806impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
1807 for InnerProduct
1808{
1809 #[inline(always)]
1810 fn run(
1811 self,
1812 arch: diskann_wide::arch::x86_64::V3,
1813 x: &[f32],
1814 y: USlice<'_, 4>,
1815 ) -> MathematicalResult<f32> {
1816 let len = check_lengths!(x, y)?;
1817
1818 diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
1819 diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
1820
1821 let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
1822 let mask = i32s::splat(arch, 0x0f);
1823
1824 let to_f32 = |v: u32| -> f32s {
1825 ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
1826 };
1827
1828 let px: *const f32 = x.as_ptr();
1829 let py: *const u32 = y.as_ptr().cast();
1830
1831 let mut i = 0;
1832 let mut s = f32s::default(arch);
1833
1834 let blocks = len / 8;
1835 while i < blocks {
1836 let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
1838 let iy = to_f32(unsafe { py.add(i).read_unaligned() });
1840 s = ix.mul_add_simd(iy, s);
1841
1842 i += 1;
1843 }
1844
1845 let remainder = len % 8;
1846 if remainder != 0 {
1847 let f = |iy| {
1848 let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
1852 s = ix.mul_add_simd(to_f32(iy), s);
1853 };
1854
1855 let py = unsafe { py.add(blocks) };
1858
1859 if remainder <= 2 {
1860 unsafe { load_one(py, f) };
1862 } else if remainder <= 4 {
1863 unsafe { load_two(py, f) };
1865 } else if remainder <= 6 {
1866 unsafe { load_three(py, f) };
1868 } else {
1869 unsafe { load_four(py, f) };
1871 }
1872 }
1873
1874 Ok(MV::new(s.sum_tree()))
1875 }
1876}
1877
1878impl<const N: usize>
1879 Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
1880 for InnerProduct
1881where
1882 Unsigned: Representation<N>,
1883{
1884 #[inline(always)]
1887 fn run(
1888 self,
1889 _: diskann_wide::arch::Scalar,
1890 x: &[f32],
1891 y: USlice<'_, N>,
1892 ) -> MathematicalResult<f32> {
1893 check_lengths!(x, y)?;
1894
1895 let mut s = 0.0;
1896 for (i, x) in x.iter().enumerate() {
1897 let y = unsafe { y.get_unchecked(i) } as f32;
1900 s += x * y;
1901 }
1902
1903 Ok(MV::new(s))
1904 }
1905}
1906
1907#[cfg(target_arch = "x86_64")]
1909macro_rules! ip_retarget {
1910 ($arch:path, $N:literal) => {
1911 impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
1912 for InnerProduct
1913 {
1914 #[inline(always)]
1915 fn run(
1916 self,
1917 arch: $arch,
1918 x: &[f32],
1919 y: USlice<'_, $N>,
1920 ) -> MathematicalResult<f32> {
1921 self.run(arch.retarget(), x, y)
1922 }
1923 }
1924 };
1925 ($arch:path, $($Ns:literal),*) => {
1926 $(ip_retarget!($arch, $Ns);)*
1927 }
1928}
1929
1930#[cfg(target_arch = "x86_64")]
1931ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);
1932
1933#[cfg(target_arch = "x86_64")]
1934ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);
1935
1936macro_rules! dispatch_full_ip {
1939 ($N:literal) => {
1940 impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
1944 for InnerProduct
1945 {
1946 fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
1947 Self.run(ARCH, x, y)
1948 }
1949 }
1950 };
1951 ($($Ns:literal),*) => {
1952 $(dispatch_full_ip!($Ns);)*
1953 }
1954}
1955
1956dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
1957
1958#[cfg(test)]
1963mod tests {
1964 use std::{collections::HashMap, sync::LazyLock};
1965
1966 use diskann_utils::Reborrow;
1967 use rand::{
1968 Rng, SeedableRng,
1969 distr::{Distribution, Uniform},
1970 rngs::StdRng,
1971 seq::IndexedRandom,
1972 };
1973
1974 use super::*;
1975 use crate::bits::{BoxedBitSlice, Representation, Unsigned};
1976
1977 type MR = MathematicalResult<u32>;
1978
1979 fn test_bitslice_distances<const NBITS: usize, R>(
1989 dim_max: usize,
1990 trials_per_dim: usize,
1991 evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
1992 evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
1993 context: &str,
1994 rng: &mut R,
1995 ) where
1996 Unsigned: Representation<NBITS>,
1997 R: Rng,
1998 {
1999 let domain = Unsigned::domain_const::<NBITS>();
2000 let min: i64 = *domain.start();
2001 let max: i64 = *domain.end();
2002
2003 let dist = Uniform::new_inclusive(min, max).unwrap();
2004
2005 for dim in 0..dim_max {
2006 let mut x_reference: Vec<u8> = vec![0; dim];
2007 let mut y_reference: Vec<u8> = vec![0; dim];
2008
2009 let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2010 let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2011
2012 for trial in 0..trials_per_dim {
2013 x_reference
2014 .iter_mut()
2015 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2016 y_reference
2017 .iter_mut()
2018 .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2019
2020 x.as_mut_slice().fill(u8::MAX);
2023 y.as_mut_slice().fill(u8::MAX);
2024
2025 for i in 0..dim {
2026 x.set(i, x_reference[i].into()).unwrap();
2027 y.set(i, y_reference[i].into()).unwrap();
2028 }
2029
2030 let expected: MV<f32> =
2032 diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2033
2034 let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2035
2036 assert_eq!(
2038 expected.into_inner(),
2039 got.into_inner() as f32,
2040 "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2041 NBITS,
2042 dim,
2043 trial,
2044 context,
2045 );
2046
2047 let expected: MV<f32> =
2049 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2050
2051 let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2052
2053 assert_eq!(
2055 expected.into_inner(),
2056 got.into_inner() as f32,
2057 "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2058 NBITS,
2059 dim,
2060 trial,
2061 context,
2062 );
2063 }
2064 }
2065
2066 let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2068 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2069
2070 assert!(
2071 evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2072 "context: {}",
2073 context
2074 );
2075 assert!(
2076 evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2077 "context: {}",
2078 context
2079 );
2080
2081 assert!(
2082 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2083 "context: {}",
2084 context
2085 );
2086 assert!(
2087 evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2088 "context: {}",
2089 context
2090 );
2091 }
2092
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri runs far slower than native execution, so shrink the
            // default search space under it.
            const MAX_DIM: usize = 128;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2102
2103 static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
2113 use ArchKey::{Scalar, X86_64_V3, X86_64_V4};
2114 [
2115 (Key::new(1, Scalar), Bounds::new(64, 64)),
2116 (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
2117 (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
2118 (Key::new(2, Scalar), Bounds::new(64, 64)),
2119 (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
2121 (Key::new(2, X86_64_V4), Bounds::new(768, 600)), (Key::new(3, Scalar), Bounds::new(64, 64)),
2123 (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
2124 (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
2125 (Key::new(4, Scalar), Bounds::new(64, 64)),
2126 (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
2128 (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
2129 (Key::new(5, Scalar), Bounds::new(64, 64)),
2130 (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
2131 (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
2132 (Key::new(6, Scalar), Bounds::new(64, 64)),
2133 (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
2134 (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
2135 (Key::new(7, Scalar), Bounds::new(64, 64)),
2136 (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
2137 (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
2138 (Key::new(8, Scalar), Bounds::new(64, 64)),
2139 (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
2140 (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
2141 ]
2142 .into_iter()
2143 .collect()
2144 });
2145
    // Which distance implementation a bound applies to.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ArchKey {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
    }

    // Lookup key for `BITSLICE_TEST_BOUNDS`: bit width plus architecture.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct Key {
        nbits: usize,
        arch: ArchKey,
    }

    impl Key {
        fn new(nbits: usize, arch: ArchKey) -> Self {
            Self { nbits, arch }
        }
    }

    // A pair of dimension limits: one for native runs, one for Miri.
    #[derive(Debug, Clone, Copy)]
    struct Bounds {
        standard: usize,
        miri: usize,
    }

    impl Bounds {
        fn new(standard: usize, miri: usize) -> Self {
            Self { standard, miri }
        }

        // Selects the limit appropriate for the current build.
        fn get(&self) -> usize {
            if cfg!(miri) { self.miri } else { self.standard }
        }
    }
2182
    // Generates one `#[test]` per bit width that exercises the packed
    // SquaredL2/InnerProduct kernels through every dispatch path: the pure
    // distance-function entry point, the explicit Scalar arch, and (when
    // available at runtime) the x86-64-v3 and x86-64-v4 tokens.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps the test deterministic across runs.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // Runs only when the CPU actually supports x86-64-v3.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // `new_checked_miri` — presumably also yields the V4 token
                // under Miri, unlike `new_checked`; confirm against its docs.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }
            }
        };
    }
2238
    // One regression test per supported bit width; seeds are fixed for
    // reproducibility.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2247
    // Cross-checks the binary Hamming distance against dense SquaredL2 on
    // {-1, +1} vectors: each differing coordinate contributes (±2)^2 = 4 to
    // the squared L2 distance, so L2 == 4 × hamming.
    fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
    where
        R: Rng,
    {
        let dist: [i8; 2] = [-1, 1];

        for dim in 0..dim_max {
            let mut x_reference: Vec<i8> = vec![1; dim];
            let mut y_reference: Vec<i8> = vec![1; dim];

            let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
            let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);

            for _ in 0..trials_per_dim {
                x_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());
                y_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());

                // Poison the packed buffers so stale bits from a previous
                // trial cannot mask an encoding bug.
                x.as_mut_slice().fill(u8::MAX);
                y.as_mut_slice().fill(u8::MAX);

                for i in 0..dim {
                    x.set(i, x_reference[i].into()).unwrap();
                    y.set(i, y_reference[i].into()).unwrap();
                }

                let expected: MV<f32> =
                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
                let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
                assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
            }
        }

        // Mismatched lengths must be reported as errors in both directions.
        let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
        let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
        assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
        assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
    }
2300
    #[test]
    fn test_hamming_distance() {
        // Fixed seed keeps the test deterministic across runs.
        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
    }
2306
2307 fn test_bit_transpose_distances<R>(
2312 dim_max: usize,
2313 trials_per_dim: usize,
2314 evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2315 context: &str,
2316 rng: &mut R,
2317 ) where
2318 R: Rng,
2319 {
2320 let dist_4bit = {
2321 let domain = Unsigned::domain_const::<4>();
2322 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2323 };
2324
2325 let dist_1bit = {
2326 let domain = Unsigned::domain_const::<1>();
2327 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2328 };
2329
2330 for dim in 0..dim_max {
2331 let mut x_reference: Vec<u8> = vec![0; dim];
2332 let mut y_reference: Vec<u8> = vec![0; dim];
2333
2334 let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2335 let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2336
2337 for trial in 0..trials_per_dim {
2338 x_reference
2339 .iter_mut()
2340 .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2341 y_reference
2342 .iter_mut()
2343 .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2344
2345 x.as_mut_slice().fill(u8::MAX);
2347 y.as_mut_slice().fill(u8::MAX);
2348
2349 for i in 0..dim {
2350 x.set(i, x_reference[i].into()).unwrap();
2351 y.set(i, y_reference[i].into()).unwrap();
2352 }
2353
2354 let expected: MV<f32> =
2356 diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2357
2358 let got = evaluate_ip(x.reborrow(), y.reborrow());
2359
2360 assert_eq!(
2362 expected.into_inner(),
2363 got.unwrap().into_inner() as f32,
2364 "faild InnerProduct for dim = {}, trial = {} -- context {}",
2365 dim,
2366 trial,
2367 context,
2368 );
2369 }
2370 }
2371
2372 let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2373 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2374 assert!(
2375 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2376 "context: {}",
2377 context
2378 );
2379
2380 let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2381 assert!(
2382 evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2383 "context: {}",
2384 context
2385 );
2386 }
2387
    #[test]
    fn test_bit_transpose_distance() {
        // Fixed seed keeps the test deterministic across runs.
        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);

        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| InnerProduct::evaluate(x, y),
            "pure distance function",
            &mut rng,
        );

        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
            "scalar",
            &mut rng,
        );

        // Runs only when the CPU actually supports x86-64-v3.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v3",
                &mut rng,
            );
        }

        // `new_checked_miri` — presumably also yields the V4 token under
        // Miri, unlike `new_checked`; confirm against its docs.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v4",
                &mut rng,
            );
        }
    }
2432
2433 fn test_full_distances<const NBITS: usize>(
2438 dim_max: usize,
2439 trials_per_dim: usize,
2440 evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
2441 context: &str,
2442 rng: &mut impl Rng,
2443 ) where
2444 Unsigned: Representation<NBITS>,
2445 {
2446 let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
2448 let dist_bit = {
2449 let domain = Unsigned::domain_const::<NBITS>();
2450 Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2451 };
2452
2453 for dim in 0..dim_max {
2454 let mut x: Vec<f32> = vec![0.0; dim];
2455
2456 let mut y_reference: Vec<u8> = vec![0; dim];
2457 let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
2458
2459 for trial in 0..trials_per_dim {
2460 x.iter_mut()
2461 .for_each(|i| *i = *dist_float.choose(rng).unwrap());
2462 y_reference
2463 .iter_mut()
2464 .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
2465
2466 y.as_mut_slice().fill(u8::MAX);
2468
2469 let mut expected = 0.0;
2470 for i in 0..dim {
2471 y.set(i, y_reference[i].into()).unwrap();
2472 expected += y_reference[i] as f32 * x[i];
2473 }
2474
2475 let got = evaluate_ip(&x, y.reborrow()).unwrap();
2477
2478 assert_eq!(
2480 expected,
2481 got.into_inner(),
2482 "faild InnerProduct for dim = {}, trial = {} -- context {}",
2483 dim,
2484 trial,
2485 context,
2486 );
2487
2488 let scalar: MV<f32> = InnerProduct
2491 .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
2492 .unwrap();
2493 assert_eq!(got.into_inner(), scalar.into_inner());
2494 }
2495 }
2496
2497 let x = vec![0.0; 10];
2499 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2500 assert!(
2501 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2502 "context: {}",
2503 context
2504 );
2505
2506 let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
2507 assert!(
2508 evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2509 "context: {}",
2510 context
2511 );
2512 }
2513
    // Generates one `#[test]` per bit width that exercises the mixed
    // `&[f32]` × packed inner product through every dispatch path.
    macro_rules! test_full {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps the test deterministic across runs.
                let mut rng = StdRng::seed_from_u64($seed);

                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar",
                    &mut rng,
                );

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // NOTE(review): this uses `new_checked()` where the other two
                // test macros use `new_checked_miri()` for V4 — possibly
                // deliberate (to skip the f32 V4 path under Miri), possibly an
                // oversight; confirm intent.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }
            }
        };
    }
2561
    // One mixed-precision regression test per supported bit width; seeds are
    // fixed for reproducibility.
    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
2570}