Skip to main content

diskann_quantization/bits/
distances.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! # Low-level functions
7//!
8//! The methods here are meant to be primitives used by the distance functions for the
9//! various scalar-quantized-like quantizers.
10//!
11//! As such, they typically return integer distance results since they largely operate over
12//! raw bit-slices.
13//!
14//! ## Micro-architecture Mapping
15//!
16//! There are two interfaces for interacting with the distance primitives:
17//!
18//! * [`diskann_wide::arch::Target2`]: A micro-architecture aware interface where the target
19//!   micro-architecture is provided as an explicit argument.
20//!
21//!   This can be used in conjunction with [`diskann_wide::Architecture::run2`] to apply the
22//!   necessary target-features to opt-into newer architecture code generation when
23//!   compiling the whole binary for an older architecture.
24//!
25//!   This interface is also composable with micro-architecture dispatching done higher in
26//!   the callstack, and so should be preferred when incorporating into quantizer distance
27//!   computations.
28//!
29//! * [`diskann_vector::PureDistanceFunction`]: If micro-architecture awareness is not needed,
30//!   this provides a simple interface targeting [`diskann_wide::ARCH`] (the current compilation
31//!   architecture).
32//!
33//!   This interface will always yield a binary compatible with the compilation architecture
34//!   target, but will not enable faster code-paths when compiling for older architectures.
35//!
36//! The following table summarizes the implementation status of kernels. All kernels have
37//! `diskann_wide::arch::Scalar` implementation fallbacks.
38//!
39//! Implementation Kind:
40//!
41//! * "Fallback": A fallback implementation using scalar indexing.
42//!
43//! * "Optimized": A better implementation than "fallback" that does not contain
//!   target-dependent code, instead relying on compiler optimizations.
45//!
46//!   Micro-architecture dispatch is still relevant as it allows the compiler to generate
47//!   better code for newer machines.
48//!
49//! * "Yes": Architecture specific SIMD implementation exists.
50//!
51//! * "No": Architecture specific implementation does not exist - the next most-specific
52//!   implementation is used. For example, if a `x86-64-v3` implementation does not exist,
53//!   then the "scalar" implementation will be used instead.
54//!
//! ## Type Aliases
56//!
57//! * `USlice<N>`: `BitSlice<N, Unsigned, Dense>`
58//! * `TSlice<N>`: `BitSlice<N, Unsigned, BitTranspose>`
59//! * `BSlice`: `BitSlice<1, Binary, Dense>`
60//!
61//! * `MV<T>`: [`diskann_vector::MathematicalValue<T>`]
62//!
63//! ### Inner Product
64//!
65//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 |
66//! |---------------|---------------|-----------|-----------|---------------|-----------|
67//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   |
68//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Yes       |
69//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
70//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   |
71//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
72//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
73//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
74//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       |
75//! |               |               | `       ` |           |               |           |
76//! | `TSlice<4>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Optimized |
77//! |               |               | `       ` |           |               |           |
78//! | `&[f32]`      | `USlice<1>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   |
79//! | `&[f32]`      | `USlice<2>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   |
80//! | `&[f32]`      | `USlice<3>`   | `MV<f32>` | Fallback  | No            | Uses V3   |
81//! | `&[f32]`      | `USlice<4>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   |
82//! | `&[f32]`      | `USlice<5>`   | `MV<f32>` | Fallback  | No            | Uses V3   |
83//! | `&[f32]`      | `USlice<6>`   | `MV<f32>` | Fallback  | No            | Uses V3   |
84//! | `&[f32]`      | `USlice<7>`   | `MV<f32>` | Fallback  | No            | Uses V3   |
85//! | `&[f32]`      | `USlice<8>`   | `MV<f32>` | Fallback  | No            | Uses V3   |
86//!
87//! ### Squared L2
88//!
89//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 |
90//! |---------------|---------------|-----------|-----------|---------------|-----------|
91//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   |
92//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   |
93//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
94//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   |
95//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
96//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
97//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   |
98//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       |
99//!
100//! ### Hamming
101//!
102//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 |
103//! |---------------|---------------|-----------|-----------|---------------|-----------|
104//! | `BSlice`      | `BSlice`      | `MV<u32>` | Optimized | Optimized     | Uses V3   |
105
106use diskann_vector::PureDistanceFunction;
107use diskann_wide::{ARCH, Architecture, arch::Target2};
108#[cfg(target_arch = "x86_64")]
109use diskann_wide::{
110    SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
111};
112
113use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
114use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
115
// Convenience alias: an `N`-bit unsigned `BitSlice`, dense by default (see module docs).
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
118
/// Retarget the [`diskann_wide::arch::x86_64::V3`] architecture to
/// [`diskann_wide::arch::Scalar`] or [`diskann_wide::arch::x86_64::V4`] to V3 etc.
///
/// Used when a kernel has no dedicated implementation for `$arch`: the generated
/// [`Target2`] impl forwards to the next-most-specific architecture via `arch.retarget()`.
#[cfg(target_arch = "x86_64")]
macro_rules! retarget {
    ($arch:path, $op:ty, $N:literal) => {
        impl Target2<
            $arch,
            MathematicalResult<u32>,
            USlice<'_, $N>,
            USlice<'_, $N>,
        > for $op {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                // Forward to the implementation for the demoted architecture.
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Convenience arm: expand once for each `N` in a comma-separated list.
    ($arch:path, $op:ty, $($N:literal),+ $(,)?) => {
        $(retarget!($arch, $op, $N);)+
    }
}
145
/// Implement [`diskann_vector::PureDistanceFunction`] using the current compilation architecture.
macro_rules! dispatch_pure {
    ($op:ty, $N:literal) => {
        /// Compute the distance between `x` and `y`, dispatching on the current
        /// compilation architecture ([`diskann_wide::ARCH`]).
        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $N>, MathematicalResult<u32>> for $op {
            #[inline(always)]
            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $N>) -> MathematicalResult<u32> {
                (diskann_wide::ARCH).run2(Self, x, y)
            }
        }
    };
    // Convenience arm: expand once for each `N` in a comma-separated list.
    ($op:ty, $($N:literal),+ $(,)?) => {
        $(dispatch_pure!($op, $N);)+
    }
}
161
/// Load 1 byte beginning at `ptr` and invoke `f` with that byte (zero-extended to `u32`).
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 1)` (in bytes) must be dereferencable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, f: F) -> R
where
    // `f` is called exactly once, so `FnOnce` (the most general bound) suffices.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that one byte is readable.
    f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
}
176
/// Load 2 bytes beginning at `ptr` and invoke `f` with the value (zero-extended to `u32`).
///
/// The 2-byte value is read with native (little-endian on x86-64) byte order.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 2)` (in bytes) must be dereferencable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, f: F) -> R
where
    // `f` is called exactly once, so `FnOnce` (the most general bound) suffices.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that two bytes are readable.
    f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
}
191
/// Load 3 bytes beginning at `ptr` and invoke `f` with the value (zero-extended to `u32`).
///
/// The result places byte 0 in bits `0..8`, byte 1 in bits `8..16`, and byte 2 in
/// bits `16..24` (little-endian composition, matching x86-64 byte order).
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 3)` (in bytes) must be dereferencable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, f: F) -> R
where
    // `f` is called exactly once, so `FnOnce` (the most general bound) suffices.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that three bytes are readable. This loads the first two.
    let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
    // SAFETY: Caller asserts that three bytes are readable. This loads the third.
    let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
    // `<<` binds tighter than `|`, so this is `lo | (hi << 16)`.
    f(lo | hi << 16)
}
209
/// Load 4 bytes beginning at `ptr` and invoke `f` with the value.
///
/// The 4-byte value is read with native (little-endian on x86-64) byte order.
///
/// # Safety
///
/// * The memory range `[ptr, ptr + 4)` (in bytes) must be dereferencable.
/// * `ptr` does not need to be aligned.
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, f: F) -> R
where
    // `f` is called exactly once, so `FnOnce` (the most general bound) suffices.
    F: FnOnce(u32) -> R,
{
    // SAFETY: Caller asserts that four bytes are readable.
    f(unsafe { ptr.read_unaligned() })
}
224
225////////////////////////////
226// Distances on BitSlices //
227////////////////////////////
228
/// Operations to apply to 1-bit encodings.
///
/// The general structure of 1-bit vector operations is the same, but the element wise
/// operator is different. This trait encapsulates the differences in behavior required
/// for different distance functions.
///
/// The exact operations to apply depend on the representation of the bit encoding.
trait BitVectorOp<Repr>
where
    Repr: Representation<1>,
{
    /// Apply the op to all bits in the 64-bit arguments, returning the number of bits
    /// set in the combined result.
    fn on_u64(x: u64, y: u64) -> u32;

    /// Apply the op to all bits in the 8-bit arguments, returning the number of bits
    /// set in the combined result.
    ///
    /// NOTE: Implementations must have the correct behavior when the upper bits of `x`
    /// and `y` are set to 0 when handling epilogues.
    fn on_u8(x: u8, y: u8) -> u32;
}
249
250/// Computing Squared-L2 amounts to evaluating the pop-count of a bitwise `xor`.
251impl BitVectorOp<Unsigned> for SquaredL2 {
252    #[inline(always)]
253    fn on_u64(x: u64, y: u64) -> u32 {
254        (x ^ y).count_ones()
255    }
256    #[inline(always)]
257    fn on_u8(x: u8, y: u8) -> u32 {
258        (x ^ y).count_ones()
259    }
260}
261
/// Computing the Hamming distance amounts to evaluating the pop-count of a bitwise `xor`.
impl BitVectorOp<Binary> for Hamming {
    #[inline(always)]
    fn on_u64(x: u64, y: u64) -> u32 {
        (x ^ y).count_ones()
    }
    #[inline(always)]
    fn on_u8(x: u8, y: u8) -> u32 {
        (x ^ y).count_ones()
    }
}
273
274/// The implementation as `and` is not straight-forward.
275///
276/// Recall that scalar quantization encodings are unsigned, so "0" is zero and "1" is some
277/// non-zero value.
278///
279/// When computing the inner product, `0 * x == 0` for all `x` and only `x * x` has a
280/// non-zero value. Therefore, the elementwise op is an `and` and not `xnor`.
281impl BitVectorOp<Unsigned> for InnerProduct {
282    #[inline(always)]
283    fn on_u64(x: u64, y: u64) -> u32 {
284        (x & y).count_ones()
285    }
286    #[inline(always)]
287    fn on_u8(x: u8, y: u8) -> u32 {
288        (x & y).count_ones()
289    }
290}
291
/// A general algorithm for applying a bitwise operand to two dense bit vectors of equal
/// but arbitrary length.
///
/// Returns an error if the arguments have different lengths.
///
/// The input is processed in three phases: 64-bit chunks, then 8-bit chunks, then a
/// final partial byte whose unused high bits are masked to zero (the masked-epilogue
/// contract documented on [`BitVectorOp::on_u8`]).
///
/// NOTE: The `inline(always)` attribute is required to inherit the caller's target-features.
#[inline(always)]
fn bitvector_op<Op, Repr>(
    x: BitSlice<'_, 1, Repr>,
    y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
    Repr: Representation<1>,
    Op: BitVectorOp<Repr>,
{
    let len = check_lengths!(x, y)?;

    let px: *const u64 = x.as_ptr().cast();
    let py: *const u64 = y.as_ptr().cast();

    let mut i = 0;
    let mut s: u32 = 0;

    // Work in groups of 64
    let blocks = len / 64;
    while i < blocks {
        // SAFETY: We know at least 64-bits (8-bytes) are valid from this offset (by
        // guarantee of the `BitSlice`). All bit-patterns of a `u64` are valid, `u64: Copy`,
        // and an `unaligned` read is used.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: The same logic applies to `y` because:
        // 1. It has the same type as `x`.
        // 2. We've verified that it has the same length as `x`.
        let vy = unsafe { py.add(i).read_unaligned() };

        s += Op::on_u64(vx, vy);
        i += 1;
    }

    // Work in groups of 8
    // Convert `i` from a 64-bit block index into a byte index.
    i *= 8;
    let px: *const u8 = x.as_ptr();
    let py: *const u8 = y.as_ptr();

    let blocks = len / 8;
    while i < blocks {
        // SAFETY: The underlying pointer is a `*const u8` and we have checked that this
        // offset is within the bounds of the slice underlying the bitslice.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: The same logic applies to `y` because:
        // 1. It has the same type as `x`.
        // 2. We've verified that it has the same length as `x`.
        let vy = unsafe { py.add(i).read_unaligned() };
        s += Op::on_u8(vx, vy);
        i += 1;
    }

    // Epilogue: handle the trailing `len % 8` bits, if any.
    if i * 8 != len {
        // SAFETY: The underlying slice is readable in the range
        // `[px, px + floor(len / 8) + 1)`. This accesses `px + floor(len / 8)`.
        let vx = unsafe { px.add(i).read_unaligned() };

        // SAFETY: Same as above.
        let vy = unsafe { py.add(i).read_unaligned() };
        // Select the low `len % 8` valid bits. Here `len - 8 * i == len % 8`, which lies
        // in `1..=7`, so the `u8` shift cannot overflow.
        let m = (0x01u8 << (len - 8 * i)) - 1;

        s += Op::on_u8(vx & m, vy & m)
    }
    Ok(MV::new(s))
}
362
363/// Compute the hamming distance between `x` and `y`.
364///
365/// Returns an error if the arguments have different lengths.
366impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
367    for Hamming
368{
369    fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
370        bitvector_op::<Hamming, Binary>(x, y)
371    }
372}
373
374///////////////
375// SquaredL2 //
376///////////////
377
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This can directly invoke the methods implemented in `vector` because
/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
    A: Architecture,
    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
    #[inline(always)]
    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
        check_lengths!(x, y)?;

        // View the 8-bit slices as `&[u8]` and defer to the `diskann_vector` kernel.
        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
            diskann_vector::distance::SquaredL2 {},
            arch,
            x.as_slice(),
            y.as_slice(),
        );

        // NOTE(review): the kernel returns `MV<f32>`; integer sums above 2^24 lose
        // precision in `f32` before this truncating cast — confirm expected value ranges.
        Ok(MV::new(r.into_inner() as u32))
    }
}
405
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        // Each block packs eight 4-bit elements (four per 16-bit lane).
        let blocks = len / 8;
        if i < blocks {
            // Four independent partial sums, combined after the SIMD phase.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the lowest 4-bit element of each 16-bit lane; shifting the source
            // by 4, 8, and 12 exposes the remaining three elements per lane.
            let mask = u32s::splat(arch, 0x000f000f);
            // Main loop: always exits with between 1 and 8 blocks remaining, which are
            // handled by the predicated loads below.
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // NOTE: `>>` binds tighter than `&`, so `vx >> 4 & mask` is `(vx >> 4) & mask`.
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Combine the four partial sums and reduce horizontally.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 8;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
549
/// Compute the squared L2 distance between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for SquaredL2
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        // Each block packs sixteen 2-bit elements (eight per 16-bit lane).
        let blocks = len / 16;
        if i < blocks {
            // Four independent partial sums; each receives two dot-products per block
            // (eight shift positions total).
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the lowest 2-bit element of each 16-bit lane; shifting the source
            // by 2, 4, ..., 14 exposes the remaining seven elements per lane.
            let mask = u32s::splat(arch, 0x00030003);
            // Main loop: always exits with between 1 and 8 blocks remaining, which are
            // handled by the predicated loads below.
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                // NOTE: `>>` binds tighter than `&`, so `vx >> 2 & mask` is `(vx >> 2) & mask`.
                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                let d = wx - wy;
                s0 = s0.dot_simd(d, d);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                let d = wx - wy;
                s1 = s1.dot_simd(d, d);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                let d = wx - wy;
                s2 = s2.dot_simd(d, d);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                let d = wx - wy;
                s3 = s3.dot_simd(d, d);

                i += 8;
            }

            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            let d = wx - wy;
            s0 = s0.dot_simd(d, d);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            let d = wx - wy;
            s1 = s1.dot_simd(d, d);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            let d = wx - wy;
            s2 = s2.dot_simd(d, d);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            let d = wx - wy;
            s3 = s3.dot_simd(d, d);

            i += remainder;

            // Combine the four partial sums and reduce horizontally.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 16;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: i32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as i32;
                    let d = ix - iy;
                    s += d * d;
                }
                s as u32
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
732
733/// Compute the squared L2 distance between bitvectors `x` and `y`.
734///
735/// Returns an error if the arguments have different lengths.
736impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
737where
738    A: Architecture,
739{
740    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
741        bitvector_op::<Self, Unsigned>(x, y)
742    }
743}
744
/// An implementation for L2 distance that uses scalar indexing for the implementation.
macro_rules! impl_fallback_l2 {
    ($N:literal) => {
        /// Compute the squared L2 distance between `x` and `y`.
        ///
        /// Returns an error if the arguments have different lengths.
        ///
        /// # Performance
        ///
        /// This function uses a generic implementation and therefore is not very fast.
        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
            // Outlined (`inline(never)`) to keep code-generation in callers cleaner.
            #[inline(never)]
            fn run(
                self,
                _: diskann_wide::arch::Scalar,
                x: USlice<'_, $N>,
                y: USlice<'_, $N>
            ) -> MathematicalResult<u32> {
                let len = check_lengths!(x, y)?;

                let mut accum: i32 = 0;
                for i in 0..len {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
                    let diff = ix - iy;
                    accum += diff * diff;
                }
                Ok(MV::new(accum as u32))
            }
        }
    };
    // Convenience arm: expand once for each `N` in a comma-separated list.
    ($($N:literal),+ $(,)?) => {
        $(impl_fallback_l2!($N);)+
    };
}
782
783impl_fallback_l2!(7, 6, 5, 4, 3, 2);
784
// NOTE(review): `retarget!` is presumably defined elsewhere in this file and forwards
// an existing implementation to the listed target for these bit-widths — confirm
// against the macro definition. Widths 4 and 2 are absent from the V3 list,
// suggesting dedicated V3 kernels exist for them elsewhere.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

// The V4 (AVX-512) retarget covers all of widths 2 through 7.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);

// Generate the `PureDistanceFunction` entry points (targeting `diskann_wide::ARCH`,
// the compilation architecture) for widths 1 through 8.
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
792
793///////////////////
794// Inner Product //
795///////////////////
796
797/// Compute the inner product between `x` and `y`.
798///
799/// Returns an error if the arguments have different lengths.
800///
801/// # Implementation Notes
802///
803/// This can directly invoke the methods implemented in `vector` because
804/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
805impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
806where
807    A: Architecture,
808    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
809{
810    #[inline(always)]
811    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
812        check_lengths!(x, y)?;
813        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
814            diskann_vector::distance::InnerProduct {},
815            arch,
816            x.as_slice(),
817            y.as_slice(),
818        );
819
820        Ok(MV::new(r.into_inner() as u32))
821    }
822}
823
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This is optimized around the `_mm512_dpbusd_epi32` VNNI instruction, which computes the
/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with
/// an `i32` accumulation vector.
///
/// One quirk of this instruction is that one argument must be unsigned and the other must
/// be signed. Since this kernel works on 2-bit integers, this is not a limitation. Just
/// something to be aware of.
///
/// Since AVX512 does not have an 8-bit shift instruction, we generally load data as
/// `u32x16` (which has a native shift) and bit-cast it to `u8x64` as needed.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Lane-type shorthands for the V4 (AVX-512) architecture.
        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        // Reinterpret the packed 2-bit data as a stream of 32-bit words
        // (16 elements per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts 32-bit words in the main loop; it is converted to bytes and then
        // to element indexes after the SIMD phase.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        // Rounded up: the partially-filled trailing word is covered by the predicated
        // tail load below.
        let blocks = len.div_ceil(16);
        if i < blocks {
            // Four independent accumulators, one per 2-bit field position in a byte,
            // to break the dependency chain between `dot_simd` calls.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2 bits of every byte.
            let mask = u32s::splat(arch, 0x03030303);
            while i + 16 < blocks {
                // SAFETY: We have checked that `i + 16 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Fields at bits 0-1 of each byte.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                // Fields at bits 2-3 of each byte.
                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                // Fields at bits 4-5 of each byte.
                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                // Fields at bits 6-7 of each byte.
                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Here
            // * `len / 4` gives the number of full bytes
            // * `4 * i` gives the number of bytes processed.
            let remainder = len / 4 - 4 * i;

            // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            // Tail iteration: same four-way field extraction as the main loop, applied
            // to the masked (zero-filled) tail vectors.
            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            // Reduce the four accumulators to a scalar.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            // `i` now counts bytes processed.
            i = (4 * i) + remainder;
        }

        // Convert bytes to element indexes (4 elements per byte).
        i *= 4;

        // Deal with the remainder the slow way.
        // At most 3 elements of a partially-filled trailing byte remain.
        debug_assert!(len - i <= 3);
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                // SAFETY: `i` is guaranteed to be less than `x.len()`.
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                // SAFETY: `i` is guaranteed to be less than `y.len()`.
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
961
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit integers.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Lane-type shorthands for the V3 (AVX2) architecture.
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        // Reinterpret the packed 4-bit data as a stream of 32-bit words
        // (8 elements per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts 32-bit words in the SIMD phase; it is converted to element
        // indexes afterwards.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of complete 32-bit blocks; leftover elements (fewer than 8) are
        // handled by the scalar fallback at the end.
        let blocks = len / 8;
        if i < blocks {
            // Four independent accumulators, one per nibble position in a 16-bit lane.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low nibble of each 16-bit lane.
            let mask = u32s::splat(arch, 0x000f000f);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Nibble 0 (bits 0-3) of each 16-bit lane.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                // Nibble 1 (bits 4-7).
                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                // Nibble 2 (bits 8-11).
                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                // Nibble 3 (bits 12-15).
                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Number of full 32-bit words not yet processed (at most 8).
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            // Tail iteration: same nibble extraction applied to the zero-filled tail.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Reduce the four accumulators to a scalar.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 8;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1095
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit integers.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Lane-type shorthands for the V3 (AVX2) architecture.
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        // Reinterpret the packed 2-bit data as a stream of 32-bit words
        // (16 elements per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        // `i` counts 32-bit words in the SIMD phase; it is converted to element
        // indexes afterwards.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        let blocks = len / 16;
        if i < blocks {
            // Four accumulators shared by the eight 2-bit field positions of a 16-bit
            // lane (each accumulator absorbs two positions per iteration).
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Selects the low 2 bits of each 16-bit lane.
            let mask = u32s::splat(arch, 0x00030003);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Eight 2-bit fields per 16-bit lane, extracted at shifts 0..=14.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // Number of full 32-bit words not yet processed (at most 8).
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            // Tail iteration: same eight-way field extraction on the zero-filled tail.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Reduce the four accumulators to a scalar.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes.
        i *= 16;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1261
/// Compute the inner product between bitvectors `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// Delegates to the shared `bitvector_op` helper (also used by the 1-bit `SquaredL2`
/// implementation above).
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
        bitvector_op::<Self, Unsigned>(x, y)
    }
}
1274
1275/// An implementation for inner products that uses scalar indexing for the implementation.
1276macro_rules! impl_fallback_ip {
1277    ($N:literal) => {
1278        /// Compute the inner product between `x` and `y`.
1279        ///
1280        /// Returns an error if the arguments have different lengths.
1281        ///
1282        /// # Performance
1283        ///
1284        /// This function uses a generic implementation and therefore is not very fast.
1285        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for InnerProduct {
1286            #[inline(never)]
1287            fn run(
1288                self,
1289                _: diskann_wide::arch::Scalar,
1290                x: USlice<'_, $N>,
1291                y: USlice<'_, $N>
1292            ) -> MathematicalResult<u32> {
1293                let len = check_lengths!(x, y)?;
1294
1295                let mut accum: u32 = 0;
1296                for i in 0..len {
1297                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
1298                    let ix = unsafe { x.get_unchecked(i) } as u32;
1299                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
1300                    let iy = unsafe { y.get_unchecked(i) } as u32;
1301                    accum += ix * iy;
1302                }
1303                Ok(MV::new(accum))
1304            }
1305        }
1306    };
1307    ($($N:literal),+ $(,)?) => {
1308        $(impl_fallback_ip!($N);)+
1309    };
1310}
1311
1312impl_fallback_ip!(7, 6, 5, 4, 3, 2);
1313
// NOTE(review): `retarget!` presumably forwards an existing implementation to V3 for
// these bit-widths — confirm against the macro definition. Widths 4 and 2 are
// excluded because they have dedicated hand-written V3 kernels above.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);

// Likewise for V4; width 2 is excluded because it has a dedicated VNNI kernel above.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, InnerProduct, 7, 6, 4, 5, 3);

// Generate the `PureDistanceFunction` entry points (targeting `diskann_wide::ARCH`,
// the compilation architecture) for widths 1 through 8.
dispatch_pure!(InnerProduct, 1, 2, 3, 4, 5, 6, 7, 8);
1321
1322//////////////////
1323// BitTranspose //
1324//////////////////
1325
/// The strategy is to compute the inner product `<x, y>` by decomposing the problem into
/// groups of 64-dimensions.
///
/// For each group, we load the 64-bits of `y` into a word `bits`. And the four 64-bit words
/// of the group in `x` in `b0`, `b1`, `b2`, and `b3`.
///
/// Note that bit `i` in `b0` is bit-0 of the `i`-th value in this group. Likewise, bit `i`
/// in `b1` is bit-1 of the same word.
///
/// This means that we can compute the partial inner product for this group as
/// ```math
/// (bits & b0).count_ones()                // Contribution of bit 0
///     + 2 * (bits & b1).count_ones()      // Contribution of bit 1
///     + 4 * (bits & b2).count_ones()      // Contribution of bit 2
///     + 8 * (bits & b3).count_ones()      // Contribution of bit 3
/// ```
/// We process as many full groups as we can.
///
/// To handle the remainder, we need to be careful about accessing `y` because `BitSlice`
/// only guarantees the validity of reads at the byte level. That is - we cannot assume that
/// a full 64-bit read is valid.
///
/// The bit-transposed `x`, on the other hand, guarantees allocations in blocks of
/// 4 * 64-bits, so it can be treated as normal.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // We work in blocks of 64 elements.
        //
        // The `BitTranspose` guarantees reads are valid in blocks of 64 elements (32 bytes).
        // However, the `Dense` representation only pads to bytes.
        // Our strategy for dealing with fewer than 64 remaining elements is to reconstruct
        // a 64-bit integer from bytes.
        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        // `i` counts 64-element blocks.
        let mut i = 0;
        let mut s: u32 = 0;

        // The number of complete 64-element groups.
        let blocks = len / 64;
        while i < blocks {
            // SAFETY: `y` is valid for at least `blocks` 64-bit reads and `i < blocks`.
            let bits = unsafe { py.add(i).read_unaligned() };

            // SAFETY: The layout for `x` is grouped into 32-byte blocks. We've ensured that
            // the lengths of the two vectors are the same, so we know that `x` has at least
            // `blocks` such regions.
            //
            // This loads the first 64-bits of block `i` where `i < blocks`.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            // SAFETY: This loads the second 64-bit word of block `i`.
            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            // SAFETY: This loads the third 64-bit word of block `i`.
            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            // SAFETY: This loads the fourth 64-bit word of block `i`.
            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        // If the input length is a multiple of 64 - then we're done.
        if 64 * i == len {
            return Ok(MV::new(s));
        }

        // Convert blocks to bytes.
        let k = i * 8;

        // Unpack the last elements from the bit-vector.
        //
        // SAFETY: The length of the 1-bit BitSlice is `ceil(len / 8)`. This computation
        // effectively computes `ceil((64 * floor(len / 64)) / 8)`, which is less.
        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        // Code - generation: Applying `min(8)` gives a constant upper-bound to the
        // compiler, allowing better code-generation.
        for j in 0..bytes_remaining.min(8) {
            // SAFETY: Starting at `py`, there are `bytes_remaining` valid bytes. This
            // accesses all of them.
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Because the upper-bits of the last loaded byte can contain indeterminate bits,
        // we must mask out all out-of-bounds bits.
        // The shift amount `len - 64 * i` is in `1..=63` here because the early return
        // above handled the exact-multiple case.
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        // Combine with the remainders.
        //
        // SAFETY: The `BitTranspose` permutation always allocates in granularities of blocks.
        // This loads the first 64-bit word of the last block.
        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        // SAFETY: This loads the second 64-bit word of the last block.
        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        // SAFETY: This loads the third 64-bit word of the last block.
        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        // SAFETY: This loads the fourth 64-bit word of the last block.
        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1454
/// Architecture-unaware entry point for the bit-transposed inner product.
///
/// Dispatches to the `Target2` implementation above via `diskann_wide::ARCH`
/// (the current compilation architecture).
impl
    PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
    for InnerProduct
{
    fn evaluate(
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        (diskann_wide::ARCH).run2(Self, x, y)
    }
}
1466
1467////////////////////
1468// Full Precision //
1469////////////////////
1470
1471/// The main trick here is avoiding explicit conversion from 1 bit integers to 32-bit
1472/// floating-point numbers by using `_mm256_permutevar_ps`, which performs a shuffle on two
1473/// independent 128-bit lanes of `f32` values in a register `A` using the lower 2-bits of
1474/// each 32-bit integer in a register `B`.
1475///
1476/// Importantly, this instruction only takes a single cycle and we can avoid any kind of
/// masking. Going the route of conversion would require an `AND` operation to isolate
1478/// bottom bits and a somewhat lengthy 32-bit integer to `f32` conversion instruction.
1479///
1480/// The overall strategy broadcasts a 32-bit integer (consisting of 32, 1-bit values) across
1481/// 8 lanes into a register `A`.
1482///
1483/// Each lane is then shifted by a different amount so:
1484///
1485/// * Lane 0 has value 0 as its least significant bit (LSB)
1486/// * Lane 1 has value 1 as its LSB.
1487/// * Lane 2 has value 2 as its LSB.
1488/// * etc.
1489///
1490/// These LSB's are used to power the shuffle function to convert to `f32` values (either
1491/// 0.0 or 1.0) and we can FMA as needed.
1492///
1493/// To process the next group of 8 values, we shift all lanes in `A` by 8-bits so lane 0
1494/// has value 8 as its LSB, lane 1 has value 9 etc.
1495///
/// A total of three shifts are applied to extract all 32 1-bit values as `f32` in order.
1497#[cfg(target_arch = "x86_64")]
1498impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
1499    for InnerProduct
1500{
1501    #[inline(always)]
1502    fn run(
1503        self,
1504        arch: diskann_wide::arch::x86_64::V3,
1505        x: &[f32],
1506        y: USlice<'_, 1>,
1507    ) -> MathematicalResult<f32> {
1508        let len = check_lengths!(x, y)?;
1509
1510        use std::arch::x86_64::*;
1511
1512        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
1513        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
1514
1515        // Replicate 0s and 1s so we effectively get a shuffle that only depends on the
1516        // bottom bit (instead of the lowest 2).
1517        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
1518
1519        // Shifts required to offset each lane.
1520        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);
1521
1522        let px: *const f32 = x.as_ptr();
1523        let py: *const u32 = y.as_ptr().cast();
1524
1525        let mut i = 0;
1526        let mut s = f32s::default(arch);
1527
1528        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
1529        let to_f32 = |v: u32s| -> f32s {
1530            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
1531            // which the presence of the `x86_64::V3` architecture guarantees is available.
1532            f32s::from_underlying(arch, unsafe {
1533                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
1534            })
1535        };
1536
1537        // Data is processed in groups of 32 elements.
1538        let blocks = len / 32;
1539        if i < blocks {
1540            let mut s0 = f32s::default(arch);
1541            let mut s1 = f32s::default(arch);
1542
1543            while i < blocks {
1544                // SAFETY: `i < blocks` implies 32-bits are readable from this offset.
1545                let iy = prep(unsafe { py.add(i).read_unaligned() });
1546
1547                // SAFETY: `i < blocks` implies 32 f32 values are readable beginning at `32*i`.
1548                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
1549                // SAFETY: See above.
1550                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
1551                // SAFETY: See above.
1552                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
1553                // SAFETY: See above.
1554                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };
1555
1556                s0 = ix0.mul_add_simd(to_f32(iy), s0);
1557                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
1558                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
1559                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);
1560
1561                i += 1;
1562            }
1563            s = s0 + s1;
1564        }
1565
1566        let remainder = len % 32;
1567        if remainder != 0 {
1568            let tail = if len % 8 == 0 { 8 } else { len % 8 };
1569
1570            // SAFETY: Because `remainder != 0`, there is valid memory beginning at the
1571            // offset `blocks`, so this addition remains within an allocated object.
1572            let py = unsafe { py.add(blocks) };
1573
1574            if remainder <= 8 {
1575                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
1576                // The same logic applies to the SIMD loads.
1577                unsafe {
1578                    load_one(py, |iy| {
1579                        let iy = prep(iy);
1580                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
1581                        s = ix.mul_add_simd(to_f32(iy), s);
1582                    })
1583                }
1584            } else if remainder <= 16 {
1585                // SAFETY: At least two bytes are readable for `py`.
1586                // The same logic applies to the SIMD loads.
1587                unsafe {
1588                    load_two(py, |iy| {
1589                        let iy = prep(iy);
1590                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1591                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
1592                        s = ix0.mul_add_simd(to_f32(iy), s);
1593                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1594                    })
1595                }
1596            } else if remainder <= 24 {
1597                // SAFETY: At least three bytes are readable for `py`.
1598                // The same logic applies to the SIMD loads.
1599                unsafe {
1600                    load_three(py, |iy| {
1601                        let iy = prep(iy);
1602
1603                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1604                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
1605                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);
1606
1607                        s = ix0.mul_add_simd(to_f32(iy), s);
1608                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1609                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
1610                    })
1611                }
1612            } else {
1613                // SAFETY: At least four bytes are readable for `py`.
1614                // The same logic applies to the SIMD loads.
1615                unsafe {
1616                    load_four(py, |iy| {
1617                        let iy = prep(iy);
1618
1619                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1620                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
1621                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
1622                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);
1623
1624                        s = ix0.mul_add_simd(to_f32(iy), s);
1625                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1626                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
1627                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
1628                    })
1629                }
1630            }
1631        }
1632
1633        Ok(MV::new(s.sum_tree()))
1634    }
1635}
1636
/// The strategy used here is almost identical to that used for 1-bit distances. The main
/// difference is that now we use the full 2-bit shuffle capabilities of `_mm256_permutevar_ps`
/// and the relative sizes of the shifts are slightly different.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 2>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // This is the lookup table mapping 2-bit patterns to their equivalent `f32`
        // representation. The AVX2 shuffle only applies within each 128-bit group of the
        // full 256-bit register, so we replicate the contents.
        let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);

        // Shifts required to get logical dimensions shifted to the lower 2-bits of each lane.
        let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a packed 32-bit word so lane `k` has the `k`-th 2-bit code in its
        // lowest 2 bits.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        // Map each lane's low 2 bits (0-3) to the `f32` values 0.0-3.0 via the in-lane
        // shuffle against `values`.
        let to_f32 = |v: u32s| -> f32s {
            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
            // which the presence of the `x86_64::V3` architecture guarantees is available.
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // Each 32-bit word of `y` packs a block of 16, 2-bit codes.
        let blocks = len / 16;
        if blocks != 0 {
            // Two independent accumulators so consecutive FMAs form separate dependency
            // chains.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            // Process 32 elements.
            while i + 2 <= blocks {
                // SAFETY: `i + 2 <= blocks` implies `py.add(i)` is in-bounds and readable
                // for 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: Same logic as above, just applied to `f32` values instead of
                // packed bits.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                // `iy >> 16` exposes the upper 8 codes of the 16-code block.
                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                // SAFETY: `i + 2 <= blocks` implies `py.add(i + 1)` is in-bounds and readable
                // for 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i + 1).read_unaligned() });

                // SAFETY: Same logic as above.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * (i + 1))),
                        f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                i += 2;
            }

            // Process 16 elements
            if i < blocks {
                // SAFETY: `i < blocks` implies `py.add(i)` is in-bounds and readable for
                // 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: Same logic as above.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
            }

            // Fold the two accumulators back into the single running sum.
            s = s0 + s1;
        }

        let remainder = len % 16;
        if remainder != 0 {
            // Number of `f32` lanes to read in the final (possibly partial) SIMD load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };
            // SAFETY: Non-zero remainder implies there are readable bytes after the offset
            // `blocks`, so the addition is valid.
            let py = unsafe { py.add(blocks) };

            if remainder <= 4 {
                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 8 {
                // SAFETY: At least two bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 12 {
                // SAFETY: At least three bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            } else {
                // SAFETY: At least four bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            }
        }

        // Horizontal reduction of the 8 lanes into the scalar result.
        Ok(MV::new(s.sum_tree()))
    }
}
1800
/// The strategy here is similar to the 1 and 2-bit strategies. However, instead of using
/// `_mm256_permutevar_ps`, we now go directly for 32-bit integer to 32-bit floating point.
///
/// This is because the shuffle intrinsic only supports 2-bit shuffles.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 4>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);

        // Per-lane shifts placing nibble `k` (bits `4k..4k+4`) in the low bits of lane `k`.
        let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
        // Isolates the low 4 bits of each lane after shifting (also clears any bits
        // introduced by the arithmetic shift on the top nibble).
        let mask = i32s::splat(arch, 0x0f);

        // Unpack the 8, 4-bit codes of a 32-bit word into lanes and convert to `f32`.
        let to_f32 = |v: u32| -> f32s {
            ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
        };

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Each 32-bit word of `y` packs a block of 8, 4-bit codes.
        let blocks = len / 8;
        while i < blocks {
            // SAFETY: `i < blocks` implies that 8 `f32` values are readable from `8*i`.
            let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
            // SAFETY: Same logic as above - but applied to the packed bits.
            let iy = to_f32(unsafe { py.add(i).read_unaligned() });
            s = ix.mul_add_simd(iy, s);

            i += 1;
        }

        let remainder = len % 8;
        if remainder != 0 {
            let f = |iy| {
                // SAFETY: The epilogue handles at most 8 values. Since the remainder is
                // non-zero, the pointer arithmetic is in-bounds and `load_simd_first` will
                // avoid accessing the out-of-bounds elements.
                let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
                s = ix.mul_add_simd(to_f32(iy), s);
            };

            // SAFETY: Non-zero remainder means there are readable bytes from the offset
            // `blocks`.
            let py = unsafe { py.add(blocks) };

            // Each byte holds 2 codes, so load only as many bytes as the remainder needs.
            if remainder <= 2 {
                // SAFETY: Non-zero remainder less than 2 implies that one byte is readable.
                unsafe { load_one(py, f) };
            } else if remainder <= 4 {
                // SAFETY: At least two bytes are readable from `py`.
                unsafe { load_two(py, f) };
            } else if remainder <= 6 {
                // SAFETY: At least three bytes are readable from `py`.
                unsafe { load_three(py, f) };
            } else {
                // SAFETY: At least four bytes are readable from `py`.
                unsafe { load_four(py, f) };
            }
        }

        // Horizontal reduction of the 8 lanes into the scalar result.
        Ok(MV::new(s.sum_tree()))
    }
}
1877
1878impl<const N: usize>
1879    Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
1880    for InnerProduct
1881where
1882    Unsigned: Representation<N>,
1883{
1884    /// A fallback implementation that uses scaler indexing to retrieve values from
1885    /// the corresponding `BitSlice`.
1886    #[inline(always)]
1887    fn run(
1888        self,
1889        _: diskann_wide::arch::Scalar,
1890        x: &[f32],
1891        y: USlice<'_, N>,
1892    ) -> MathematicalResult<f32> {
1893        check_lengths!(x, y)?;
1894
1895        let mut s = 0.0;
1896        for (i, x) in x.iter().enumerate() {
1897            // SAFETY: We've ensured that `x.len() == y.len()`, so this access is
1898            // always inbounds.
1899            let y = unsafe { y.get_unchecked(i) } as f32;
1900            s += x * y;
1901        }
1902
1903        Ok(MV::new(s))
1904    }
1905}
1906
/// Implement `Target2` for higher architectures by delegating to the architecture
/// produced by `retarget()`.
#[cfg(target_arch = "x86_64")]
macro_rules! ip_retarget {
    // Single bit-width: forward the call one architecture level down.
    ($arch:path, $N:literal) => {
        impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
            for InnerProduct
        {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: &[f32],
                y: USlice<'_, $N>,
            ) -> MathematicalResult<f32> {
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Convenience arm: expand the single-width arm for each listed bit-width.
    ($arch:path, $($Ns:literal),*) => {
        $(ip_retarget!($arch, $Ns);)*
    }
}

// V3 has dedicated kernels above for 1, 2, and 4 bits, so only the remaining
// bit-widths are delegated.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);

// V4 delegates every bit-width.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);
1935
/// Delegate the implementation of `PureDistanceFunction` to `diskann_wide::arch::Target2`
/// with the current compilation architecture (`ARCH`).
macro_rules! dispatch_full_ip {
    // Single bit-width: implement the non-architecture-aware entry point.
    ($N:literal) => {
        /// Compute the inner product between `x` and `y`.
        ///
        /// Returns an error if the arguments have different lengths.
        impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
            for InnerProduct
        {
            fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
                Self.run(ARCH, x, y)
            }
        }
    };
    // Convenience arm: expand the single-width arm for each listed bit-width.
    ($($Ns:literal),*) => {
        $(dispatch_full_ip!($Ns);)*
    }
}

dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
1957
1958///////////
1959// Tests //
1960///////////
1961
1962#[cfg(test)]
1963mod tests {
1964    use std::{collections::HashMap, sync::LazyLock};
1965
1966    use diskann_utils::Reborrow;
1967    use rand::{
1968        Rng, SeedableRng,
1969        distr::{Distribution, Uniform},
1970        rngs::StdRng,
1971        seq::IndexedRandom,
1972    };
1973
1974    use super::*;
1975    use crate::bits::{BoxedBitSlice, Representation, Unsigned};
1976
1977    type MR = MathematicalResult<u32>;
1978
1979    /////////////////////////
1980    // Unsigned Bit Slices //
1981    /////////////////////////
1982
1983    // This test works by generating random integer codes for the compressed vectors,
1984    // then uses the functions implemented in `vector` to compute the expected result of
1985    // the computation in "full precision integer space".
1986    //
1987    // We verify that the exact same results are returned by each computation.
1988    fn test_bitslice_distances<const NBITS: usize, R>(
1989        dim_max: usize,
1990        trials_per_dim: usize,
1991        evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
1992        evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
1993        context: &str,
1994        rng: &mut R,
1995    ) where
1996        Unsigned: Representation<NBITS>,
1997        R: Rng,
1998    {
1999        let domain = Unsigned::domain_const::<NBITS>();
2000        let min: i64 = *domain.start();
2001        let max: i64 = *domain.end();
2002
2003        let dist = Uniform::new_inclusive(min, max).unwrap();
2004
2005        for dim in 0..dim_max {
2006            let mut x_reference: Vec<u8> = vec![0; dim];
2007            let mut y_reference: Vec<u8> = vec![0; dim];
2008
2009            let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2010            let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2011
2012            for trial in 0..trials_per_dim {
2013                x_reference
2014                    .iter_mut()
2015                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2016                y_reference
2017                    .iter_mut()
2018                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2019
2020                // Fill the input slices with 1's so we can catch situations where we don't
2021                // correctly handle odd remaining elements.
2022                x.as_mut_slice().fill(u8::MAX);
2023                y.as_mut_slice().fill(u8::MAX);
2024
2025                for i in 0..dim {
2026                    x.set(i, x_reference[i].into()).unwrap();
2027                    y.set(i, y_reference[i].into()).unwrap();
2028                }
2029
2030                // Check L2
2031                let expected: MV<f32> =
2032                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2033
2034                let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2035
2036                // Integer computations should be exact.
2037                assert_eq!(
2038                    expected.into_inner(),
2039                    got.into_inner() as f32,
2040                    "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2041                    NBITS,
2042                    dim,
2043                    trial,
2044                    context,
2045                );
2046
2047                // Check IP
2048                let expected: MV<f32> =
2049                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2050
2051                let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2052
2053                // Integer computations should be exact.
2054                assert_eq!(
2055                    expected.into_inner(),
2056                    got.into_inner() as f32,
2057                    "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2058                    NBITS,
2059                    dim,
2060                    trial,
2061                    context,
2062                );
2063            }
2064        }
2065
2066        // Test that we correctly return error types for length mismatches.
2067        let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2068        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2069
2070        assert!(
2071            evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2072            "context: {}",
2073            context
2074        );
2075        assert!(
2076            evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2077            "context: {}",
2078            context
2079        );
2080
2081        assert!(
2082            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2083            "context: {}",
2084            context
2085        );
2086        assert!(
2087            evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2088            "context: {}",
2089            context
2090        );
2091    }
2092
    // Miri executes far more slowly than native code, so shrink the default test sizes
    // when running under it.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 128;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2102
    // For the bit-slice kernels, we want to use different maximum dimensions for the distance
    // test depending on the implementation of the kernel, and whether or not we are running
    // under Miri.
    //
    // For implementations that use the scalar fallback, we need not set very high bounds
    // (particularly when running under miri) because the implementations are quite simple.
    //
    // However, some SIMD kernels (especially for the lower bit widths), require higher bounds
    // to trigger all possible corner cases.
    static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
        use ArchKey::{Scalar, X86_64_V3, X86_64_V4};
        [
            (Key::new(1, Scalar), Bounds::new(64, 64)),
            (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
            (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
            (Key::new(2, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
            (Key::new(2, X86_64_V4), Bounds::new(768, 600)), // main loop processes 256 items
            (Key::new(3, Scalar), Bounds::new(64, 64)),
            (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(4, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
            (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
            (Key::new(5, Scalar), Bounds::new(64, 64)),
            (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(6, Scalar), Bounds::new(64, 64)),
            (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(7, Scalar), Bounds::new(64, 64)),
            (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(8, Scalar), Bounds::new(64, 64)),
            (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
        ]
        .into_iter()
        .collect()
    });
2145
    /// Architectures for which per-kernel test bounds are registered in
    /// `BITSLICE_TEST_BOUNDS`.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ArchKey {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
    }
2154
    /// Lookup key pairing a bit-width with an architecture.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct Key {
        // Number of bits per packed element.
        nbits: usize,
        // Architecture whose kernel is being tested.
        arch: ArchKey,
    }

    impl Key {
        /// Convenience constructor for a `(nbits, arch)` key.
        fn new(nbits: usize, arch: ArchKey) -> Self {
            Self { nbits, arch }
        }
    }
2166
2167    #[derive(Debug, Clone, Copy)]
2168    struct Bounds {
2169        standard: usize,
2170        miri: usize,
2171    }
2172
2173    impl Bounds {
2174        fn new(standard: usize, miri: usize) -> Self {
2175            Self { standard, miri }
2176        }
2177
2178        fn get(&self) -> usize {
2179            if cfg!(miri) { self.miri } else { self.standard }
2180        }
2181    }
2182
    // Generates a `#[test]` for the given bit-width that exercises every available
    // implementation path: the `PureDistanceFunction` entry point, the explicit scalar
    // architecture, and (on x86_64, when supported at runtime) the V3 and V4 kernels.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed so failures are reproducible.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // Architecture Specific.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }
            }
        };
    }
2238
    // One test per supported bit-width; the seeds are arbitrary but fixed for
    // reproducibility.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2247
2248    ///////////////////////////
2249    // Hamming Bit Distances //
2250    ///////////////////////////
2251
    /// Randomized check of `Hamming` against a reference `SquaredL2` computation over
    /// `±1`-valued vectors, followed by a length-mismatch error check.
    fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
    where
        R: Rng,
    {
        // Each element is drawn uniformly from {-1, +1}.
        let dist: [i8; 2] = [-1, 1];

        for dim in 0..dim_max {
            let mut x_reference: Vec<i8> = vec![1; dim];
            let mut y_reference: Vec<i8> = vec![1; dim];

            let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
            let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);

            for _ in 0..trials_per_dim {
                x_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());
                y_reference
                    .iter_mut()
                    .for_each(|i| *i = *dist.choose(rng).unwrap());

                // Fill the input slices with 1's so we can catch situations where we don't
                // correctly handle odd remaining elements.
                x.as_mut_slice().fill(u8::MAX);
                y.as_mut_slice().fill(u8::MAX);

                for i in 0..dim {
                    x.set(i, x_reference[i].into()).unwrap();
                    y.set(i, y_reference[i].into()).unwrap();
                }

                // We can check equality by evaluating the L2 distance between the reference
                // vectors.
                //
                // This is proportional to the Hamming distance by a factor of 4 (since the
                // distance between +1 and -1 is 2, and 2^2 = 4).
                let expected: MV<f32> =
                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
                let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
                assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
            }
        }

        // Mismatched lengths must produce an error in either argument order.
        let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
        let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
        assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
        assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
    }
2300
    #[test]
    fn test_hamming_distance() {
        // Fixed seed keeps the randomized Hamming trials reproducible.
        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
    }
2306
2307    ///////////////////
2308    // Heterogeneous //
2309    ///////////////////
2310
2311    fn test_bit_transpose_distances<R>(
2312        dim_max: usize,
2313        trials_per_dim: usize,
2314        evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2315        context: &str,
2316        rng: &mut R,
2317    ) where
2318        R: Rng,
2319    {
2320        let dist_4bit = {
2321            let domain = Unsigned::domain_const::<4>();
2322            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2323        };
2324
2325        let dist_1bit = {
2326            let domain = Unsigned::domain_const::<1>();
2327            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2328        };
2329
2330        for dim in 0..dim_max {
2331            let mut x_reference: Vec<u8> = vec![0; dim];
2332            let mut y_reference: Vec<u8> = vec![0; dim];
2333
2334            let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2335            let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2336
2337            for trial in 0..trials_per_dim {
2338                x_reference
2339                    .iter_mut()
2340                    .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2341                y_reference
2342                    .iter_mut()
2343                    .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2344
2345                // First - pre-set all the values in the bit-slices to 1.
2346                x.as_mut_slice().fill(u8::MAX);
2347                y.as_mut_slice().fill(u8::MAX);
2348
2349                for i in 0..dim {
2350                    x.set(i, x_reference[i].into()).unwrap();
2351                    y.set(i, y_reference[i].into()).unwrap();
2352                }
2353
2354                // Check IP
2355                let expected: MV<f32> =
2356                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2357
2358                let got = evaluate_ip(x.reborrow(), y.reborrow());
2359
2360                // Integer computations should be exact.
2361                assert_eq!(
2362                    expected.into_inner(),
2363                    got.unwrap().into_inner() as f32,
2364                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
2365                    dim,
2366                    trial,
2367                    context,
2368                );
2369            }
2370        }
2371
2372        let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2373        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2374        assert!(
2375            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2376            "context: {}",
2377            context
2378        );
2379
2380        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2381        assert!(
2382            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2383            "context: {}",
2384            context
2385        );
2386    }
2387
2388    #[test]
2389    fn test_bit_transpose_distance() {
2390        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);
2391
2392        test_bit_transpose_distances(
2393            MAX_DIM,
2394            TRIALS_PER_DIM,
2395            &|x, y| InnerProduct::evaluate(x, y),
2396            "pure distance function",
2397            &mut rng,
2398        );
2399
2400        test_bit_transpose_distances(
2401            MAX_DIM,
2402            TRIALS_PER_DIM,
2403            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
2404            "scalar",
2405            &mut rng,
2406        );
2407
2408        // Architecture Specific.
2409        #[cfg(target_arch = "x86_64")]
2410        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
2411            test_bit_transpose_distances(
2412                MAX_DIM,
2413                TRIALS_PER_DIM,
2414                &|x, y| arch.run2(InnerProduct, x, y),
2415                "x86-64-v3",
2416                &mut rng,
2417            );
2418        }
2419
2420        // Architecture Specific.
2421        #[cfg(target_arch = "x86_64")]
2422        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
2423            test_bit_transpose_distances(
2424                MAX_DIM,
2425                TRIALS_PER_DIM,
2426                &|x, y| arch.run2(InnerProduct, x, y),
2427                "x86-64-v4",
2428                &mut rng,
2429            );
2430        }
2431    }
2432
2433    //////////
2434    // Full //
2435    //////////
2436
2437    fn test_full_distances<const NBITS: usize>(
2438        dim_max: usize,
2439        trials_per_dim: usize,
2440        evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
2441        context: &str,
2442        rng: &mut impl Rng,
2443    ) where
2444        Unsigned: Representation<NBITS>,
2445    {
2446        // let dist_float = Uniform::new_inclusive(-2.0f32, 2.0f32).unwrap();
2447        let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
2448        let dist_bit = {
2449            let domain = Unsigned::domain_const::<NBITS>();
2450            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2451        };
2452
2453        for dim in 0..dim_max {
2454            let mut x: Vec<f32> = vec![0.0; dim];
2455
2456            let mut y_reference: Vec<u8> = vec![0; dim];
2457            let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
2458
2459            for trial in 0..trials_per_dim {
2460                x.iter_mut()
2461                    .for_each(|i| *i = *dist_float.choose(rng).unwrap());
2462                y_reference
2463                    .iter_mut()
2464                    .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
2465
2466                // First - pre-set all the values in the bit-slices to 1.
2467                y.as_mut_slice().fill(u8::MAX);
2468
2469                let mut expected = 0.0;
2470                for i in 0..dim {
2471                    y.set(i, y_reference[i].into()).unwrap();
2472                    expected += y_reference[i] as f32 * x[i];
2473                }
2474
2475                // Check IP
2476                let got = evaluate_ip(&x, y.reborrow()).unwrap();
2477
2478                // Integer computations should be exact.
2479                assert_eq!(
2480                    expected,
2481                    got.into_inner(),
2482                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
2483                    dim,
2484                    trial,
2485                    context,
2486                );
2487
2488                // Ensure that using the `Scalar` architecture providers the same
2489                // results.
2490                let scalar: MV<f32> = InnerProduct
2491                    .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
2492                    .unwrap();
2493                assert_eq!(got.into_inner(), scalar.into_inner());
2494            }
2495        }
2496
2497        // Error Checking
2498        let x = vec![0.0; 10];
2499        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2500        assert!(
2501            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2502            "context: {}",
2503            context
2504        );
2505
2506        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
2507        assert!(
2508            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2509            "context: {}",
2510            context
2511        );
2512    }
2513
2514    macro_rules! test_full {
2515        ($name:ident, $nbits:literal, $seed:literal) => {
2516            #[test]
2517            fn $name() {
2518                let mut rng = StdRng::seed_from_u64($seed);
2519
2520                test_full_distances::<$nbits>(
2521                    MAX_DIM,
2522                    TRIALS_PER_DIM,
2523                    &|x, y| InnerProduct::evaluate(x, y),
2524                    "pure distance function",
2525                    &mut rng,
2526                );
2527
2528                test_full_distances::<$nbits>(
2529                    MAX_DIM,
2530                    TRIALS_PER_DIM,
2531                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
2532                    "scalar",
2533                    &mut rng,
2534                );
2535
2536                // Architecture Specific.
2537                #[cfg(target_arch = "x86_64")]
2538                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
2539                    test_full_distances::<$nbits>(
2540                        MAX_DIM,
2541                        TRIALS_PER_DIM,
2542                        &|x, y| arch.run2(InnerProduct, x, y),
2543                        "x86-64-v3",
2544                        &mut rng,
2545                    );
2546                }
2547
2548                #[cfg(target_arch = "x86_64")]
2549                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked() {
2550                    test_full_distances::<$nbits>(
2551                        MAX_DIM,
2552                        TRIALS_PER_DIM,
2553                        &|x, y| arch.run2(InnerProduct, x, y),
2554                        "x86-64-v4",
2555                        &mut rng,
2556                    );
2557                }
2558            }
2559        };
2560    }
2561
    // Instantiate the heterogeneous float x quantized inner-product tests for
    // every supported bit-width (1 through 8 bits), each with its own fixed
    // RNG seed for reproducibility.
    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
2570}