// diskann_vector/distance/simd.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
use std::convert::AsRef;

#[cfg(target_arch = "x86_64")]
use diskann_wide::arch::x86_64::{V3, V4};

#[cfg(not(target_arch = "aarch64"))]
use diskann_wide::SIMDDotProduct;
use diskann_wide::{
    arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDMulAdd, SIMDSumTree,
    SIMDVector,
};

use crate::Half;
19
/// A helper trait to allow integer to f32 conversion (which may be lossy).
pub trait LossyF32Conversion: Copy {
    /// Convert `self` to `f32`.
    ///
    /// The conversion may lose precision: integers with magnitude above 2^24 are not all
    /// exactly representable as `f32` and are rounded to the nearest representable value.
    fn as_f32_lossy(self) -> f32;
}

impl LossyF32Conversion for f32 {
    // Identity conversion: already an `f32`, so this is exact.
    #[inline(always)]
    fn as_f32_lossy(self) -> f32 {
        self
    }
}

impl LossyF32Conversion for i32 {
    // `as` performs round-to-nearest for values outside f32's exact integer range.
    #[inline(always)]
    fn as_f32_lossy(self) -> f32 {
        self as f32
    }
}
36
cfg_if::cfg_if! {
    if #[cfg(miri)] {
        // Under Miri, inline assembly is not supported; forced evaluation is purely a
        // code-generation aid, so a no-op keeps behavior identical.
        fn force_eval(_x: f32) {}
    } else if #[cfg(target_arch = "x86_64")] {
        use std::arch::asm;

        /// Force the evaluation of the argument, preventing the compiler from reordering the
        /// computation of `x` behind a condition.
        ///
        /// In the context of Cosine similarity, this can help code generation for
        /// static-dimensional kernels
        #[inline(always)]
        fn force_eval(x: f32) {
            // SAFETY: This function executes no instructions. As such, it satisfies the long
            // list of requirements for inline assembly.
            //
            // See: https://doc.rust-lang.org/reference/inline-assembly.html#rules-for-inline-assembly
            unsafe {
                asm!(
                    // Assembly comment to "use" the argument
                    "/* {0} */",
                    // Use an `xmm_reg` since LLVM almost always uses such a register for
                    // scalar floating point
                    in(xmm_reg) x,
                    // Explanation:
                    // * `nostack`: This function does not touch the stack, so the compiler
                    //   does not need to worry that the stack gets messed up.
                    // * `nomem`: This function does not touch memory. The compiler doesn't
                    //   have to reload any values.
                    // * `preserves_flags`: This function preserves architectural condition
                    //   flags. We can make this guarantee because this function literally
                    //   does nothing.
                    options(nostack, nomem, preserves_flags)
                )
            }
        }
    } else {
        // Fallback implementation for architectures without a dedicated barrier: rely on
        // default code generation (no forced evaluation).
        fn force_eval(_x: f32) {}
    }
}
78
/// A utility struct to help with SIMD loading.
///
/// The main loop of SIMD kernels consists of various tilings of loads and arithmetic.
/// Outside of the epilogue, these loads are all full-width vector loads.
///
/// To aid in defining different tilings, this struct takes the base pointers for left and
/// right hand pointers and provides a `load` method to extract full vectors for both
/// the left and right-hand sides.
///
/// This works in conjunction with the [`SIMDSchema`] to help write unrolled loops.
#[derive(Debug, Clone, Copy)]
pub struct Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    // Architecture token used to dispatch the SIMD loads.
    arch: A,
    // Schema describing vector types and main-loop unroll factor.
    schema: Schema,
    // Base pointer of the left-hand operand span.
    left: *const Left,
    // Base pointer of the right-hand operand span.
    right: *const Right,
    // Number of elements in each span; used for debug bounds checks in `load`.
    len: usize,
}
101
impl<Schema, Left, Right, A> Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    /// Construct a new loader for the left and right hand pointers.
    ///
    /// Requires that the memory ranges `[left, left + len)` and `[right, right + len)` are
    /// both valid, where `len` is the *number* of the elements of type `T` and `U`.
    #[inline(always)]
    fn new(arch: A, schema: Schema, left: *const Left, right: *const Right, len: usize) -> Self {
        Self {
            arch,
            schema,
            left,
            right,
            len,
        }
    }

    /// Return the underlying architecture.
    #[inline(always)]
    fn arch(&self) -> A {
        self.arch
    }

    /// Return the SIMD Schema.
    #[inline(always)]
    fn schema(&self) -> Schema {
        self.schema
    }

    /// Load full width vectors for the left and right hand memory spans.
    ///
    /// This loads a [`SIMDSchema::SIMDWidth`] chunk of data using the following formula:
    ///
    /// ```text
    /// // The number of elements in an unrolled [`MainLoop`].
    /// let simd_width = Schema::SIMDWidth::value();
    /// let block_size = simd_width * Schema::Main::BLOCK_SIZE;
    ///
    /// load(px + block_size * block + simd_width * offset);
    /// ```
    ///
    /// # Safety
    ///
    /// Requires that the following memory addresses are in-bounds (i.e., the highest
    /// read address is at an offset less than `len`):
    ///
    /// ```text
    /// [
    ///     px + block_size * block + simd_width * offset,
    ///     px + block_size * block + simd_width * (offset + 1)
    /// )
    ///
    /// [
    ///     py + block_size * block + simd_width * offset,
    ///     py + block_size * block + simd_width * (offset + 1)
    /// )
    /// ```
    ///
    /// This invariant is checked in debug builds.
    #[inline(always)]
    unsafe fn load(&self, block: usize, offset: usize) -> (Schema::Left, Schema::Right) {
        // Elements per SIMD vector.
        let stride = Schema::SIMDWidth::value();
        // Elements consumed by one fully-unrolled block of the main loop.
        let block_stride = stride * Schema::Main::BLOCK_SIZE;
        // Element index of the requested vector within both spans.
        let offset = block_stride * block + stride * offset;

        // Debug-only verification that the full-width read stays within `[0, len)`.
        debug_assert!(
            offset + stride <= self.len,
            "length = {}, offset = {}",
            self.len,
            offset
        );

        (
            Schema::Left::load_simd(self.arch, self.left.add(offset)),
            Schema::Right::load_simd(self.arch, self.right.add(offset)),
        )
    }
}
183
/// A representation of the main unrolled-loop for SIMD kernels.
pub trait MainLoop {
    /// The effective number of unrolling (in terms of SIMD vectors) performed by this
    /// kernel. For example, if `BLOCK_SIZE = 4` and the SIMD width is 8, then each iteration
    /// of the main loop will process `4 * 8 = 32` elements.
    ///
    /// This parameter will be used to compute the number of full-width epilogues that need
    /// to be executed.
    const BLOCK_SIZE: usize;

    /// Perform the main unrolled loops of a SIMD kernel. This loop is expected to process
    /// all elements in the range `[0, trip_count * S::get_simd_width() * Self::BLOCK_SIZE)`
    /// and return an accumulator consisting of the result.
    ///
    /// # Arguments
    ///
    /// * `loader`: A SIMD loader to emit loads to the two source spans.
    /// * `trip_count`: The number of blocks of size `BLOCK_SIZE` to process. A single "trip"
    ///   will process `S::get_simd_width() * Self::BLOCK_SIZE` elements. So, computation of
    ///   `trip_count` should be computed as:
    ///   ```math
    ///   let trip_count = len / (S::get_simd_width() * <_ as MainLoop>::BLOCK_SIZE);
    ///   ```
    /// * `epilogues`: The number of `S::get_simd_width()` vectors remaining after all the
    ///   main blocks have been processed. This is guaranteed to be less than
    ///   `Self::BLOCK_SIZE`.
    ///
    /// # Safety
    ///
    /// All elements in the accessed range must be valid. The memory addresses touched are
    ///
    /// ```text
    /// let block_size = Self::BLOCK_SIZE;
    /// let simd_width = S::get_simd_width();
    /// [
    ///     loader.left,
    ///     loader.left + trip_count * simd_width * block_size + epilogues * simd_width
    /// )
    /// [
    ///     loader.right,
    ///     loader.right + trip_count * simd_width * block_size + epilogues * simd_width
    /// )
    /// ```
    ///
    /// The `loader` will ensure that all accesses are in-bounds in debug builds.
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>;
}
/// An inner loop implementation strategy using 1 parallel instance of the schema
/// accumulator with a manual inner loop unroll of 1.
pub struct Strategy1x1;

/// An inner loop implementation strategy using 2 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 1.
pub struct Strategy2x1;

/// An inner loop implementation strategy using 4 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 1.
pub struct Strategy4x1;

/// An inner loop implementation strategy using 4 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 2.
pub struct Strategy4x2;

/// An inner loop implementation strategy using 2 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 4.
pub struct Strategy2x4;
257
258impl MainLoop for Strategy1x1 {
259    const BLOCK_SIZE: usize = 1;
260
261    #[inline(always)]
262    unsafe fn main<S, L, R, A>(
263        loader: &Loader<S, L, R, A>,
264        trip_count: usize,
265        _epilogues: usize,
266    ) -> S::Accumulator
267    where
268        A: Architecture,
269        S: SIMDSchema<L, R, A>,
270    {
271        let arch = loader.arch();
272        let schema = loader.schema();
273
274        let mut s0 = schema.init(arch);
275        for i in 0..trip_count {
276            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
277        }
278
279        s0
280    }
281}
282
283impl MainLoop for Strategy2x1 {
284    const BLOCK_SIZE: usize = 2;
285
286    #[inline(always)]
287    unsafe fn main<S, L, R, A>(
288        loader: &Loader<S, L, R, A>,
289        trip_count: usize,
290        epilogues: usize,
291    ) -> S::Accumulator
292    where
293        A: Architecture,
294        S: SIMDSchema<L, R, A>,
295    {
296        let arch = loader.arch();
297        let schema = loader.schema();
298
299        let mut s0 = schema.init(arch);
300        let mut s1 = schema.init(arch);
301
302        for i in 0..trip_count {
303            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
304            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
305        }
306
307        let mut s = schema.combine(s0, s1);
308        if epilogues != 0 {
309            s = schema.accumulate_tuple(s, loader.load(trip_count, 0));
310        }
311
312        s
313    }
314}
315
impl MainLoop for Strategy4x1 {
    const BLOCK_SIZE: usize = 4;

    /// Four independent accumulators, one SIMD vector each per trip.
    ///
    /// After the main loop, up to `BLOCK_SIZE - 1 = 3` full-width epilogue vectors are
    /// distributed across the first three accumulators before the final combine.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        // Independent accumulators break the loop-carried dependency chain.
        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        for i in 0..trip_count {
            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(i, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(i, 3));
        }

        // Handle the remaining full-width vectors (epilogues < 4, so `s3` never gets one).
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        // Pairwise combine keeps the reduction tree balanced.
        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
359
impl MainLoop for Strategy4x2 {
    const BLOCK_SIZE: usize = 4;

    /// Four independent accumulators with the outer loop manually unrolled by 2: each
    /// iteration consumes two `BLOCK_SIZE` blocks (8 SIMD vectors). A single peeled block
    /// handles an odd `trip_count`, then up to 3 full-width epilogue vectors are applied.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        // Each iteration processes blocks `j` and `j + 1` (offsets 4..8 fall in block
        // `j + 1` since a block is 4 vectors wide).
        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));

            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 6));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 7));
        }

        // Peeled iteration for the last block when `trip_count` is odd.
        if !trip_count.is_multiple_of(2) {
            // Will not underflow because `trip_count` is odd.
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));
        }

        // Remaining full-width vectors (epilogues < 4, so `s3` never gets one).
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        // Pairwise combine keeps the reduction tree balanced.
        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
418
impl MainLoop for Strategy2x4 {
    const BLOCK_SIZE: usize = 4;

    /// The implementation here has a global unroll of 4, but the unroll factor of the main
    /// loop is actually 8.
    ///
    /// There is a single peeled iteration at the end that handles the last group of 4
    /// if needed.
    ///
    /// Only two accumulators are used; each consumes every other SIMD vector.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);

        // Each iteration processes blocks `j` and `j + 1` (offsets 4..8 fall in block
        // `j + 1` since a block is 4 vectors wide), alternating accumulators.
        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));

            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 6));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 7));
        }

        // Peeled iteration for the last block when `trip_count` is odd.
        // Will not underflow because `trip_count` is odd.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));
        }

        // Remaining full-width vectors (epilogues < 4), alternating accumulators.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 2));
        }

        schema.combine(s0, s1)
    }
}
479
/// An interface trait for SIMD operations.
///
/// Patterns like unrolling, pointer arithmetic, and epilogue handling are common across
/// many different combinations of left and right hand types for distance computations.
///
/// This higher level handling is delegated to functions like `simd_op`, which in turn
/// uses a `SIMDSchema` to customize the mechanics of loading and accumulation.
pub trait SIMDSchema<T, U, A: Architecture = diskann_wide::arch::Current>: Copy {
    /// The desired SIMD read width.
    /// Reads from the input slice will use this stride when accessing memory.
    type SIMDWidth: Constant<Type = usize>;

    /// The type used to represent partial accumulated values.
    type Accumulator: std::ops::Add<Output = Self::Accumulator> + std::fmt::Debug + Copy;

    /// The type used for the left-hand side.
    type Left: SIMDVector<Arch = A, Scalar = T, ConstLanes = Self::SIMDWidth>;

    /// The type used for the right-hand side.
    type Right: SIMDVector<Arch = A, Scalar = U, ConstLanes = Self::SIMDWidth>;

    /// The final return type.
    /// This is often `f32` for complete distance functions, but need not always be.
    type Return;

    /// The implementation of the main loop.
    type Main: MainLoop;

    /// Initialize an empty (identity) accumulator.
    fn init(&self, arch: A) -> Self::Accumulator;

    /// Perform an accumulation.
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator;

    /// Combine two independent accumulators (allows for unrolling).
    #[inline(always)]
    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
        x + y
    }

    /// A supplied trait for dealing with non-full-width epilogues.
    /// Often, masked based loading will do the right thing, but for architectures like AVX2
    /// that have limited support for masking 8 and 16-bit operations, using a scalar
    /// fallback may just be better.
    ///
    /// This provides a customization point to enable a scalar fallback.
    ///
    /// # Safety
    ///
    /// * Both pointers `x` and `y` must point to memory.
    /// * It must be safe to read `len` contiguous items of type `T` starting at `x` and
    ///   `len` contiguous items of type `U` starting at `y`.
    ///
    /// The following guarantee is made:
    ///
    /// * No read will be emitted to memory locations at and after `x.add(len)` and
    ///   `y.add(len)`.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: A,
        x: *const T,
        y: *const U,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // SAFETY: Performing this read is safe by the safety preconditions of `epilogue`.
        // Guarantee: The load implementation must be correct.
        let a = Self::Left::load_simd_first(arch, x, len);

        // SAFETY: Performing this read is safe by the safety preconditions of `epilogue`.
        // Guarantee: The load implementation must be correct.
        let b = Self::Right::load_simd_first(arch, y, len);
        self.accumulate(a, b, acc)
    }

    /// Perform a reduction on the accumulator to yield the final result.
    ///
    /// This will be called at the end of distance processing.
    fn reduce(&self, x: Self::Accumulator) -> Self::Return;

    /// !! Do not extend this function !!
    ///
    /// Due to limitations on how associated constants can be used, we need a function
    /// to access the SIMD width and rely on the compiler to constant propagate the result.
    #[inline(always)]
    fn get_simd_width() -> usize {
        Self::SIMDWidth::value()
    }

    /// !! Do not extend this function !!
    ///
    /// Due to limitations on how associated constants can be used, we need a function
    /// to access the unroll factor of the main loop and rely on the compiler to constant
    /// propagate the result.
    // NOTE(review): the name has a typo ("bocksize" should be "blocksize") but is kept
    // as-is for API compatibility with existing callers.
    #[inline(always)]
    fn get_main_bocksize() -> usize {
        Self::Main::BLOCK_SIZE
    }

    /// A helper method to access [`Self::accumulate`] in a way that is immediately
    /// compatible with [`Loader::load`].
    #[doc(hidden)]
    #[inline(always)]
    fn accumulate_tuple(
        &self,
        acc: Self::Accumulator,
        (x, y): (Self::Left, Self::Right),
    ) -> Self::Accumulator {
        self.accumulate(x, y, acc)
    }
}
597
/// In some contexts - it can be beneficial to begin a computation on one pair of slices and
/// then store intermediate state for resumption on another pair of slices.
///
/// A good example of this is direct-computation of PQ distances where different chunks need
/// to be gathered and partially accumulated before the final reduction.
///
/// The `ResumableSchema` provides a relatively straight-forward way of achieving this.
pub trait ResumableSIMDSchema<T, U, A = diskann_wide::arch::Current>: Copy
where
    A: Architecture,
{
    /// The associated one-shot (non-resumable) schema used for the actual accumulation.
    type NonResumable: SIMDSchema<T, U, A> + Default;
    /// The type produced by the final [`Self::sum`] reduction.
    type FinalReturn;

    /// Create an empty resumable state for `arch`.
    fn init(arch: A) -> Self;
    /// Fold a freshly-computed accumulator into the saved state, returning the new state.
    fn combine_with(&self, other: <Self::NonResumable as SIMDSchema<T, U, A>>::Accumulator)
        -> Self;
    /// Perform the final reduction over all folded state.
    fn sum(&self) -> Self::FinalReturn;
}
618
/// A thin wrapper holding the intermediate state of a resumable SIMD computation.
///
/// Returned by the [`SIMDSchema::reduce`] of a resumable schema so that accumulation can
/// be continued on another pair of slices before the final reduction.
#[derive(Debug, Clone, Copy)]
pub struct Resumable<T>(T);

impl<T> Resumable<T> {
    /// Wrap `val` as resumable intermediate state.
    #[inline(always)]
    pub fn new(val: T) -> Self {
        Self(val)
    }

    /// Unwrap and return the inner intermediate state.
    #[inline(always)]
    pub fn consume(self) -> T {
        self.0
    }
}
631
632impl<T, U, R, A> SIMDSchema<T, U, A> for Resumable<R>
633where
634    A: Architecture,
635    R: ResumableSIMDSchema<T, U, A>,
636{
637    type SIMDWidth = <R::NonResumable as SIMDSchema<T, U, A>>::SIMDWidth;
638    type Accumulator = <R::NonResumable as SIMDSchema<T, U, A>>::Accumulator;
639    type Left = <R::NonResumable as SIMDSchema<T, U, A>>::Left;
640    type Right = <R::NonResumable as SIMDSchema<T, U, A>>::Right;
641    type Return = Self;
642    type Main = <R::NonResumable as SIMDSchema<T, U, A>>::Main;
643
644    fn init(&self, arch: A) -> Self::Accumulator {
645        R::NonResumable::default().init(arch)
646    }
647
648    fn accumulate(
649        &self,
650        x: Self::Left,
651        y: Self::Right,
652        acc: Self::Accumulator,
653    ) -> Self::Accumulator {
654        R::NonResumable::default().accumulate(x, y, acc)
655    }
656
657    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
658        R::NonResumable::default().combine(x, y)
659    }
660
661    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
662        Self(self.0.combine_with(x))
663    }
664}
665
/// Report mismatched input lengths for [`simd_op`] by panicking.
///
/// Marked `#[cold]` and `#[inline(never)]` so the panic formatting machinery stays out of
/// the hot distance-kernel code path and the length-check branch is laid out as unlikely.
#[cold]
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!("lengths must be equal, instead got: xlen = {xlen}, ylen = {ylen}")
}
674
/// A SIMD executor for binary ops using the provided `SIMDSchema`.
///
/// Processing proceeds in three phases:
/// 1. the schema's [`MainLoop`] consumes full unrolled blocks plus any remaining
///    full-width vectors,
/// 2. the final partial vector (`len % simd_width` elements), if any, goes through
///    [`SIMDSchema::epilogue`],
/// 3. [`SIMDSchema::reduce`] folds the accumulator into the final return value.
///
/// # Panics
///
/// Panics if `x.len() != y.len()`.
#[inline(always)]
pub fn simd_op<L, R, S, T, U, A>(schema: &S, arch: A, x: T, y: U) -> S::Return
where
    A: Architecture,
    T: AsRef<[L]>,
    U: AsRef<[R]>,
    S: SIMDSchema<L, R, A>,
{
    let x: &[L] = x.as_ref();
    let y: &[R] = y.as_ref();

    let len = x.len();

    // The two lengths of the vectors must be the same.
    // Eventually - it will probably be worth looking into various wrapper functions for
    // `simd_op` that perform this checking, but for now, consider providing two
    // different-length slices as a hard program bug.
    //
    // N.B.: Redirect through `emit_length_error` to keep code generation as clean as
    // possible.
    if len != y.len() {
        emit_length_error(len, y.len());
    }
    let px = x.as_ptr();
    let py = y.as_ptr();

    // N.B.: Due to limitations in Rust's handling of const generics (and outer type
    // parameters), we cannot just reach into `S` and pull out the constant SIMDWidth.
    //
    // Instead, we need to go through a helper function. Since associated functions cannot
    // be marked as `const`, we cannot require that the extracted width is evaluated at
    // compile time.
    //
    // HOWEVER, compilers are very good at optimizing these kinds of patterns and
    // recognizing that this value is indeed constant and optimizing accordingly.
    let simd_width: usize = S::get_simd_width();
    let unroll: usize = S::get_main_bocksize();

    // Full unrolled blocks, then leftover full-width vectors (always < `unroll`).
    let trip_count = len / (simd_width * unroll);
    let epilogues = (len - simd_width * unroll * trip_count) / simd_width;

    // Create a loader that (in debug mode) will check that all of our full-width accesses
    // are in-bounds.
    let loader: Loader<S, L, R, A> = Loader::new(arch, *schema, px, py, len);

    // SAFETY: The values of `trip_count` and `epilogues` are computed such that the
    // element range `[0, trip_count * simd_width * unroll + epilogues * simd_width)` is
    // in-bounds, satisfying the requirements of `main`.
    let mut s0 = unsafe { <S as SIMDSchema<L, R, A>>::Main::main(&loader, trip_count, epilogues) };

    let remainder = len % simd_width;
    if remainder != 0 {
        let i = len - remainder;

        // SAFETY: We have ensured that the lengths of the two inputs are the same.
        //
        // Furthermore, preceding computations on the induction variable mean that the
        // remaining memory must be valid.
        s0 = unsafe { schema.epilogue(arch, px.add(i), py.add(i), remainder, s0) };
    }

    schema.reduce(s0)
}
743
/////
///// L2 Implementations
/////

// A pure L2 (squared Euclidean) distance function that provides a final reduction.
// Stateless marker type; per-(type, architecture) behavior lives in the `SIMDSchema`
// implementations below.
#[derive(Debug, Default, Clone, Copy)]
pub struct L2;
751
// L2 over `f32` slices on the `V4` x86_64 architecture level: 8-lane f32 vectors, FMA
// accumulation, 4 parallel accumulators.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        // Zero-valued accumulator (additive identity).
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // acc += (x - y)^2 per lane via fused multiply-add.
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        // Horizontal tree-sum of the 8 per-lane partial sums.
        x.sum_tree()
    }
}
782
// L2 over `f32` slices on the `V3` x86_64 architecture level. Identical structure to the
// `V4` implementation: 8-lane f32 vectors, FMA accumulation, 4 parallel accumulators.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        // Zero-valued accumulator (additive identity).
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // acc += (x - y)^2 per lane via fused multiply-add.
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        // Horizontal tree-sum of the 8 per-lane partial sums.
        x.sum_tree()
    }
}
813
// Portable fallback: L2 over `f32` slices using 4-wide emulated vectors on the scalar
// architecture. Overrides `epilogue` with a plain scalar loop.
impl SIMDSchema<f32, f32, Scalar> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        // Zero-valued accumulator (additive identity).
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Don't assume the presence of FMA.
        let c = x - y;
        (c * c) + acc
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Scalar epilogue: process the trailing `len < 4` elements one at a time and fold the
    // partial sum into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx = unsafe { x.add(i).read() };
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy = unsafe { y.add(i).read() };
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
865
// L2 over `Half` (f16) slices on the `V4` x86_64 architecture level: 8-lane f16 loads
// widened to f32 before the subtract/FMA, accumulating in f32 for precision.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        // Zero-valued accumulator (additive identity).
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen both operands to f32 before subtracting.
        let x: f32s = x.into();
        let y: f32s = y.into();

        // acc += (x - y)^2 per lane via fused multiply-add.
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        // Horizontal tree-sum of the 8 per-lane partial sums.
        x.sum_tree()
    }
}
901
// L2 over `Half` (f16) slices on the `V3` x86_64 architecture level. Identical structure
// to the `V4` implementation: f16 loads widened to f32, f32 accumulation.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        // Zero-valued accumulator (additive identity).
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen both operands to f32 before subtracting.
        let x: f32s = x.into();
        let y: f32s = y.into();

        // acc += (x - y)^2 per lane via fused multiply-add.
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
938
impl SIMDSchema<Half, Half, Scalar> for L2 {
    // Squared-L2 distance for `Half` operands, scalar (portable) fallback.
    // Width is 1, so the main loop touches every element and no epilogue is
    // needed.
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the single-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `(x - y)^2` after converting f16 -> f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();

        let c = x - y;
        acc + (c * c)
    }

    // With a single lane, the reduction is just extracting that lane.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }
}
971
impl<A> SIMDSchema<f32, Half, A> for L2
where
    A: Architecture,
{
    // Mixed-precision squared-L2: the left operand is already f32; only the
    // right (f16) side is widened. Generic over every architecture `A`.
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `(x - y)^2` lane-wise via a SIMD multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen the f16 side to match the f32 side.
        let y: A::f32x8 = y.into();
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1006
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for L2 {
    // Squared-L2 distance for `i8` operands on the x86-64 `V4` target.
    // Differences are computed in i16 (wide enough to be exact for i8 inputs)
    // and `dot_simd` accumulates the squared differences into i32 lanes.
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen i8 -> i16, then accumulate `(x - y) . (x - y)`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Reduce the integer lanes, then convert (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1041
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for L2 {
    // Squared-L2 distance for `i8` operands on the x86-64 `V3` target.
    // Same scheme as `V4` with half-width registers.
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen i8 -> i16, then accumulate `(x - y) . (x - y)`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1077
impl SIMDSchema<i8, i8, Scalar> for L2 {
    // Squared-L2 distance for `i8` operands, scalar (portable) fallback.
    // Works in groups of 4 lanes with i32 arithmetic (exact for i8 inputs).
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the 4-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `(x - y)^2` after widening i8 -> i32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Handle the `len < 4` tail element-by-element, folding the scalar sum
    // into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: i32 = unsafe { x.add(i).read() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: i32 = unsafe { y.add(i).read() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1131
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for L2 {
    // Squared-L2 distance for `u8` operands on the x86-64 `V4` target.
    // u8 values are widened to i16 (exact), so the signed difference and its
    // square are computed without overflow.
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen u8 -> i16, then accumulate `(x - y) . (x - y)`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Reduce the integer lanes, then convert (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1166
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for L2 {
    // Squared-L2 distance for `u8` operands on the x86-64 `V3` target.
    // Same scheme as `V4` with half-width registers.
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen u8 -> i16, then accumulate `(x - y) . (x - y)`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1202
impl SIMDSchema<u8, u8, Scalar> for L2 {
    // Squared-L2 distance for `u8` operands, scalar (portable) fallback.
    // Works in groups of 4 lanes with i32 arithmetic (exact for u8 inputs).
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<i32, 4>;
    type Left = Emulated<u8, 4>;
    type Right = Emulated<u8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the 4-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `(x - y)^2` after widening u8 -> i32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        let c = x - y;
        acc + c * c
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // Handle the `len < 4` tail element-by-element, folding the scalar sum
    // into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: i32 = 0;
        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: i32 = unsafe { x.add(i).read() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: i32 = unsafe { y.add(i).read() }.into();
            let d = vx - vy;
            s += d * d;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
    }
}
1256
/// An L2 distance function that defers a final reduction, allowing for a distance
/// computation to take place across multiple slice pairs.
#[derive(Clone, Copy, Debug)]
pub struct ResumableL2<A = diskann_wide::arch::Current>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A>,
{
    // Running vector accumulator; only horizontally reduced when `sum` is called.
    acc: <L2 as SIMDSchema<f32, f32, A>>::Accumulator,
}
1267
1268impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableL2<A>
1269where
1270    A: Architecture,
1271    L2: SIMDSchema<f32, f32, A, Return = f32>,
1272{
1273    type NonResumable = L2;
1274    type FinalReturn = f32;
1275
1276    #[inline(always)]
1277    fn init(arch: A) -> Self {
1278        Self { acc: L2.init(arch) }
1279    }
1280
1281    #[inline(always)]
1282    fn combine_with(&self, other: <L2 as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
1283        Self {
1284            acc: self.acc + other,
1285        }
1286    }
1287
1288    #[inline(always)]
1289    fn sum(&self) -> f32 {
1290        L2.reduce(self.acc)
1291    }
1292}
1293
1294/////
1295///// IP Implementations
1296/////
1297
/// A pure IP (inner-product) distance function that provides a final reduction.
#[derive(Clone, Copy, Debug, Default)]
pub struct IP;
1301
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for IP {
    // Inner product for `f32` operands on the x86-64 `V4` target:
    // accumulate `x * y` with a SIMD multiply-add, reduce at the end.
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` lane-wise.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Perform a final horizontal reduction of the lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1331
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for IP {
    // Inner product for `f32` operands on the x86-64 `V3` target.
    // Mirrors the `V4` implementation.
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` lane-wise.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1362
impl SIMDSchema<f32, f32, Scalar> for IP {
    // Inner product for `f32` operands, scalar (portable) fallback with
    // 4 emulated lanes.
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    // Zero-initialize the 4-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` lane-wise (unfused multiply then add).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x * y + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Handle the `len < 4` tail element-by-element, folding the scalar sum
    // into lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx = unsafe { x.add(i).read() };
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy = unsafe { y.add(i).read() };
            s += vx * vy;
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
1412
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for IP {
    // Inner product for `Half` (f16) operands on the x86-64 `V4` target:
    // widen both sides to f32, then multiply-add.
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen f16 -> f32, then accumulate `x * y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final horizontal reduction of the lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1446
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for IP {
    // Inner product for `Half` (f16) operands on the x86-64 `V3` target.
    // Mirrors the `V4` implementation (with a different unroll strategy).
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen f16 -> f32, then accumulate `x * y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1481
impl SIMDSchema<Half, Half, Scalar> for IP {
    // Inner product for `Half` operands, scalar (portable) fallback.
    // Width is 1, so the main loop touches every element.
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the single-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` after converting f16 -> f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    // With a single lane, the reduction is just extracting that lane.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }
}
1512
impl<A> SIMDSchema<f32, Half, A> for IP
where
    A: Architecture,
{
    // Mixed-precision inner product: the left operand is already f32; only
    // the right (f16) side is widened. Generic over every architecture `A`.
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    // Zero-initialize the f32 accumulator.
    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen the f16 side, then accumulate `x * y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let y: A::f32x8 = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1546
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for IP {
    // Inner product for `i8` operands on the x86-64 `V4` target: widen to
    // i16 (exact) and let `dot_simd` accumulate the products into i32 lanes.
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen i8 -> i16, then accumulate `x . y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Reduce the integer lanes, then convert (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1580
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for IP {
    // Inner product for `i8` operands on the x86-64 `V3` target.
    // Same scheme as `V4` with half-width registers.
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen i8 -> i16, then accumulate `x . y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1615
impl SIMDSchema<i8, i8, Scalar> for IP {
    // Inner product for `i8` operands, scalar (portable) fallback.
    // Width is 1, so the main loop consumes every element and the epilogue
    // can never be reached.
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<i32, 1>;
    type Left = Emulated<i8, 1>;
    type Right = Emulated<i8, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the single-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` after widening i8 -> i32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // With width 1 the main loop leaves no tail, so this path is impossible.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const i8,
        _y: *const i8,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
1659
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for IP {
    // Inner product for `u8` operands on the x86-64 `V4` target: widen to
    // i16 (exact for u8) and let `dot_simd` accumulate into i32 lanes.
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen u8 -> i16, then accumulate `x . y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Reduce the integer lanes, then convert (possibly lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1693
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for IP {
    // Inner product for `u8` operands on the x86-64 `V3` target.
    // Same scheme as `V4` with half-width registers.
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Zero-initialize the i32 accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen u8 -> i16, then accumulate `x . y`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // NOTE: Promoting to `i16` rather than `u16` to hit specialized AVX2
        // instructions on x86 hardware.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1730
impl SIMDSchema<u8, u8, Scalar> for IP {
    // Inner product for `u8` operands, scalar (portable) fallback.
    // Width is 1, so the main loop consumes every element and the epilogue
    // can never be reached.
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<i32, 1>;
    type Left = Emulated<u8, 1>;
    type Right = Emulated<u8, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    // Zero-initialize the single-lane accumulator.
    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate `x * y` after widening u8 -> i32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: Self::Accumulator = x.into();
        let y: Self::Accumulator = y.into();
        x * y + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
    }

    // With width 1 the main loop leaves no tail, so this path is impossible.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const u8,
        _y: *const u8,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
1774
/// An IP distance function that defers a final reduction, allowing for a distance
/// computation to take place across multiple slice pairs.
#[derive(Clone, Copy, Debug)]
pub struct ResumableIP<A = diskann_wide::arch::Current>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A>,
{
    // Running vector accumulator; only horizontally reduced when `sum` is called.
    acc: <IP as SIMDSchema<f32, f32, A>>::Accumulator,
}
1784
1785impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableIP<A>
1786where
1787    A: Architecture,
1788    IP: SIMDSchema<f32, f32, A, Return = f32>,
1789{
1790    type NonResumable = IP;
1791    type FinalReturn = f32;
1792
1793    #[inline(always)]
1794    fn init(arch: A) -> Self {
1795        Self { acc: IP.init(arch) }
1796    }
1797
1798    #[inline(always)]
1799    fn combine_with(&self, other: <IP as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
1800        Self {
1801            acc: self.acc + other,
1802        }
1803    }
1804
1805    #[inline(always)]
1806    fn sum(&self) -> f32 {
1807        IP.reduce(self.acc)
1808    }
1809}
1810
1811/////////////////////////////////
1812// Stateless Cosine Similarity //
1813/////////////////////////////////
1814
/// Accumulator of partial products for a full cosine distance computation (where
/// the norms of both the query and the dataset vector are computed on the fly).
#[derive(Debug, Clone, Copy)]
pub struct FullCosineAccumulator<T> {
    // Running sum of `x[i] * x[i]` (squared norm of the left operand).
    normx: T,
    // Running sum of `y[i] * y[i]` (squared norm of the right operand).
    normy: T,
    // Running sum of `x[i] * y[i]` (the inner product).
    xy: T,
}
1823
impl<T> FullCosineAccumulator<T>
where
    T: SIMDVector
        + SIMDSumTree
        + SIMDMulAdd
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    T::Scalar: LossyF32Conversion,
{
    /// Create an accumulator with all three partial sums zeroed.
    #[inline(always)]
    pub fn new(arch: T::Arch) -> Self {
        // SAFETY: Zero initializing a SIMD vector is safe.
        let zero = T::default(arch);
        Self {
            normx: zero,
            normy: zero,
            xy: zero,
        }
    }

    /// Fold one pair of vectors into the partial sums using SIMD multiply-adds.
    #[inline(always)]
    pub fn add_with(&self, x: T, y: T) -> Self {
        // SAFETY: Arithmetic on valid arguments is valid.
        FullCosineAccumulator {
            normx: x.mul_add_simd(x, self.normx),
            normy: y.mul_add_simd(y, self.normy),
            xy: x.mul_add_simd(y, self.xy),
        }
    }

    /// Same as [`Self::add_with`] but using separate multiply and add
    /// operations instead of `mul_add_simd`.
    #[inline(always)]
    pub fn add_with_unfused(&self, x: T, y: T) -> Self {
        // SAFETY: Arithmetic on valid arguments is valid.
        FullCosineAccumulator {
            normx: x * x + self.normx,
            normy: y * y + self.normy,
            xy: x * y + self.xy,
        }
    }

    /// Reduce the partial sums into a cosine similarity clamped to [-1, 1].
    /// Returns 0.0 when either vector's squared norm is subnormal.
    #[inline(always)]
    pub fn sum(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();

        // Evaluate the denominator early and use `force_eval`.
        // This will allow the long `sqrt` to be overlapped with some other instructions
        // rather than waiting at the end of the function.
        //
        // There is some worry of subnormal numbers, but we're optimizing for the common
        // case where norms are reasonable values.
        let denominator = normx.sqrt() * normy.sqrt();
        let prod = self.xy.sum_tree().as_f32_lossy();

        // Force the final products to be completely computed before the range check.
        //
        // This prevents LLVM from trying to compute `normy` or `prod` *after* the check
        // to `normx`, which causes it to spill heavily to the stack.
        //
        // Unfortunately, this results in a reduction pattern that appears to be slightly
        // slower on AMD or Windows.
        force_eval(denominator);
        force_eval(prod);

        // This basically checks if either norm is subnormal and if so, we treat the vector
        // as having norm zero.
        //
        // The reason to do this rather than checking `denominator` directly is to have
        // consistent behavior when one vector has a small norm (i.e., always treat it as
        // zero) rather than potentially changing behavior when the other vector has a very
        // large norm to compensate.
        if normx < f32::MIN_POSITIVE || normy < f32::MIN_POSITIVE {
            return 0.0;
        }

        // Clamp to the mathematically valid range to absorb rounding error.
        let v = prod / denominator;
        (-1.0f32).max(1.0f32.min(v))
    }

    /// Compute the L2 distance from the partial products rather than the cosine similarity.
    #[inline(always)]
    pub fn sum_as_l2(&self) -> f32 {
        // ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();
        let xy = self.xy.sum_tree().as_f32_lossy();
        normx + normy - (xy + xy)
    }
}
1912
1913impl<T> std::ops::Add for FullCosineAccumulator<T>
1914where
1915    T: std::ops::Add<Output = T>,
1916{
1917    type Output = Self;
1918    #[inline(always)]
1919    fn add(self, other: Self) -> Self {
1920        FullCosineAccumulator {
1921            normx: self.normx + other.normx,
1922            normy: self.normy + other.normy,
1923            xy: self.xy + other.xy,
1924        }
1925    }
1926}
1927
/// A pure Cosine Similarity function that provides a final reduction.
///
/// Stateless marker type; the per-architecture `SIMDSchema` implementations
/// carry all of the logic. Derives `Debug` for parity with the `L2` and `IP`
/// marker types.
#[derive(Debug, Default, Clone, Copy)]
pub struct CosineStateless;
1931
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for CosineStateless {
    // Cosine similarity for `f32` operands on the x86-64 `V4` target: track
    // both squared norms and the inner product via `FullCosineAccumulator`.
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;

    // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a
    // hefty unroll factor.
    type Main = Strategy2x4;

    // Start with all three partial sums zeroed.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Fold `x` and `y` into the norm/product partial sums.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
1965
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for CosineStateless {
    // Cosine similarity for `f32` operands on the x86-64 `V3` target.
    // Mirrors the `V4` implementation with half-width registers.
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;

    // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a
    // hefty unroll factor.
    type Main = Strategy2x4;

    // Start with all three partial sums zeroed.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Fold `x` and `y` into the norm/product partial sums.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
1999
// f32 x f32 cosine on the portable scalar fallback, emulated 4 lanes wide.
impl SIMDSchema<f32, f32, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<f32, 4>>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;

    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // NOTE(review): the unfused variant rounds like a plain scalar multiply-then-add
        // (no FMA) — presumably so the emulated path matches ordinary scalar arithmetic;
        // confirm against `FullCosineAccumulator::add_with_unfused`.
        acc.add_with_unfused(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2029
// f16 x f16 cosine on x86-64 V4: inputs are widened to f32 before accumulating.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
    type Left = <V4 as Architecture>::f16x16;
    type Right = <V4 as Architecture>::f16x16;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x16);

        // Widen the half-precision lanes to f32, then accumulate like the f32 kernel.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2063
// f16 x f16 cosine on x86-64 V3: inputs are widened to f32 before accumulating.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for CosineStateless {
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen the half-precision lanes to f32, then accumulate like the f32 kernel.
        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2098
// f16 x f16 cosine on the portable scalar fallback, one lane at a time.
// With a SIMD width of 1 the main loop covers every element.
impl SIMDSchema<Half, Half, Scalar> for CosineStateless {
    type SIMDWidth = Const<1>;
    type Accumulator = FullCosineAccumulator<Emulated<f32, 1>>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen to f32 and use the unfused (non-FMA) accumulate to match plain scalar
        // rounding.
        let x: Emulated<f32, 1> = x.into();
        let y: Emulated<f32, 1> = y.into();
        acc.add_with_unfused(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
// Mixed-precision f32 x f16 cosine, generic over the architecture: the f16 operand is
// widened to f32 and then accumulated exactly like the f32 x f32 kernels.
impl<A> SIMDSchema<f32, Half, A> for CosineStateless
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = FullCosineAccumulator<A::f32x8>;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Only the right-hand (f16) operand needs widening.
        let y: A::f32x8 = y.into();
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2161
// i8 x i8 cosine on x86-64 V4: bytes are widened to i16 and the three sums are
// accumulated via widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Sign-extend bytes to i16 so the dot products cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        // Accumulate x·x, y·y, and x·y via i16 -> i32 widening dot products.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2201
// i8 x i8 cosine on x86-64 V3: same widening dot-product scheme as V4, at 16 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Sign-extend bytes to i16 so the dot products cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        // Accumulate x·x, y·y, and x·y via i16 -> i32 widening dot products.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2241
// i8 x i8 cosine on the portable scalar fallback, emulated 4 lanes wide, with a plain
// scalar tail loop for the remainder.
impl SIMDSchema<i8, i8, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
    type Left = Emulated<i8, 4>;
    type Right = Emulated<i8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen each i8 lane to i32 before multiply-accumulating.
        let x: Emulated<i32, 4> = x.into();
        let y: Emulated<i32, 4> = y.into();
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }

    // Handle the remainder elements with a scalar loop, folding the partial sums into
    // lane 0 of the accumulator. Each i8*i8 product fits comfortably in an i32.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xy: i32 = 0;
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;

        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: i32 = unsafe { x.add(i).read() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: i32 = unsafe { y.add(i).read() }.into();

            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }

        acc + FullCosineAccumulator {
            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
        }
    }
}
2304
// u8 x u8 cosine on x86-64 V4: bytes are widened to i16 (a u8 always fits) and
// accumulated via widening dot products into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for CosineStateless {
    type SIMDWidth = Const<32>;
    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Zero-extend the bytes to i16 so the dot products cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        // Accumulate x·x, y·y, and x·y via i16 -> i32 widening dot products.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2344
// u8 x u8 cosine on x86-64 V3: same widening dot-product scheme as V4, at 16 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Zero-extend the bytes to i16 so the dot products cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();

        // Accumulate x·x, y·y, and x·y via i16 -> i32 widening dot products.
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2384
// u8 x u8 cosine on the portable scalar fallback, emulated 4 lanes wide, with a plain
// scalar tail loop for the remainder.
impl SIMDSchema<u8, u8, Scalar> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
    type Left = Emulated<u8, 4>;
    type Right = Emulated<u8, 4>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen each u8 lane to i32 before multiply-accumulating.
        let x: Emulated<i32, 4> = x.into();
        let y: Emulated<i32, 4> = y.into();
        acc.add_with(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }

    // Handle the remainder elements with a scalar loop, folding the partial sums into
    // lane 0 of the accumulator. Each u8*u8 product fits comfortably in an i32.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xy: i32 = 0;
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;

        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: i32 = unsafe { x.add(i).read() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: i32 = unsafe { y.add(i).read() }.into();

            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }

        acc + FullCosineAccumulator {
            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
        }
    }
}
2447
/// A resumable cosine similarity computation.
///
/// Wraps the [`CosineStateless`] accumulator so a computation can be split across
/// multiple chunked invocations and combined before the final reduction.
#[derive(Debug, Clone, Copy)]
pub struct ResumableCosine<A = diskann_wide::arch::Current>(
    <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
)
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A>;
2456
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableCosine<A>
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = CosineStateless;
    type FinalReturn = f32;

    // Start a fresh (empty) accumulator.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self(CosineStateless.init(arch))
    }

    // Merge a partial accumulator from another chunk into this one.
    #[inline(always)]
    fn combine_with(
        &self,
        other: <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
    ) -> Self {
        Self(self.0 + other)
    }

    // Perform the final reduction over the combined accumulator.
    #[inline(always)]
    fn sum(&self) -> f32 {
        CosineStateless.reduce(self.0)
    }
}
2483
2484/////
2485///// L1 Norm Implementations
2486/////
2487
2488// ==================================================================================================
2489// NOTE: L1Norm IS A LOGICAL UNARY OPERATION
2490// --------------------------------------------------------------------------------------------------
2491// Although wired through the generic binary 'SIMDSchema'/'simd_op' infrastructure (which expects
2492// two input slices of equal length), 'L1Norm' conceptually computes: sum_i |x_i|
2493// The right-hand operand is completely ignored and exists ONLY to satisfy the shared execution
2494// machinery (loop tiling, epilogue handling, etc.).
2495// ==================================================================================================
2496
/// A pure L1 norm function that provides a final reduction.
///
/// Logically unary (`sum_i |x_i|`): the right-hand operand is ignored and exists only
/// to satisfy the shared binary execution machinery (see the note above).
#[derive(Clone, Copy, Debug, Default)]
pub struct L1Norm;
2500
// f32 L1 norm on x86-64 V4 (16 lanes). The `_y` operand is ignored.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for L1Norm {
    type SIMDWidth = Const<16>;
    type Accumulator = <V4 as Architecture>::f32x16;
    type Left = <V4 as Architecture>::f32x16;
    type Right = <V4 as Architecture>::f32x16;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Add the element-wise absolute values into the accumulator.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2531
// f32 L1 norm on x86-64 V3 (8 lanes). The `_y` operand is ignored.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Add the element-wise absolute values into the accumulator.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2562
// f32 L1 norm on the portable scalar fallback, emulated 4 lanes wide, with a scalar
// tail loop. The `_y` operand is ignored.
impl SIMDSchema<f32, f32, Scalar> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = Emulated<f32, 4>;
    type Left = Emulated<f32, 4>;
    type Right = Emulated<f32, 4>;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Add the element-wise absolute values into the accumulator.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }

    // Handle the remainder elements with a scalar loop, folding the partial sum into
    // lane 0 of the accumulator.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Scalar,
        x: *const f32,
        _y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx = unsafe { x.add(i).read() };
            s += vx.abs();
        }
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }
}
2610
// f16 L1 norm on x86-64 V4: lanes are widened to f32 before the absolute sum.
// The `_y` operand is ignored.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen to f32, then add absolute values into the accumulator.
        let x: <V4 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2642
// f16 L1 norm on x86-64 V3: lanes are widened to f32 before the absolute sum.
// The `_y` operand is ignored.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for L1Norm {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen to f32, then add absolute values into the accumulator.
        let x: <V3 as Architecture>::f32x8 = x.into();
        x.abs_simd() + acc
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
2674
// f16 L1 norm on the portable scalar fallback, one lane at a time.
// The `_y` operand is ignored.
impl SIMDSchema<Half, Half, Scalar> for L1Norm {
    type SIMDWidth = Const<1>;
    type Accumulator = Emulated<f32, 1>;
    type Left = Emulated<Half, 1>;
    type Right = Emulated<Half, 1>;
    type Return = f32;
    type Main = Strategy1x1;

    #[inline(always)]
    fn init(&self, arch: Scalar) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // Widen the half-precision lane to f32 before taking the absolute value.
        let x: Self::Accumulator = x.into();
        x.abs_simd() + acc
    }

    // Perform a final reduction. The width is 1, so the full sum lives in lane 0.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.to_array()[0]
    }

    // With a SIMD width of 1 the main loop consumes every element, so the epilogue can
    // never be invoked.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        _arch: Scalar,
        _x: *const Half,
        _y: *const Half,
        _len: usize,
        _acc: Self::Accumulator,
    ) -> Self::Accumulator {
        unreachable!("The SIMD width is 1, so there should be no epilogue")
    }
}
2717
2718///////////
2719// Tests //
2720///////////
2721
2722#[cfg(test)]
2723mod tests {
2724    use std::{collections::HashMap, sync::LazyLock};
2725
2726    use approx::assert_relative_eq;
2727    use diskann_wide::{arch::Target1, ARCH};
2728    use half::f16;
2729    use rand::{distr::StandardUniform, rngs::StdRng, Rng, SeedableRng};
2730    use rand_distr;
2731
2732    use super::*;
2733    use crate::{distance::reference, norm::LInfNorm, test_util};
2734
2735    ///////////////////////
2736    // Cosine Norm Check //
2737    ///////////////////////
2738
2739    fn cosine_norm_check_impl<A>(arch: A)
2740    where
2741        A: diskann_wide::Architecture,
2742        CosineStateless:
2743            SIMDSchema<f32, f32, A, Return = f32> + SIMDSchema<Half, Half, A, Return = f32>,
2744    {
2745        // Zero - f32
2746        {
2747            let x: [f32; 2] = [0.0, 0.0];
2748            let y: [f32; 2] = [0.0, 1.0];
2749            assert_eq!(
2750                simd_op(&CosineStateless {}, arch, x, x),
2751                0.0,
2752                "when both vectors are zero, similarity should be zero",
2753            );
2754            assert_eq!(
2755                simd_op(&CosineStateless {}, arch, x, y),
2756                0.0,
2757                "when one vector is zero, similarity should be zero",
2758            );
2759            assert_eq!(
2760                simd_op(&CosineStateless {}, arch, y, x),
2761                0.0,
2762                "when one vector is zero, similarity should be zero",
2763            );
2764        }
2765
2766        // Subnormal - f32
2767        {
2768            let x: [f32; 4] = [0.0, 0.0, 2.938736e-39f32, 0.0];
2769            let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
2770            assert_eq!(
2771                simd_op(&CosineStateless {}, arch, x, x),
2772                0.0,
2773                "when both vectors are almost zero, similarity should be zero",
2774            );
2775            assert_eq!(
2776                simd_op(&CosineStateless {}, arch, x, y),
2777                0.0,
2778                "when one vector is almost zero, similarity should be zero",
2779            );
2780            assert_eq!(
2781                simd_op(&CosineStateless {}, arch, y, x),
2782                0.0,
2783                "when one vector is almost zero, similarity should be zero",
2784            );
2785        }
2786
2787        // Small - f32
2788        {
2789            let x: [f32; 4] = [0.0, 0.0, 1.0842022e-19f32, 0.0];
2790            let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
2791            assert_eq!(
2792                simd_op(&CosineStateless {}, arch, x, x),
2793                1.0,
2794                "cosine-stateless should handle vectors this small",
2795            );
2796            assert_eq!(
2797                simd_op(&CosineStateless {}, arch, x, y),
2798                1.0,
2799                "cosine-stateless should handle vectors this small",
2800            );
2801            assert_eq!(
2802                simd_op(&CosineStateless {}, arch, y, x),
2803                1.0,
2804                "cosine-stateless should handle vectors this small",
2805            );
2806        }
2807
2808        let cvt = diskann_wide::cast_f32_to_f16;
2809
2810        // Zero - f16
2811        {
2812            let x: [Half; 2] = [Half::default(), Half::default()];
2813            let y: [Half; 2] = [Half::default(), cvt(1.0)];
2814            assert_eq!(
2815                simd_op(&CosineStateless {}, arch, x, x),
2816                0.0,
2817                "when both vectors are zero, similarity should be zero",
2818            );
2819            assert_eq!(
2820                simd_op(&CosineStateless {}, arch, x, y),
2821                0.0,
2822                "when one vector is zero, similarity should be zero",
2823            );
2824            assert_eq!(
2825                simd_op(&CosineStateless {}, arch, y, x),
2826                0.0,
2827                "when one vector is zero, similarity should be zero",
2828            );
2829        }
2830
2831        // Subnormal - f16
2832        {
2833            let x: [Half; 4] = [
2834                Half::default(),
2835                Half::default(),
2836                Half::MIN_POSITIVE_SUBNORMAL,
2837                Half::default(),
2838            ];
2839            let y: [Half; 4] = [Half::default(), Half::default(), cvt(1.0), Half::default()];
2840            assert_eq!(
2841                simd_op(&CosineStateless {}, arch, x, x),
2842                1.0,
2843                "when both vectors are almost zero, similarity should be zero",
2844            );
2845            assert_eq!(
2846                simd_op(&CosineStateless {}, arch, x, y),
2847                1.0,
2848                "when one vector is almost zero, similarity should be zero",
2849            );
2850            assert_eq!(
2851                simd_op(&CosineStateless {}, arch, y, x),
2852                1.0,
2853                "when one vector is almost zero, similarity should be zero",
2854            );
2855
2856            // Grab a range of floating point numbers whose squares cover the range of
2857            // our target threshold.
2858            //
2859            // Ensure that all combinations of values within this critical range to not
2860            // result in a misrounding.
2861            let threshold = f32::MIN_POSITIVE;
2862            let bound = 50;
2863            let values = {
2864                let mut down = threshold;
2865                let mut up = threshold;
2866                for _ in 0..bound {
2867                    down = down.next_down();
2868                    up = up.next_up();
2869                }
2870                assert!(down > 0.0);
2871                let min = down.sqrt();
2872                let max = up.sqrt();
2873                let mut v = min;
2874                let mut values = Vec::new();
2875                while v <= max {
2876                    values.push(v);
2877                    v = v.next_up();
2878                }
2879                values
2880            };
2881
2882            let mut lo = 0;
2883            let mut hi = 0;
2884            for i in values.iter() {
2885                for j in values.iter() {
2886                    let s: f32 = simd_op(&CosineStateless {}, arch, [*i], [*j]);
2887                    if i * i < threshold || j * j < threshold {
2888                        lo += 1;
2889                        assert_eq!(s, 0.0, "failed for i = {}, j = {}", i, j);
2890                    } else {
2891                        hi += 1;
2892                        assert_eq!(s, 1.0, "failed for i = {}, j = {}", i, j);
2893                    }
2894                }
2895            }
2896            assert_ne!(lo, 0);
2897            assert_ne!(hi, 0);
2898        }
2899    }
2900
    // Run the edge-case checks on the current (best available) architecture and on the
    // portable scalar fallback.
    #[test]
    fn cosine_norm_check() {
        cosine_norm_check_impl::<diskann_wide::arch::Current>(diskann_wide::arch::current());
        cosine_norm_check_impl::<diskann_wide::arch::Scalar>(diskann_wide::arch::Scalar::new());
    }
2906
    // Run the edge-case checks on the x86-64 V3/V4 kernels when the CPU supports them.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn cosine_norm_check_x86_64() {
        if let Some(arch) = V3::new_checked() {
            cosine_norm_check_impl::<V3>(arch);
        }

        // NOTE(review): `new_checked_miri` presumably also gates on running under miri —
        // confirm against its definition in `diskann_wide`.
        if let Some(arch) = V4::new_checked_miri() {
            cosine_norm_check_impl::<V4>(arch);
        }
    }
2918
2919    ////////////
2920    // Schema //
2921    ////////////
2922
    // Chunk the left and right hand slices and compute the result using a resumable function.
    //
    // Each chunk pair is folded through `simd_op` with the running `Resumable`
    // accumulator; the final reduction happens once at the end via `sum()`.
    fn test_resumable<T, L, R, A>(arch: A, x: &[L], y: &[R], chunk_size: usize) -> f32
    where
        A: Architecture,
        T: ResumableSIMDSchema<L, R, A, FinalReturn = f32>,
    {
        let mut acc = Resumable(<T as ResumableSIMDSchema<L, R, A>>::init(arch));
        let iter = std::iter::zip(x.chunks(chunk_size), y.chunks(chunk_size));
        for (a, b) in iter {
            acc = simd_op(&acc, arch, a, b);
        }
        acc.0.sum()
    }
2936
2937    fn stress_test_with_resumable<
2938        A: Architecture,
2939        O: Default + SIMDSchema<f32, f32, A, Return = f32>,
2940        T: ResumableSIMDSchema<f32, f32, A, NonResumable = O, FinalReturn = f32>,
2941        Rand: Rng,
2942    >(
2943        arch: A,
2944        reference: fn(&[f32], &[f32]) -> f32,
2945        dim: usize,
2946        epsilon: f32,
2947        max_relative: f32,
2948        rng: &mut Rand,
2949    ) {
2950        // Pick chunk sizes that exercise combinations of the unrolled loops.
2951        let chunk_divisors: Vec<usize> = vec![1, 2, 3, 4, 16, 54, 64, 65, 70, 77];
2952        let checker = test_util::AdHocChecker::<f32, f32>::new(|a: &[f32], b: &[f32]| {
2953            let expected = reference(a, b);
2954            let got = simd_op(&O::default(), arch, a, b);
2955            println!("dim = {}", dim);
2956            assert_relative_eq!(
2957                expected,
2958                got,
2959                epsilon = epsilon,
2960                max_relative = max_relative,
2961            );
2962
2963            if dim == 0 {
2964                return;
2965            }
2966
2967            for d in &chunk_divisors {
2968                let chunk_size = dim / d + (!dim.is_multiple_of(*d) as usize);
2969                let chunked = test_resumable::<T, f32, f32, _>(arch, a, b, chunk_size);
2970                assert_relative_eq!(chunked, got, epsilon = epsilon, max_relative = max_relative);
2971            }
2972        });
2973
2974        test_util::test_distance_function(
2975            checker,
2976            rand_distr::Normal::new(0.0, 10.0).unwrap(),
2977            rand_distr::Normal::new(0.0, 10.0).unwrap(),
2978            dim,
2979            10,
2980            rng,
2981        )
2982    }
2983
2984    #[allow(clippy::too_many_arguments)]
2985    fn stress_test<L, R, DistLeft, DistRight, O, Rand, A>(
2986        arch: A,
2987        reference: fn(&[L], &[R]) -> f32,
2988        left_dist: DistLeft,
2989        right_dist: DistRight,
2990        dim: usize,
2991        epsilon: f32,
2992        max_relative: f32,
2993        rng: &mut Rand,
2994    ) where
2995        L: test_util::CornerCases,
2996        R: test_util::CornerCases,
2997        DistLeft: test_util::GenerateRandomArguments<L>,
2998        DistRight: test_util::GenerateRandomArguments<R>,
2999        O: Default + SIMDSchema<L, R, A, Return = f32>,
3000        Rand: Rng,
3001        A: Architecture,
3002    {
3003        let checker = test_util::Checker::<L, R, f32>::new(
3004            |x: &[L], y: &[R]| simd_op(&O::default(), arch, x, y),
3005            reference,
3006            |got, expected| {
3007                assert_relative_eq!(
3008                    expected,
3009                    got,
3010                    epsilon = epsilon,
3011                    max_relative = max_relative
3012                );
3013            },
3014        );
3015
3016        let trials = if cfg!(miri) { 0 } else { 10 };
3017
3018        test_util::test_distance_function(checker, left_dist, right_dist, dim, trials, rng);
3019    }
3020
3021    fn stress_test_linf<L, Dist, Rand, A>(
3022        arch: A,
3023        reference: fn(&[L]) -> f32,
3024        dist: Dist,
3025        dim: usize,
3026        epsilon: f32,
3027        max_relative: f32,
3028        rng: &mut Rand,
3029    ) where
3030        L: test_util::CornerCases + Copy,
3031        Dist: Clone + test_util::GenerateRandomArguments<L>,
3032        Rand: Rng,
3033        A: Architecture,
3034        LInfNorm: for<'a> Target1<A, f32, &'a [L]>,
3035    {
3036        let checker = test_util::Checker::<L, L, f32>::new(
3037            |x: &[L], _y: &[L]| (LInfNorm).run(arch, x),
3038            |x: &[L], _y: &[L]| reference(x),
3039            |got, expected| {
3040                assert_relative_eq!(
3041                    expected,
3042                    got,
3043                    epsilon = epsilon,
3044                    max_relative = max_relative
3045                );
3046            },
3047        );
3048
3049        println!("checking {dim}");
3050        test_util::test_distance_function(checker, dist.clone(), dist, dim, 10, rng);
3051    }
3052
3053    /////////
3054    // f32 //
3055    /////////
3056
    /// Instantiate a `#[test]` that stress-tests a floating-point distance kernel
    /// (`$impl`) together with its resumable variant (`$resumable`) against a
    /// scalar reference, for every dimension in `0..$upper`.
    ///
    /// Arguments:
    /// - `$name`: name of the generated test function.
    /// - `$impl`: the `SIMDSchema` implementation under test.
    /// - `$resumable`: the matching `ResumableSIMDSchema` wrapper type.
    /// - `$reference`: scalar reference; its result is unwrapped via `into_inner`.
    /// - `$eps` / `$relative`: tolerances passed to `assert_relative_eq!`.
    /// - `$seed`: fixed RNG seed so failures reproduce deterministically.
    /// - `$upper`: exclusive upper bound on the tested dimensions.
    /// - `$arch`: an expression yielding an `Option` of the architecture token;
    ///   the test body is skipped entirely when it evaluates to `None`.
    macro_rules! float_test {
        ($name:ident,
         $impl:ty,
         $resumable:ident,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                // Skip silently when the architecture is unavailable at runtime.
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_with_resumable::<_, $impl, $resumable<_>, StdRng>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
3086
3087    //----//
3088    // L2 //
3089    //----//
3090
3091    float_test!(
3092        test_l2_f32_current,
3093        L2,
3094        ResumableL2,
3095        reference::reference_squared_l2_f32_mathematical,
3096        1e-5,
3097        1e-5,
3098        0xf149c2bcde660128,
3099        64,
3100        Some(diskann_wide::ARCH)
3101    );
3102
3103    float_test!(
3104        test_l2_f32_scalar,
3105        L2,
3106        ResumableL2,
3107        reference::reference_squared_l2_f32_mathematical,
3108        1e-5,
3109        1e-5,
3110        0xf149c2bcde660128,
3111        64,
3112        Some(diskann_wide::arch::Scalar)
3113    );
3114
3115    #[cfg(target_arch = "x86_64")]
3116    float_test!(
3117        test_l2_f32_x86_64_v3,
3118        L2,
3119        ResumableL2,
3120        reference::reference_squared_l2_f32_mathematical,
3121        1e-5,
3122        1e-5,
3123        0xf149c2bcde660128,
3124        256,
3125        V3::new_checked()
3126    );
3127
3128    #[cfg(target_arch = "x86_64")]
3129    float_test!(
3130        test_l2_f32_x86_64_v4,
3131        L2,
3132        ResumableL2,
3133        reference::reference_squared_l2_f32_mathematical,
3134        1e-5,
3135        1e-5,
3136        0xf149c2bcde660128,
3137        256,
3138        V4::new_checked_miri()
3139    );
3140
3141    //----//
3142    // IP //
3143    //----//
3144
3145    float_test!(
3146        test_ip_f32_current,
3147        IP,
3148        ResumableIP,
3149        reference::reference_innerproduct_f32_mathematical,
3150        2e-4,
3151        1e-3,
3152        0xb4687c17a9ea9866,
3153        64,
3154        Some(diskann_wide::ARCH)
3155    );
3156
3157    float_test!(
3158        test_ip_f32_scalar,
3159        IP,
3160        ResumableIP,
3161        reference::reference_innerproduct_f32_mathematical,
3162        2e-4,
3163        1e-3,
3164        0xb4687c17a9ea9866,
3165        64,
3166        Some(diskann_wide::arch::Scalar)
3167    );
3168
3169    #[cfg(target_arch = "x86_64")]
3170    float_test!(
3171        test_ip_f32_x86_64_v3,
3172        IP,
3173        ResumableIP,
3174        reference::reference_innerproduct_f32_mathematical,
3175        2e-4,
3176        1e-3,
3177        0xb4687c17a9ea9866,
3178        256,
3179        V3::new_checked()
3180    );
3181
3182    #[cfg(target_arch = "x86_64")]
3183    float_test!(
3184        test_ip_f32_x86_64_v4,
3185        IP,
3186        ResumableIP,
3187        reference::reference_innerproduct_f32_mathematical,
3188        2e-4,
3189        1e-3,
3190        0xb4687c17a9ea9866,
3191        256,
3192        V4::new_checked_miri()
3193    );
3194
3195    //--------//
3196    // Cosine //
3197    //--------//
3198
3199    float_test!(
3200        test_cosine_f32_current,
3201        CosineStateless,
3202        ResumableCosine,
3203        reference::reference_cosine_f32_mathematical,
3204        1e-5,
3205        1e-5,
3206        0xe860e9dc65f38bb8,
3207        64,
3208        Some(diskann_wide::ARCH)
3209    );
3210
3211    float_test!(
3212        test_cosine_f32_scalar,
3213        CosineStateless,
3214        ResumableCosine,
3215        reference::reference_cosine_f32_mathematical,
3216        1e-5,
3217        1e-5,
3218        0xe860e9dc65f38bb8,
3219        64,
3220        Some(diskann_wide::arch::Scalar)
3221    );
3222
3223    #[cfg(target_arch = "x86_64")]
3224    float_test!(
3225        test_cosine_f32_x86_64_v3,
3226        CosineStateless,
3227        ResumableCosine,
3228        reference::reference_cosine_f32_mathematical,
3229        1e-5,
3230        1e-5,
3231        0xe860e9dc65f38bb8,
3232        256,
3233        V3::new_checked()
3234    );
3235
3236    #[cfg(target_arch = "x86_64")]
3237    float_test!(
3238        test_cosine_f32_x86_64_v4,
3239        CosineStateless,
3240        ResumableCosine,
3241        reference::reference_cosine_f32_mathematical,
3242        1e-5,
3243        1e-5,
3244        0xe860e9dc65f38bb8,
3245        256,
3246        V4::new_checked_miri()
3247    );
3248
3249    /////////
3250    // f16 //
3251    /////////
3252
    /// Instantiate a `#[test]` that stress-tests a half-precision (`Half`)
    /// distance kernel (`$impl`) against a scalar reference for every dimension
    /// in `0..$upper`. Inputs are drawn from `Normal(0, 10)`.
    ///
    /// `$seed` fixes the RNG for reproducibility; `$arch` is an expression
    /// yielding an `Option` of the architecture token, and the test is skipped
    /// when it evaluates to `None`.
    macro_rules! half_test {
        ($name:ident,
         $impl:ty,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                // Skip silently when the architecture is unavailable at runtime.
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test::<
                            Half,
                            Half,
                            rand_distr::Normal<f32>,
                            rand_distr::Normal<f32>,
                            $impl,
                            StdRng,
                            _
                        >(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng
                        );
                    }
                }
            }
        }
    }
3291
3292    //----//
3293    // L2 //
3294    //----//
3295
3296    half_test!(
3297        test_l2_f16_current,
3298        L2,
3299        reference::reference_squared_l2_f16_mathematical,
3300        1e-5,
3301        1e-5,
3302        0x87ca6f1051667500,
3303        64,
3304        Some(diskann_wide::ARCH)
3305    );
3306
3307    half_test!(
3308        test_l2_f16_scalar,
3309        L2,
3310        reference::reference_squared_l2_f16_mathematical,
3311        1e-5,
3312        1e-5,
3313        0x87ca6f1051667500,
3314        64,
3315        Some(diskann_wide::arch::Scalar)
3316    );
3317
3318    #[cfg(target_arch = "x86_64")]
3319    half_test!(
3320        test_l2_f16_x86_64_v3,
3321        L2,
3322        reference::reference_squared_l2_f16_mathematical,
3323        1e-5,
3324        1e-5,
3325        0x87ca6f1051667500,
3326        256,
3327        V3::new_checked()
3328    );
3329
3330    #[cfg(target_arch = "x86_64")]
3331    half_test!(
3332        test_l2_f16_x86_64_v4,
3333        L2,
3334        reference::reference_squared_l2_f16_mathematical,
3335        1e-5,
3336        1e-5,
3337        0x87ca6f1051667500,
3338        256,
3339        V4::new_checked_miri()
3340    );
3341
3342    //----//
3343    // IP //
3344    //----//
3345
3346    half_test!(
3347        test_ip_f16_current,
3348        IP,
3349        reference::reference_innerproduct_f16_mathematical,
3350        2e-4,
3351        2e-4,
3352        0x5909f5f20307ccbe,
3353        64,
3354        Some(diskann_wide::ARCH)
3355    );
3356
3357    half_test!(
3358        test_ip_f16_scalar,
3359        IP,
3360        reference::reference_innerproduct_f16_mathematical,
3361        2e-4,
3362        2e-4,
3363        0x5909f5f20307ccbe,
3364        64,
3365        Some(diskann_wide::arch::Scalar)
3366    );
3367
3368    #[cfg(target_arch = "x86_64")]
3369    half_test!(
3370        test_ip_f16_x86_64_v3,
3371        IP,
3372        reference::reference_innerproduct_f16_mathematical,
3373        2e-4,
3374        2e-4,
3375        0x5909f5f20307ccbe,
3376        256,
3377        V3::new_checked()
3378    );
3379
3380    #[cfg(target_arch = "x86_64")]
3381    half_test!(
3382        test_ip_f16_x86_64_v4,
3383        IP,
3384        reference::reference_innerproduct_f16_mathematical,
3385        2e-4,
3386        2e-4,
3387        0x5909f5f20307ccbe,
3388        256,
3389        V4::new_checked_miri()
3390    );
3391
3392    //--------//
3393    // Cosine //
3394    //--------//
3395
3396    half_test!(
3397        test_cosine_f16_current,
3398        CosineStateless,
3399        reference::reference_cosine_f16_mathematical,
3400        1e-5,
3401        1e-5,
3402        0x41dda34655f05ef6,
3403        64,
3404        Some(diskann_wide::ARCH)
3405    );
3406
3407    half_test!(
3408        test_cosine_f16_scalar,
3409        CosineStateless,
3410        reference::reference_cosine_f16_mathematical,
3411        1e-5,
3412        1e-5,
3413        0x41dda34655f05ef6,
3414        64,
3415        Some(diskann_wide::arch::Scalar)
3416    );
3417
3418    #[cfg(target_arch = "x86_64")]
3419    half_test!(
3420        test_cosine_f16_x86_64_v3,
3421        CosineStateless,
3422        reference::reference_cosine_f16_mathematical,
3423        1e-5,
3424        1e-5,
3425        0x41dda34655f05ef6,
3426        256,
3427        V3::new_checked()
3428    );
3429
3430    #[cfg(target_arch = "x86_64")]
3431    half_test!(
3432        test_cosine_f16_x86_64_v4,
3433        CosineStateless,
3434        reference::reference_cosine_f16_mathematical,
3435        1e-5,
3436        1e-5,
3437        0x41dda34655f05ef6,
3438        256,
3439        V4::new_checked_miri()
3440    );
3441
3442    /////////////
3443    // Integer //
3444    /////////////
3445
    /// Instantiate a `#[test]` that stress-tests an integer distance kernel
    /// (`$impl`) for element type `$T` against a scalar reference for every
    /// dimension in `0..$upper`.
    ///
    /// Inputs are sampled from `StandardUniform`, and the comparison is exact:
    /// both `epsilon` and `max_relative` are 0, since integer kernels must match
    /// the reference bit-for-bit (after conversion to `f32`).
    macro_rules! int_test {
        (
            $name:ident,
            $T:ty,
            $impl:ty,
            $reference:path,
            $seed:literal,
            $upper:literal,
            { $($arch:tt)* }
        ) => {
            #[test]
            fn $name() {
                // Skip silently when the architecture is unavailable at runtime.
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test::<$T, $T, _, _, $impl, _, _>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            StandardUniform,
                            StandardUniform,
                            dim,
                            0.0,
                            0.0,
                            &mut rng,
                        )
                    }
                }
            }
        }
    }
3476
3477    //----//
3478    // U8 //
3479    //----//
3480
3481    int_test!(
3482        test_l2_u8_current,
3483        u8,
3484        L2,
3485        reference::reference_squared_l2_u8_mathematical,
3486        0x945bdc37d8279d4b,
3487        128,
3488        { Some(ARCH) }
3489    );
3490
3491    int_test!(
3492        test_l2_u8_scalar,
3493        u8,
3494        L2,
3495        reference::reference_squared_l2_u8_mathematical,
3496        0x74c86334ab7a51f9,
3497        128,
3498        { Some(diskann_wide::arch::Scalar) }
3499    );
3500
3501    #[cfg(target_arch = "x86_64")]
3502    int_test!(
3503        test_l2_u8_x86_64_v3,
3504        u8,
3505        L2,
3506        reference::reference_squared_l2_u8_mathematical,
3507        0x74c86334ab7a51f9,
3508        256,
3509        { V3::new_checked() }
3510    );
3511
3512    #[cfg(target_arch = "x86_64")]
3513    int_test!(
3514        test_l2_u8_x86_64_v4,
3515        u8,
3516        L2,
3517        reference::reference_squared_l2_u8_mathematical,
3518        0x74c86334ab7a51f9,
3519        320,
3520        { V4::new_checked_miri() }
3521    );
3522
3523    int_test!(
3524        test_ip_u8_current,
3525        u8,
3526        IP,
3527        reference::reference_innerproduct_u8_mathematical,
3528        0xcbe0342c75085fd5,
3529        64,
3530        { Some(ARCH) }
3531    );
3532
3533    int_test!(
3534        test_ip_u8_scalar,
3535        u8,
3536        IP,
3537        reference::reference_innerproduct_u8_mathematical,
3538        0x888e07fc489e773f,
3539        64,
3540        { Some(diskann_wide::arch::Scalar) }
3541    );
3542
3543    #[cfg(target_arch = "x86_64")]
3544    int_test!(
3545        test_ip_u8_x86_64_v3,
3546        u8,
3547        IP,
3548        reference::reference_innerproduct_u8_mathematical,
3549        0x888e07fc489e773f,
3550        256,
3551        { V3::new_checked() }
3552    );
3553
3554    #[cfg(target_arch = "x86_64")]
3555    int_test!(
3556        test_ip_u8_x86_64_v4,
3557        u8,
3558        IP,
3559        reference::reference_innerproduct_u8_mathematical,
3560        0x888e07fc489e773f,
3561        320,
3562        { V4::new_checked_miri() }
3563    );
3564
3565    int_test!(
3566        test_cosine_u8_current,
3567        u8,
3568        CosineStateless,
3569        reference::reference_cosine_u8_mathematical,
3570        0x96867b6aff616b28,
3571        64,
3572        { Some(ARCH) }
3573    );
3574
3575    int_test!(
3576        test_cosine_u8_scalar,
3577        u8,
3578        CosineStateless,
3579        reference::reference_cosine_u8_mathematical,
3580        0xcc258c9391733211,
3581        64,
3582        { Some(diskann_wide::arch::Scalar) }
3583    );
3584
3585    #[cfg(target_arch = "x86_64")]
3586    int_test!(
3587        test_cosine_u8_x86_64_v3,
3588        u8,
3589        CosineStateless,
3590        reference::reference_cosine_u8_mathematical,
3591        0xcc258c9391733211,
3592        256,
3593        { V3::new_checked() }
3594    );
3595
3596    #[cfg(target_arch = "x86_64")]
3597    int_test!(
3598        test_cosine_u8_x86_64_v4,
3599        u8,
3600        CosineStateless,
3601        reference::reference_cosine_u8_mathematical,
3602        0xcc258c9391733211,
3603        320,
3604        { V4::new_checked_miri() }
3605    );
3606
3607    //----//
3608    // I8 //
3609    //----//
3610
3611    int_test!(
3612        test_l2_i8_current,
3613        i8,
3614        L2,
3615        reference::reference_squared_l2_i8_mathematical,
3616        0xa60136248cd3c2f0,
3617        64,
3618        { Some(ARCH) }
3619    );
3620
3621    int_test!(
3622        test_l2_i8_scalar,
3623        i8,
3624        L2,
3625        reference::reference_squared_l2_i8_mathematical,
3626        0x3e8bada709e176be,
3627        64,
3628        { Some(diskann_wide::arch::Scalar) }
3629    );
3630
3631    #[cfg(target_arch = "x86_64")]
3632    int_test!(
3633        test_l2_i8_x86_64_v3,
3634        i8,
3635        L2,
3636        reference::reference_squared_l2_i8_mathematical,
3637        0x3e8bada709e176be,
3638        256,
3639        { V3::new_checked() }
3640    );
3641
3642    #[cfg(target_arch = "x86_64")]
3643    int_test!(
3644        test_l2_i8_x86_64_v4,
3645        i8,
3646        L2,
3647        reference::reference_squared_l2_i8_mathematical,
3648        0x3e8bada709e176be,
3649        320,
3650        { V4::new_checked_miri() }
3651    );
3652
3653    int_test!(
3654        test_ip_i8_current,
3655        i8,
3656        IP,
3657        reference::reference_innerproduct_i8_mathematical,
3658        0xe8306104740509e1,
3659        64,
3660        { Some(ARCH) }
3661    );
3662
3663    int_test!(
3664        test_ip_i8_scalar,
3665        i8,
3666        IP,
3667        reference::reference_innerproduct_i8_mathematical,
3668        0x8a263408c7b31d85,
3669        64,
3670        { Some(diskann_wide::arch::Scalar) }
3671    );
3672
3673    #[cfg(target_arch = "x86_64")]
3674    int_test!(
3675        test_ip_i8_x86_64_v3,
3676        i8,
3677        IP,
3678        reference::reference_innerproduct_i8_mathematical,
3679        0x8a263408c7b31d85,
3680        256,
3681        { V3::new_checked() }
3682    );
3683
3684    #[cfg(target_arch = "x86_64")]
3685    int_test!(
3686        test_ip_i8_x86_64_v4,
3687        i8,
3688        IP,
3689        reference::reference_innerproduct_i8_mathematical,
3690        0x8a263408c7b31d85,
3691        320,
3692        { V4::new_checked_miri() }
3693    );
3694
3695    int_test!(
3696        test_cosine_i8_current,
3697        i8,
3698        CosineStateless,
3699        reference::reference_cosine_i8_mathematical,
3700        0x818c210190701e4b,
3701        64,
3702        { Some(ARCH) }
3703    );
3704
3705    int_test!(
3706        test_cosine_i8_scalar,
3707        i8,
3708        CosineStateless,
3709        reference::reference_cosine_i8_mathematical,
3710        0x2d077bed2629b18e,
3711        64,
3712        { Some(diskann_wide::arch::Scalar) }
3713    );
3714
3715    #[cfg(target_arch = "x86_64")]
3716    int_test!(
3717        test_cosine_i8_x86_64_v3,
3718        i8,
3719        CosineStateless,
3720        reference::reference_cosine_i8_mathematical,
3721        0x2d077bed2629b18e,
3722        256,
3723        { V3::new_checked() }
3724    );
3725
3726    #[cfg(target_arch = "x86_64")]
3727    int_test!(
3728        test_cosine_i8_x86_64_v4,
3729        i8,
3730        CosineStateless,
3731        reference::reference_cosine_i8_mathematical,
3732        0x2d077bed2629b18e,
3733        320,
3734        { V4::new_checked_miri() }
3735    );
3736
3737    //////////
3738    // LInf //
3739    //////////
3740
    /// Instantiate a `#[test]` that stress-tests the `LInfNorm` kernel for
    /// element type `$T` against a scalar reference for every dimension in
    /// `0..$upper`. Inputs are drawn from `Normal(-10, 10)`.
    ///
    /// `$seed` fixes the RNG; `$arch` yields an `Option` of the architecture
    /// token, and the test is skipped when it is `None`.
    macro_rules! linf_test {
        ($name:ident,
         $T:ty,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                // Skip silently when the architecture is unavailable at runtime.
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_linf::<$T, _, StdRng, _>(
                            arch,
                            |l| $reference(l).into_inner(),
                            rand_distr::Normal::new(-10.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
3770
    // L-infinity norm for f32 and f16 on the scalar and x86-64 V3/V4 backends.
    // All variants share the same seed and dimension bound (256).
    linf_test!(
        test_linf_f32_scalar,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v3,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v4,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    linf_test!(
        test_linf_f16_scalar,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v3,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v4,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );
3840
3841    ////////////////
3842    // Miri Tests //
3843    ////////////////
3844
    /// Tag identifying the element types exercised by the Miri bounds tests.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum DataType {
        Float32,
        Float16,
        UInt8,
        Int8,
    }

    /// Maps a concrete element type to its `DataType` tag at compile time.
    trait AsDataType {
        fn as_data_type() -> DataType;
    }

    impl AsDataType for f32 {
        fn as_data_type() -> DataType {
            DataType::Float32
        }
    }

    impl AsDataType for f16 {
        fn as_data_type() -> DataType {
            DataType::Float16
        }
    }

    impl AsDataType for u8 {
        fn as_data_type() -> DataType {
            DataType::UInt8
        }
    }

    impl AsDataType for i8 {
        fn as_data_type() -> DataType {
            DataType::Int8
        }
    }
3880
    /// Architecture tag used as one component of the `MIRI_BOUNDS` lookup key.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum Arch {
        Scalar,
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
    }

    /// `MIRI_BOUNDS` lookup key: (architecture, left element type, right element type).
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    struct Key {
        arch: Arch,
        left: DataType,
        right: DataType,
    }

    impl Key {
        /// Convenience constructor; avoids struct-literal noise at the call sites.
        fn new(arch: Arch, left: DataType, right: DataType) -> Self {
            Self { arch, left, right }
        }
    }
3902
    // Exclusive upper bounds on the dimensions each (arch, left, right)
    // combination is exercised with by the `test_bounds!` tests below. The
    // values mirror the `$upper` bounds used by the stress-test macros above
    // (e.g. 320 for the u8/i8 V4 kernels).
    static MIRI_BOUNDS: LazyLock<HashMap<Key, usize>> = LazyLock::new(|| {
        use Arch::{Scalar, X86_64_V3, X86_64_V4};
        use DataType::{Float16, Float32, Int8, UInt8};

        [
            (Key::new(Scalar, Float32, Float32), 64),
            (Key::new(X86_64_V3, Float32, Float32), 256),
            (Key::new(X86_64_V4, Float32, Float32), 256),
            (Key::new(Scalar, Float16, Float16), 64),
            (Key::new(X86_64_V3, Float16, Float16), 256),
            (Key::new(X86_64_V4, Float16, Float16), 256),
            (Key::new(Scalar, Float32, Float16), 64),
            (Key::new(X86_64_V3, Float32, Float16), 256),
            (Key::new(X86_64_V4, Float32, Float16), 256),
            (Key::new(Scalar, UInt8, UInt8), 64),
            (Key::new(X86_64_V3, UInt8, UInt8), 256),
            (Key::new(X86_64_V4, UInt8, UInt8), 320),
            (Key::new(Scalar, Int8, Int8), 64),
            (Key::new(X86_64_V3, Int8, Int8), 256),
            (Key::new(X86_64_V4, Int8, Int8), 320),
        ]
        .into_iter()
        .collect()
    });
3927
    /// Instantiate a `#[test]` that runs the L2, IP, and Cosine kernels over
    /// every dimension below the per-architecture bound recorded in
    /// `MIRI_BOUNDS`, using constant-valued inputs (`$left_ex` / `$right_ex`).
    ///
    /// Results are discarded: these tests exercise the kernels' memory accesses
    /// (e.g. under Miri) rather than asserting numerical values.
    macro_rules! test_bounds {
        (
            $function:ident,
            $left:ty,
            $left_ex:expr,
            $right:ty,
            $right_ex:expr
        ) => {
            #[test]
            fn $function() {
                let left: $left = $left_ex;
                let right: $right = $right_ex;

                let left_type = <$left>::as_data_type();
                let right_type = <$right>::as_data_type();

                // Scalar
                {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Scalar, left_type, right_type)];
                    for dim in 0..max {
                        // Shadow the scalars with vectors of the current length.
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        let arch = diskann_wide::arch::Scalar;
                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V3: only runs when the backend is available at runtime.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V3::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V3, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // x86-64 V4: only runs when the backend is available at runtime.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V4::new_checked_miri() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V4, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }
            }
        };
    }
3986
    // Instantiate the bounds tests for each supported (left, right) type pair.
    // The fill values are arbitrary constants; results are discarded by the
    // macro, so only execution (not numerics) is being checked.
    test_bounds!(miri_test_bounds_f32xf32, f32, 1.0f32, f32, 2.0f32);
    test_bounds!(
        miri_test_bounds_f16xf16,
        f16,
        diskann_wide::cast_f32_to_f16(1.0f32),
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(
        miri_test_bounds_f32xf16,
        f32,
        1.0f32,
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(miri_test_bounds_u8xu8, u8, 1u8, u8, 1u8);
    test_bounds!(miri_test_bounds_i8xi8, i8, 1i8, i8, 1i8);
4004}