Skip to main content

diskann_quantization/bits/
distances.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! # Low-level functions
7//!
8//! The methods here are meant to be primitives used by the distance functions for the
9//! various scalar-quantized-like quantizers.
10//!
11//! As such, they typically return integer distance results since they largely operate over
12//! raw bit-slices.
13//!
14//! ## Micro-architecture Mapping
15//!
16//! There are two interfaces for interacting with the distance primitives:
17//!
18//! * [`diskann_wide::arch::Target2`]: A micro-architecture aware interface where the target
19//!   micro-architecture is provided as an explicit argument.
20//!
21//!   This can be used in conjunction with [`diskann_wide::Architecture::run2`] to apply the
22//!   necessary target-features to opt-into newer architecture code generation when
23//!   compiling the whole binary for an older architecture.
24//!
25//!   This interface is also composable with micro-architecture dispatching done higher in
26//!   the callstack, and so should be preferred when incorporating into quantizer distance
27//!   computations.
28//!
29//! * [`diskann_vector::PureDistanceFunction`]: If micro-architecture awareness is not needed,
30//!   this provides a simple interface targeting [`diskann_wide::ARCH`] (the current compilation
31//!   architecture).
32//!
33//!   This interface will always yield a binary compatible with the compilation architecture
34//!   target, but will not enable faster code-paths when compiling for older architectures.
35//!
36//! The following table summarizes the implementation status of kernels. All kernels have
37//! `diskann_wide::arch::Scalar` implementation fallbacks.
38//!
39//! Implementation Kind:
40//!
41//! * "Fallback": A fallback implementation using scalar indexing.
42//!
43//! * "Optimized": A better implementation than "fallback" that does not contain
//!   target-dependent code, instead relying on compiler optimizations.
45//!
46//!   Micro-architecture dispatch is still relevant as it allows the compiler to generate
47//!   better code for newer machines.
48//!
49//! * "Yes": Architecture specific SIMD implementation exists.
50//!
51//! * "No": Architecture specific implementation does not exist - the next most-specific
52//!   implementation is used. For example, if a `x86-64-v3` implementation does not exist,
53//!   then the "scalar" implementation will be used instead.
54//!
55//! Type Aliases
56//!
57//! * `USlice<N>`: `BitSlice<N, Unsigned, Dense>`
58//! * `TSlice<N>`: `BitSlice<N, Unsigned, BitTranspose>`
59//! * `BSlice`: `BitSlice<1, Binary, Dense>`
60//!
61//! * `MV<T>`: [`diskann_vector::MathematicalValue<T>`]
62//!
63//! ### Inner Product
64//!
65//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
66//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
67//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
68//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Yes       | Fallback  |
69//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
70//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
71//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
72//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
73//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
74//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       | Fallback  |
75//! |               |               | `       ` |           |               |           |           |
76//! | `TSlice<4>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Optimized | Optimized |
//! |               |               | `       ` |           |               |           |           |
78//! | `&[f32]`      | `USlice<1>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
79//! | `&[f32]`      | `USlice<2>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
80//! | `&[f32]`      | `USlice<3>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
81//! | `&[f32]`      | `USlice<4>`   | `MV<f32>` | Fallback  | Yes           | Uses V3   | Fallback  |
82//! | `&[f32]`      | `USlice<5>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
83//! | `&[f32]`      | `USlice<6>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
84//! | `&[f32]`      | `USlice<7>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
85//! | `&[f32]`      | `USlice<8>`   | `MV<f32>` | Fallback  | No            | Uses V3   | Fallback  |
86//!
87//! ### Squared L2
88//!
89//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
90//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
91//! | `USlice<1>`   | `USlice<1>`   | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
92//! | `USlice<2>`   | `USlice<2>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
93//! | `USlice<3>`   | `USlice<3>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
94//! | `USlice<4>`   | `USlice<4>`   | `MV<u32>` | Fallback  | Yes           | Uses V3   | Fallback  |
95//! | `USlice<5>`   | `USlice<5>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
96//! | `USlice<6>`   | `USlice<6>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
97//! | `USlice<7>`   | `USlice<7>`   | `MV<u32>` | Fallback  | No            | Uses V3   | Fallback  |
98//! | `USlice<8>`   | `USlice<8>`   | `MV<u32>` | Yes       | Yes           | Yes       | Fallback  |
99//!
100//! ### Hamming
101//!
102//! | LHS           | RHS           | Result    | Scalar    | x86-64-v3     | x86-64-v4 | Neon      |
103//! |---------------|---------------|-----------|-----------|---------------|-----------|-----------|
104//! | `BSlice`      | `BSlice`      | `MV<u32>` | Optimized | Optimized     | Uses V3   | Optimized |
105
106use diskann_vector::PureDistanceFunction;
107use diskann_wide::{ARCH, Architecture, arch::Target2};
108#[cfg(target_arch = "x86_64")]
109use diskann_wide::{
110    SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
111};
112
113use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
114use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
115
116// Convenience alias.
117type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
118
119/// Retarget the architectures via the `retarget` inherent method.
120///
121/// * [`diskann_wide::arch::x86_64::V3`] -> [`diskann_wide::arch::Scalar`]
122/// * [`diskann_wide::arch::x86_64::V4`] -> [`diskann_wide::arch::x86_64::V3`]
123/// * [`diskann_wide::arch::aarch64::Neon`] -> [`diskann_wide::arch::Scalar`]
124macro_rules! retarget {
125    ($arch:path, $op:ty, $N:literal) => {
126        impl Target2<
127            $arch,
128            MathematicalResult<u32>,
129            USlice<'_, $N>,
130            USlice<'_, $N>,
131        > for $op {
132            #[inline(always)]
133            fn run(
134                self,
135                arch: $arch,
136                x: USlice<'_, $N>,
137                y: USlice<'_, $N>
138            ) -> MathematicalResult<u32> {
139                self.run(arch.retarget(), x, y)
140            }
141        }
142    };
143    ($arch:path, $op:ty, $($N:literal),+ $(,)?) => {
144        $(retarget!($arch, $op, $N);)+
145    }
146}
147
148/// Impledment [`diskann_vector::PureDistanceFunction`] using the current compilation architecture
149macro_rules! dispatch_pure {
150    ($op:ty, $N:literal) => {
151        /// Compute the squared L2 distance between `x` and `y`.
152        impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $N>, MathematicalResult<u32>> for $op {
153            #[inline(always)]
154            fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $N>) -> MathematicalResult<u32> {
155                (diskann_wide::ARCH).run2(Self, x, y)
156            }
157        }
158    };
159    ($op:ty, $($N:literal),+ $(,)?) => {
160        $(dispatch_pure!($op, $N);)+
161    }
162}
163
164/// Load 1 byte beginning at `ptr` and invoke `f` with that byte.
165///
166/// # Safety
167///
168/// * The memory range `[ptr, ptr + 1)` (in bytes) must be dereferencable.
169/// * `ptr` does not need to be aligned.
170#[cfg(target_arch = "x86_64")]
171unsafe fn load_one<F, R>(ptr: *const u32, mut f: F) -> R
172where
173    F: FnMut(u32) -> R,
174{
175    // SAFETY: Caller asserts that one byte is readable.
176    f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
177}
178
179/// Load 2 bytes beginning at `ptr` and invoke `f` with the value.
180///
181/// # Safety
182///
183/// * The memory range `[ptr, ptr + 2)` (in bytes) must be dereferencable.
184/// * `ptr` does not need to be aligned.
185#[cfg(target_arch = "x86_64")]
186unsafe fn load_two<F, R>(ptr: *const u32, mut f: F) -> R
187where
188    F: FnMut(u32) -> R,
189{
190    // SAFETY: Caller asserts that two bytes are readable.
191    f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
192}
193
194/// Load 3 bytes beginning at `ptr` and invoke `f` with the value.
195///
196/// # Safety
197///
198/// * The memory range `[ptr, ptr + 3)` (in bytes) must be dereferencable.
199/// * `ptr` does not need to be aligned.
200#[cfg(target_arch = "x86_64")]
201unsafe fn load_three<F, R>(ptr: *const u32, mut f: F) -> R
202where
203    F: FnMut(u32) -> R,
204{
205    // SAFETY: Caller asserts that three bytes are readable. This loads the first two.
206    let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
207    // SAFETY: Caller asserts that three bytes are readable. This loads the third.
208    let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
209    f(lo | hi << 16)
210}
211
212/// Load 4 bytes beginning at `ptr` and invoke `f` with the value.
213///
214/// # Safety
215///
216/// * The memory range `[ptr, ptr + 4)` (in bytes) must be dereferencable.
217/// * `ptr` does not need to be aligned.
218#[cfg(target_arch = "x86_64")]
219unsafe fn load_four<F, R>(ptr: *const u32, mut f: F) -> R
220where
221    F: FnMut(u32) -> R,
222{
223    // SAFETY: Caller asserts that four bytes are readable.
224    f(unsafe { ptr.read_unaligned() })
225}
226
227////////////////////////////
228// Distances on BitSlices //
229////////////////////////////
230
231/// Operations to apply to 1-bit encodings.
232///
233/// The general structure of 1-bit vector operations is the same, but the element wise
234/// operator is different. This trait encapsulates the differences in behavior required
235/// for different distance function.
236///
237/// The exact operations to apply depending on the representation of the bit encoding.
238trait BitVectorOp<Repr>
239where
240    Repr: Representation<1>,
241{
242    /// Apply the op to all bits in the 64-bit arguments.
243    fn on_u64(x: u64, y: u64) -> u32;
244
245    /// Apply the op to all bits in the 8-bit arguments.
246    ///
247    /// NOTE: Implementations must have the correct behavior when the upper bits of `x`
248    /// and `y` are set to 0 when handling epilogues.
249    fn on_u8(x: u8, y: u8) -> u32;
250}
251
252/// Computing Squared-L2 amounts to evaluating the pop-count of a bitwise `xor`.
253impl BitVectorOp<Unsigned> for SquaredL2 {
254    #[inline(always)]
255    fn on_u64(x: u64, y: u64) -> u32 {
256        (x ^ y).count_ones()
257    }
258    #[inline(always)]
259    fn on_u8(x: u8, y: u8) -> u32 {
260        (x ^ y).count_ones()
261    }
262}
263
264/// Computing Squared-L2 amounts to evaluating the pop-count of a bitwise `xor`.
265impl BitVectorOp<Binary> for Hamming {
266    #[inline(always)]
267    fn on_u64(x: u64, y: u64) -> u32 {
268        (x ^ y).count_ones()
269    }
270    #[inline(always)]
271    fn on_u8(x: u8, y: u8) -> u32 {
272        (x ^ y).count_ones()
273    }
274}
275
276/// The implementation as `and` is not straight-forward.
277///
278/// Recall that scalar quantization encodings are unsigned, so "0" is zero and "1" is some
279/// non-zero value.
280///
281/// When computing the inner product, `0 * x == 0` for all `x` and only `x * x` has a
282/// non-zero value. Therefore, the elementwise op is an `and` and not `xnor`.
283impl BitVectorOp<Unsigned> for InnerProduct {
284    #[inline(always)]
285    fn on_u64(x: u64, y: u64) -> u32 {
286        (x & y).count_ones()
287    }
288    #[inline(always)]
289    fn on_u8(x: u8, y: u8) -> u32 {
290        (x & y).count_ones()
291    }
292}
293
294/// A general algorithm for applying a bitwise operand to two dense bit vectors of equal
295/// but arbitrary length.
296///
297/// NOTE: The `inline(always)` attribute is required to inheret the caller's target-features.
298#[inline(always)]
299fn bitvector_op<Op, Repr>(
300    x: BitSlice<'_, 1, Repr>,
301    y: BitSlice<'_, 1, Repr>,
302) -> MathematicalResult<u32>
303where
304    Repr: Representation<1>,
305    Op: BitVectorOp<Repr>,
306{
307    let len = check_lengths!(x, y)?;
308
309    let px: *const u64 = x.as_ptr().cast();
310    let py: *const u64 = y.as_ptr().cast();
311
312    let mut i = 0;
313    let mut s: u32 = 0;
314
315    // Work in groups of 64
316    let blocks = len / 64;
317    while i < blocks {
318        // SAFETY: We know at least 64-bits (8-bytes) are valid from this offset (by
319        // guarantee of the `BitSlice`). All bit-patterns of a `u64` are valid, `u64: Copy`,
320        // and an `unaligned` read is used.
321        let vx = unsafe { px.add(i).read_unaligned() };
322
323        // SAFETY: The same logic applies to `y` because:
324        // 1. It has the same type as `x`.
325        // 2. We've verified that it has the same length as `x`.
326        let vy = unsafe { py.add(i).read_unaligned() };
327
328        s += Op::on_u64(vx, vy);
329        i += 1;
330    }
331
332    // Work in groups of 8
333    i *= 8;
334    let px: *const u8 = x.as_ptr();
335    let py: *const u8 = y.as_ptr();
336
337    let blocks = len / 8;
338    while i < blocks {
339        // SAFETY: The underlying pointer is a `*const u8` and we have checked that this
340        // offset is within the bounds of the slice underlying the bitslice.
341        let vx = unsafe { px.add(i).read_unaligned() };
342
343        // SAFETY: The same logic applies to `y` because:
344        // 1. It has the same type as `x`.
345        // 2. We've verified that it has the same length as `x`.
346        let vy = unsafe { py.add(i).read_unaligned() };
347        s += Op::on_u8(vx, vy);
348        i += 1;
349    }
350
351    if i * 8 != len {
352        // SAFETY: The underlying slice is readable in the range
353        // `[px, px + floor(len / 8) + 1)`. This accesses `px + floor(len / 8)`.
354        let vx = unsafe { px.add(i).read_unaligned() };
355
356        // SAFETY: Same as above.
357        let vy = unsafe { py.add(i).read_unaligned() };
358        let m = (0x01u8 << (len - 8 * i)) - 1;
359
360        s += Op::on_u8(vx & m, vy & m)
361    }
362    Ok(MV::new(s))
363}
364
365/// Compute the hamming distance between `x` and `y`.
366///
367/// Returns an error if the arguments have different lengths.
368impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
369    for Hamming
370{
371    fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
372        bitvector_op::<Hamming, Binary>(x, y)
373    }
374}
375
376///////////////
377// SquaredL2 //
378///////////////
379
380/// Compute the squared L2 distance between `x` and `y`.
381///
382/// Returns an error if the arguments have different lengths.
383///
384/// # Implementation Notes
385///
386/// This can directly invoke the methods implemented in `vector` because
387/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
388impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
389where
390    A: Architecture,
391    diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
392{
393    #[inline(always)]
394    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
395        check_lengths!(x, y)?;
396
397        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
398            diskann_vector::distance::SquaredL2 {},
399            arch,
400            x.as_slice(),
401            y.as_slice(),
402        );
403
404        Ok(MV::new(r.into_inner() as u32))
405    }
406}
407
408/// Compute the squared L2 distance between `x` and `y`.
409///
410/// Returns an error if the arguments have different lengths.
411///
412/// # Implementation Notes
413///
414/// This implementation is optimized around x86 with the AVX2 vector extension.
415/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
416/// hit the `_mm256_madd_epi16` intrinsic.
417///
418/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
419/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
420/// This works because we need to apply the same shift to all lanes.
421#[cfg(target_arch = "x86_64")]
422impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
423    for SquaredL2
424{
425    #[inline(always)]
426    fn run(
427        self,
428        arch: diskann_wide::arch::x86_64::V3,
429        x: USlice<'_, 4>,
430        y: USlice<'_, 4>,
431    ) -> MathematicalResult<u32> {
432        let len = check_lengths!(x, y)?;
433
434        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
435        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
436        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
437
438        let px_u32: *const u32 = x.as_ptr().cast();
439        let py_u32: *const u32 = y.as_ptr().cast();
440
441        let mut i = 0;
442        let mut s: u32 = 0;
443
444        // The number of 32-bit blocks over the underlying slice.
445        let blocks = len / 8;
446        if i < blocks {
447            let mut s0 = i32s::default(arch);
448            let mut s1 = i32s::default(arch);
449            let mut s2 = i32s::default(arch);
450            let mut s3 = i32s::default(arch);
451            let mask = u32s::splat(arch, 0x000f000f);
452            while i + 8 < blocks {
453                // SAFETY: We have checked that `i + 8 < blocks` which means the address
454                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
455                //
456                // The load has no alignment requirements.
457                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
458
459                // SAFETY: The same logic applies to `y` because:
460                // 1. It has the same type as `x`.
461                // 2. We've verified that it has the same length as `x`.
462                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
463
464                let wx: i16s = (vx & mask).reinterpret_simd();
465                let wy: i16s = (vy & mask).reinterpret_simd();
466                let d = wx - wy;
467                s0 = s0.dot_simd(d, d);
468
469                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
470                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
471                let d = wx - wy;
472                s1 = s1.dot_simd(d, d);
473
474                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
475                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
476                let d = wx - wy;
477                s2 = s2.dot_simd(d, d);
478
479                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
480                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
481                let d = wx - wy;
482                s3 = s3.dot_simd(d, d);
483
484                i += 8;
485            }
486
487            let remainder = blocks - i;
488
489            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
490            // at offset `i`. The exact number is computed as `remainder`.
491            //
492            // The predicated load is guaranteed not to access memory after `remainder` and
493            // has no alignment requirements.
494            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
495
496            // SAFETY: The same logic applies to `y` because:
497            // 1. It has the same type as `x`.
498            // 2. We've verified that it has the same length as `x`.
499            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
500
501            let wx: i16s = (vx & mask).reinterpret_simd();
502            let wy: i16s = (vy & mask).reinterpret_simd();
503            let d = wx - wy;
504            s0 = s0.dot_simd(d, d);
505
506            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
507            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
508            let d = wx - wy;
509            s1 = s1.dot_simd(d, d);
510
511            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
512            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
513            let d = wx - wy;
514            s2 = s2.dot_simd(d, d);
515
516            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
517            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
518            let d = wx - wy;
519            s3 = s3.dot_simd(d, d);
520
521            i += remainder;
522
523            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
524        }
525
526        // Convert blocks to indexes.
527        i *= 8;
528
529        // Deal with the remainder the slow way.
530        if i != len {
531            // Outline the fallback routine to keep code-generation at this level cleaner.
532            #[inline(never)]
533            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
534                let mut s: i32 = 0;
535                for i in from..x.len() {
536                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
537                    let ix = unsafe { x.get_unchecked(i) } as i32;
538                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
539                    let iy = unsafe { y.get_unchecked(i) } as i32;
540                    let d = ix - iy;
541                    s += d * d;
542                }
543                s as u32
544            }
545            s += fallback(x, y, i);
546        }
547
548        Ok(MV::new(s))
549    }
550}
551
552/// Compute the squared L2 distance between `x` and `y`.
553///
554/// Returns an error if the arguments have different lengths.
555///
556/// # Implementation Notes
557///
558/// This implementation is optimized around x86 with the AVX2 vector extension.
559/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 8>>` so we can
560/// hit the `_mm256_madd_epi16` intrinsic.
561///
562/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
563/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
564/// This works because we need to apply the same shift to all lanes.
565#[cfg(target_arch = "x86_64")]
566impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
567    for SquaredL2
568{
569    #[inline(always)]
570    fn run(
571        self,
572        arch: diskann_wide::arch::x86_64::V3,
573        x: USlice<'_, 2>,
574        y: USlice<'_, 2>,
575    ) -> MathematicalResult<u32> {
576        let len = check_lengths!(x, y)?;
577
578        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
579        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
580        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
581
582        let px_u32: *const u32 = x.as_ptr().cast();
583        let py_u32: *const u32 = y.as_ptr().cast();
584
585        let mut i = 0;
586        let mut s: u32 = 0;
587
588        // The number of 32-bit blocks over the underlying slice.
589        let blocks = len / 16;
590        if i < blocks {
591            let mut s0 = i32s::default(arch);
592            let mut s1 = i32s::default(arch);
593            let mut s2 = i32s::default(arch);
594            let mut s3 = i32s::default(arch);
595            let mask = u32s::splat(arch, 0x00030003);
596            while i + 8 < blocks {
597                // SAFETY: We have checked that `i + 8 < blocks` which means the address
598                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
599                //
600                // The load has no alignment requirements.
601                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
602
603                // SAFETY: The same logic applies to `y` because:
604                // 1. It has the same type as `x`.
605                // 2. We've verified that it has the same length as `x`.
606                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
607
608                let wx: i16s = (vx & mask).reinterpret_simd();
609                let wy: i16s = (vy & mask).reinterpret_simd();
610                let d = wx - wy;
611                s0 = s0.dot_simd(d, d);
612
613                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
614                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
615                let d = wx - wy;
616                s1 = s1.dot_simd(d, d);
617
618                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
619                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
620                let d = wx - wy;
621                s2 = s2.dot_simd(d, d);
622
623                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
624                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
625                let d = wx - wy;
626                s3 = s3.dot_simd(d, d);
627
628                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
629                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
630                let d = wx - wy;
631                s0 = s0.dot_simd(d, d);
632
633                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
634                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
635                let d = wx - wy;
636                s1 = s1.dot_simd(d, d);
637
638                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
639                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
640                let d = wx - wy;
641                s2 = s2.dot_simd(d, d);
642
643                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
644                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
645                let d = wx - wy;
646                s3 = s3.dot_simd(d, d);
647
648                i += 8;
649            }
650
651            let remainder = blocks - i;
652
653            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
654            // at offset `i`. The exact number is computed as `remainder`.
655            //
656            // The predicated load is guaranteed not to access memory after `remainder` and
657            // has no alignment requirements.
658            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
659
660            // SAFETY: The same logic applies to `y` because:
661            // 1. It has the same type as `x`.
662            // 2. We've verified that it has the same length as `x`.
663            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
664            let wx: i16s = (vx & mask).reinterpret_simd();
665            let wy: i16s = (vy & mask).reinterpret_simd();
666            let d = wx - wy;
667            s0 = s0.dot_simd(d, d);
668
669            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
670            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
671            let d = wx - wy;
672            s1 = s1.dot_simd(d, d);
673
674            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
675            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
676            let d = wx - wy;
677            s2 = s2.dot_simd(d, d);
678
679            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
680            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
681            let d = wx - wy;
682            s3 = s3.dot_simd(d, d);
683
684            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
685            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
686            let d = wx - wy;
687            s0 = s0.dot_simd(d, d);
688
689            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
690            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
691            let d = wx - wy;
692            s1 = s1.dot_simd(d, d);
693
694            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
695            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
696            let d = wx - wy;
697            s2 = s2.dot_simd(d, d);
698
699            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
700            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
701            let d = wx - wy;
702            s3 = s3.dot_simd(d, d);
703
704            i += remainder;
705
706            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
707        }
708
709        // Convert blocks to indexes.
710        i *= 16;
711
712        // Deal with the remainder the slow way.
713        if i != len {
714            // Outline the fallback routine to keep code-generation at this level cleaner.
715            #[inline(never)]
716            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
717                let mut s: i32 = 0;
718                for i in from..x.len() {
719                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
720                    let ix = unsafe { x.get_unchecked(i) } as i32;
721                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
722                    let iy = unsafe { y.get_unchecked(i) } as i32;
723                    let d = ix - iy;
724                    s += d * d;
725                }
726                s as u32
727            }
728            s += fallback(x, y, i);
729        }
730
731        Ok(MV::new(s))
732    }
733}
734
735/// Compute the squared L2 distance between bitvectors `x` and `y`.
736///
737/// Returns an error if the arguments have different lengths.
738impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
739where
740    A: Architecture,
741{
742    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
743        bitvector_op::<Self, Unsigned>(x, y)
744    }
745}
746
747/// An implementation for L2 distance that uses scalar indexing for the implementation.
748macro_rules! impl_fallback_l2 {
749    ($N:literal) => {
750        /// Compute the squared L2 distance between `x` and `y`.
751        ///
752        /// Returns an error if the arguments have different lengths.
753        ///
754        /// # Performance
755        ///
756        /// This function uses a generic implementation and therefore is not very fast.
757        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
758            #[inline(never)]
759            fn run(
760                self,
761                _: diskann_wide::arch::Scalar,
762                x: USlice<'_, $N>,
763                y: USlice<'_, $N>
764            ) -> MathematicalResult<u32> {
765                let len = check_lengths!(x, y)?;
766
767                let mut accum: i32 = 0;
768                for i in 0..len {
769                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
770                    let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
771                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
772                    let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
773                    let diff = ix - iy;
774                    accum += diff * diff;
775                }
776                Ok(MV::new(accum as u32))
777            }
778        }
779    };
780    ($($N:literal),+ $(,)?) => {
781        $(impl_fallback_l2!($N);)+
782    };
783}
784
785impl_fallback_l2!(7, 6, 5, 4, 3, 2);
786
// Route these bit-widths for x86-64/V3 to existing `SquaredL2` implementations.
// NOTE(review): widths 4 and 2 are absent from this list — presumably dedicated V3
// kernels exist elsewhere in this file; confirm against the `retarget!` definition
// before adding them here.
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);

#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);

#[cfg(target_arch = "aarch64")]
retarget!(
    diskann_wide::arch::aarch64::Neon,
    SquaredL2,
    7,
    6,
    5,
    4,
    3,
    2
);

// Generate the architecture-dispatching entry points for `SquaredL2` over all supported
// bit-widths (presumably `PureDistanceFunction` impls targeting `diskann_wide::ARCH`;
// see the manual impl for `BitTranspose` below for the pattern).
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
806
807///////////////////
808// Inner Product //
809///////////////////
810
811/// Compute the inner product between `x` and `y`.
812///
813/// Returns an error if the arguments have different lengths.
814///
815/// # Implementation Notes
816///
817/// This can directly invoke the methods implemented in `vector` because
818/// `BitSlice<'_, 8, Unsigned>` is isomorphic to `&[u8]`.
819impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
820where
821    A: Architecture,
822    diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
823{
824    #[inline(always)]
825    fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
826        check_lengths!(x, y)?;
827        let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
828            diskann_vector::distance::InnerProduct {},
829            arch,
830            x.as_slice(),
831            y.as_slice(),
832        );
833
834        Ok(MV::new(r.into_inner() as u32))
835    }
836}
837
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This is optimized around the `_mm512_dpbusd_epi32` VNNI instruction, which computes the
/// pairwise dot product between vectors of 8-bit integers and accumulates groups of 4 with
/// an `i32` accumulation vector.
///
/// One quirk of this instruction is that one argument must be unsigned and the other must
/// be signed. Since this kernel works on 2-bit integers, this is not a limitation. Just
/// something to be aware of.
///
/// Since AVX512 does not have an 8-bit shift instruction, we generally load data as
/// `u32x16` (which has a native shift) and bit-cast it to `u8x64` as needed.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[expect(non_camel_case_types)]
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V4,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // Short lowercase aliases for the V4 (AVX-512) vector types used below.
        type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
        type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
        type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
        type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;

        // View the packed 2-bit data as `u32` words (16 values per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        let blocks = len.div_ceil(16);
        if i < blocks {
            // Four independent accumulators, one per 2-bit position within a byte.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Keeps the low 2 bits of every byte after each shift stage.
            let mask = u32s::splat(arch, 0x03030303);
            while i + 16 < blocks {
                // SAFETY: We have checked that `i + 16 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 16 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Isolate each 2-bit sub-lane in turn and feed it to the VNNI dot product.
                let wx: u8s = (vx & mask).reinterpret_simd();
                let wy: i8s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
                let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 16;
            }

            // Here
            // * `len / 4` gives the number of full bytes
            // * `4 * i` gives the number of bytes processed.
            let remainder = len / 4 - 4 * i;

            // SAFETY: At least `remainder` bytes are valid starting at an offset of `i`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
            let vx: u32s = vx.reinterpret_simd();

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
            let vy: u32s = vy.reinterpret_simd();

            // Same four-stage extraction as the main loop on the (zero-padded) tail.
            let wx: u8s = (vx & mask).reinterpret_simd();
            let wy: i8s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
            let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            // Horizontal reduction of the four partial accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
            // `i` now counts processed *bytes* rather than `u32` words.
            i = (4 * i) + remainder;
        }

        // Convert bytes to element indexes (4 packed 2-bit values per byte).
        i *= 4;

        // Deal with the remainder the slow way.
        debug_assert!(len - i <= 3);
        // Code-generation: `.min(3)` gives the compiler a constant trip-count bound
        // (at most 3 trailing elements, per the `debug_assert!` above).
        let rest = (len - i).min(3);
        if i != len {
            for j in 0..rest {
                // SAFETY: `i + j` is less than `x.len()` because `j < rest <= len - i`.
                let ix = unsafe { x.get_unchecked(i + j) } as u32;
                // SAFETY: the same bound applies to `y`, which has the same length.
                let iy = unsafe { y.get_unchecked(i + j) } as u32;
                s += ix * iy;
            }
        }

        Ok(MV::new(s))
    }
}
975
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 16>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 4>,
        y: USlice<'_, 4>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        // View the packed 4-bit data as `u32` words (8 values per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of *full* 32-bit blocks (8 packed 4-bit values per `u32`).
        let blocks = len / 8;
        if i < blocks {
            // Four independent accumulators, one per nibble position within a 16-bit lane.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Keeps the low nibble of each 16-bit lane after every shift stage.
            let mask = u32s::splat(arch, 0x000f000f);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Isolate each nibble position in turn and feed it to `madd`.
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // The number of whole `u32` words left (at most 8, by the loop condition).
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };

            // Same four-stage extraction as the main loop on the (zero-padded) tail.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of the four partial accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes (8 values per `u32` word).
        i *= 8;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1109
/// Compute the inner product between `x` and `y`.
///
/// Returns an error if the arguments have different lengths.
///
/// # Implementation Notes
///
/// This implementation is optimized around x86 with the AVX2 vector extension.
/// Specifically, we try to hit `Wide::<i32, 8> as SIMDDotProduct<Wide<i16, 16>>` so we can
/// hit the `_mm256_madd_epi16` intrinsic.
///
/// Also note that AVX2 does not have 16-bit integer bit-shift instructions. Instead, we
/// have to use 32-bit integer shifts and then bit-cast to 16-bit intrinsics.
/// This works because we need to apply the same shift to all lanes.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: USlice<'_, 2>,
        y: USlice<'_, 2>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
        diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);

        // View the packed 2-bit data as `u32` words (16 values per word).
        let px_u32: *const u32 = x.as_ptr().cast();
        let py_u32: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        // The number of 32-bit blocks over the underlying slice.
        let blocks = len / 16;
        if i < blocks {
            // Four accumulators shared by the eight shift stages below — each accumulator
            // is fed twice per loaded word.
            let mut s0 = i32s::default(arch);
            let mut s1 = i32s::default(arch);
            let mut s2 = i32s::default(arch);
            let mut s3 = i32s::default(arch);
            // Keeps the low 2 bits of each 16-bit lane after every shift stage.
            let mask = u32s::splat(arch, 0x00030003);
            while i + 8 < blocks {
                // SAFETY: We have checked that `i + 8 < blocks` which means the address
                // range `[px_u32 + i, px_u32 + i + 8 * std::mem::size_of::<u32>())` is valid.
                //
                // The load has no alignment requirements.
                let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };

                // SAFETY: The same logic applies to `y` because:
                // 1. It has the same type as `x`.
                // 2. We've verified that it has the same length as `x`.
                let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };

                // Isolate each 2-bit position of the 16-bit lanes in turn (8 stages).
                let wx: i16s = (vx & mask).reinterpret_simd();
                let wy: i16s = (vy & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
                s0 = s0.dot_simd(wx, wy);

                let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
                s1 = s1.dot_simd(wx, wy);

                let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
                s2 = s2.dot_simd(wx, wy);

                let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
                let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
                s3 = s3.dot_simd(wx, wy);

                i += 8;
            }

            // The number of whole `u32` words left (at most 8, by the loop condition).
            let remainder = blocks - i;

            // SAFETY: At least one value of type `u32` is valid for an unaligned starting
            // at offset `i`. The exact number is computed as `remainder`.
            //
            // The predicated load is guaranteed not to access memory after `remainder` and
            // has no alignment requirements.
            let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };

            // SAFETY: The same logic applies to `y` because:
            // 1. It has the same type as `x`.
            // 2. We've verified that it has the same length as `x`.
            let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
            // Same eight-stage extraction as the main loop on the (zero-padded) tail.
            let wx: i16s = (vx & mask).reinterpret_simd();
            let wy: i16s = (vy & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
            s0 = s0.dot_simd(wx, wy);

            let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
            s1 = s1.dot_simd(wx, wy);

            let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
            s2 = s2.dot_simd(wx, wy);

            let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
            let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
            s3 = s3.dot_simd(wx, wy);

            i += remainder;

            // Horizontal reduction of the four partial accumulators.
            s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
        }

        // Convert blocks to indexes (16 values per `u32` word).
        i *= 16;

        // Deal with the remainder the slow way.
        if i != len {
            // Outline the fallback routine to keep code-generation at this level cleaner.
            #[inline(never)]
            fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
                let mut s: u32 = 0;
                for i in from..x.len() {
                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
                    let ix = unsafe { x.get_unchecked(i) } as u32;
                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
                    let iy = unsafe { y.get_unchecked(i) } as u32;
                    s += ix * iy;
                }
                s
            }
            s += fallback(x, y, i);
        }

        Ok(MV::new(s))
    }
}
1275
1276/// Compute the inner product between bitvectors `x` and `y`.
1277///
1278/// Returns an error if the arguments have different lengths.
1279impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
1280where
1281    A: Architecture,
1282{
1283    #[inline(always)]
1284    fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
1285        bitvector_op::<Self, Unsigned>(x, y)
1286    }
1287}
1288
1289/// An implementation for inner products that uses scalar indexing for the implementation.
1290macro_rules! impl_fallback_ip {
1291    ($N:literal) => {
1292        /// Compute the inner product between `x` and `y`.
1293        ///
1294        /// Returns an error if the arguments have different lengths.
1295        ///
1296        /// # Performance
1297        ///
1298        /// This function uses a generic implementation and therefore is not very fast.
1299        impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for InnerProduct {
1300            #[inline(never)]
1301            fn run(
1302                self,
1303                _: diskann_wide::arch::Scalar,
1304                x: USlice<'_, $N>,
1305                y: USlice<'_, $N>
1306            ) -> MathematicalResult<u32> {
1307                let len = check_lengths!(x, y)?;
1308
1309                let mut accum: u32 = 0;
1310                for i in 0..len {
1311                    // SAFETY: `i` is guaranteed to be less than `x.len()`.
1312                    let ix = unsafe { x.get_unchecked(i) } as u32;
1313                    // SAFETY: `i` is guaranteed to be less than `y.len()`.
1314                    let iy = unsafe { y.get_unchecked(i) } as u32;
1315                    accum += ix * iy;
1316                }
1317                Ok(MV::new(accum))
1318            }
1319        }
1320    };
1321    ($($N:literal),+ $(,)?) => {
1322        $(impl_fallback_ip!($N);)+
1323    };
1324}
1325
1326impl_fallback_ip!(7, 6, 5, 4, 3, 2);
1327
1328#[cfg(target_arch = "x86_64")]
1329retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);
1330
1331#[cfg(target_arch = "x86_64")]
1332retarget!(diskann_wide::arch::x86_64::V4, InnerProduct, 7, 6, 4, 5, 3);
1333
1334#[cfg(target_arch = "aarch64")]
1335retarget!(
1336    diskann_wide::arch::aarch64::Neon,
1337    InnerProduct,
1338    7,
1339    6,
1340    4,
1341    5,
1342    3,
1343    2
1344);
1345
1346dispatch_pure!(InnerProduct, 1, 2, 3, 4, 5, 6, 7, 8);
1347
1348//////////////////
1349// BitTranspose //
1350//////////////////
1351
/// The strategy is to compute the inner product `<x, y>` by decomposing the problem into
/// groups of 64-dimensions.
///
/// For each group, we load the 64-bits of `y` into a word `bits`. And the four 64-bit words
/// of the group in `x` in `b0`, `b1`, `b2`, and `b3`.
///
/// Note that bit `i` in `b0` is bit-0 of the `i`-th value in the group. Likewise, bit `i`
/// in `b1` is bit-1 of the same word.
///
/// This means that we can compute the partial inner product for this group as
/// ```math
/// (bits & b0).count_ones()                // Contribution of bit 0
///     + 2 * (bits & b1).count_ones()      // Contribution of bit 1
///     + 4 * (bits & b2).count_ones()      // Contribution of bit 2
///     + 8 * (bits & b3).count_ones()      // Contribution of bit 3
/// ```
/// We process as many full groups as we can.
///
/// To handle the remainder, we need to be careful about accessing `y` because `BitSlice`
/// only guarantees the validity of reads at the byte level. That is - we cannot assume that
/// a full 64-bit read is valid.
///
/// The bit-transposed `x`, on the other hand, guarantees allocations in blocks of
/// 4 * 64-bits, so it can be treated as normal.
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
    for InnerProduct
where
    A: Architecture,
{
    #[inline(always)]
    fn run(
        self,
        _: A,
        x: USlice<'_, 4, BitTranspose>,
        y: USlice<'_, 1, Dense>,
    ) -> MathematicalResult<u32> {
        let len = check_lengths!(x, y)?;

        // We work in blocks of 64 elements.
        //
        // The `BitTranspose` guarantees reads are valid in blocks of 64 elements (32 bytes).
        // However, the `Dense` representation only pads to bytes.
        // Our strategy for dealing with fewer than 64 remaining elements is to reconstruct
        // a 64-bit integer from bytes.
        let px: *const u64 = x.as_ptr().cast();
        let py: *const u64 = y.as_ptr().cast();

        let mut i = 0;
        let mut s: u32 = 0;

        let blocks = len / 64;
        while i < blocks {
            // SAFETY: `y` is valid for at least `blocks` 64-bit reads and `i < blocks`.
            let bits = unsafe { py.add(i).read_unaligned() };

            // SAFETY: The layout for `x` is grouped into 32-byte blocks. We've ensured that
            // the lengths of the two vectors are the same, so we know that `x` has at least
            // `blocks` such regions.
            //
            // This loads the first 64-bits of block `i` where `i < blocks`.
            let b0 = unsafe { px.add(4 * i).read_unaligned() };
            s += (bits & b0).count_ones();

            // SAFETY: This loads the second 64-bit word of block `i`.
            let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
            s += (bits & b1).count_ones() << 1;

            // SAFETY: This loads the third 64-bit word of block `i`.
            let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
            s += (bits & b2).count_ones() << 2;

            // SAFETY: This loads the fourth 64-bit word of block `i`.
            let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
            s += (bits & b3).count_ones() << 3;

            i += 1;
        }

        // If the input length is a multiple of 64 - then we're done.
        if 64 * i == len {
            return Ok(MV::new(s));
        }

        // Convert blocks to bytes.
        let k = i * 8;

        // Unpack the last elements from the bit-vector.
        //
        // SAFETY: The length of the 1-bit BitSlice is `ceil(len / 8)`. This computation
        // effectively computes `ceil((64 * floor(len / 64)) / 8)`, which is less.
        let py = unsafe { py.cast::<u8>().add(k) };
        let bytes_remaining = y.bytes() - k;
        let mut bits: u64 = 0;

        // Code - generation: Applying `min(8)` gives a constant upper-bound to the
        // compiler, allowing better code-generation.
        for j in 0..bytes_remaining.min(8) {
            // SAFETY: Starting at `py`, there are `bytes_remaining` valid bytes. This
            // accesses all of them.
            bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
        }

        // Because the upper-bits of the last loaded byte can contain indeterminate bits,
        // we must mask out all out-of-bounds bits.
        bits &= (0x01u64 << (len - (64 * i))) - 1;

        // Combine with the remainders.
        //
        // SAFETY: The `BitTranspose` permutation always allocates in granularities of
        // blocks. This loads the first 64-bit word of the last block.
        let b0 = unsafe { px.add(4 * i).read_unaligned() };
        s += (bits & b0).count_ones();

        // SAFETY: This loads the second 64-bit word of the last block.
        let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
        s += (bits & b1).count_ones() << 1;

        // SAFETY: This loads the third 64-bit word of the last block.
        let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
        s += (bits & b2).count_ones() << 2;

        // SAFETY: This loads the fourth 64-bit word of the last block.
        let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
        s += (bits & b3).count_ones() << 3;

        Ok(MV::new(s))
    }
}
1480
1481impl
1482    PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
1483    for InnerProduct
1484{
1485    fn evaluate(
1486        x: USlice<'_, 4, BitTranspose>,
1487        y: USlice<'_, 1, Dense>,
1488    ) -> MathematicalResult<u32> {
1489        (diskann_wide::ARCH).run2(Self, x, y)
1490    }
1491}
1492
1493////////////////////
1494// Full Precision //
1495////////////////////
1496
1497/// The main trick here is avoiding explicit conversion from 1 bit integers to 32-bit
1498/// floating-point numbers by using `_mm256_permutevar_ps`, which performs a shuffle on two
1499/// independent 128-bit lanes of `f32` values in a register `A` using the lower 2-bits of
1500/// each 32-bit integer in a register `B`.
1501///
1502/// Importantly, this instruction only takes a single cycle and we can avoid any kind of
/// masking. Going the route of conversion would require an `AND` operation to isolate
1504/// bottom bits and a somewhat lengthy 32-bit integer to `f32` conversion instruction.
1505///
1506/// The overall strategy broadcasts a 32-bit integer (consisting of 32, 1-bit values) across
1507/// 8 lanes into a register `A`.
1508///
1509/// Each lane is then shifted by a different amount so:
1510///
1511/// * Lane 0 has value 0 as its least significant bit (LSB)
1512/// * Lane 1 has value 1 as its LSB.
1513/// * Lane 2 has value 2 as its LSB.
1514/// * etc.
1515///
1516/// These LSB's are used to power the shuffle function to convert to `f32` values (either
1517/// 0.0 or 1.0) and we can FMA as needed.
1518///
1519/// To process the next group of 8 values, we shift all lanes in `A` by 8-bits so lane 0
1520/// has value 8 as its LSB, lane 1 has value 9 etc.
1521///
1522/// A total of three shifts are applied to extract all 32 1-bit value as `f32` in order.
1523#[cfg(target_arch = "x86_64")]
1524impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
1525    for InnerProduct
1526{
1527    #[inline(always)]
1528    fn run(
1529        self,
1530        arch: diskann_wide::arch::x86_64::V3,
1531        x: &[f32],
1532        y: USlice<'_, 1>,
1533    ) -> MathematicalResult<f32> {
1534        let len = check_lengths!(x, y)?;
1535
1536        use std::arch::x86_64::*;
1537
1538        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
1539        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
1540
1541        // Replicate 0s and 1s so we effectively get a shuffle that only depends on the
1542        // bottom bit (instead of the lowest 2).
1543        let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
1544
1545        // Shifts required to offset each lane.
1546        let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);
1547
1548        let px: *const f32 = x.as_ptr();
1549        let py: *const u32 = y.as_ptr().cast();
1550
1551        let mut i = 0;
1552        let mut s = f32s::default(arch);
1553
1554        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
1555        let to_f32 = |v: u32s| -> f32s {
1556            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
1557            // which the presence of the `x86_64::V3` architecture guarantees is available.
1558            f32s::from_underlying(arch, unsafe {
1559                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
1560            })
1561        };
1562
1563        // Data is processed in groups of 32 elements.
1564        let blocks = len / 32;
1565        if i < blocks {
1566            let mut s0 = f32s::default(arch);
1567            let mut s1 = f32s::default(arch);
1568
1569            while i < blocks {
1570                // SAFETY: `i < blocks` implies 32-bits are readable from this offset.
1571                let iy = prep(unsafe { py.add(i).read_unaligned() });
1572
1573                // SAFETY: `i < blocks` implies 32 f32 values are readable beginning at `32*i`.
1574                let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
1575                // SAFETY: See above.
1576                let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
1577                // SAFETY: See above.
1578                let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
1579                // SAFETY: See above.
1580                let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };
1581
1582                s0 = ix0.mul_add_simd(to_f32(iy), s0);
1583                s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
1584                s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
1585                s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);
1586
1587                i += 1;
1588            }
1589            s = s0 + s1;
1590        }
1591
1592        let remainder = len % 32;
1593        if remainder != 0 {
1594            let tail = if len % 8 == 0 { 8 } else { len % 8 };
1595
1596            // SAFETY: Because `remainder != 0`, there is valid memory beginning at the
1597            // offset `blocks`, so this addition remains within an allocated object.
1598            let py = unsafe { py.add(blocks) };
1599
1600            if remainder <= 8 {
1601                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
1602                // The same logic applies to the SIMD loads.
1603                unsafe {
1604                    load_one(py, |iy| {
1605                        let iy = prep(iy);
1606                        let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
1607                        s = ix.mul_add_simd(to_f32(iy), s);
1608                    })
1609                }
1610            } else if remainder <= 16 {
1611                // SAFETY: At least two bytes are readable for `py`.
1612                // The same logic applies to the SIMD loads.
1613                unsafe {
1614                    load_two(py, |iy| {
1615                        let iy = prep(iy);
1616                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1617                        let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
1618                        s = ix0.mul_add_simd(to_f32(iy), s);
1619                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1620                    })
1621                }
1622            } else if remainder <= 24 {
1623                // SAFETY: At least three bytes are readable for `py`.
1624                // The same logic applies to the SIMD loads.
1625                unsafe {
1626                    load_three(py, |iy| {
1627                        let iy = prep(iy);
1628
1629                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1630                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
1631                        let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);
1632
1633                        s = ix0.mul_add_simd(to_f32(iy), s);
1634                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1635                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
1636                    })
1637                }
1638            } else {
1639                // SAFETY: At least four bytes are readable for `py`.
1640                // The same logic applies to the SIMD loads.
1641                unsafe {
1642                    load_four(py, |iy| {
1643                        let iy = prep(iy);
1644
1645                        let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
1646                        let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
1647                        let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
1648                        let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);
1649
1650                        s = ix0.mul_add_simd(to_f32(iy), s);
1651                        s = ix1.mul_add_simd(to_f32(iy >> 8), s);
1652                        s = ix2.mul_add_simd(to_f32(iy >> 16), s);
1653                        s = ix3.mul_add_simd(to_f32(iy >> 24), s);
1654                    })
1655                }
1656            }
1657        }
1658
1659        Ok(MV::new(s.sum_tree()))
1660    }
1661}
1662
/// The strategy used here is almost identical to that used for 1-bit distances. The main
/// difference is that now we use the full 2-bit shuffle capabilities of `_mm256_permutevar_ps`
/// and the relative sizes of the shifts are slightly different.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 2>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        use std::arch::x86_64::*;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);

        // This is the lookup table mapping 2-bit patterns to their equivalent `f32`
        // representation. The AVX2 shuffle only applies within each 128-bit group of the
        // full 256-bit register, so we replicate the contents.
        let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);

        // Shifts required to get logical dimensions shifted to the lower 2-bits of each lane.
        let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Broadcast a packed 32-bit group (16 2-bit codes) and shift so lane `k` holds
        // code `k` in its low 2 bits.
        let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
        // Convert the low 2 bits of each lane into 0.0..=3.0 via the in-register shuffle.
        let to_f32 = |v: u32s| -> f32s {
            // SAFETY: The `_mm256_permutevar_ps` instruction requires the AVX extension,
            // which the presence of the `x86_64::V3` architecture guarantees is available.
            f32s::from_underlying(arch, unsafe {
                _mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
            })
        };

        // Each block is one `u32` of packed codes, i.e. 16 logical elements.
        let blocks = len / 16;
        if blocks != 0 {
            // Two independent accumulators break the FMA dependency chain.
            let mut s0 = f32s::default(arch);
            let mut s1 = f32s::default(arch);

            // Process 32 elements.
            while i + 2 <= blocks {
                // SAFETY: `i + 2 <= blocks` implies `py.add(i)` is in-bounds and readable
                // for 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: Same logic as above, just applied to `f32` values instead of
                // packed bits.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                // SAFETY: `i + 2 <= blocks` implies `py.add(i + 1)` is in-bounds and readable
                // for 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i + 1).read_unaligned() });

                // SAFETY: Same logic as above.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * (i + 1))),
                        f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);

                i += 2;
            }

            // Process 16 elements
            if i < blocks {
                // SAFETY: `i < blocks` implies `py.add(i)` is in-bounds and readable for
                // 4 unaligned bytes.
                let iy = prep(unsafe { py.add(i).read_unaligned() });

                // SAFETY: Same logic as above.
                let (ix0, ix1) = unsafe {
                    (
                        f32s::load_simd(arch, px.add(16 * i)),
                        f32s::load_simd(arch, px.add(16 * i + 8)),
                    )
                };

                s0 = ix0.mul_add_simd(to_f32(iy), s0);
                s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
            }

            s = s0 + s1;
        }

        // Epilogue: handle the final, partial group of fewer than 16 elements, loading
        // only as many packed bytes as the remainder requires.
        let remainder = len % 16;
        if remainder != 0 {
            // Number of `f32` values in the last (possibly partial) SIMD load.
            let tail = if len % 8 == 0 { 8 } else { len % 8 };
            // SAFETY: Non-zero remainder implies there are readable bytes after the offset
            // `blocks`, so the addition is valid.
            let py = unsafe { py.add(blocks) };

            if remainder <= 4 {
                // SAFETY: Non-zero remainder implies at least one byte is readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_one(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 8 {
                // SAFETY: At least two bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_two(py, |iy| {
                        let iy = prep(iy);
                        let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
                        s = ix.mul_add_simd(to_f32(iy), s);
                    });
                }
            } else if remainder <= 12 {
                // SAFETY: At least three bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_three(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            } else {
                // SAFETY: At least four bytes are readable for `py`.
                // The same logic applies to the SIMD loads.
                unsafe {
                    load_four(py, |iy| {
                        let iy = prep(iy);
                        let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
                        let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
                        s = ix0.mul_add_simd(to_f32(iy), s);
                        s = ix1.mul_add_simd(to_f32(iy >> 16), s);
                    });
                }
            }
        }

        // Horizontal reduction of the 8 accumulator lanes into a scalar.
        Ok(MV::new(s.sum_tree()))
    }
}
1826
/// The strategy here is similar to the 1 and 2-bit strategies. However, instead of using
/// `_mm256_permutevar_ps`, we now go directly for 32-bit integer to 32-bit floating point.
///
/// This is because the shuffle intrinsic only supports 2-bit shuffles.
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
    for InnerProduct
{
    #[inline(always)]
    fn run(
        self,
        arch: diskann_wide::arch::x86_64::V3,
        x: &[f32],
        y: USlice<'_, 4>,
    ) -> MathematicalResult<f32> {
        let len = check_lengths!(x, y)?;

        diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
        diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);

        // Per-lane shifts that bring each lane's 4-bit nibble down to the low bits.
        let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
        // Isolates a single 4-bit code after shifting.
        let mask = i32s::splat(arch, 0x0f);

        // Unpack the 8 4-bit codes in `v` and convert each lane to `f32`.
        let to_f32 = |v: u32| -> f32s {
            ((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
        };

        let px: *const f32 = x.as_ptr();
        let py: *const u32 = y.as_ptr().cast();

        let mut i = 0;
        let mut s = f32s::default(arch);

        // Main loop: 8 elements (one packed `u32`) per iteration.
        let blocks = len / 8;
        while i < blocks {
            // SAFETY: `i < blocks` implies that 8 `f32` values are readable from `8*i`.
            let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
            // SAFETY: Same logic as above - but applied to the packed bits.
            let iy = to_f32(unsafe { py.add(i).read_unaligned() });
            s = ix.mul_add_simd(iy, s);

            i += 1;
        }

        // Epilogue: handle the final, partial group of fewer than 8 elements, loading
        // only as many packed bytes as the remainder requires (2 codes per byte).
        let remainder = len % 8;
        if remainder != 0 {
            let f = |iy| {
                // SAFETY: The epilogue handles at most 8 values. Since the remainder is
                // non-zero, the pointer arithmetic is in-bounds and `load_simd_first` will
                // avoid accessing the out-of-bounds elements.
                let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
                s = ix.mul_add_simd(to_f32(iy), s);
            };

            // SAFETY: Non-zero remainder means there are readable bytes from the offset
            // `blocks`.
            let py = unsafe { py.add(blocks) };

            if remainder <= 2 {
                // SAFETY: Non-zero remainder less than 2 implies that one byte is readable.
                unsafe { load_one(py, f) };
            } else if remainder <= 4 {
                // SAFETY: At least two bytes are readable from `py`.
                unsafe { load_two(py, f) };
            } else if remainder <= 6 {
                // SAFETY: At least three bytes are readable from `py`.
                unsafe { load_three(py, f) };
            } else {
                // SAFETY: At least four bytes are readable from `py`.
                unsafe { load_four(py, f) };
            }
        }

        // Horizontal reduction of the 8 accumulator lanes into a scalar.
        Ok(MV::new(s.sum_tree()))
    }
}
1903
1904impl<const N: usize>
1905    Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
1906    for InnerProduct
1907where
1908    Unsigned: Representation<N>,
1909{
1910    /// A fallback implementation that uses scaler indexing to retrieve values from
1911    /// the corresponding `BitSlice`.
1912    #[inline(always)]
1913    fn run(
1914        self,
1915        _: diskann_wide::arch::Scalar,
1916        x: &[f32],
1917        y: USlice<'_, N>,
1918    ) -> MathematicalResult<f32> {
1919        check_lengths!(x, y)?;
1920
1921        let mut s = 0.0;
1922        for (i, x) in x.iter().enumerate() {
1923            // SAFETY: We've ensured that `x.len() == y.len()`, so this access is
1924            // always inbounds.
1925            let y = unsafe { y.get_unchecked(i) } as f32;
1926            s += x * y;
1927        }
1928
1929        Ok(MV::new(s))
1930    }
1931}
1932
/// Implement `Target2` for higher architecture in terms of the scalar fallback.
macro_rules! ip_retarget {
    // Single width: forward `run` to whatever architecture `arch.retarget()` lowers to.
    // NOTE(review): per the macro's doc comment this ultimately reaches the scalar
    // fallback — confirm against `diskann_wide`'s `retarget` semantics.
    ($arch:path, $N:literal) => {
        impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
            for InnerProduct
        {
            #[inline(always)]
            fn run(
                self,
                arch: $arch,
                x: &[f32],
                y: USlice<'_, $N>,
            ) -> MathematicalResult<f32> {
                self.run(arch.retarget(), x, y)
            }
        }
    };
    // Multiple widths: expand the single-width arm once per `$Ns`.
    ($arch:path, $($Ns:literal),*) => {
        $(ip_retarget!($arch, $Ns);)*
    }
}
1954
// Widths 1, 2, and 4 have dedicated x86-64-v3 kernels above, so only the remaining
// widths are retargeted for V3.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);

// NOTE(review): no dedicated V4 kernels are visible in this section, so all widths
// retarget — presumably lowering toward V3/scalar; confirm against `retarget()`.
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);

// Likewise, all widths fall back via retargeting on NEON.
#[cfg(target_arch = "aarch64")]
ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8);
1963
/// Delegate the implementation of `PureDistanceFunction` to `diskann_wide::arch::Target2`
/// with the current architectures.
macro_rules! dispatch_full_ip {
    // Single width: implement `PureDistanceFunction` by running the `Target2`
    // implementation on the compile-time architecture `ARCH`.
    ($N:literal) => {
        /// Compute the inner product between `x` and `y`.
        ///
        /// Returns an error if the arguments have different lengths.
        impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
            for InnerProduct
        {
            fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
                Self.run(ARCH, x, y)
            }
        }
    };
    // Multiple widths: expand the single-width arm once per `$Ns`.
    ($($Ns:literal),*) => {
        $(dispatch_full_ip!($Ns);)*
    }
}

// Full-precision-versus-quantized inner products for every width from 1 to 8 bits.
dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
1985
1986///////////
1987// Tests //
1988///////////
1989
1990#[cfg(test)]
1991mod tests {
1992    use std::{collections::HashMap, sync::LazyLock};
1993
1994    use diskann_utils::Reborrow;
1995    use rand::{
1996        Rng, SeedableRng,
1997        distr::{Distribution, Uniform},
1998        rngs::StdRng,
1999        seq::IndexedRandom,
2000    };
2001
2002    use super::*;
2003    use crate::bits::{BoxedBitSlice, Representation, Unsigned};
2004
    /// Shorthand for the integer-valued result type returned by the bit-slice kernels.
    type MR = MathematicalResult<u32>;
2006
2007    /////////////////////////
2008    // Unsigned Bit Slices //
2009    /////////////////////////
2010
2011    // This test works by generating random integer codes for the compressed vectors,
2012    // then uses the functions implemented in `vector` to compute the expected result of
2013    // the computation in "full precision integer space".
2014    //
2015    // We verify that the exact same results are returned by each computation.
2016    fn test_bitslice_distances<const NBITS: usize, R>(
2017        dim_max: usize,
2018        trials_per_dim: usize,
2019        evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2020        evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
2021        context: &str,
2022        rng: &mut R,
2023    ) where
2024        Unsigned: Representation<NBITS>,
2025        R: Rng,
2026    {
2027        let domain = Unsigned::domain_const::<NBITS>();
2028        let min: i64 = *domain.start();
2029        let max: i64 = *domain.end();
2030
2031        let dist = Uniform::new_inclusive(min, max).unwrap();
2032
2033        for dim in 0..dim_max {
2034            let mut x_reference: Vec<u8> = vec![0; dim];
2035            let mut y_reference: Vec<u8> = vec![0; dim];
2036
2037            let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2038            let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
2039
2040            for trial in 0..trials_per_dim {
2041                x_reference
2042                    .iter_mut()
2043                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2044                y_reference
2045                    .iter_mut()
2046                    .for_each(|i| *i = dist.sample(rng).try_into().unwrap());
2047
2048                // Fill the input slices with 1's so we can catch situations where we don't
2049                // correctly handle odd remaining elements.
2050                x.as_mut_slice().fill(u8::MAX);
2051                y.as_mut_slice().fill(u8::MAX);
2052
2053                for i in 0..dim {
2054                    x.set(i, x_reference[i].into()).unwrap();
2055                    y.set(i, y_reference[i].into()).unwrap();
2056                }
2057
2058                // Check L2
2059                let expected: MV<f32> =
2060                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2061
2062                let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
2063
2064                // Integer computations should be exact.
2065                assert_eq!(
2066                    expected.into_inner(),
2067                    got.into_inner() as f32,
2068                    "failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
2069                    NBITS,
2070                    dim,
2071                    trial,
2072                    context,
2073                );
2074
2075                // Check IP
2076                let expected: MV<f32> =
2077                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2078
2079                let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
2080
2081                // Integer computations should be exact.
2082                assert_eq!(
2083                    expected.into_inner(),
2084                    got.into_inner() as f32,
2085                    "faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
2086                    NBITS,
2087                    dim,
2088                    trial,
2089                    context,
2090                );
2091            }
2092        }
2093
2094        // Test that we correctly return error types for length mismatches.
2095        let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
2096        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2097
2098        assert!(
2099            evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
2100            "context: {}",
2101            context
2102        );
2103        assert!(
2104            evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
2105            "context: {}",
2106            context
2107        );
2108
2109        assert!(
2110            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2111            "context: {}",
2112            context
2113        );
2114        assert!(
2115            evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
2116            "context: {}",
2117            context
2118        );
2119    }
2120
    // Default test sizes. Miri interprets every memory access and is orders of magnitude
    // slower than a native build, so keep the search space small there.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 128;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
2130
    // For the bit-slice kernels, we want to use different maximum dimensions for the distance
    // test depending on the implementation of the kernel, and whether or not we are running
    // under Miri.
    //
    // For implementations that use the scalar fallback, we need not set very high bounds
    // (particularly when running under miri) because the implementations are quite simple.
    //
    // However, some SIMD kernels (especially for the lower bit widths), require higher bounds
    // to trigger all possible corner cases.
    //
    // Each entry maps a (bit width, architecture) pair to its (standard, miri) bounds.
    static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
        use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4};
        [
            (Key::new(1, Scalar), Bounds::new(64, 64)),
            (Key::new(1, X86_64_V3), Bounds::new(256, 256)),
            (Key::new(1, X86_64_V4), Bounds::new(256, 256)),
            (Key::new(1, Neon), Bounds::new(64, 64)),
            (Key::new(2, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(2, X86_64_V3), Bounds::new(512, 300)),
            (Key::new(2, X86_64_V4), Bounds::new(768, 600)), // main loop processes 256 items
            (Key::new(2, Neon), Bounds::new(64, 64)),
            (Key::new(3, Scalar), Bounds::new(64, 64)),
            (Key::new(3, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(3, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(3, Neon), Bounds::new(64, 64)),
            (Key::new(4, Scalar), Bounds::new(64, 64)),
            // Need a higher miri-amount due to the larger block size
            (Key::new(4, X86_64_V3), Bounds::new(256, 150)),
            (Key::new(4, X86_64_V4), Bounds::new(256, 150)),
            (Key::new(4, Neon), Bounds::new(64, 64)),
            (Key::new(5, Scalar), Bounds::new(64, 64)),
            (Key::new(5, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(5, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(5, Neon), Bounds::new(64, 64)),
            (Key::new(6, Scalar), Bounds::new(64, 64)),
            (Key::new(6, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(6, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(6, Neon), Bounds::new(64, 64)),
            (Key::new(7, Scalar), Bounds::new(64, 64)),
            (Key::new(7, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(7, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(7, Neon), Bounds::new(64, 64)),
            (Key::new(8, Scalar), Bounds::new(64, 64)),
            (Key::new(8, X86_64_V3), Bounds::new(256, 96)),
            (Key::new(8, X86_64_V4), Bounds::new(256, 96)),
            (Key::new(8, Neon), Bounds::new(64, 64)),
        ]
        .into_iter()
        .collect()
    });
2181
    /// Micro-architecture tag used to key the per-architecture test bounds.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum ArchKey {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
        Neon,
    }
2191
    /// Lookup key for `BITSLICE_TEST_BOUNDS`: a bit width paired with the
    /// micro-architecture of the kernel under test.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct Key {
        // Number of bits per quantized code.
        nbits: usize,
        // Target micro-architecture of the kernel.
        arch: ArchKey,
    }

    impl Key {
        /// Construct a key for `nbits`-wide codes on `arch`.
        fn new(nbits: usize, arch: ArchKey) -> Self {
            Self { nbits, arch }
        }
    }
2203
2204    #[derive(Debug, Clone, Copy)]
2205    struct Bounds {
2206        standard: usize,
2207        miri: usize,
2208    }
2209
2210    impl Bounds {
2211        fn new(standard: usize, miri: usize) -> Self {
2212            Self { standard, miri }
2213        }
2214
2215        fn get(&self) -> usize {
2216            if cfg!(miri) { self.miri } else { self.standard }
2217        }
2218    }
2219
    // Generates one `#[test]` for a given bit width that exercises every available
    // kernel implementation: the `PureDistanceFunction` entry points, the explicit
    // scalar architecture, and each SIMD architecture detected at runtime.
    macro_rules! test_bitslice {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                // Fixed seed keeps failures reproducible across runs.
                let mut rng = StdRng::seed_from_u64($seed);

                let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();

                // The `PureDistanceFunction` entry points (compile-time `ARCH` dispatch).
                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| SquaredL2::evaluate(x, y),
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                // The explicit scalar fallback.
                test_bitslice_distances::<$nbits, _>(
                    max_dim,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar arch",
                    &mut rng,
                );

                // Architecture Specific.
                // Each branch runs only when the corresponding architecture is available
                // at runtime, using that architecture's (possibly larger) test bounds.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get();
                    test_bitslice_distances::<$nbits, _>(
                        max_dim,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(SquaredL2, x, y),
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2288
    // Instantiate the bit-slice distance tests for every supported bit-width
    // (8 down to 1). Each invocation gets a distinct fixed seed so the
    // randomized sweeps are reproducible and independent across bit-widths.
    test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
    test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
    test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
    test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
    test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
    test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
    test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
    test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
2297
2298    ///////////////////////////
2299    // Hamming Bit Distances //
2300    ///////////////////////////
2301
2302    fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
2303    where
2304        R: Rng,
2305    {
2306        let dist: [i8; 2] = [-1, 1];
2307
2308        for dim in 0..dim_max {
2309            let mut x_reference: Vec<i8> = vec![1; dim];
2310            let mut y_reference: Vec<i8> = vec![1; dim];
2311
2312            let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
2313            let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);
2314
2315            for _ in 0..trials_per_dim {
2316                x_reference
2317                    .iter_mut()
2318                    .for_each(|i| *i = *dist.choose(rng).unwrap());
2319                y_reference
2320                    .iter_mut()
2321                    .for_each(|i| *i = *dist.choose(rng).unwrap());
2322
2323                // Fill the input slices with 1's so we can catch situations where we don't
2324                // correctly handle odd remaining elements.
2325                x.as_mut_slice().fill(u8::MAX);
2326                y.as_mut_slice().fill(u8::MAX);
2327
2328                for i in 0..dim {
2329                    x.set(i, x_reference[i].into()).unwrap();
2330                    y.set(i, y_reference[i].into()).unwrap();
2331                }
2332
2333                // We can check equality by evaluating the L2 distance between the reference
2334                // vectors.
2335                //
2336                // This is proportional to the Hamming distance by a factor of 4 (since the
2337                // distance betwwen +1 and -1 is 2 - and 2^2 = 4.
2338                let expected: MV<f32> =
2339                    diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
2340                let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
2341                assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
2342            }
2343        }
2344
2345        let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
2346        let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
2347        assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
2348        assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
2349    }
2350
2351    #[test]
2352    fn test_hamming_distance() {
2353        let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
2354        test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
2355    }
2356
2357    ///////////////////
2358    // Heterogeneous //
2359    ///////////////////
2360
2361    fn test_bit_transpose_distances<R>(
2362        dim_max: usize,
2363        trials_per_dim: usize,
2364        evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
2365        context: &str,
2366        rng: &mut R,
2367    ) where
2368        R: Rng,
2369    {
2370        let dist_4bit = {
2371            let domain = Unsigned::domain_const::<4>();
2372            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2373        };
2374
2375        let dist_1bit = {
2376            let domain = Unsigned::domain_const::<1>();
2377            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2378        };
2379
2380        for dim in 0..dim_max {
2381            let mut x_reference: Vec<u8> = vec![0; dim];
2382            let mut y_reference: Vec<u8> = vec![0; dim];
2383
2384            let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
2385            let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
2386
2387            for trial in 0..trials_per_dim {
2388                x_reference
2389                    .iter_mut()
2390                    .for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
2391                y_reference
2392                    .iter_mut()
2393                    .for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
2394
2395                // First - pre-set all the values in the bit-slices to 1.
2396                x.as_mut_slice().fill(u8::MAX);
2397                y.as_mut_slice().fill(u8::MAX);
2398
2399                for i in 0..dim {
2400                    x.set(i, x_reference[i].into()).unwrap();
2401                    y.set(i, y_reference[i].into()).unwrap();
2402                }
2403
2404                // Check IP
2405                let expected: MV<f32> =
2406                    diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
2407
2408                let got = evaluate_ip(x.reborrow(), y.reborrow());
2409
2410                // Integer computations should be exact.
2411                assert_eq!(
2412                    expected.into_inner(),
2413                    got.unwrap().into_inner() as f32,
2414                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
2415                    dim,
2416                    trial,
2417                    context,
2418                );
2419            }
2420        }
2421
2422        let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
2423        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
2424        assert!(
2425            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2426            "context: {}",
2427            context
2428        );
2429
2430        let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
2431        assert!(
2432            evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
2433            "context: {}",
2434            context
2435        );
2436    }
2437
2438    #[test]
2439    fn test_bit_transpose_distance() {
2440        let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);
2441
2442        test_bit_transpose_distances(
2443            MAX_DIM,
2444            TRIALS_PER_DIM,
2445            &|x, y| InnerProduct::evaluate(x, y),
2446            "pure distance function",
2447            &mut rng,
2448        );
2449
2450        test_bit_transpose_distances(
2451            MAX_DIM,
2452            TRIALS_PER_DIM,
2453            &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
2454            "scalar",
2455            &mut rng,
2456        );
2457
2458        // Architecture Specific.
2459        #[cfg(target_arch = "x86_64")]
2460        if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
2461            test_bit_transpose_distances(
2462                MAX_DIM,
2463                TRIALS_PER_DIM,
2464                &|x, y| arch.run2(InnerProduct, x, y),
2465                "x86-64-v3",
2466                &mut rng,
2467            );
2468        }
2469
2470        // Architecture Specific.
2471        #[cfg(target_arch = "x86_64")]
2472        if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
2473            test_bit_transpose_distances(
2474                MAX_DIM,
2475                TRIALS_PER_DIM,
2476                &|x, y| arch.run2(InnerProduct, x, y),
2477                "x86-64-v4",
2478                &mut rng,
2479            );
2480        }
2481
2482        // Architecture Specific.
2483        #[cfg(target_arch = "aarch64")]
2484        if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
2485            test_bit_transpose_distances(
2486                MAX_DIM,
2487                TRIALS_PER_DIM,
2488                &|x, y| arch.run2(InnerProduct, x, y),
2489                "neon",
2490                &mut rng,
2491            );
2492        }
2493    }
2494
2495    //////////
2496    // Full //
2497    //////////
2498
2499    fn test_full_distances<const NBITS: usize>(
2500        dim_max: usize,
2501        trials_per_dim: usize,
2502        evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
2503        context: &str,
2504        rng: &mut impl Rng,
2505    ) where
2506        Unsigned: Representation<NBITS>,
2507    {
2508        // let dist_float = Uniform::new_inclusive(-2.0f32, 2.0f32).unwrap();
2509        let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
2510        let dist_bit = {
2511            let domain = Unsigned::domain_const::<NBITS>();
2512            Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
2513        };
2514
2515        for dim in 0..dim_max {
2516            let mut x: Vec<f32> = vec![0.0; dim];
2517
2518            let mut y_reference: Vec<u8> = vec![0; dim];
2519            let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
2520
2521            for trial in 0..trials_per_dim {
2522                x.iter_mut()
2523                    .for_each(|i| *i = *dist_float.choose(rng).unwrap());
2524                y_reference
2525                    .iter_mut()
2526                    .for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
2527
2528                // First - pre-set all the values in the bit-slices to 1.
2529                y.as_mut_slice().fill(u8::MAX);
2530
2531                let mut expected = 0.0;
2532                for i in 0..dim {
2533                    y.set(i, y_reference[i].into()).unwrap();
2534                    expected += y_reference[i] as f32 * x[i];
2535                }
2536
2537                // Check IP
2538                let got = evaluate_ip(&x, y.reborrow()).unwrap();
2539
2540                // Integer computations should be exact.
2541                assert_eq!(
2542                    expected,
2543                    got.into_inner(),
2544                    "faild InnerProduct for dim = {}, trial = {} -- context {}",
2545                    dim,
2546                    trial,
2547                    context,
2548                );
2549
2550                // Ensure that using the `Scalar` architecture providers the same
2551                // results.
2552                let scalar: MV<f32> = InnerProduct
2553                    .run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
2554                    .unwrap();
2555                assert_eq!(got.into_inner(), scalar.into_inner());
2556            }
2557        }
2558
2559        // Error Checking
2560        let x = vec![0.0; 10];
2561        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
2562        assert!(
2563            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2564            "context: {}",
2565            context
2566        );
2567
2568        let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
2569        assert!(
2570            evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
2571            "context: {}",
2572            context
2573        );
2574    }
2575
    // Generates a `#[test]` function named `$name` that exercises
    // `test_full_distances` at bit-width `$nbits`, seeding the RNG with `$seed`
    // so each generated test is reproducible.
    //
    // Each generated test covers the pure distance-function interface, the
    // explicit `Scalar` architecture path, and every runtime-detected
    // architecture-specific path (x86-64-v3 / x86-64-v4 / Neon).
    macro_rules! test_full {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);

                // Architecture-agnostic: pure distance-function interface.
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| InnerProduct::evaluate(x, y),
                    "pure distance function",
                    &mut rng,
                );

                // Architecture-agnostic: explicit `Scalar` target.
                test_full_distances::<$nbits>(
                    MAX_DIM,
                    TRIALS_PER_DIM,
                    &|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
                    "scalar",
                    &mut rng,
                );

                // Architecture Specific.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v3",
                        &mut rng,
                    );
                }

                // NOTE(review): `new_checked_miri` presumably performs a
                // Miri-compatible feature check - confirm in `diskann_wide`.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "x86-64-v4",
                        &mut rng,
                    );
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                    test_full_distances::<$nbits>(
                        MAX_DIM,
                        TRIALS_PER_DIM,
                        &|x, y| arch.run2(InnerProduct, x, y),
                        "neon",
                        &mut rng,
                    );
                }
            }
        };
    }
2634
2635    test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
2636    test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
2637    test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
2638    test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
2639    test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
2640    test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
2641    test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
2642    test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
2643}