Skip to main content

diskann_quantization/bits/
distances.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! # Low-level functions
7//!
8//! The methods here are meant to be primitives used by the distance functions for the
9//! various scalar-quantized-like quantizers.
10//!
11//! As such, they typically return integer distance results since they largely operate over
12//! raw bit-slices.
13//!
14//! ## Micro-architecture Mapping
15//!
16//! There are two interfaces for interacting with the distance primitives:
17//!
18//! * [`diskann_wide::arch::Target2`]: A micro-architecture aware interface where the target
19//!   micro-architecture is provided as an explicit argument.
20//!
21//!   This can be used in conjunction with [`diskann_wide::Architecture::run2`] to apply the
22//!   necessary target-features to opt-into newer architecture code generation when
23//!   compiling the whole binary for an older architecture.
24//!
25//!   This interface is also composable with micro-architecture dispatching done higher in
26//!   the callstack, and so should be preferred when incorporating into quantizer distance
27//!   computations.
28//!
29//! * [`diskann_vector::PureDistanceFunction`]: If micro-architecture awareness is not needed,
30//!   this provides a simple interface targeting [`diskann_wide::ARCH`] (the current compilation
31//!   architecture).
32//!
33//!   This interface will always yield a binary compatible with the compilation architecture
34//!   target, but will not enable faster code-paths when compiling for older architectures.
35//!
36//! The following table summarizes the implementation status of kernels. All kernels have
37//! `diskann_wide::arch::Scalar` implementation fallbacks.
38//!
39//! Implementation Kind:
40//!
41//! * "Fallback": A fallback implementation using scalar indexing.
42//!
43//! * "Optimized": A better implementation than "fallback" that does not contain
//!   target-dependent code, instead relying on compiler optimizations.
45//!
46//!   Micro-architecture dispatch is still relevant as it allows the compiler to generate
47//!   better code for newer machines.
48//!
49//! * "Yes": Architecture specific SIMD implementation exists.
50//!
51//! * "No": Architecture specific implementation does not exist - the next most-specific
52//!   implementation is used. For example, if a `x86-64-v3` implementation does not exist,
53//!   then the "scalar" implementation will be used instead.
54//!
//! ### Type Aliases
56//!
57//! * `USlice<N>`: `BitSlice<N, Unsigned, Dense>`
58//! * `TSlice<N>`: `BitSlice<N, Unsigned, BitTranspose>`
59//! * `BSlice`: `BitSlice<1, Binary, Dense>`
60//!
61//! * `MV<T>`: [`diskann_vector::MathematicalValue<T>`]
62//!
63//! ### Inner Product
64//!
65//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
66//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
67//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
68//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Yes       | Fallback  |
69//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
70//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
71//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
72//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
73//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
74//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       | Fallback  |
75//! |               |               |           |           |               |           |           |
76//! | `USlice<8>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
77//! | `USlice<8>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
78//! | `USlice<8>`   | `USlice<1>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
79//! |               |               | `       ` |           |               |           |           |
80//! | `TSlice<4>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Optimized | Optimized |
81//! |               |               | `       ` |           |               |           |           |
82//! | `&[f32]`      | `USlice<1>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
83//! | `&[f32]`      | `USlice<2>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
84//! | `&[f32]`      | `USlice<3>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
85//! | `&[f32]`      | `USlice<4>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
86//! | `&[f32]`      | `USlice<5>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
87//! | `&[f32]`      | `USlice<6>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
88//! | `&[f32]`      | `USlice<7>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
89//! | `&[f32]`      | `USlice<8>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
90//!
91//! ### Squared L2
92//!
93//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
94//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
95//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
96//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
97//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
98//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
99//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
100//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
101//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
102//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       | Fallback  |
103//!
104//! ### Hamming
105//!
106//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
107//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
108//! | `BSlice`      | `BSlice`      | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
109
110use diskann_vector::PureDistanceFunction;
111use diskann_wide::{ARCH, Architecture, arch::Target2};
112#[cfg(target_arch = "x86_64")]
113use diskann_wide::{
114    SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
115};
116
117use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
118use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
119
// Convenience alias: an `N`-bit unsigned `BitSlice`, defaulting to the dense bit layout.
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
122
/// Retarget the architectures via the `retarget` inherent method.
///
/// Used to forward a micro-architecture that has no specialized kernel to the next
/// most-specific implementation:
///
/// * [`diskann_wide::arch::x86_64::V3`] -> [`diskann_wide::arch::Scalar`]
/// * [`diskann_wide::arch::x86_64::V4`] -> [`diskann_wide::arch::x86_64::V3`]
/// * [`diskann_wide::arch::aarch64::Neon`] -> [`diskann_wide::arch::Scalar`]
macro_rules! retarget {
    // Base case: a mixed-width `($N, $M)` pair of bit-slice operand widths.
    ($arch:path, $op:ty, ($N:literal, $M:literal)) => {
        impl Target2<
            $arch,
            MathematicalResult<u32>,
            USlice<'_, $N>,
            USlice<'_, $M>,
        > for $op {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: USlice<'_, $N>,
                y: USlice<'_, $M>
            ) -> MathematicalResult<u32> {
                // Delegate to the implementation for the less-specific architecture.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Shorthand: a single width `$N` means both operands use `$N` bits.
    ($arch:path, $op:ty, $N:literal) => {
        retarget!($arch, $op, ($N, $N));
    };
    // Variadic form: expand once per trailing argument.
    ($arch:path, $op:ty, $($args:tt),+ $(,)?) => {
        $(retarget!($arch, $op, $args);)+
    };
}
154
/// Implement [`diskann_vector::PureDistanceFunction`] using the current compilation
/// architecture ([`diskann_wide::ARCH`]), dispatching through the `Target2` impls.
macro_rules! dispatch_pure {
    // Base case: a mixed-width `($N, $M)` pair of bit-slice operand widths.
    ($op:ty, ($N:literal, $M:literal)) => {
        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $M>, MathematicalResult<u32>> for $op {
            #[inline(always)]
            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $M>) -> MathematicalResult<u32> {
                (diskann_wide::ARCH).run2(Self, x, y)
            }
        }
    };
    // Shorthand: a single width `$N` means both operands use `$N` bits.
    ($op:ty, $N:literal) => {
        dispatch_pure!($op, ($N, $N));
    };
    // Variadic form: expand once per trailing argument.
    ($op:ty, $($args:tt),+ $(,)?) => {
        $(dispatch_pure!($op, $args);)+
    }
}
172
/// Load 1 byte beginning at `ptr` and invoke `f` with that byte (zero-extended to `u32`).
///
/// The `*const u32` parameter type is wider than the read — presumably to match the
/// SIMD epilogue call sites; only a single byte is accessed.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 1)` (in bytes) must be dereferenceable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, f: F) -> R
where
    // `FnOnce` rather than `FnMut`: the callback is invoked exactly once, and every
    // `FnMut` closure is also `FnOnce`, so this is a strict relaxation for callers.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that one byte is readable.
    f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
}
187
/// Load 2 bytes beginning at `ptr` and invoke `f` with the value (zero-extended to `u32`).
///
/// The `*const u32` parameter type is wider than the read — presumably to match the
/// SIMD epilogue call sites; only two bytes are accessed.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 2)` (in bytes) must be dereferenceable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, f: F) -> R
where
    // `FnOnce` rather than `FnMut`: the callback is invoked exactly once, and every
    // `FnMut` closure is also `FnOnce`, so this is a strict relaxation for callers.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that two bytes are readable.
    f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
}
202
/// Load 3 bytes beginning at `ptr` and invoke `f` with the value (zero-extended to `u32`).
///
/// The 3-byte value is assembled from a 2-byte read plus a 1-byte read, since no native
/// 3-byte load exists.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 3)` (in bytes) must be dereferenceable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, f: F) -> R
where
    // `FnOnce` rather than `FnMut`: the callback is invoked exactly once, and every
    // `FnMut` closure is also `FnOnce`, so this is a strict relaxation for callers.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that three bytes are readable. This loads the first two.
    let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
    // SAFETY: Caller asserts that three bytes are readable. This loads the third.
    let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
    // Little-endian recombination: the third byte occupies bits 16..24.
    f(lo | (hi << 16))
}
220
/// Load 4 bytes beginning at `ptr` and invoke `f` with the value.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 4)` (in bytes) must be dereferenceable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, f: F) -> R
where
    // `FnOnce` rather than `FnMut`: the callback is invoked exactly once, and every
    // `FnMut` closure is also `FnOnce`, so this is a strict relaxation for callers.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that four bytes are readable.
    f(unsafe { ptr.read_unaligned() })
}
235
236////////////////////////////
237// Distances on BitSlices //
238////////////////////////////
239
/// Operations to apply to 1-bit encodings.
///
/// The general structure of 1-bit vector operations is the same, but the element-wise
/// operator is different. This trait encapsulates the differences in behavior required
/// for the different distance functions.
///
/// The exact operations to apply depend on the representation of the bit encoding.
trait BitVectorOp<Repr>
where
    Repr: Representation<1>,
{
    /// Apply the op to all bits in the 64-bit arguments, returning the count of
    /// contributing bit positions.
    fn on_u64(x: u64, y: u64) -> u32;

    /// Apply the op to all bits in the 8-bit arguments.
    ///
    /// NOTE: Implementations must have the correct behavior when the upper bits of `x`
    /// and `y` are set to 0 when handling epilogues.
    fn on_u8(x: u8, y: u8) -> u32;
}
260
261/// Computing Squared-L2 amounts to evaluating the pop-count of a bitwise `xor`.
262impl BitVectorOp<Unsigned> for SquaredL2 {
263    #[inline(always)]
264    fn on_u64(x: u64, y: u64) -> u32 {
265        (x ^ y).count_ones()
266    }
267    #[inline(always)]
268    fn on_u8(x: u8, y: u8) -> u32 {
269        (x ^ y).count_ones()
270    }
271}
272
/// Computing the Hamming distance amounts to evaluating the pop-count of a bitwise `xor`.
impl BitVectorOp<Binary> for Hamming {
    #[inline(always)]
    fn on_u64(x: u64, y: u64) -> u32 {
        (x ^ y).count_ones()
    }
    #[inline(always)]
    fn on_u8(x: u8, y: u8) -> u32 {
        // Zeroed upper bits xor to zero and so contribute nothing in epilogues.
        (x ^ y).count_ones()
    }
}
284
285/// The implementation as `and` is not straight-forward.
286///
287/// Recall that scalar quantization encodings are unsigned, so "0" is zero and "1" is some
288/// non-zero value.
289///
290/// When computing the inner product, `0 * x == 0` for all `x` and only `x * x` has a
291/// non-zero value. Therefore, the elementwise op is an `and` and not `xnor`.
292impl BitVectorOp<Unsigned> for InnerProduct {
293    #[inline(always)]
294    fn on_u64(x: u64, y: u64) -> u32 {
295        (x & y).count_ones()
296    }
297    #[inline(always)]
298    fn on_u8(x: u8, y: u8) -> u32 {
299        (x & y).count_ones()
300    }
301}
302
/// A general algorithm for applying a bitwise operand to two dense bit vectors of equal
/// but arbitrary length.
///
/// Processes the slices in 64-bit chunks, then byte chunks, and finally masks a partial
/// trailing byte (if any) so out-of-range bits cannot contribute.
///
/// Returns an error if the arguments have different lengths.
///
/// NOTE: The `inline(always)` attribute is required to inherit the caller's target-features.
#[inline(always)]
fn bitvector_op<Op, Repr>(
    x: BitSlice<'_, 1, Repr>,
    y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
    Repr: Representation<1>,
    Op: BitVectorOp<Repr>,
{
    // `len` is the length in *bits*.
    let len = check_lengths!(x, y)?;

    let px: *const u64 = x.as_ptr().cast();
    let py: *const u64 = y.as_ptr().cast();

    // `i` counts 64-bit blocks in the first loop, bytes afterwards.
    let mut i = 0;
    let mut s: u32 = 0;

    // Work in groups of 64
    let blocks = len / 64;
    while i < blocks {
        // SAFETY: We know at least 64-bits (8-bytes) are valid from this offset (by
        // guarantee of the `BitSlice`). All bit-patterns of a `u64` are valid, `u64: Copy`,
        // and an `unaligned` read is used.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: The same logic applies to `y` because:
        // 1. It has the same type as `x`.
        // 2. We've verified that it has the same length as `x`.
        let vy = unsafe { py.add(i).read_unaligned() };

        s += Op::on_u64(vx, vy);
        i += 1;
    }

    // Work in groups of 8
    // Convert the 64-bit block count into a byte index (8 bytes per block).
    i *= 8;
    let px: *const u8 = x.as_ptr();
    let py: *const u8 = y.as_ptr();

    let blocks = len / 8;
    while i < blocks {
        // SAFETY: The underlying pointer is a `*const u8` and we have checked that this
        // offset is within the bounds of the slice underlying the bitslice.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: The same logic applies to `y` because:
        // 1. It has the same type as `x`.
        // 2. We've verified that it has the same length as `x`.
        let vy = unsafe { py.add(i).read_unaligned() };
        s += Op::on_u8(vx, vy);
        i += 1;
    }

    // Handle a trailing partial byte when `len` is not a multiple of 8.
    if i * 8 != len {
        // SAFETY: The underlying slice is readable in the range
        // `[px, px + floor(len / 8) + 1)`. This accesses `px + floor(len / 8)`.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: Same as above.
        let vy = unsafe { py.add(i).read_unaligned() };
        // `len - 8 * i` is in `1..=7` here, so the shift cannot overflow. The mask keeps
        // only the in-range low bits; `Op::on_u8` is documented to tolerate zeroed
        // upper bits.
        let m = (0x01u8 << (len - 8 * i)) - 1;

        s += Op::on_u8(vx & m, vy & m)
    }
    Ok(MV::new(s))
}
373
374/// Compute the hamming distance between `x` and `y`.
375///
376/// Returns an error if the arguments have different lengths.
377impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
378    for Hamming
379{
380    fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
381        bitvector_op::<Hamming, Binary>(x, y)
382    }
383}
384
385///////////////
386// SquaredL2 //
387///////////////
388
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This can directly invoke the methods implemented in `vector` because
/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
///
/// NOTE(review): the delegated kernel returns `MV<f32>`, and `f32` is only exact for
/// integers up to 2^24 — extremely long vectors could lose integer precision before the
/// final cast. Confirm dimensionality bounds make this acceptable.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
    A: Architecture,
    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;

        // Dispatch to the byte-slice kernel in `diskann_vector` for the given target.
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::SquaredL2 {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        Ok(MV::new(r.into_inner() as u32))
    }
}
416
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice. Each `u32` block
        // packs eight 4-bit elements.
        let blocks = len / 8;
        if i < blocks {
            // One accumulator per nibble-extraction pass; this also breaks the
            // `dot_simd` dependency chain across loop iterations.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit lane.
            let mask = u32s::splat(arch, 0x000f000f);
            // Strict `<` (not `<=`) guarantees at least one block remains for the
            // predicated epilogue below, which requires a non-empty load.
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Nibbles 0 and 4 of each `u32` block land in the 16-bit lanes.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // Nibbles 1 and 5.
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                // Nibbles 2 and 6.
                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                // Nibbles 3 and 7.
                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // `1..=8` full `u32` blocks remain at this point.
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            // Same nibble-extraction pattern as the main loop. Lanes beyond `remainder`
            // are loaded identically in both operands, so they cancel in the
            // subtraction — NOTE(review): confirm `load_simd_first` fill semantics.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of the four accumulators into a scalar.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 8;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
560
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice. Each `u32` block
        // packs sixteen 2-bit elements.
        let blocks = len / 16;
        if i < blocks {
            // Four accumulators shared across the eight extraction passes (two passes
            // each); this also breaks the `dot_simd` dependency chain.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2 bits of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Strict `<` (not `<=`) guarantees at least one block remains for the
            // predicated epilogue below, which requires a non-empty load.
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Elements 0 and 8 of each `u32` block land in the 16-bit lanes.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // Elements 1 and 9.
                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                // Elements 2 and 10.
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                // Elements 3 and 11.
                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                // Elements 4 and 12.
                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // Elements 5 and 13.
                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                // Elements 6 and 14.
                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                // Elements 7 and 15.
                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            // `1..=8` full `u32` blocks remain at this point.
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            // Same extraction pattern as the main loop. Lanes beyond `remainder` are
            // loaded identically in both operands, so they cancel in the subtraction —
            // NOTE(review): confirm `load_simd_first` fill semantics.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Horizontal reduction of the four accumulators into a scalar.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 16;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
743
744/// Compute the squared L2 distance between bitvectors `x` and `y`.
745///
746/// Returns an error if the arguments have different lengths.
747impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
748where
749    A: Architecture,
750{
751    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
752        bitvector_op::<Self, Unsigned>(x, y)
753    }
754}
755
/// An implementation for L2 distance that uses scalar indexing for the implementation.
///
/// Two invocation forms are provided:
///
/// * `impl_fallback_l2!(N)` — emit the impl for a single packed bit-width `N`.
/// * `impl_fallback_l2!(N1, N2, …)` — emit one impl per listed bit-width.
macro_rules! impl_fallback_l2 {
    ($N:literal) => {
        /// Compute the squared L2 distance between `x` and `y`.
        ///
        /// Returns an error if the arguments have different lengths.
        ///
        /// # Performance
        ///
        /// This function uses a generic implementation and therefore is not very fast.
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                // NOTE(review): `accum` is `i32`; with 7-bit elements each squared
                // difference is at most 127^2 = 16_129, so the sum can exceed
                // `i32::MAX` only for vectors of roughly 133K+ elements (panicking
                // in debug builds, wrapping in release) — confirm `len` is bounded
                // accordingly.
                let mut accum: i32 = 0;
                for i in 0..len {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
                    let diff = ix - iy;
                    accum += diff * diff;
                }
                Ok(MV::new(accum as u32))
            }
        }
    };
    // List form: expand the single-width arm once per argument.
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_l2!($N);)+
    };
}
793
// Scalar fallback implementations for all packed widths between 2 and 7 bits.
impl_fallback_l2!(7, 6, 5, 4, 3, 2);

// Re-target the fallback kernels to V3. Widths 4 and 2 are omitted here because
// dedicated hand-written V3 kernels for those widths exist elsewhere in this file
// (the 2-bit kernel is visible earlier; presumably likewise for 4-bit — a duplicate
// `retarget!` impl would conflict with them).
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

// No dedicated V4 SquaredL2 kernels appear to exist, so every width is re-targeted.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);

// Generate the `PureDistanceFunction` entry points (see the module docs) for every
// supported width.
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(target_arch = "aarch64")]
retarget!(
    diskann_wide::arch::aarch64::Neon,
    SquaredL2,
    7,
    6,
    5,
    4,
    3,
    2
);
814
815///////////////////
816// Inner Product //
817///////////////////
818
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This can directly invoke the methods implemented in `vector` because
/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
///
/// NOTE(review): the underlying kernel returns `f32`, which is truncated to `u32`
/// below. Exact integer representation in `f32` stops at 2^24, so very long vectors
/// may lose low-order bits of the true integer inner product — confirm this is
/// acceptable for callers.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
where
    A: Architecture,
    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;
        // Fully-qualified call to select the `Target2` impl named by the
        // `where`-clause bound above.
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::InnerProduct {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        // Truncating float-to-int conversion; the product of unsigned inputs is
        // non-negative, so the `as` cast cannot saturate to zero unexpectedly.
        Ok(MV::new(r.into_inner() as u32))
    }
}
845
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This is optimized around the `_mm512_dpbusd_epi32` VNNI instruction, which computes the
/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with
/// an `i32` accumulation vector.
///
/// One quirk of this instruction is that one argument must be unsigned and the other must
/// be signed. Since this kernel works on 2-bit integers (values `0..=3` fit both `u8` and
/// `i8`), this is not a limitation. Just something to be aware of.
///
/// Since AVX512 does not have an 8-bit shift instruction, we generally load data as
/// `u32x16` (which has a native shift) and bit-cast it to `u8x64` as needed.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    // The `type` aliases below intentionally use lower-case, SIMD-style names.
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Short-hand vector type aliases for the V4 (AVX512) target.
        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` changes units as the function progresses: 32-bit words inside the SIMD
        // section, then bytes after the predicated tail, then elements at the end.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        // Each `u32` packs 16 2-bit elements.
        let blocks = len.div_ceil(16);
        if i < blocks {
            // Four independent accumulators (one per 2-bit field position) keep the
            // `dot_simd` dependency chains short.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low two bits of every byte lane.
            let mask = u32s::splat(arch, 0x03030303);
            while i + 16 < blocks {
                // SAFETY: We have checked that `i + 16 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // For each 2-bit field position: isolate the field in every byte, then
                // feed VNNI one unsigned and one signed operand. Both hold values
                // `0..=3`, so the `i8` reinterpret is value-preserving.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Here
            // * `len / 4` gives the number of full bytes
            // * `4 * i` gives the number of bytes processed.
            let remainder = len / 4 - 4 * i;

            // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            // NOTE: correctness relies on `load_simd_first` leaving lanes past
            // `remainder` zeroed so they contribute nothing to the sums below.
            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            // `i` is now a byte index: `4 * i` bytes from the main loop plus the
            // `remainder` bytes consumed by the predicated load.
            i = (4 * i) + remainder;
        }

        // Convert bytes to element indexes (4 elements per byte).
        i *= 4;

        // Deal with the remainder (at most 3 elements in a partially-filled byte)
        // the slow way.
        debug_assert!(len - i <= 3);
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                // SAFETY: `i` is guaranteed to be less than `x.len()`.
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                // SAFETY: `i` is guaranteed to be less than `y.len()`.
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
983
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 16>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Short-hand vector aliases for the V3 (AVX2) target.
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts 32-bit words inside the SIMD section and is converted to an
        // element index afterwards.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of full 32-bit words; each `u32` packs 8 4-bit elements.
        // Leftover elements (fewer than 8) are handled by the scalar fallback.
        let blocks = len / 8;
        if i < blocks {
            // Four independent accumulators (one per shift position) keep the
            // `madd` dependency chains short.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit lane; shifting by 0/4/8/12
            // walks over all four nibbles of the lane.
            let mask = u32s::splat(arch, 0x000f000f);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Values are 4-bit (`0..=15`), so the `i16` reinterpret is
                // value-preserving and the 16-bit products cannot overflow.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Remaining full words (1 to 8) are handled with one predicated load.
            let remainder = blocks - i;

            // SAFETY: At least `remainder` values of type `u32` are valid for an
            // unaligned load starting at offset `i`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            // NOTE: correctness relies on `load_simd_first` leaving lanes past
            // `remainder` zeroed so they contribute nothing to the sums below.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks (32-bit words) to element indexes: 8 elements per word.
        i *= 8;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1117
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 16>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Short-hand vector aliases for the V3 (AVX2) target.
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts 32-bit words inside the SIMD section and is converted to an
        // element index afterwards.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        // Each `u32` packs 16 2-bit elements; leftovers go to the scalar fallback.
        let blocks = len / 16;
        if i < blocks {
            // Four accumulators shared between the eight shift positions (each is
            // updated twice per iteration) to keep the `madd` dependency chains short.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low two bits of each 16-bit lane; shifting by
            // 0/2/4/…/14 walks over all eight 2-bit fields of the lane.
            let mask = u32s::splat(arch, 0x00030003);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Values are 2-bit (`0..=3`), so the `i16` reinterpret is
                // value-preserving and the 16-bit products cannot overflow.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Remaining full words (1 to 8) are handled with one predicated load.
            let remainder = blocks - i;

            // SAFETY: At least `remainder` values of type `u32` are valid for an
            // unaligned load starting at offset `i`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            //
            // NOTE: correctness relies on `load_simd_first` leaving lanes past
            // `remainder` zeroed so they contribute nothing to the sums below.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks (32-bit words) to element indexes: 16 elements per word.
        i *= 16;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1283
/// Compute the inner product between bitvectors `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// Delegates to [`bitvector_op`], the shared 1-bit kernel used by both distance
/// functors in this module.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
        bitvector_op::<Self, Unsigned>(x, y)
    }
}
1296
/// An implementation for inner products that uses scalar indexing for the implementation.
///
/// Three invocation forms are provided:
///
/// * `impl_fallback_ip!((N, M))` — emit one impl for the heterogeneous pair of packed
///   bit-widths `N` (for `x`) and `M` (for `y`).
/// * `impl_fallback_ip!(N)` — shorthand for the homogeneous pair `(N, N)`.
/// * `impl_fallback_ip!(a, b, …)` — expand any mix of the above, one impl each.
macro_rules! impl_fallback_ip {
    (($N:literal, $M:literal)) => {
        /// Compute the inner product between `x` and `y`.
        ///
        /// Returns an error if the arguments have different lengths.
        ///
        /// # Performance
        ///
        /// This function uses a generic implementation and therefore is not very fast.
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $M>> for InnerProduct {
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $M>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                // NOTE(review): `accum` is `u32`; e.g. for the (8, 4) pairing each
                // product is at most 255 * 15, so the sum can exceed `u32::MAX`
                // only for vectors of roughly a million-plus elements (panicking
                // in debug builds, wrapping in release) — confirm `len` is bounded
                // accordingly.
                let mut accum: u32 = 0;
                for i in 0..len {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    accum += ix * iy;
                }
                Ok(MV::new(accum))
            }
        }
    };
    // Homogeneous shorthand: `N` expands to the pair `(N, N)`.
    ($N:literal) => {
        impl_fallback_ip!(($N, $N));
    };
    // List form: expand each argument (either `N` or `(N, M)`) independently.
    ($($args:tt),+ $(,)?) => {
        $(impl_fallback_ip!($args);)+
    };
}
1336
// Scalar fallbacks: homogeneous widths 2..=7 plus the heterogeneous 8-bit × M-bit
// pairings. (The homogeneous 8-bit case is covered by the generic impl above.)
impl_fallback_ip!(7, 6, 5, 4, 3, 2, (8, 4), (8, 2), (8, 1));

// Widths 4 and 2 are omitted from the V3 list because dedicated hand-written V3
// kernels for those widths are defined elsewhere in this file; a duplicate
// `retarget!` impl would conflict with them.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);

// Width 2 is omitted from the V4 list because a dedicated VNNI-based V4 kernel
// for 2-bit inputs is defined earlier in this file.
#[cfg(target_arch = "x86_64")]
retarget!(
    diskann_wide::arch::x86_64::V4,
    InnerProduct,
    7,
    6,
    5,
    4,
    3,
    (8, 4),
    (8, 2),
    (8, 1)
);

// Generate the `PureDistanceFunction` entry points (see the module docs) for every
// supported width combination.
dispatch_pure!(
    InnerProduct,
    1,
    2,
    3,
    4,
    5,
    6,
    7,
    (8, 8),
    (8, 4),
    (8, 2),
    (8, 1)
);
1370
1371//////////////////////////////////////////
1372// Heterogeneous USlice<8> × USlice<M> //
1373/////////////////////////////////////////
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 4>>
    for InnerProduct
{
    /// Computes the inner product of 8-bit unsigned × 4-bit unsigned vectors using V3 intrinsics.
    ///
    /// # Strategy
    ///
    /// Unpack each 16-byte chunk of `y` into 32 nibble values via [`unpack_half`],
    /// then multiply with the corresponding 32 bytes of `x` using `_mm256_maddubs_epi16`
    /// (unsigned u8 × signed i8 → i16, pairwise horizontal add; the nibble operand is
    /// at most 15, so its signed reinterpretation is value-preserving).
    ///
    /// The main loop is 4× unrolled: four i16 products are summed in i16 before a single
    /// `_mm256_madd_epi16(…, 1)` widens to i32. This is safe because
    /// 4 × (255 × 15 × 2) = 30_600 < i16::MAX.
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        use std::arch::x86_64::_mm256_maddubs_epi16;

        let len = check_lengths!(x, y)?;

        // Short-hand vector aliases for the V3 (AVX2) target.
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // Widen 16 packed bytes (two nibbles each) of `y` into 32 nibble-valued
        // bytes: zip the low- and high-nibble streams lane-wise, then mask to the
        // low 4 bits. `zip` is expected to interleave so each packed byte expands
        // to adjacent (low nibble, high nibble) bytes, matching element order.
        #[inline(always)]
        fn unpack_half(input: u8s_16) -> u8s_32 {
            let combined = diskann_wide::LoHi::new(input, input >> 4).zip::<u8s_32>();
            combined & u8s_32::splat(input.arch(), (1u8 << 4) - 1)
        }

        // Each block processes 32 elements: 32 bytes from x, 16 packed bytes from y.
        let blocks = len / 32;
        if blocks > 0 {
            let mut acc = i32s::default(arch);

            // `maddubs`: per-lane u8 × i8 products, adjacent pairs summed into i16.
            let products = |x: u8s_32, y: u8s_32| -> i16s {
                // SAFETY: `arch` is V3 (AVX2), which provides `_mm256_maddubs_epi16`.
                i16s::from_underlying(arch, unsafe {
                    _mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
                })
            };

            // Multiplying by an all-ones i16 vector in `madd` widens (and pairwise
            // sums) the i16 partials into the i32 accumulator.
            let ones = i16s::splat(arch, 1);

            // Main loop: 4× unrolled, processing 128 elements per iteration.
            //
            // `products` returns i16 lanes, each at most 255 × 15 × 2 = 7_650.
            // We sum 4 such values in i16 before widening to i32:
            //   4 × 7_650 = 30_600 < i16::MAX (32_767)
            while i + 4 <= blocks {
                // Block 0
                // SAFETY: `i + 4 <= blocks` guarantees at least 4 full blocks remain.
                // Each block needs 32 bytes from `px` and 16 bytes from `py`.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
                // SAFETY: `i + 4 <= blocks` guarantees 16 bytes readable at `py.add(16 * i)`.
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
                let m0 = products(vx, unpack_half(vy));

                // Block 1
                // SAFETY: `i + 1 < i + 4 <= blocks`.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 1))) };
                // SAFETY: same bound; 16 bytes readable at `py.add(16 * (i + 1))`.
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 1))) };
                let m1 = products(vx, unpack_half(vy));

                // Block 2
                // SAFETY: `i + 2 < i + 4 <= blocks`.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 2))) };
                // SAFETY: same bound; 16 bytes readable at `py.add(16 * (i + 2))`.
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 2))) };
                let m2 = products(vx, unpack_half(vy));

                // Block 3
                // SAFETY: `i + 3 < i + 4 <= blocks`.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 3))) };
                // SAFETY: same bound; 16 bytes readable at `py.add(16 * (i + 3))`.
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 3))) };
                let m3 = products(vx, unpack_half(vy));

                acc = acc.dot_simd(m0 + m1 + m2 + m3, ones);
                i += 4;
            }

            // Drain remaining full 32-element blocks (0..3 iterations).
            while i < blocks {
                // SAFETY: `i < blocks` guarantees 32 bytes from `px` and 16 bytes
                // from `py` are readable at this offset.
                let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
                // SAFETY: `i < blocks` guarantees 16 bytes readable at `py.add(16 * i)`.
                let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
                acc = acc.dot_simd(products(vx, unpack_half(vy)), ones);
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Convert block count to element index.
        i *= 32;

        // Scalar fallback for the remaining < 32 elements.
        if i != len {
            // Outlined to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is bounded by `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is bounded by `x.len()` which equals `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1507
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 2>>
    for InnerProduct
{
    /// Computes the inner product of 8-bit unsigned × 2-bit unsigned vectors using AVX2.
    ///
    /// # Strategy
    ///
    /// Unpack each 16-byte chunk of `y` into 64 crumb values via a two-level cascade:
    /// first a nibble split (`unpack_sub::<4>`) splits bytes into nibbles, then a second
    /// pass (`unpack_sub::<2>`) splits nibbles into crumbs (masked with `0x03`). Each
    /// unpacked half is paired with 32 bytes of `x` and multiplied via
    /// `_mm256_maddubs_epi16`.
    ///
    /// The main loop is 4× unrolled: eight i16 products (4 blocks × 2 halves) are summed
    /// in i16 before a single `_mm256_madd_epi16(…, 1)` widens to i32. This is safe
    /// because 8 × (255 × 3 × 2) = 12_240 < i16::MAX.
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        use diskann_wide::SplitJoin;
        use std::arch::x86_64::_mm256_maddubs_epi16;

        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // Each block processes 64 elements: 64 bytes from x, 16 packed bytes from y.
        let blocks = len / 64;
        if blocks > 0 {
            let mut acc = i32s::default(arch);

            // Multiply unsigned bytes of `x` with unpacked crumb values, summing
            // adjacent byte-product pairs into i16 lanes.
            let products = |x: u8s_32, y: u8s_32| -> i16s {
                // SAFETY: `arch` is V3 (AVX2), which provides `_mm256_maddubs_epi16`.
                i16s::from_underlying(arch, unsafe {
                    _mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
                })
            };

            // Split each byte of `input` into its low and high `N`-bit halves:
            // zip `input` with `input >> N` into a 32-lane vector, then mask each
            // lane down to its lowest `N` bits.
            #[inline(always)]
            fn unpack_sub<const N: u8>(input: u8s_16) -> u8s_32 {
                let combined = diskann_wide::LoHi::new(input, input >> N).zip::<u8s_32>();
                combined & u8s_32::splat(input.arch(), (1u8 << N) - 1)
            }

            let unpack_crumbs = |x: u8s_16| -> (u8s_32, u8s_32) {
                // Level 1: nibble split → 32 × 4-bit values in a u8x32.
                let nibbles = unpack_sub::<4>(x);

                // Split the u8x32 into two u8x16 halves (lo = elements 0..15,
                // hi = elements 16..31), then apply Level 2 crumb split to each.
                let diskann_wide::LoHi { lo, hi } = nibbles.split();
                let lower = unpack_sub::<2>(lo);
                let upper = unpack_sub::<2>(hi);

                (lower, upper)
            };

            let ones = i16s::splat(arch, 1);

            // Main loop: 4× unrolled, processing 256 elements per iteration.
            //
            // `products` returns i16 lanes, each at most 255 × 3 × 2 = 1_530.
            // Each iteration produces 4 blocks × 2 products = 8 values.
            // We sum all 8 in i16 before widening to i32:
            //   8 × 1_530 = 12_240 < i16::MAX (32_767)
            while i + 4 <= blocks {
                // Block 0
                // SAFETY: `i + 4 <= blocks` guarantees at least 4 full blocks remain.
                // Each block needs 64 bytes from `px` and 16 bytes from `py`.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * i)),
                        u8s_32::load_simd(arch, px.add(64 * i + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
                    )
                };
                let m0a = products(vx0, vy0);
                let m0b = products(vx1, vy1);

                // Block 1
                // SAFETY: `i + 1 < i + 4 <= blocks`.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 1))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 1) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 1)))),
                    )
                };
                let m1a = products(vx0, vy0);
                let m1b = products(vx1, vy1);

                // Block 2
                // SAFETY: `i + 2 < i + 4 <= blocks`.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 2))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 2) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 2)))),
                    )
                };
                let m2a = products(vx0, vy0);
                let m2b = products(vx1, vy1);

                // Block 3
                // SAFETY: `i + 3 < i + 4 <= blocks`.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * (i + 3))),
                        u8s_32::load_simd(arch, px.add(64 * (i + 3) + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 3)))),
                    )
                };
                let m3a = products(vx0, vy0);
                let m3b = products(vx1, vy1);

                acc = acc.dot_simd((m0a + m0b + m1a + m1b) + (m2a + m2b + m3a + m3b), ones);
                i += 4;
            }

            // Drain remaining full 64-element blocks (0..3 iterations).
            while i < blocks {
                // SAFETY: `i < blocks` guarantees 64 bytes from `px` and 16 bytes
                // from `py` are readable at this offset.
                let (vx0, vx1, (vy0, vy1)) = unsafe {
                    (
                        u8s_32::load_simd(arch, px.add(64 * i)),
                        u8s_32::load_simd(arch, px.add(64 * i + 32)),
                        unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
                    )
                };
                acc = acc.dot_simd(products(vx0, vy0) + products(vx1, vy1), ones);
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Convert block count to element index.
        i *= 64;

        // Scalar fallback for the remaining < 64 elements.
        if i != len {
            // Scalar dot product over the tail `from..x.len()`. Kept out-of-line so
            // the cold remainder path does not bloat the hot SIMD code.
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is in `from..x.len()`, which equals `y.len()`.
                    let (ix, iy) =
                        unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1680
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 1>>
    for InnerProduct
{
    /// Computes the inner product of 8-bit unsigned × 1-bit unsigned vectors using V3 intrinsics.
    ///
    /// For each 32-element block we load 32 bytes from `x` and 4 bytes (32 bits) from `y`.
    /// ANDing the data with the mask created from 4 bytes from `y` zeroes unselected lanes.
    /// Finally, `_mm256_sad_epu8` horizontally sums the masked bytes in groups of 8.
    ///
    /// The main loop is 4× unrolled, processing 128 elements per iteration.
    ///
    /// ## Overflow
    ///
    /// Each `sad` output lane holds at most `8 × 255 = 2_040`. Accumulated across `d/32`
    /// blocks, the per-lane max is `(d/32) × 2_040`. At dim = 3072: `96 × 2_040 = 195_840`,
    /// well within i32 range.
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 8>,
        y: USlice<'_, 1>,
    ) -> MathematicalResult<u32> {
        use diskann_wide::{FromInt, SIMDMask};
        use std::arch::x86_64::_mm256_sad_epu8;

        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);

        type Mask32 = diskann_wide::BitMask<32, diskann_wide::arch::x86_64::V3>;
        type Mask8x32 = diskann_wide::arch::x86_64::v3::masks::mask8x32;

        let px: *const u8 = x.as_ptr();
        let py: *const u8 = y.as_ptr();

        let mut i: usize = 0;
        let mut s: u32 = 0;

        // Each block processes 32 elements: 32 bytes from x, 4 bytes (32 bits) from y.
        let blocks = len / 32;
        if blocks > 0 {
            let mut acc = i32s::default(arch);
            let zero = u8s_32::default(arch);

            // Expand 32 bits of `y` to a byte mask, AND with `x`, and horizontally sum
            // the masked bytes via `_mm256_sad_epu8`.
            //
            // Returns an `i32x8` where lanes 0, 2, 4, 6 hold group sums (each ≤ 2040)
            // and lanes 1, 3, 5, 7 are zero.
            let masked_sad = |vx: u8s_32, bits: u32| -> i32s {
                // Expand 32-bit mask → 32 bytes of 0xFF/0x00.
                let byte_mask: Mask8x32 = Mask32::from_int(arch, bits).into();

                // AND data with mask: zeroes lanes where the bit is 0.
                let masked = vx & u8s_32::from_underlying(arch, byte_mask.to_underlying());

                // Horizontal byte sum in groups of 8 → 4 × u64 partial sums.
                // SAFETY: `arch` is V3 (AVX2), `_mm256_sad_epu8` is available.
                i32s::from_underlying(arch, unsafe {
                    _mm256_sad_epu8(masked.to_underlying(), zero.to_underlying())
                })
            };

            // Main loop: 4× unrolled, processing 128 elements per iteration.
            while i + 4 <= blocks {
                // SAFETY: `i + 4 <= blocks` guarantees at least 4 full blocks remain.
                // Each block needs 32 bytes from `px` and 4 bytes from `py`.
                let s0 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * i));
                    let bits = (py.add(4 * i) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                // SAFETY: `i + 1 < i + 4 <= blocks`.
                let s1 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 1)));
                    let bits = (py.add(4 * (i + 1)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                // SAFETY: `i + 2 < i + 4 <= blocks`.
                let s2 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 2)));
                    let bits = (py.add(4 * (i + 2)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                // SAFETY: `i + 3 < i + 4 <= blocks`.
                let s3 = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * (i + 3)));
                    let bits = (py.add(4 * (i + 3)) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };

                acc = acc + s0 + s1 + s2 + s3;
                i += 4;
            }

            // Drain remaining full 32-element blocks (0..3 iterations).
            while i < blocks {
                // SAFETY: `i < blocks` guarantees 32 bytes from `px` and 4 bytes from `py`
                // are readable at this offset.
                let si = unsafe {
                    let vx = u8s_32::load_simd(arch, px.add(32 * i));
                    let bits = (py.add(4 * i) as *const u32).read_unaligned();
                    masked_sad(vx, bits)
                };
                acc = acc + si;
                i += 1;
            }

            s = acc.sum_tree() as u32;
        }

        // Convert block count to element index.
        i *= 32;

        // Scalar fallback for the remaining < 32 elements.
        if i != len {
            // Scalar dot product over the tail `from..x.len()`. Kept out-of-line so
            // the cold remainder path does not bloat the hot SIMD code.
            #[inline(never)]
            fn fallback(x: USlice<'_, 8>, y: USlice<'_, 1>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is in `from..x.len()`, which equals `y.len()`.
                    let (ix, iy) =
                        unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1820
#[cfg(target_arch = "aarch64")]
// Provide `InnerProduct` implementations for the Neon target covering the 7..2-bit
// widths and the mixed (8, 4), (8, 2), and (8, 1) argument pairs.
//
// NOTE(review): `retarget!` is defined elsewhere in this crate — it presumably
// forwards these `Target2` impls to a fallback (scalar/portable) implementation;
// confirm against the macro definition.
retarget!(
    diskann_wide::arch::aarch64::Neon,
    InnerProduct,
    7,
    6,
    5,
    4,
    3,
    2,
    (8, 4),
    (8, 2),
    (8, 1)
);
1835
1836//////////////////
1837// BitTranspose //
1838//////////////////
1839
/// The strategy is to compute the inner product `<x, y>` by decomposing the problem into
/// groups of 64-dimensions.
///
/// For each group, we load the 64-bits of `y` into a word `bits`. And the four 64-bit words
/// of the group in `x` in `b0`, `b1`, `b2`, and `b3`.
///
/// Note that bit `i` in `b0` is bit-0 of the `i`-th value in this group. Likewise, bit `i`
/// in `b1` is bit-1 of the same word.
///
/// This means that we can compute the partial inner product for this group as
/// ```math
/// (bits & b0).count_ones()                // Contribution of bit 0
///     + 2 * (bits & b1).count_ones()      // Contribution of bit 1
///     + 4 * (bits & b2).count_ones()      // Contribution of bit 2
///     + 8 * (bits & b3).count_ones()      // Contribution of bit 3
/// ```
/// We process as many full groups as we can.
///
/// To handle the remainder, we need to be careful about accessing `y` because `BitSlice`
/// only guarantees the validity of reads at the byte level. That is - we cannot assume that
/// a full 64-bit read is valid.
///
/// The bit-transposed `x`, on the other hand, guarantees allocations in blocks of
/// 4 * 64-bits, so it can be treated as normal.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // We work in blocks of 64 elements.
        //
        // The `BitTranspose` guarantees reads are valid in blocks of 64 elements (32 bytes).
        // However, the `Dense` representation only pads to bytes.
        // Our strategy for dealing with fewer than 64 remaining elements is to reconstruct
        // a 64-bit integer from bytes.
        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 64;
        while i < blocks {
            // SAFETY: `y` is valid for at least `blocks` 64-bit reads and `i < blocks`.
            let bits = unsafe { py.add(i).read_unaligned() };

            // SAFETY: The layout for `x` is grouped into 32-byte blocks. We've ensured that
            // the lengths of the two vectors are the same, so we know that `x` has at least
            // `blocks` such regions.
            //
            // This loads the first 64-bits of block `i` where `i < blocks`.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            // SAFETY: This loads the second 64-bit word of block `i`.
            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            // SAFETY: This loads the third 64-bit word of block `i`.
            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            // SAFETY: This loads the fourth 64-bit word of block `i`.
            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        // If the input length is a multiple of 64 - then we're done.
        if 64 * i == len {
            return Ok(MV::new(s));
        }

        // Convert blocks to bytes.
        let k = i * 8;

        // Unpack the last elements from the bit-vector.
        //
        // SAFETY: The length of the 1-bit BitSlice is `ceil(len / 8)`. This computation
        // effectively computes `ceil((64 * floor(len / 64)) / 8)`, which is less.
        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        // Code - generation: Applying `min(8)` gives a constant upper-bound to the
        // compiler, allowing better code-generation.
        for j in 0..bytes_remaining.min(8) {
            // SAFETY: Starting at `py`, there are `bytes_remaining` valid bytes. This
            // accesses all of them.
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Because the upper-bits of the last loaded byte can contain indeterminate bits,
        // we must mask out all out-of-bounds bits.
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        // Combine with the remainders.
        //
        // SAFETY: The `BitTranspose` permutation always allocates in granularities of
        // blocks. This loads the first 64-bit word of the last block.
        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        // SAFETY: This loads the second 64-bit word of the last block.
        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        // SAFETY: This loads the third 64-bit word of the last block.
        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        // SAFETY: This loads the fourth 64-bit word of the last block.
        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1968
1969impl
1970    PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
1971    for InnerProduct
1972{
1973    fn evaluate(
1974        x: USlice<'_, 4, BitTranspose>,
1975        y: USlice<'_, 1, Dense>,
1976    ) -> MathematicalResult<u32> {
1977        (diskann_wide::ARCH).run2(Self, x, y)
1978    }
1979}
1980
1981////////////////////
1982// Full Precision //
1983////////////////////
1984
/// The main trick here is avoiding explicit conversion from 1 bit integers to 32-bit
/// floating-point numbers by using `_mm256_permutevar_ps`, which performs a shuffle on two
/// independent 128-bit lanes of `f32` values in a register `A` using the lower 2-bits of
/// each 32-bit integer in a register `B`.
///
/// Importantly, this instruction only takes a single cycle and we can avoid any kind of
/// masking. Going the route of conversion would require an `AND` operation to isolate
/// bottom bits and a somewhat lengthy 32-bit integer to `f32` conversion instruction.
///
/// The overall strategy broadcasts a 32-bit integer (consisting of 32, 1-bit values) across
/// 8 lanes into a register `A`.
///
/// Each lane is then shifted by a different amount so:
///
/// * Lane 0 has value 0 as its least significant bit (LSB)
/// * Lane 1 has value 1 as its LSB.
/// * Lane 2 has value 2 as its LSB.
/// * etc.
///
/// These LSB's are used to power the shuffle function to convert to `f32` values (either
/// 0.0 or 1.0) and we can FMA as needed.
///
/// To process the next group of 8 values, we shift all lanes in `A` by 8-bits so lane 0
/// has value 8 as its LSB, lane 1 has value 9 etc.
///
/// A total of three shifts are applied to extract all 32 1-bit values as `f32` in order.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 1>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // Replicate 0s and 1s so we effectively get a shuffle that only depends on the
        // bottom bit (instead of the lowest 2).
        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        // Shifts required to offset each lane.
        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast `v` and shift each lane so that lane `j` has bit `j` as its LSB.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        let to_f32 = |v: u32s| -> f32s {
            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
            // which the presence of the `x86_64::V3` architecture guarantees is available.
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // Data is processed in groups of 32 elements.
        let blocks = len / 32;
        if i < blocks {
            // Two independent accumulators to break the FMA dependency chain.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            while i < blocks {
                // SAFETY: `i < blocks` implies 32-bits are readable from this offset.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: `i < blocks` implies 32 f32 values are readable beginning at `32*i`.
                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
                // SAFETY: See above.
                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
                // SAFETY: See above.
                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
                // SAFETY: See above.
                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);

                i += 1;
            }
            s = s0 + s1;
        }

        let remainder = len % 32;
        if remainder != 0 {
            // Number of valid f32 lanes in the final (possibly partial) 8-lane load.
            // A multiple-of-8 remainder still needs a full 8-lane load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };

            // SAFETY: Because `remainder != 0`, there is valid memory beginning at the
            // offset `blocks`, so this addition remains within an allocated object.
            let py = unsafe { py.add(blocks) };

            if remainder <= 8 {
                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    })
                }
            } else if remainder <= 16 {
                // SAFETY: At least two bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                    })
                }
            } else if remainder <= 24 {
                // SAFETY: At least three bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                    })
                }
            } else {
                // SAFETY: At least four bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);

                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);

                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
                    })
                }
            }
        }

        Ok(MV::new(s.sum_tree()))
    }
}
2150
2151/// The strategy used here is almost identical to that used for 1-bit distances. The main
2152/// difference is that now we use the full 2-bit shuffle capabilities of `_mm256_permutevar_ps`
/// and the relative sizes of the shifts are slightly different.
2154#[cfg(target_arch = "x86_64")]
2155impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
2156    for InnerProduct
2157{
2158    #[inline(always)]
2159    fn run(
2160        self,
2161        arch: diskann_wide::arch::x86_64::V3,
2162        x: &[f32],
2163        y: USlice<'_, 2>,
2164    ) -> MathematicalResult<f32> {
2165        let len = check_lengths!(x, y)?;
2166
2167        use std::arch::x86_64::*;
2168
2169        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
2170        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
2171
2172        // This is the lookup table mapping 2-bit patterns to their equivalent `f32`
2173        // representation. The AVX2 shuffle only applies within each 128-bit group of the
2174        // full 256-bit register, so we replicate the contents.
2175        let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);
2176
2177        // Shifts required to get logical dimensions shifted to the lower 2-bits of each lane.
2178        let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);
2179
2180        let px: *const f32 = x.as_ptr();
2181        let py: *const u32 = y.as_ptr().cast();
2182
2183        let mut i = 0;
2184        let mut s = f32s::default(arch);
2185
2186        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
2187        let to_f32 = |v: u32s| -> f32s {
2188            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
2189            // which the presense of the `x86_64::V3` architecture guarantees is available.
2190            f32s::from_underlying(arch, unsafe {
2191                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
2192            })
2193        };
2194
2195        let blocks = len / 16;
2196        if blocks != 0 {
2197            let mut s0 = f32s::default(arch);
2198            let mut s1 = f32s::default(arch);
2199
2200            // Process 32 elements.
2201            while i + 2 <= blocks {
2202                // SAFETY: `i + 2 <= blocks` implies `py.add(i)` is in-bounds and readable
2203                // for 4 unaligned bytes.
2204                let iy = prep(unsafe { py.add(i).read_unaligned() });
2205
2206                // SAFETY: Same logic as above, just applied to `f32` values instead of
2207                // packed bits.
2208                let (ix0, ix1) = unsafe {
2209                    (
2210                        f32s::load_simd(arch, px.add(16 * i)),
2211                        f32s::load_simd(arch, px.add(16 * i + 8)),
2212                    )
2213                };
2214
2215                s0 = ix0.mul_add_simd(to_f32(iy), s0);
2216                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
2217
2218                // SAFETY: `i + 2 <= blocks` implies `py.add(i + 1)` is in-bounds and readable
2219                // for 4 unaligned bytes.
2220                let iy = prep(unsafe { py.add(i + 1).read_unaligned() });
2221
2222                // SAFETY: Same logic as above.
2223                let (ix0, ix1) = unsafe {
2224                    (
2225                        f32s::load_simd(arch, px.add(16 * (i + 1))),
2226                        f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
2227                    )
2228                };
2229
2230                s0 = ix0.mul_add_simd(to_f32(iy), s0);
2231                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
2232
2233                i += 2;
2234            }
2235
2236            // Process 16 elements
2237            if i < blocks {
2238                // SAFETY: `i < blocks` implies `py.add(i)` is in-bounds and readable for
2239                // 4 unaligned bytes.
2240                let iy = prep(unsafe { py.add(i).read_unaligned() });
2241
2242                // SAFETY: Same logic as above.
2243                let (ix0, ix1) = unsafe {
2244                    (
2245                        f32s::load_simd(arch, px.add(16 * i)),
2246                        f32s::load_simd(arch, px.add(16 * i + 8)),
2247                    )
2248                };
2249
2250                s0 = ix0.mul_add_simd(to_f32(iy), s0);
2251                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
2252            }
2253
2254            s = s0 + s1;
2255        }
2256
2257        let remainder = len % 16;
2258        if remainder != 0 {
2259            let tail = if len % 8 == 0 { 8 } else { len % 8 };
2260            // SAFETY: Non-zero remainder implies there are readable bytes after the offset
2261            // `blocks`, so the addition is valid.
2262            let py = unsafe { py.add(blocks) };
2263
2264            if remainder <= 4 {
2265                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
2266                // The same logic applies to the SIMD loads.
2267                unsafe {
2268                    load_one(py, |iy| {
2269                        let iy = prep(iy);
2270                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
2271                        s = ix.mul_add_simd(to_f32(iy), s);
2272                    });
2273                }
2274            } else if remainder <= 8 {
2275                // SAFETY: At least two bytes are readable for `py`.
2276                // The same logic applies to the SIMD loads.
2277                unsafe {
2278                    load_two(py, |iy| {
2279                        let iy = prep(iy);
2280                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
2281                        s = ix.mul_add_simd(to_f32(iy), s);
2282                    });
2283                }
2284            } else if remainder <= 12 {
2285                // SAFETY: At least three bytes are readable for `py`.
2286                // The same logic applies to the SIMD loads.
2287                unsafe {
2288                    load_three(py, |iy| {
2289                        let iy = prep(iy);
2290                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
2291                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
2292                        s = ix0.mul_add_simd(to_f32(iy), s);
2293                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
2294                    });
2295                }
2296            } else {
2297                // SAFETY: At least four bytes are readable for `py`.
2298                // The same logic applies to the SIMD loads.
2299                unsafe {
2300                    load_four(py, |iy| {
2301                        let iy = prep(iy);
2302                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
2303                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
2304                        s = ix0.mul_add_simd(to_f32(iy), s);
2305                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
2306                    });
2307                }
2308            }
2309        }
2310
2311        Ok(MV::new(s.sum_tree()))
2312    }
2313}
2314
/// The strategy here is similar to the 1 and 2-bit strategies. However, instead of using
/// `_mm256_permutevar_ps`, we now go directly for 32-bit integer to 32-bit floating point.
///
/// This is because the shuffle intrinsic only supports 2-bit shuffles.
///
/// Each `u32` loaded from `y` packs 8 4-bit codes, so one main-loop iteration consumes
/// 8 `f32` values from `x` and a single `u32` from `y`.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 4>,
    ) -> MathematicalResult<f32> {
        // Errors out (rather than panicking) on a length mismatch.
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);

        // Per-lane shift amounts that move each of the 8 nibbles packed in a `u32`
        // down into the low bits of its own 32-bit lane.
        let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
        let mask = i32s::splat(arch, 0x0f);

        // Unpack the 8 4-bit codes in `v` (one per lane) and convert them to `f32`.
        //
        // The arithmetic right shift can smear sign bits into the upper bits of a lane,
        // but `& mask` keeps only the low nibble, so every lane ends up in `0..=15`.
        let to_f32 = |v: u32| -> f32s {
            ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
        };

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Main loop: 8 elements (one packed `u32` of `y`) per iteration.
        let blocks = len / 8;
        while i < blocks {
            // SAFETY: `i < blocks` implies that 8 `f32` values are readable from `8*i`.
            let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
            // SAFETY: Same logic as above - but applied to the packed bits.
            let iy = to_f32(unsafe { py.add(i).read_unaligned() });
            s = ix.mul_add_simd(iy, s);

            i += 1;
        }

        // Epilogue: handle the trailing 1-7 elements. `load_one` through `load_four`
        // read only `ceil(remainder / 2)` bytes so the packed reads stay in-bounds,
        // while `load_simd_first` restricts the `x` load to the first `remainder`
        // lanes (the unused lanes are presumed zero so the extra unpacked `y` lanes
        // contribute nothing to the sum).
        let remainder = len % 8;
        if remainder != 0 {
            let f = |iy| {
                // SAFETY: The epilogue handles at most 8 values. Since the remainder is
                // non-zero, the pointer arithmetic is in-bounds and `load_simd_first` will
                // avoid accessing the out-of-bounds elements.
                let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
                s = ix.mul_add_simd(to_f32(iy), s);
            };

            // SAFETY: Non-zero remainder means there are readable bytes from the offset
            // `blocks`.
            let py = unsafe { py.add(blocks) };

            if remainder <= 2 {
                // SAFETY: A non-zero remainder of at most 2 elements implies at least
                // one byte is readable.
                unsafe { load_one(py, f) };
            } else if remainder <= 4 {
                // SAFETY: At least two bytes are readable from `py`.
                unsafe { load_two(py, f) };
            } else if remainder <= 6 {
                // SAFETY: At least three bytes are readable from `py`.
                unsafe { load_three(py, f) };
            } else {
                // SAFETY: At least four bytes are readable from `py`.
                unsafe { load_four(py, f) };
            }
        }

        // Horizontal tree-reduction of the 8 accumulator lanes to a scalar.
        Ok(MV::new(s.sum_tree()))
    }
}
2391
2392impl<const N: usize>
2393    Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
2394    for InnerProduct
2395where
2396    Unsigned: Representation<N>,
2397{
2398    /// A fallback implementation that uses scaler indexing to retrieve values from
2399    /// the corresponding `BitSlice`.
2400    #[inline(always)]
2401    fn run(
2402        self,
2403        _: diskann_wide::arch::Scalar,
2404        x: &[f32],
2405        y: USlice<'_, N>,
2406    ) -> MathematicalResult<f32> {
2407        check_lengths!(x, y)?;
2408
2409        let mut s = 0.0;
2410        for (i, x) in x.iter().enumerate() {
2411            // SAFETY: We've ensured that `x.len() == y.len()`, so this access is
2412            // always inbounds.
2413            let y = unsafe { y.get_unchecked(i) } as f32;
2414            s += x * y;
2415        }
2416
2417        Ok(MV::new(s))
2418    }
2419}
2420
/// Implement `Target2` for higher architectures in terms of the scalar fallback.
///
/// Takes the target architecture path followed by the bit widths to cover, and emits
/// one retargeting `impl` per width.
macro_rules! ip_retarget {
    ($arch:path, $($N:literal),*) => {
        $(
            impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
                for InnerProduct
            {
                #[inline(always)]
                fn run(
                    self,
                    arch: $arch,
                    x: &[f32],
                    y: USlice<'_, $N>,
                ) -> MathematicalResult<f32> {
                    // Downgrade to the scalar architecture and reuse its kernel.
                    self.run(arch.retarget(), x, y)
                }
            }
        )*
    };
}
2442
// x86-64-v3 skips the 1, 2, and 4-bit widths here because those have dedicated SIMD
// kernels above in this file; the remaining widths fall back to the scalar kernel.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);

// No dedicated `&[f32]` x `USlice` inner-product kernels are provided in this file for
// x86-64-v4 or Neon, so every width retargets to the scalar fallback.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);

#[cfg(target_arch = "aarch64")]
ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8);
2451
/// Delegate the implementation of `PureDistanceFunction` to `diskann_wide::arch::Target2`
/// with the current architecture.
///
/// Takes the list of bit widths to cover and emits one `impl` per width.
macro_rules! dispatch_full_ip {
    ($($N:literal),*) => {
        $(
            /// Compute the inner product between `x` and `y`.
            ///
            /// Returns an error if the arguments have different lengths.
            impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
                for InnerProduct
            {
                fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
                    Self.run(ARCH, x, y)
                }
            }
        )*
    };
}
2471
// Expose `PureDistanceFunction` (dispatching on the compile-time `ARCH`) for every
// supported bit width.
dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
2473
2474///////////
2475// Tests //
2476///////////
2477
2478#[cfg(test)]
2479mod tests {
2480    use std::{collections::HashMap, fmt::Display, sync::LazyLock};
2481
2482    use diskann_utils::{Reborrow, lazy_format};
2483    use rand::{
2484        Rng, SeedableRng,
2485        distr::{Distribution, Uniform},
2486        rngs::StdRng,
2487        seq::IndexedRandom,
2488    };
2489
2490    use super::*;
2491    use crate::bits::{BoxedBitSlice, Representation, Unsigned};
2492
    /// Shorthand for the integer-valued distance result returned by the bit-slice kernels.
    type MR = MathematicalResult<u32>;
2494
2495    #[inline(always)]
2496    fn should_check_this_dimension(dim: usize) -> bool {
2497        if cfg!(miri) {
2498            return dim.is_power_of_two()
2499                || (dim > 1 && (dim - 1).is_power_of_two())
2500                || (dim < 64 && (dim % 8 == 7));
2501        }
2502
2503        true
2504    }
2505
2506    /////////////////////////
2507    // Unsigned Bit Slices //
2508    /////////////////////////
2509
2510    // This test works by generating random integer codes for the compressed vectors,
2511    // then uses the functions implemented in `vector` to compute the expected result of
2512    // the computation in "full precision integer space".
2513    //
2514    // We verify that the exact same results are returned by each computation.
2515    fn test_bitslice_distances<const NBITS: usize, R>(
2516        dim_max: usize,
2517        trials_per_dim: usize,
2518        evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2519        evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2520        context: &str,
2521        rng: &mut R,
2522    ) where
2523        Unsigned: Representation<NBITS>,
2524        R: Rng,
2525    {
2526        let domain = Unsigned::domain_const::<NBITS>();
2527        let min: i64 = *domain.start();
2528        let max: i64 = *domain.end();
2529
2530        let dist = Uniform::new_inclusive(min, max).unwrap();
2531
2532        for dim in 0..dim_max {
2533            if !should_check_this_dimension(dim) {
2534                continue;
2535            }
2536
2537            let mut x_reference: Vec<u8> = vec![0; dim];
2538            let mut y_reference: Vec<u8> = vec![0; dim];
2539
2540            let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2541            let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2542
2543            for trial in 0..trials_per_dim {
2544                x_reference
2545                    .iter_mut()
2546                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2547                y_reference
2548                    .iter_mut()
2549                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2550
2551                // Fill the input slices with 1's so we can catch situations where we don't
2552                // correctly handle odd remaining elements.
2553                x.as_mut_slice().fill(u8::MAX);
2554                y.as_mut_slice().fill(u8::MAX);
2555
2556                for i in 0..dim {
2557                    x.set(i, x_reference[i].into()).unwrap();
2558                    y.set(i, y_reference[i].into()).unwrap();
2559                }
2560
2561                // Check L2
2562                let expected: MV<f32> =
2563                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2564
2565                let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2566
2567                // Integer computations should be exact.
2568                assert_eq!(
2569                    expected.into_inner(),
2570                    got.into_inner() as f32,
2571                    "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2572                    NBITS,
2573                    dim,
2574                    trial,
2575                    context,
2576                );
2577
2578                // Check IP
2579                let expected: MV<f32> =
2580                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2581
2582                let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2583
2584                // Integer computations should be exact.
2585                assert_eq!(
2586                    expected.into_inner(),
2587                    got.into_inner() as f32,
2588                    "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2589                    NBITS,
2590                    dim,
2591                    trial,
2592                    context,
2593                );
2594            }
2595        }
2596
2597        // Test that we correctly return error types for length mismatches.
2598        let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2599        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2600
2601        assert!(
2602            evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2603            "context: {}",
2604            context
2605        );
2606        assert!(
2607            evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2608            "context: {}",
2609            context
2610        );
2611
2612        assert!(
2613            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2614            "context: {}",
2615            context
2616        );
2617        assert!(
2618            evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2619            "context: {}",
2620            context
2621        );
2622    }
2623
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri is dramatically slower than native execution, so shrink the test
            // space while keeping enough dimensions/trials to hit the interesting cases.
            const MAX_DIM: usize = 132;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2633
    // For the bit-slice kernels, we want to use different maximum dimensions for the distance
    // test depending on the implementation of the kernel, and whether or not we are running
    // under Miri.
    //
    // For implementations that use the scalar fallback, we need not set very high bounds
    // (particularly when running under miri) because the implementations are quite simple.
    //
    // However, some SIMD kernels (especially for the lower bit widths), require higher bounds
    // to trigger all possible corner cases.
    static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
        use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4};
        [
            (Key::new(1, Scalar), Bounds::new(64, 64)),
            (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
            (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
            (Key::new(1, Neon), Bounds::new(64, 64)),
            (Key::new(2, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
            (Key::new(2, X86_64_V4), Bounds::new(768, 600)), // main loop processes 256 items
            (Key::new(2, Neon), Bounds::new(64, 64)),
            (Key::new(3, Scalar), Bounds::new(64, 64)),
            (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(3, Neon), Bounds::new(64, 64)),
            (Key::new(4, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
            (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
            (Key::new(4, Neon), Bounds::new(64, 64)),
            (Key::new(5, Scalar), Bounds::new(64, 64)),
            (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(5, Neon), Bounds::new(64, 64)),
            (Key::new(6, Scalar), Bounds::new(64, 64)),
            (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(6, Neon), Bounds::new(64, 64)),
            (Key::new(7, Scalar), Bounds::new(64, 64)),
            (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(7, Neon), Bounds::new(64, 64)),
            (Key::new(8, Scalar), Bounds::new(64, 64)),
            (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(8, Neon), Bounds::new(64, 64)),
        ]
        .into_iter()
        .collect()
    });
2684
    /// Identifies which kernel implementation (micro-architecture) a set of test
    /// bounds applies to.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ArchKey {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
        Neon,
    }
2694
    /// Lookup key into `BITSLICE_TEST_BOUNDS`: a bit width paired with the kernel
    /// architecture under test.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct Key {
        nbits: usize,
        arch: ArchKey,
    }

    impl Key {
        /// Construct a key for the given bit width and architecture.
        fn new(nbits: usize, arch: ArchKey) -> Self {
            Self { nbits, arch }
        }
    }
2706
2707    #[derive(Debug, Clone, Copy)]
2708    struct Bounds {
2709        standard: usize,
2710        miri: usize,
2711    }
2712
2713    impl Bounds {
2714        fn new(standard: usize, miri: usize) -> Self {
2715            Self { standard, miri }
2716        }
2717
2718        fn get(&self) -> usize {
2719            if cfg!(miri) { self.miri } else { self.standard }
2720        }
2721    }
2722
    // Generate a `#[test]` exercising the `$nbits`-wide bit-slice distance kernels.
    //
    // Each generated test runs the randomized distance comparison through every
    // available interface: the `PureDistanceFunction` entry points, the explicit
    // `Scalar` architecture, and any micro-architectures detected at runtime
    // (x86-64-v3 / x86-64-v4 / Neon). Per-architecture dimension bounds come from
    // `BITSLICE_TEST_BOUNDS`.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps failures reproducible.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // Architecture Specific.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2791
    // One generated test per supported bit width. Seeds are arbitrary but fixed so
    // that failures reproduce deterministically.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2800
2801    ///////////////////////////
2802    // Hamming Bit Distances //
2803    ///////////////////////////
2804
    /// Randomized check of the `Hamming` kernel over 1-bit `Binary` slices.
    ///
    /// Reference vectors live in `{-1, +1}`; the packed slices store the corresponding
    /// bits. Also verifies that mismatched lengths produce an error.
    fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
    where
        R: Rng,
    {
        let dist: [i8; 2] = [-1, 1];

        for dim in 0..dim_max {
            if !should_check_this_dimension(dim) {
                continue;
            }

            let mut x_reference: Vec<i8> = vec![1; dim];
            let mut y_reference: Vec<i8> = vec![1; dim];

            let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
            let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);

            for _ in 0..trials_per_dim {
                x_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());
                y_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());

                // Fill the input slices with 1's so we can catch situations where we don't
                // correctly handle odd remaining elements.
                x.as_mut_slice().fill(u8::MAX);
                y.as_mut_slice().fill(u8::MAX);

                for i in 0..dim {
                    x.set(i, x_reference[i].into()).unwrap();
                    y.set(i, y_reference[i].into()).unwrap();
                }

                // We can check equality by evaluating the L2 distance between the reference
                // vectors.
                //
                // This is proportional to the Hamming distance by a factor of 4 (since the
                // distance between +1 and -1 is 2, and 2^2 = 4).
                let expected: MV<f32> =
                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
                let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
                assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
            }
        }

        // Mismatched lengths must error in both argument orders.
        let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
        let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
        assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
        assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
    }
2857
    #[test]
    fn test_hamming_distance() {
        // Fixed seed for reproducibility.
        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
    }
2863
2864    ///////////////////
2865    // Heterogeneous //
2866    ///////////////////
2867
2868    fn test_bit_transpose_distances<R>(
2869        dim_max: usize,
2870        trials_per_dim: usize,
2871        evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2872        context: &str,
2873        rng: &mut R,
2874    ) where
2875        R: Rng,
2876    {
2877        let dist_4bit = {
2878            let domain = Unsigned::domain_const::<4>();
2879            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2880        };
2881
2882        let dist_1bit = {
2883            let domain = Unsigned::domain_const::<1>();
2884            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2885        };
2886
2887        for dim in 0..dim_max {
2888            if !should_check_this_dimension(dim) {
2889                continue;
2890            }
2891
2892            let mut x_reference: Vec<u8> = vec![0; dim];
2893            let mut y_reference: Vec<u8> = vec![0; dim];
2894
2895            let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2896            let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2897
2898            for trial in 0..trials_per_dim {
2899                x_reference
2900                    .iter_mut()
2901                    .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2902                y_reference
2903                    .iter_mut()
2904                    .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2905
2906                // First - pre-set all the values in the bit-slices to 1.
2907                x.as_mut_slice().fill(u8::MAX);
2908                y.as_mut_slice().fill(u8::MAX);
2909
2910                for i in 0..dim {
2911                    x.set(i, x_reference[i].into()).unwrap();
2912                    y.set(i, y_reference[i].into()).unwrap();
2913                }
2914
2915                // Check IP
2916                let expected: MV<f32> =
2917                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2918
2919                let got = evaluate_ip(x.reborrow(), y.reborrow());
2920
2921                // Integer computations should be exact.
2922                assert_eq!(
2923                    expected.into_inner(),
2924                    got.unwrap().into_inner() as f32,
2925                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
2926                    dim,
2927                    trial,
2928                    context,
2929                );
2930            }
2931        }
2932
2933        let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2934        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2935        assert!(
2936            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2937            "context: {}",
2938            context
2939        );
2940
2941        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2942        assert!(
2943            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2944            "context: {}",
2945            context
2946        );
2947    }
2948
    #[test]
    fn test_bit_transpose_distance() {
        // Fixed seed for reproducibility.
        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);

        // Exercise every available interface: the pure distance function, the explicit
        // scalar architecture, and any micro-architectures detected at runtime.
        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| InnerProduct::evaluate(x, y),
            "pure distance function",
            &mut rng,
        );

        test_bit_transpose_distances(
            MAX_DIM,
            TRIALS_PER_DIM,
            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
            "scalar",
            &mut rng,
        );

        // Architecture Specific.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v3",
                &mut rng,
            );
        }

        // Architecture Specific.
        #[cfg(target_arch = "x86_64")]
        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "x86-64-v4",
                &mut rng,
            );
        }

        // Architecture Specific.
        #[cfg(target_arch = "aarch64")]
        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
            test_bit_transpose_distances(
                MAX_DIM,
                TRIALS_PER_DIM,
                &|x, y| arch.run2(InnerProduct, x, y),
                "neon",
                &mut rng,
            );
        }
    }
3005
3006    //////////
3007    // Full //
3008    //////////
3009
3010    fn test_full_distances<const NBITS: usize>(
3011        dim_max: usize,
3012        trials_per_dim: usize,
3013        evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
3014        context: &str,
3015        rng: &mut impl Rng,
3016    ) where
3017        Unsigned: Representation<NBITS>,
3018    {
3019        // let dist_float = Uniform::new_inclusive(-2.0f32, 2.0f32).unwrap();
3020        let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
3021        let dist_bit = {
3022            let domain = Unsigned::domain_const::<NBITS>();
3023            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
3024        };
3025
3026        for dim in 0..dim_max {
3027            if !should_check_this_dimension(dim) {
3028                continue;
3029            }
3030
3031            let mut x: Vec<f32> = vec![0.0; dim];
3032
3033            let mut y_reference: Vec<u8> = vec![0; dim];
3034            let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
3035
3036            for trial in 0..trials_per_dim {
3037                x.iter_mut()
3038                    .for_each(|i| *i = *dist_float.choose(rng).unwrap());
3039                y_reference
3040                    .iter_mut()
3041                    .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
3042
3043                // First - pre-set all the values in the bit-slices to 1.
3044                y.as_mut_slice().fill(u8::MAX);
3045
3046                let mut expected = 0.0;
3047                for i in 0..dim {
3048                    y.set(i, y_reference[i].into()).unwrap();
3049                    expected += y_reference[i] as f32 * x[i];
3050                }
3051
3052                // Check IP
3053                let got = evaluate_ip(&x, y.reborrow()).unwrap();
3054
3055                // Integer computations should be exact.
3056                assert_eq!(
3057                    expected,
3058                    got.into_inner(),
3059                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
3060                    dim,
3061                    trial,
3062                    context,
3063                );
3064
3065                // Ensure that using the `Scalar` architecture providers the same
3066                // results.
3067                let scalar: MV<f32> = InnerProduct
3068                    .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
3069                    .unwrap();
3070                assert_eq!(got.into_inner(), scalar.into_inner());
3071            }
3072        }
3073
3074        // Error Checking
3075        let x = vec![0.0; 10];
3076        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
3077        assert!(
3078            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
3079            "context: {}",
3080            context
3081        );
3082
3083        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
3084        assert!(
3085            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
3086            "context: {}",
3087            context
3088        );
3089    }
3090
    /// Generates a `#[test]` that runs `test_full_distances` for the given bit-width
    /// across every dispatch path: the pure distance function, the explicit `Scalar`
    /// architecture, and any runtime-detected x86-64 / aarch64 targets.
    /// A fixed `$seed` keeps each generated test reproducible.
    macro_rules! test_full {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);

                // Dispatch through the plain `PureDistanceFunction` interface.
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                // Explicit scalar architecture.
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar",
                    &mut rng,
                );

                // Architecture Specific.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // Architecture Specific (V4 uses the miri-aware check).
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                // Architecture Specific.
                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
3149
    // One test per supported bit-width (1..=8). The seeds are arbitrary fixed
    // values so each width exercises a distinct, reproducible random stream.
    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
3158
3159    ///////////////////////////////////////////
3160    // Heterogeneous: USlice<8> × USlice<M> //
3161    ///////////////////////////////////////////
3162
    /// Helper that builds vectors from a fill function and asserts the
    /// inner product matches.
    ///
    /// Values are kept widened to `i64`; `check_with` packs them into 8-bit
    /// and `M`-bit bit-slices respectively before evaluation.
    struct HetCase<const M: usize> {
        /// Values for the 8-bit side of the heterogeneous product.
        x_vals: Vec<i64>,
        /// Values for the `M`-bit side of the heterogeneous product.
        y_vals: Vec<i64>,
    }
3169
3170    impl<const M: usize> HetCase<M>
3171    where
3172        Unsigned: Representation<M>,
3173    {
3174        fn new(dim: usize, fill: impl FnMut(usize) -> (i64, i64)) -> Self {
3175            let (x_vals, y_vals) = (0..dim).map(fill).unzip();
3176            Self { x_vals, y_vals }
3177        }
3178
3179        fn check_with(
3180            &self,
3181            label: impl Display,
3182            evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3183        ) {
3184            let dim = self.x_vals.len();
3185            let mut x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
3186            let mut y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim);
3187            // Pre-fill with 0xFF to catch trailing-element bugs.
3188            x.as_mut_slice().fill(u8::MAX);
3189            y.as_mut_slice().fill(u8::MAX);
3190            for (i, (&xv, &yv)) in self.x_vals.iter().zip(&self.y_vals).enumerate() {
3191                x.set(i, xv).unwrap();
3192                y.set(i, yv).unwrap();
3193            }
3194            let expected: u32 = self
3195                .x_vals
3196                .iter()
3197                .zip(&self.y_vals)
3198                .map(|(&a, &b)| a as u32 * b as u32)
3199                .sum();
3200            let got = evaluate(x.reborrow(), y.reborrow()).unwrap().into_inner();
3201            assert_eq!(expected, got, "{} failed for dim = {}", label, dim);
3202        }
3203    }
3204
3205    /// Fuzz test helper: random vectors across many dimensions.
3206    fn fuzz_heterogeneous_ip<const M: usize>(
3207        dim_max: usize,
3208        trials_per_dim: usize,
3209        max_val: i64,
3210        evaluate_ip: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3211        context: &str,
3212        rng: &mut impl Rng,
3213    ) where
3214        Unsigned: Representation<M>,
3215    {
3216        let dist_8bit = Uniform::new_inclusive(0i64, 255i64).unwrap();
3217        let dist_mbit = Uniform::new_inclusive(0i64, max_val).unwrap();
3218
3219        for dim in 0..dim_max {
3220            for trial in 0..trials_per_dim {
3221                HetCase::<M>::new(dim, |_| {
3222                    (dist_8bit.sample(&mut *rng), dist_mbit.sample(&mut *rng))
3223                })
3224                .check_with(
3225                    lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M),
3226                    evaluate_ip,
3227                );
3228            }
3229
3230            // Length mismatch → error.
3231            let x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
3232            let y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim + 1);
3233            assert!(
3234                evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
3235                "context: {}",
3236                context,
3237            );
3238        }
3239    }
3240
3241    /// All values at maximum: x[i] = 255, y[i] = max_val.
3242    /// Confirms no i16 saturation in vpmaddubsw.
3243    fn het_test_max_values<const M: usize>(
3244        max_val: i64,
3245        context: &str,
3246        evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3247    ) where
3248        Unsigned: Representation<M>,
3249    {
3250        let dims = [127, 128, 129, 255, 256, 512, 768, 896, 3072];
3251        for &dim in &dims {
3252            let case = HetCase::<M>::new(dim, |_| (255, max_val));
3253            case.check_with(lazy_format!("max-value {context} dim={dim}"), evaluate);
3254        }
3255    }
3256
    /// Known-answer tests to catch bit-ordering bugs.
    fn het_test_known_answers<const M: usize>(
        max_val: i64,
        evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
    ) where
        Unsigned: Representation<M>,
    {
        // vpmaddubsw (`_mm256_maddubs_epi16`) unsigned treatment of its first
        // operand: x[i] = 200 (> 127), y[i] = max_val. An implementation that swaps
        // the unsigned/signed operand roles would sign-extend 200 and fail here.
        // Correct: 200 × max_val per element.
        HetCase::<M>::new(64, |_| (200, max_val)).check_with("vpmaddubsw operand-order", evaluate);

        // Ascending x, constant y: detects element-order (bit-ordering) mistakes.
        let y_val = (max_val / 2).max(1);
        HetCase::<M>::new(128, |i| ((i % 256) as i64, y_val))
            .check_with("ascending-x constant-y", evaluate);

        // Single element (pure scalar fallback).
        HetCase::<M>::new(1, |_| (200, max_val)).check_with("single element", evaluate);
    }
3276
3277    /// Exhaustive edge-case coverage.
3278    fn het_test_edge_cases<const M: usize>(
3279        max_val: i64,
3280        block_size: usize,
3281        evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
3282    ) where
3283        Unsigned: Representation<M>,
3284    {
3285        let y_half = (max_val / 2).max(1);
3286
3287        // One side zero.
3288        HetCase::<M>::new(64, |_| (0, max_val)).check_with("x-zero y-nonzero", evaluate);
3289        HetCase::<M>::new(64, |_| (255, 0)).check_with("y-zero x-nonzero", evaluate);
3290
3291        // Every dimension from 0..block_size+1 (scalar fallback boundary).
3292        for dim in 0..=(block_size + 1) {
3293            HetCase::<M>::new(dim, |_| (3, y_half)).check_with("uniform fill", evaluate);
3294        }
3295
3296        // Exact block boundaries.
3297        for &dim in &[block_size, 2 * block_size, 4 * block_size, 8 * block_size] {
3298            HetCase::<M>::new(dim, |_| (100, max_val)).check_with("exact block boundary", evaluate);
3299        }
3300
3301        // x varies, y constant.
3302        HetCase::<M>::new(300, |i| ((i % 256) as i64, 1))
3303            .check_with("x-varies y-constant", evaluate);
3304
3305        // x constant, y varies.
3306        HetCase::<M>::new(300, |i| (1, (i as i64) % (max_val + 1)))
3307            .check_with("x-constant y-varies", evaluate);
3308
3309        // Alternating pattern.
3310        HetCase::<M>::new(128, |i| if i % 2 == 0 { (255, max_val) } else { (0, 0) })
3311            .check_with("alternating pattern", evaluate);
3312
3313        // Opposite alternating pattern.
3314        HetCase::<M>::new(128, |i| if i % 2 == 0 { (0, 0) } else { (255, max_val) })
3315            .check_with("opposite alternating", evaluate);
3316
3317        // Large accumulation check for overflow.
3318        HetCase::<M>::new(1024, |_| (255, max_val)).check_with("large accumulation", evaluate);
3319
3320        // x > 127 sweep (vpmaddubsw unsigned treatment).
3321        for x_val in [128i64, 170, 200, 240, 255] {
3322            HetCase::<M>::new(block_size, move |_| (x_val, y_half))
3323                .check_with(lazy_format!("x > 127 (x_val={x_val})"), evaluate);
3324        }
3325
3326        // Dim = block_size - 1 (no full block, all scalar).
3327        HetCase::<M>::new(block_size - 1, |i| {
3328            (
3329                ((i * 7 + 3) % 256) as i64,
3330                ((i * 11 + 5) as i64) % (max_val + 1),
3331            )
3332        })
3333        .check_with("dim=block_size-1 (all scalar)", evaluate);
3334
3335        // 4× unroll boundary exercises.
3336        let unroll4 = 4 * block_size;
3337        for &dim in &[
3338            unroll4,
3339            unroll4 + 1,
3340            unroll4 + block_size,
3341            unroll4 + block_size + 1,
3342        ] {
3343            HetCase::<M>::new(dim, |i| {
3344                (((i + 1) % 256) as i64, ((i + 1) as i64) % (max_val + 1))
3345            })
3346            .check_with("unroll boundary", evaluate);
3347        }
3348    }
3349
    /// Instantiates a test module covering the heterogeneous 8-bit × `$M`-bit
    /// inner-product kernels:
    ///
    /// * `all_ip_dispatches`: fuzzing across every available dispatch path.
    /// * `max_values`: saturation checks with all elements at their maxima.
    /// * `known_answers`: operand-order / bit-ordering regression cases.
    /// * `edge_cases`: boundary dimensions derived from `$block_size`.
    macro_rules! heterogeneous_ip_tests_8xM {
        (
            mod_name: $mod:ident,
            M: $M:literal,
            max_val: $max_val:literal,
            block_size: $block_size:literal,
            seed_fuzz: $seed_fuzz:literal,
        ) => {
            mod $mod {
                use super::*;

                #[test]
                fn all_ip_dispatches() {
                    // Fixed seed keeps fuzz failures reproducible.
                    let mut rng = StdRng::seed_from_u64($seed_fuzz);

                    // Pure distance-function interface.
                    fuzz_heterogeneous_ip::<$M>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        $max_val,
                        &|x, y| InnerProduct::evaluate(x, y),
                        "pure distance function",
                        &mut rng,
                    );
                    // Explicit scalar architecture.
                    fuzz_heterogeneous_ip::<$M>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        $max_val,
                        &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                        "scalar arch",
                        &mut rng,
                    );
                    // Runtime-detected x86-64 micro-architecture levels.
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                        fuzz_heterogeneous_ip::<$M>(
                            MAX_DIM,
                            TRIALS_PER_DIM,
                            $max_val,
                            &|x, y| arch.run2(InnerProduct, x, y),
                            "x86-64-v3",
                            &mut rng,
                        );
                    }
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                        fuzz_heterogeneous_ip::<$M>(
                            MAX_DIM,
                            TRIALS_PER_DIM,
                            $max_val,
                            &|x, y| arch.run2(InnerProduct, x, y),
                            "x86-64-v4",
                            &mut rng,
                        );
                    }
                }

                #[test]
                fn max_values() {
                    het_test_max_values::<$M>($max_val, "dispatch", &|x, y| {
                        InnerProduct::evaluate(x, y)
                    });
                    // NOTE(review): only the V3 path is exercised explicitly here;
                    // consider adding V4/Neon for parity with `all_ip_dispatches`.
                    #[cfg(target_arch = "x86_64")]
                    if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                        het_test_max_values::<$M>($max_val, "V3", &|x, y| {
                            arch.run2(InnerProduct, x, y)
                        });
                    }
                }

                #[test]
                fn known_answers() {
                    het_test_known_answers::<$M>($max_val, &|x, y| InnerProduct::evaluate(x, y));
                }

                #[test]
                fn edge_cases() {
                    het_test_edge_cases::<$M>($max_val, $block_size, &|x, y| {
                        InnerProduct::evaluate(x, y)
                    });
                }
            }
        };
    }
3432
    // 8-bit × 4-bit: M-bit values span 0..=15; boundary dimensions derived from
    // a 32-element block.
    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x4,
        M: 4,
        max_val: 15,
        block_size: 32,
        seed_fuzz: 0xd3a7f1c09b2e4856,
    }

    // 8-bit × 2-bit: M-bit values span 0..=3; 64-element block.
    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x2,
        M: 2,
        max_val: 3,
        block_size: 64,
        seed_fuzz: 0x82c4a6e809f1d3b5,
    }

    // 8-bit × 1-bit: M-bit values are 0 or 1; 32-element block.
    heterogeneous_ip_tests_8xM! {
        mod_name: heterogeneous_ip_8x1,
        M: 1,
        max_val: 1,
        block_size: 32,
        seed_fuzz: 0x1b17_a5e7c2d0f839,
    }
3456}