Skip to main content

diskann_vector/distance/
simd.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::convert::AsRef;
7
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10
11#[cfg(target_arch = "aarch64")]
12use diskann_wide::arch::aarch64::{algorithms, Neon};
13
14use diskann_wide::{
15    arch::Scalar, Architecture, Const, Constant, Emulated, SIMDAbs, SIMDDotProduct, SIMDMulAdd,
16    SIMDSumTree, SIMDVector,
17};
18
19use crate::Half;
20
/// A helper trait to allow integer to f32 conversion (which may be lossy).
pub trait LossyF32Conversion: Copy {
    fn as_f32_lossy(self) -> f32;
}

/// Identity conversion: an `f32` is already an `f32`.
impl LossyF32Conversion for f32 {
    fn as_f32_lossy(self) -> f32 { self }
}

/// Plain `as` cast; magnitudes beyond `f32` precision round to the nearest `f32`.
impl LossyF32Conversion for i32 {
    fn as_f32_lossy(self) -> f32 { self as f32 }
}

/// Plain `as` cast; values above 2^24 may round to the nearest representable `f32`.
impl LossyF32Conversion for u32 {
    fn as_f32_lossy(self) -> f32 { self as f32 }
}
43
cfg_if::cfg_if! {
    if #[cfg(miri)] {
        // Miri does not execute inline assembly; evaluation forcing is purely a
        // codegen concern, so a no-op keeps Miri runs working.
        fn force_eval(_x: f32) {}
    } else if #[cfg(target_arch = "x86_64")] {
        use std::arch::asm;

        /// Force the evaluation of the argument, preventing the compiler from reordering the
        /// computation of `x` behind a condition.
        ///
        /// In the context of Cosine similarity, this can help code generation for
        /// static-dimensional kernels
        #[inline(always)]
        fn force_eval(x: f32) {
            // SAFETY: This function executes no instructions. As such, it satisfies the long
            // list of requirements for inline assembly.
            //
            // See: https://doc.rust-lang.org/reference/inline-assembly.html#rules-for-inline-assembly
            unsafe {
                asm!(
                    // Assembly comment to "use" the argument
                    "/* {0} */",
                    // Use an `xmm_reg` since LLVM almost always uses such a register for
                    // scalar floating point
                    in(xmm_reg) x,
                    // Explanation:
                    // * `nostack`: This function does not touch the stack, so the compiler
                    //   does not need to worry that the stack gets messed up.
                    // * `nomem`: This function does not touch memory. The compiler doesn't
                    //   have to reload any values.
                    // * `preserves_flags`: This function preserves architectural condition
                    //   flags. We can make this guarantee because this function literally
                    //   does nothing.
                    options(nostack, nomem, preserves_flags)
                )
            }
        }
    } else {
        // Fallback implementation: a no-op on architectures without a dedicated
        // barrier. Correctness is unaffected; only codegen quality may differ.
        fn force_eval(_x: f32) {}
    }
}
85
/// A utility struct to help with SIMD loading.
///
/// The main loop of SIMD kernels consists of various tilings of loads and arithmetic.
/// Outside of the epilogue, these loads are all full-width vector loads.
///
/// To aid in defining different tilings, this struct takes the base pointers for left and
/// right hand pointers and provides a `load` method to extract full vectors for both
/// the left and right-hand sides.
///
/// This works in conjunction with the [`SIMDSchema`] to help write unrolled loops.
#[derive(Debug, Clone, Copy)]
pub struct Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    /// Architecture token used to issue SIMD loads.
    arch: A,
    /// Schema describing vector types and the main-loop unroll factor.
    schema: Schema,
    /// Base pointer of the left-hand operand span.
    left: *const Left,
    /// Base pointer of the right-hand operand span.
    right: *const Right,
    /// Number of elements in each span; used for debug bounds checking in `load`.
    len: usize,
}
108
impl<Schema, Left, Right, A> Loader<Schema, Left, Right, A>
where
    Schema: SIMDSchema<Left, Right, A>,
    A: Architecture,
{
    /// Construct a new loader for the left and right hand pointers.
    ///
    /// Requires that the memory ranges `[left, left + len)` and `[right, right + len)` are
    /// both valid, where `len` is the *number* of the elements of type `T` and `U`.
    #[inline(always)]
    fn new(arch: A, schema: Schema, left: *const Left, right: *const Right, len: usize) -> Self {
        Self {
            arch,
            schema,
            left,
            right,
            len,
        }
    }

    /// Return the underlying architecture.
    #[inline(always)]
    fn arch(&self) -> A {
        self.arch
    }

    /// Return the SIMD Schema.
    #[inline(always)]
    fn schema(&self) -> Schema {
        self.schema
    }

    /// Load full width vectors for the left and right hand memory spans.
    ///
    /// This loads a [`SIMDSchema::SIMDWidth`] chunk of data using the following formula:
    ///
    /// ```text
    /// // The number of elements in an unrolled [`MainLoop`].
    /// let simd_width = Schema::SIMDWidth::value();
    /// let block_size = simd_width * Schema::Main::BLOCK_SIZE;
    ///
    /// load(px + block_size * block + simd_width * offset);
    /// ```
    ///
    /// # Safety
    ///
    /// Requires that the following memory addresses are in-bounds (i.e., the highest
    /// read address is at an offset less than `len`):
    ///
    /// ```text
    /// [
    ///     px + block_size * block + simd_width * offset,
    ///     px + block_size * block + simd_width * (offset + 1)
    /// )
    ///
    /// [
    ///     py + block_size * block + simd_width * offset,
    ///     py + block_size * block + simd_width * (offset + 1)
    /// )
    /// ```
    ///
    /// This invariant is checked in debug builds.
    #[inline(always)]
    unsafe fn load(&self, block: usize, offset: usize) -> (Schema::Left, Schema::Right) {
        // Elements per SIMD vector.
        let stride = Schema::SIMDWidth::value();
        // Elements per unrolled block of the main loop.
        let block_stride = stride * Schema::Main::BLOCK_SIZE;
        // Element offset (not byte offset) of the start of this load.
        let offset = block_stride * block + stride * offset;

        debug_assert!(
            offset + stride <= self.len,
            "length = {}, offset = {}",
            self.len,
            offset
        );

        (
            Schema::Left::load_simd(self.arch, self.left.add(offset)),
            Schema::Right::load_simd(self.arch, self.right.add(offset)),
        )
    }
}
190
/// A representation of the main unrolled-loop for SIMD kernels.
pub trait MainLoop {
    /// The effective number of unrolling (in terms of SIMD vectors) performed by this
    /// kernel. For example, if `BLOCK_SIZE = 4` and the SIMD width is 8, then each iteration
    /// of the main loop will process `4 * 8 = 32` elements.
    ///
    /// This parameter will be used to compute the number of full-width epilogues that need
    /// to be executed.
    const BLOCK_SIZE: usize;

    /// Perform the main unrolled loops of a SIMD kernel. This loop is expected to process
    /// all elements in the range `[0, trip_count * S::get_simd_width() * Self::BLOCK_SIZE)`
    /// and return an accumulator consisting of the result.
    ///
    /// # Arguments
    ///
    /// * `loader`: A SIMD loader to emit loads to the two source spans.
    /// * `trip_count`: The number of blocks of size `BLOCK_SIZE` to process. A single "trip"
    ///   will process `S::get_simd_width() * Self::BLOCK_SIZE` elements. So, computation of
    ///   `trip_count` should be computed as:
    ///   ```math
    ///   let trip_count = len / (S::get_simd_width() * <_ as MainLoop>::BLOCK_SIZE);
    ///   ```
    /// * `epilogues`: The number of `S::get_simd_width()` vectors remaining after all the
    ///   main blocks have been processed. This is guaranteed to be less than
    ///   `Self::BLOCK_SIZE`.
    ///
    /// # Safety
    ///
    /// All elements in the accessed range must be valid. The memory addresses touched are
    ///
    /// ```text
    /// let block_size = Self::BLOCK_SIZE;
    /// let simd_width = S::get_simd_width();
    /// [
    ///     loader.left,
    ///     loader.left + trip_count * simd_width * block_size + epilogues * simd_width
    /// )
    /// [
    ///     loader.right,
    ///     loader.right + trip_count * simd_width * block_size + epilogues * simd_width
    /// )
    /// ```
    ///
    /// The `loader` will ensure that all accesses are in-bounds in debug builds.
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>;
}
/// An inner loop implementation strategy using 1 parallel instance of the schema
/// accumulator with a manual inner loop unroll of 1.
///
/// These strategy types are zero-sized markers selected via [`SIMDSchema::Main`];
/// derives are provided for consistency with the other public types in this file.
#[derive(Debug, Default, Clone, Copy)]
pub struct Strategy1x1;

/// An inner loop implementation strategy using 2 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 1.
#[derive(Debug, Default, Clone, Copy)]
pub struct Strategy2x1;

/// An inner loop implementation strategy using 4 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 1.
#[derive(Debug, Default, Clone, Copy)]
pub struct Strategy4x1;

/// An inner loop implementation strategy using 4 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 2.
#[derive(Debug, Default, Clone, Copy)]
pub struct Strategy4x2;

/// An inner loop implementation strategy using 2 parallel instances of the schema
/// accumulator with a manual inner loop unroll of 4.
#[derive(Debug, Default, Clone, Copy)]
pub struct Strategy2x4;
264
265impl MainLoop for Strategy1x1 {
266    const BLOCK_SIZE: usize = 1;
267
268    #[inline(always)]
269    unsafe fn main<S, L, R, A>(
270        loader: &Loader<S, L, R, A>,
271        trip_count: usize,
272        _epilogues: usize,
273    ) -> S::Accumulator
274    where
275        A: Architecture,
276        S: SIMDSchema<L, R, A>,
277    {
278        let arch = loader.arch();
279        let schema = loader.schema();
280
281        let mut s0 = schema.init(arch);
282        for i in 0..trip_count {
283            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
284        }
285
286        s0
287    }
288}
289
290impl MainLoop for Strategy2x1 {
291    const BLOCK_SIZE: usize = 2;
292
293    #[inline(always)]
294    unsafe fn main<S, L, R, A>(
295        loader: &Loader<S, L, R, A>,
296        trip_count: usize,
297        epilogues: usize,
298    ) -> S::Accumulator
299    where
300        A: Architecture,
301        S: SIMDSchema<L, R, A>,
302    {
303        let arch = loader.arch();
304        let schema = loader.schema();
305
306        let mut s0 = schema.init(arch);
307        let mut s1 = schema.init(arch);
308
309        for i in 0..trip_count {
310            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
311            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
312        }
313
314        let mut s = schema.combine(s0, s1);
315        if epilogues != 0 {
316            s = schema.accumulate_tuple(s, loader.load(trip_count, 0));
317        }
318
319        s
320    }
321}
322
323impl MainLoop for Strategy4x1 {
324    const BLOCK_SIZE: usize = 4;
325
326    #[inline(always)]
327    unsafe fn main<S, L, R, A>(
328        loader: &Loader<S, L, R, A>,
329        trip_count: usize,
330        epilogues: usize,
331    ) -> S::Accumulator
332    where
333        A: Architecture,
334        S: SIMDSchema<L, R, A>,
335    {
336        let arch = loader.arch();
337        let schema = loader.schema();
338
339        let mut s0 = schema.init(arch);
340        let mut s1 = schema.init(arch);
341        let mut s2 = schema.init(arch);
342        let mut s3 = schema.init(arch);
343
344        for i in 0..trip_count {
345            s0 = schema.accumulate_tuple(s0, loader.load(i, 0));
346            s1 = schema.accumulate_tuple(s1, loader.load(i, 1));
347            s2 = schema.accumulate_tuple(s2, loader.load(i, 2));
348            s3 = schema.accumulate_tuple(s3, loader.load(i, 3));
349        }
350
351        if epilogues >= 1 {
352            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
353        }
354
355        if epilogues >= 2 {
356            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
357        }
358
359        if epilogues >= 3 {
360            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
361        }
362
363        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
364    }
365}
366
impl MainLoop for Strategy4x2 {
    const BLOCK_SIZE: usize = 4;

    /// Four independent accumulators with the block loop manually unrolled by two:
    /// each iteration of the `for` loop consumes two `BLOCK_SIZE` groups (8 vectors),
    /// with a peeled iteration when `trip_count` is odd.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);
        let mut s2 = schema.init(arch);
        let mut s3 = schema.init(arch);

        // Offsets 4..=7 of block `j` are offsets 0..=3 of block `j + 1` under the
        // offset formula in [`Loader::load`], so each iteration covers blocks j and j+1.
        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));

            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 6));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 7));
        }

        // Peeled iteration: handle the final block when `trip_count` is odd.
        if !trip_count.is_multiple_of(2) {
            // Will not underflow because `trip_count` is odd.
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s2 = schema.accumulate_tuple(s2, loader.load(j, 2));
            s3 = schema.accumulate_tuple(s3, loader.load(j, 3));
        }

        // `epilogues < BLOCK_SIZE`, so at most three full-width vectors remain.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s2 = schema.accumulate_tuple(s2, loader.load(trip_count, 2));
        }

        schema.combine(schema.combine(s0, s1), schema.combine(s2, s3))
    }
}
425
impl MainLoop for Strategy2x4 {
    const BLOCK_SIZE: usize = 4;

    /// The implementation here has a global unroll of 4, but the unroll factor of the main
    /// loop is actually 8.
    ///
    /// There is a single peeled iteration at the end that handles the last group of 4
    /// if needed.
    #[inline(always)]
    unsafe fn main<S, L, R, A>(
        loader: &Loader<S, L, R, A>,
        trip_count: usize,
        epilogues: usize,
    ) -> S::Accumulator
    where
        A: Architecture,
        S: SIMDSchema<L, R, A>,
    {
        let arch = loader.arch();
        let schema = loader.schema();

        // Two accumulators; alternating assignment spreads the 8 loads per iteration
        // across both chains.
        let mut s0 = schema.init(arch);
        let mut s1 = schema.init(arch);

        // Offsets 4..=7 of block `j` are offsets 0..=3 of block `j + 1` under the
        // offset formula in [`Loader::load`], so each iteration covers blocks j and j+1.
        for i in 0..(trip_count / 2) {
            let j = 2 * i;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));

            s0 = schema.accumulate_tuple(s0, loader.load(j, 4));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 5));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 6));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 7));
        }

        // Peeled iteration: handle the final block when `trip_count` is odd.
        // Will not underflow because `trip_count` is odd here.
        if !trip_count.is_multiple_of(2) {
            let j = trip_count - 1;
            s0 = schema.accumulate_tuple(s0, loader.load(j, 0));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 1));
            s0 = schema.accumulate_tuple(s0, loader.load(j, 2));
            s1 = schema.accumulate_tuple(s1, loader.load(j, 3));
        }

        // `epilogues < BLOCK_SIZE`, so at most three full-width vectors remain.
        if epilogues >= 1 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 0));
        }

        if epilogues >= 2 {
            s1 = schema.accumulate_tuple(s1, loader.load(trip_count, 1));
        }

        if epilogues >= 3 {
            s0 = schema.accumulate_tuple(s0, loader.load(trip_count, 2));
        }

        schema.combine(s0, s1)
    }
}
486
487/// An interface trait for SIMD operations.
488///
489/// Patterns like unrolling, pointer arithmetic, and epilogue handling are common across
490/// many different combinations of left and right hand types for distance computations.
491///
492/// This higher level handling is delegated to functions like `simd_op`, which in turn
493/// uses a `SIMDSchema` to customize the mechanics of loading and accumulation.
494pub trait SIMDSchema<T, U, A: Architecture = diskann_wide::arch::Current>: Copy {
495    /// The desired SIMD read width.
496    /// Reads from the input slice will be use this stride when accessing memory.
497    type SIMDWidth: Constant<Type = usize>;
498
499    /// The type used to represent partial accumulated values.
500    type Accumulator: std::ops::Add<Output = Self::Accumulator> + std::fmt::Debug + Copy;
501
502    /// The type used for the left-hand side.
503    type Left: SIMDVector<Arch = A, Scalar = T, ConstLanes = Self::SIMDWidth>;
504
505    /// The type used for the right-hand side.
506    type Right: SIMDVector<Arch = A, Scalar = U, ConstLanes = Self::SIMDWidth>;
507
508    /// The final return type.
509    /// This is often `f32` for complete distance functions, but need not always be.
510    type Return;
511
512    /// The implementation of the main loop.
513    type Main: MainLoop;
514
515    /// Initialize an empty (identity) accumulator.
516    fn init(&self, arch: A) -> Self::Accumulator;
517
518    /// Perform an accumulation.
519    fn accumulate(
520        &self,
521        x: Self::Left,
522        y: Self::Right,
523        acc: Self::Accumulator,
524    ) -> Self::Accumulator;
525
526    /// Combine two independent accumulators (allows for unrolling).
527    #[inline(always)]
528    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
529        x + y
530    }
531
532    /// A supplied trait for dealing with non-full-width epilogues.
533    /// Often, masked based loading will do the right thing, but for architectures like AVX2
534    /// that have limited support for masking 8 and 16-bit operations, using a scalar
535    /// fallback may just be better.
536    ///
537    /// This provides a customization point to enable a scalar fallback.
538    ///
539    /// # Safety
540    ///
541    /// * Both pointers `x` and `y` must point to memory.
542    /// * It must be safe to read `len` contiguous items of type `T` starting at `x` and
543    ///   `len` contiguous items of type `U` starting at `y`.
544    ///
545    /// The following guarantee is made:
546    ///
547    /// * No read will be emitted to memory locations at and after `x.add(len)` and
548    ///   `y.add(len)`.
549    #[inline(always)]
550    unsafe fn epilogue(
551        &self,
552        arch: A,
553        x: *const T,
554        y: *const U,
555        len: usize,
556        acc: Self::Accumulator,
557    ) -> Self::Accumulator {
558        // SAFETY: Performing this read is safe by the safety preconditions of `epilogue`.
559        // Guarentee: The load implementation must be correct.
560        let a = Self::Left::load_simd_first(arch, x, len);
561
562        // SAFETY: Performing this read is safe by the safety preconditions of `epilogue`.
563        // Guarentee: The load implementation must be correct.
564        let b = Self::Right::load_simd_first(arch, y, len);
565        self.accumulate(a, b, acc)
566    }
567
568    /// Perform a reduction on the accumulator to yield the final result.
569    ///
570    /// This will be called at the end of distance processing.
571    fn reduce(&self, x: Self::Accumulator) -> Self::Return;
572
573    /// !! Do not extend this function !!
574    ///
575    /// Due to limitations on how associated constants can be used, we need a function
576    /// to access the SIMD width and rely on the compiler to constant propagate the result.
577    #[inline(always)]
578    fn get_simd_width() -> usize {
579        Self::SIMDWidth::value()
580    }
581
582    /// !! Do not extend this function !!
583    ///
584    /// Due to limitations on how associated constants can be used, we need a function
585    /// to access the unroll factor of the main loop and rely on the compiler to constant
586    /// propagate the result.
587    #[inline(always)]
588    fn get_main_bocksize() -> usize {
589        Self::Main::BLOCK_SIZE
590    }
591
592    /// A helper method to access [`Self::accumulate`] in a way that is immediately
593    /// compatible with [`Loader::load`].
594    #[doc(hidden)]
595    #[inline(always)]
596    fn accumulate_tuple(
597        &self,
598        acc: Self::Accumulator,
599        (x, y): (Self::Left, Self::Right),
600    ) -> Self::Accumulator {
601        self.accumulate(x, y, acc)
602    }
603}
604
/// In some contexts - it can be beneficial to begin a computation on one pair of slices and
/// then store intermediate state for resumption on another pair of slices.
///
/// A good example of this is direct-computation of PQ distances where different chunks need
/// to be gathered and partially accumulated before the final reduction.
///
/// The `ResumableSchema` provides a relatively straight-forward way of achieving this.
pub trait ResumableSIMDSchema<T, U, A = diskann_wide::arch::Current>: Copy
where
    A: Architecture,
{
    /// The underlying non-resumable (non-reentrant) schema used for each individual pass.
    type NonResumable: SIMDSchema<T, U, A> + Default;
    /// The type produced by [`Self::sum`] once all partial passes have been folded in.
    type FinalReturn;

    /// Construct the initial (empty) resumable state for `arch`.
    fn init(arch: A) -> Self;
    /// Fold one pass's accumulator into this state, returning the updated state.
    fn combine_with(&self, other: <Self::NonResumable as SIMDSchema<T, U, A>>::Accumulator)
        -> Self;
    /// Produce the final result from the accumulated state.
    fn sum(&self) -> Self::FinalReturn;
}
625
/// A thin newtype wrapper that carries resumable accumulation state between passes.
///
/// Wrapping a [`ResumableSIMDSchema`] in `Resumable` lets it be driven as an ordinary
/// [`SIMDSchema`] whose `Return` type is the updated resumable state itself.
#[derive(Debug, Clone, Copy)]
pub struct Resumable<T>(T);

impl<T> Resumable<T> {
    /// Wrap a value in the resumable adapter.
    #[inline(always)]
    pub fn new(val: T) -> Self {
        Self(val)
    }

    /// Unwrap the adapter, returning the inner value.
    #[inline(always)]
    pub fn consume(self) -> T {
        self.0
    }
}
638
639impl<T, U, R, A> SIMDSchema<T, U, A> for Resumable<R>
640where
641    A: Architecture,
642    R: ResumableSIMDSchema<T, U, A>,
643{
644    type SIMDWidth = <R::NonResumable as SIMDSchema<T, U, A>>::SIMDWidth;
645    type Accumulator = <R::NonResumable as SIMDSchema<T, U, A>>::Accumulator;
646    type Left = <R::NonResumable as SIMDSchema<T, U, A>>::Left;
647    type Right = <R::NonResumable as SIMDSchema<T, U, A>>::Right;
648    type Return = Self;
649    type Main = <R::NonResumable as SIMDSchema<T, U, A>>::Main;
650
651    fn init(&self, arch: A) -> Self::Accumulator {
652        R::NonResumable::default().init(arch)
653    }
654
655    fn accumulate(
656        &self,
657        x: Self::Left,
658        y: Self::Right,
659        acc: Self::Accumulator,
660    ) -> Self::Accumulator {
661        R::NonResumable::default().accumulate(x, y, acc)
662    }
663
664    fn combine(&self, x: Self::Accumulator, y: Self::Accumulator) -> Self::Accumulator {
665        R::NonResumable::default().combine(x, y)
666    }
667
668    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
669        Self(self.0.combine_with(x))
670    }
671}
672
/// Report mismatched slice lengths for a binary SIMD op by panicking.
///
/// Kept out-of-line (`inline(never)`) and marked `#[cold]` so the panic machinery does
/// not pollute code generation of the hot SIMD kernels that call it.
#[cold]
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!(
        "lengths must be equal, instead got: xlen = {}, ylen = {}",
        xlen, ylen
    )
}
681
/// A SIMD executor for binary ops using the provided `SIMDSchema`.
///
/// # Panics
///
/// Panics if `x.len() != y.len()`.
#[inline(always)]
pub fn simd_op<L, R, S, T, U, A>(schema: &S, arch: A, x: T, y: U) -> S::Return
where
    A: Architecture,
    T: AsRef<[L]>,
    U: AsRef<[R]>,
    S: SIMDSchema<L, R, A>,
{
    let x: &[L] = x.as_ref();
    let y: &[R] = y.as_ref();

    let len = x.len();

    // The two lengths of the vectors must be the same.
    // Eventually - it will probably be worth looking into various wrapper functions for
    // `simd_op` that perform this checking, but for now, consider providing two
    // different-length slices as a hard program bug.
    //
    // N.B.: Redirect through `emit_length_error` to keep code generation as clean as
    // possible.
    if len != y.len() {
        emit_length_error(len, y.len());
    }
    let px = x.as_ptr();
    let py = y.as_ptr();

    // N.B.: Due to limitations in Rust's handling of const generics (and outer type
    // parameters), we cannot just reach into `S` and pull out the constant SIMDWidth.
    //
    // Instead, we need to go through a helper function. Since associated functions cannot
    // be marked as `const`, we cannot require that the extracted width is evaluated at
    // compile time.
    //
    // HOWEVER, compilers are very good at optimizing these kinds of patterns and
    // recognizing that this value is indeed constant and optimizing accordingly.
    let simd_width: usize = S::get_simd_width();
    let unroll: usize = S::get_main_bocksize();

    // Number of full unrolled blocks, then the number of whole SIMD vectors left over.
    let trip_count = len / (simd_width * unroll);
    let epilogues = (len - simd_width * unroll * trip_count) / simd_width;

    // Create a loader that (in debug mode) will check that all of our full-width accesses
    // are in-bounds.
    let loader: Loader<S, L, R, A> = Loader::new(arch, *schema, px, py, len);

    // SAFETY: `trip_count` and `epilogues` are computed from `len` such that the range
    // `[0, trip_count * simd_width * unroll + epilogues * simd_width)` is in-bounds,
    // satisfying the requirements of `main`.
    let mut s0 = unsafe { <S as SIMDSchema<L, R, A>>::Main::main(&loader, trip_count, epilogues) };

    // Fewer than `simd_width` scalar elements may remain; hand them to the schema's
    // (possibly scalar) epilogue.
    let remainder = len % simd_width;
    if remainder != 0 {
        let i = len - remainder;

        // SAFETY: We have ensured that the lengths of the two inputs are the same.
        //
        // Furthermore, preceding computations on the induction variable mean that the
        // remaining memory must be valid.
        s0 = unsafe { schema.epilogue(arch, px.add(i), py.add(i), remainder, s0) };
    }

    schema.reduce(s0)
}
750
751//----------//
752// Epilogue //
753//----------//
754
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn scalar_epilogue<L, R, F, Acc>(
    left: *const L,
    right: *const R,
    len: usize,
    acc: Acc,
    mut f: F,
) -> Acc
where
    L: Copy,
    R: Copy,
    F: FnMut(Acc, L, R) -> Acc,
{
    // Fold the remaining `len` scalar pairs into the accumulator one element at a time.
    (0..len).fold(acc, |state, i| {
        // SAFETY: The range `[left, left.add(len))` is valid for reads.
        let l = unsafe { left.add(i).read_unaligned() };
        // SAFETY: The range `[right, right.add(len))` is valid for reads.
        let r = unsafe { right.add(i).read_unaligned() };
        f(state, l, r)
    })
}
778
779/////
780///// L2 Implementations
781/////
782
// A pure L2 distance function that provides a final reduction.
//
// The accumulated value is the sum of squared element-wise differences; the `reduce`
// implementations below return it without taking a square root (i.e., squared L2).
#[derive(Debug, Default, Clone, Copy)]
pub struct L2;
786
787#[cfg(target_arch = "x86_64")]
788impl SIMDSchema<f32, f32, V4> for L2 {
789    type SIMDWidth = Const<8>;
790    type Accumulator = <V4 as Architecture>::f32x8;
791    type Left = <V4 as Architecture>::f32x8;
792    type Right = <V4 as Architecture>::f32x8;
793    type Return = f32;
794    type Main = Strategy4x1;
795
796    #[inline(always)]
797    fn init(&self, arch: V4) -> Self::Accumulator {
798        Self::Accumulator::default(arch)
799    }
800
801    #[inline(always)]
802    fn accumulate(
803        &self,
804        x: Self::Left,
805        y: Self::Right,
806        acc: Self::Accumulator,
807    ) -> Self::Accumulator {
808        let c = x - y;
809        c.mul_add_simd(c, acc)
810    }
811
812    #[inline(always)]
813    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
814        x.sum_tree()
815    }
816}
817
818#[cfg(target_arch = "x86_64")]
819impl SIMDSchema<f32, f32, V3> for L2 {
820    type SIMDWidth = Const<8>;
821    type Accumulator = <V3 as Architecture>::f32x8;
822    type Left = <V3 as Architecture>::f32x8;
823    type Right = <V3 as Architecture>::f32x8;
824    type Return = f32;
825    type Main = Strategy4x1;
826
827    #[inline(always)]
828    fn init(&self, arch: V3) -> Self::Accumulator {
829        Self::Accumulator::default(arch)
830    }
831
832    #[inline(always)]
833    fn accumulate(
834        &self,
835        x: Self::Left,
836        y: Self::Right,
837        acc: Self::Accumulator,
838    ) -> Self::Accumulator {
839        let c = x - y;
840        c.mul_add_simd(c, acc)
841    }
842
843    #[inline(always)]
844    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
845        x.sum_tree()
846    }
847}
848
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    /// Start from an all-zero vector accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    /// Accumulate `(x - y)^2` lane-wise via a fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    /// Scalar tail handling: sum the leftover `(x - y)^2` terms into lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // NOTE(review): `simd_op` passes `len < SIMDWidth`, so the `min` below appears
        // to be a no-op that re-states the bound for the optimizer — confirm.
        let scalar = scalar_epilogue(
            x,
            y,
            len.min(Self::SIMDWidth::value() - 1),
            0.0f32,
            |acc, x, y| -> f32 {
                let c = x - y;
                c.mul_add(c, acc)
            },
        );
        // Fold the scalar partial sum into lane 0 of the vector accumulator.
        acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0])
    }

    /// Horizontally sum the lanes to produce the final squared distance.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
901
902impl SIMDSchema<f32, f32, Scalar> for L2 {
903    type SIMDWidth = Const<4>;
904    type Accumulator = Emulated<f32, 4>;
905    type Left = Emulated<f32, 4>;
906    type Right = Emulated<f32, 4>;
907    type Return = f32;
908    type Main = Strategy2x1;
909
910    #[inline(always)]
911    fn init(&self, arch: Scalar) -> Self::Accumulator {
912        Self::Accumulator::default(arch)
913    }
914
915    #[inline(always)]
916    fn accumulate(
917        &self,
918        x: Self::Left,
919        y: Self::Right,
920        acc: Self::Accumulator,
921    ) -> Self::Accumulator {
922        // Don't assume the presence of FMA.
923        let c = x - y;
924        (c * c) + acc
925    }
926
927    #[inline(always)]
928    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
929        x.sum_tree()
930    }
931
932    #[inline(always)]
933    unsafe fn epilogue(
934        &self,
935        arch: Scalar,
936        x: *const f32,
937        y: *const f32,
938        len: usize,
939        acc: Self::Accumulator,
940    ) -> Self::Accumulator {
941        let mut s: f32 = 0.0;
942        for i in 0..len {
943            // SAFETY: The range `[x, x.add(len))` is valid for reads.
944            let vx = unsafe { x.add(i).read_unaligned() };
945            // SAFETY: The range `[y, y.add(len))` is valid for reads.
946            let vy = unsafe { y.add(i).read_unaligned() };
947            let d = vx - vy;
948            s += d * d;
949        }
950        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
951    }
952}
953
// L2 (squared Euclidean) schema for `Half` operands on the x86-64 `V4` target.
//
// Operands are widened from f16 to f32 lanes before the subtract, so the
// squared difference is accumulated in f32 precision via FMA.
//
// NOTE(review): no `epilogue` override here, so the trait's default remainder
// handling applies — confirm against the `SIMDSchema` definition.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
989
// L2 (squared Euclidean) schema for `Half` operands on the x86-64 `V3` target.
// Mirrors the `V4` implementation: widen f16 -> f32, then accumulate the
// squared difference with FMA.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for L2 {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1026
// L2 (squared Euclidean) schema for `Half` operands on AArch64 NEON.
//
// The main loop widens f16 -> f32 and accumulates squared differences with
// FMA; the remainder is handled element-by-element in `epilogue`, keeping the
// partial result in a SIMD register so the FMA behavior matches the main loop.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for L2 {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();

        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Handle the remainder elements (fewer than one full vector).
    //
    // Each remaining pair is broadcast into lane 0 of a zeroed f16 vector,
    // widened to f32, and folded in with the same FMA used by the main loop.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let rest = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            f32s::default(arch),
            |acc, x: Half, y: Half| -> f32s {
                let zero = Half::default();
                let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                let c: f32s = x - y;
                c.mul_add_simd(c, acc)
            },
        );
        acc + rest
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1089
1090impl SIMDSchema<Half, Half, Scalar> for L2 {
1091    type SIMDWidth = Const<1>;
1092    type Accumulator = Emulated<f32, 1>;
1093    type Left = Emulated<Half, 1>;
1094    type Right = Emulated<Half, 1>;
1095    type Return = f32;
1096    type Main = Strategy1x1;
1097
1098    #[inline(always)]
1099    fn init(&self, arch: Scalar) -> Self::Accumulator {
1100        Self::Accumulator::default(arch)
1101    }
1102
1103    #[inline(always)]
1104    fn accumulate(
1105        &self,
1106        x: Self::Left,
1107        y: Self::Right,
1108        acc: Self::Accumulator,
1109    ) -> Self::Accumulator {
1110        let x: Self::Accumulator = x.into();
1111        let y: Self::Accumulator = y.into();
1112
1113        let c = x - y;
1114        acc + (c * c)
1115    }
1116
1117    #[inline(always)]
1118    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1119        x.to_array()[0]
1120    }
1121}
1122
// Mixed-precision L2 schema: an f32 left operand against an f16 right operand,
// generic over the architecture. The right side is widened to f32 before the
// subtract/FMA, so accumulation always happens in f32.
impl<A> SIMDSchema<f32, Half, A> for L2
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2 with `y` widened f16 -> f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let y: A::f32x8 = y.into();
        let c = x - y;
        c.mul_add_simd(c, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1157
// L2 schema for `i8` operands on the x86-64 `V4` target.
//
// Operands are widened to i16 (an i8 difference always fits in i16 without
// overflow); `dot_simd(c, c)` then accumulates the squared differences into
// i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for L2 {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2 via a widening dot product of the difference with itself.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen i8 -> i16 so the subtraction cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Horizontally reduce, then convert the i32 total (lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1192
// L2 schema for `i8` operands on the x86-64 `V3` target. Mirrors the `V4`
// implementation at half the vector width.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2 via a widening dot product of the difference with itself.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen i8 -> i16 so the subtraction cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1228
// L2 schema for `i8` operands on AArch64 NEON. The vectorized body is
// delegated to a dedicated helper; the remainder is computed with plain
// scalar i32 arithmetic and merged into lane 0.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::i32x8;
    type Left = diskann_wide::arch::aarch64::i8x16;
    type Right = diskann_wide::arch::aarch64::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, via the shared NEON squared-euclidean kernel.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        algorithms::squared_euclidean_accum_i8x16(x, y, acc)
    }

    // Handle the remainder elements (fewer than one full vector) in scalar
    // i32 arithmetic, then fold the tail sum into lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            0i32,
            |acc, x: i8, y: i8| -> i32 {
                let c = (x as i32) - (y as i32);
                acc + c * c
            },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0, 0, 0, 0, 0])
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1281
1282impl SIMDSchema<i8, i8, Scalar> for L2 {
1283    type SIMDWidth = Const<4>;
1284    type Accumulator = Emulated<i32, 4>;
1285    type Left = Emulated<i8, 4>;
1286    type Right = Emulated<i8, 4>;
1287    type Return = f32;
1288    type Main = Strategy1x1;
1289
1290    #[inline(always)]
1291    fn init(&self, arch: Scalar) -> Self::Accumulator {
1292        Self::Accumulator::default(arch)
1293    }
1294
1295    #[inline(always)]
1296    fn accumulate(
1297        &self,
1298        x: Self::Left,
1299        y: Self::Right,
1300        acc: Self::Accumulator,
1301    ) -> Self::Accumulator {
1302        let x: Self::Accumulator = x.into();
1303        let y: Self::Accumulator = y.into();
1304        let c = x - y;
1305        acc + c * c
1306    }
1307
1308    // Perform a final reduction.
1309    #[inline(always)]
1310    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1311        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
1312    }
1313
1314    #[inline(always)]
1315    unsafe fn epilogue(
1316        &self,
1317        arch: Scalar,
1318        x: *const i8,
1319        y: *const i8,
1320        len: usize,
1321        acc: Self::Accumulator,
1322    ) -> Self::Accumulator {
1323        let mut s: i32 = 0;
1324        for i in 0..len {
1325            // SAFETY: The range `[x, x.add(len))` is valid for reads.
1326            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
1327            // SAFETY: The range `[y, y.add(len))` is valid for reads.
1328            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
1329            let d = vx - vy;
1330            s += d * d;
1331        }
1332        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
1333    }
1334}
1335
// L2 schema for `u8` operands on the x86-64 `V4` target.
//
// Operands are widened to i16 (u8 values fit in i16, so the difference cannot
// overflow); `dot_simd(c, c)` accumulates the squared differences into i32.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V4> for L2 {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::u8x32;
    type Right = <V4 as Architecture>::u8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2 via a widening dot product of the difference with itself.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen u8 -> i16 so the subtraction cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Horizontally reduce, then convert the i32 total (lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1370
// L2 schema for `u8` operands on the x86-64 `V3` target. Mirrors the `V4`
// implementation at half the vector width.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2 via a widening dot product of the difference with itself.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen u8 -> i16 so the subtraction cannot overflow.
        let x: i16s = x.into();
        let y: i16s = y.into();
        let c = x - y;
        acc.dot_simd(c, c)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1406
// L2 schema for `u8` operands on AArch64 NEON. The vectorized body is
// delegated to a dedicated helper; the remainder is computed in scalar
// arithmetic and merged into lane 0. Note the accumulator is unsigned (u32):
// a squared difference is always non-negative.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<u8, u8, Neon> for L2 {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::u32x8;
    type Left = diskann_wide::arch::aarch64::u8x16;
    type Right = diskann_wide::arch::aarch64::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += (x - y)^2, via the shared NEON squared-euclidean kernel.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        algorithms::squared_euclidean_accum_u8x16(x, y, acc)
    }

    // Handle the remainder elements (fewer than one full vector). The
    // difference is taken in i32 (it may be negative); its square is
    // non-negative, so the cast back to u32 is safe.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            0u32,
            |acc, x: u8, y: u8| -> u32 {
                let c = (x as i32) - (y as i32);
                acc + ((c * c) as u32)
            },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0, 0, 0, 0, 0])
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1459
1460impl SIMDSchema<u8, u8, Scalar> for L2 {
1461    type SIMDWidth = Const<4>;
1462    type Accumulator = Emulated<i32, 4>;
1463    type Left = Emulated<u8, 4>;
1464    type Right = Emulated<u8, 4>;
1465    type Return = f32;
1466    type Main = Strategy1x1;
1467
1468    #[inline(always)]
1469    fn init(&self, arch: Scalar) -> Self::Accumulator {
1470        Self::Accumulator::default(arch)
1471    }
1472
1473    #[inline(always)]
1474    fn accumulate(
1475        &self,
1476        x: Self::Left,
1477        y: Self::Right,
1478        acc: Self::Accumulator,
1479    ) -> Self::Accumulator {
1480        let x: Self::Accumulator = x.into();
1481        let y: Self::Accumulator = y.into();
1482        let c = x - y;
1483        acc + c * c
1484    }
1485
1486    // Perform a final reduction.
1487    #[inline(always)]
1488    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1489        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
1490    }
1491
1492    #[inline(always)]
1493    unsafe fn epilogue(
1494        &self,
1495        arch: Scalar,
1496        x: *const u8,
1497        y: *const u8,
1498        len: usize,
1499        acc: Self::Accumulator,
1500    ) -> Self::Accumulator {
1501        let mut s: i32 = 0;
1502        for i in 0..len {
1503            // SAFETY: The range `[x, x.add(len))` is valid for reads.
1504            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
1505            // SAFETY: The range `[y, y.add(len))` is valid for reads.
1506            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
1507            let d = vx - vy;
1508            s += d * d;
1509        }
1510        acc + Self::Accumulator::from_array(arch, [s, 0, 0, 0])
1511    }
1512}
1513
/// An L2 distance function that defers the final reduction, allowing a single
/// distance computation to span multiple slice pairs.
///
/// Holds the un-reduced SIMD accumulator of the underlying [`L2`] schema for
/// the architecture `A`; call `sum` to obtain the scalar distance.
#[derive(Clone, Copy, Debug)]
pub struct ResumableL2<A = diskann_wide::arch::Current>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A>,
{
    // The running, not-yet-reduced accumulator.
    acc: <L2 as SIMDSchema<f32, f32, A>>::Accumulator,
}
1524
// Resumable wrapper around the `L2` schema: accumulators from successive
// slice-pair computations are merged lane-wise and reduced only once at the
// end via `sum`.
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableL2<A>
where
    A: Architecture,
    L2: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = L2;
    type FinalReturn = f32;

    // Start from the underlying schema's zero accumulator.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self { acc: L2.init(arch) }
    }

    // Merge another partial accumulator into this one (lane-wise add).
    #[inline(always)]
    fn combine_with(&self, other: <L2 as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
        Self {
            acc: self.acc + other,
        }
    }

    // Perform the deferred final reduction.
    #[inline(always)]
    fn sum(&self) -> f32 {
        L2.reduce(self.acc)
    }
}
1550
1551/////
1552///// IP Implementations
1553/////
1554
/// A pure inner-product (IP) distance function that provides a final
/// reduction.
#[derive(Clone, Copy, Debug, Default)]
pub struct IP;
1558
// Inner-product schema for `f32` operands on the x86-64 `V4` target:
// acc += x * y per lane via FMA, reduced horizontally at the end.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f32x8;
    type Right = <V4 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, lane-wise, as a fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1588
// Inner-product schema for `f32` operands on the x86-64 `V3` target.
// Mirrors the `V4` implementation.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<f32, f32, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f32x8;
    type Right = <V3 as Architecture>::f32x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, lane-wise, as a fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1619
// Inner-product schema for `f32` operands on AArch64 NEON. The remainder is
// handled with scalar `mul_add` and merged into lane 0 of the accumulator.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, lane-wise, as a fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.mul_add_simd(y, acc)
    }

    // Handle the remainder elements (fewer than one full vector) with scalar
    // fused multiply-adds, then fold the tail sum into lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            0.0f32,
            |acc, x: f32, y: f32| -> f32 { x.mul_add(y, acc) },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0.0, 0.0, 0.0])
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1668
1669impl SIMDSchema<f32, f32, Scalar> for IP {
1670    type SIMDWidth = Const<4>;
1671    type Accumulator = Emulated<f32, 4>;
1672    type Left = Emulated<f32, 4>;
1673    type Right = Emulated<f32, 4>;
1674    type Return = f32;
1675    type Main = Strategy2x1;
1676
1677    #[inline(always)]
1678    fn init(&self, arch: Scalar) -> Self::Accumulator {
1679        Self::Accumulator::default(arch)
1680    }
1681
1682    #[inline(always)]
1683    fn accumulate(
1684        &self,
1685        x: Self::Left,
1686        y: Self::Right,
1687        acc: Self::Accumulator,
1688    ) -> Self::Accumulator {
1689        x * y + acc
1690    }
1691
1692    // Perform a final reduction.
1693    #[inline(always)]
1694    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1695        x.sum_tree()
1696    }
1697
1698    #[inline(always)]
1699    unsafe fn epilogue(
1700        &self,
1701        arch: Scalar,
1702        x: *const f32,
1703        y: *const f32,
1704        len: usize,
1705        acc: Self::Accumulator,
1706    ) -> Self::Accumulator {
1707        let mut s: f32 = 0.0;
1708        for i in 0..len {
1709            // SAFETY: The range `[x, x.add(len))` is valid for reads.
1710            let vx = unsafe { x.add(i).read_unaligned() };
1711            // SAFETY: The range `[y, y.add(len))` is valid for reads.
1712            let vy = unsafe { y.add(i).read_unaligned() };
1713            s += vx * vy;
1714        }
1715        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
1716    }
1717}
1718
// Inner-product schema for `Half` operands on the x86-64 `V4` target.
// Operands are widened f16 -> f32 so the FMA accumulates in f32 precision.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V4> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V4 as Architecture>::f32x8;
    type Left = <V4 as Architecture>::f16x8;
    type Right = <V4 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V4>::f32x8);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1752
// Inner-product schema for `Half` operands on the x86-64 `V3` target.
// Mirrors the `V4` implementation (with a different main-loop strategy).
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<Half, Half, V3> for IP {
    type SIMDWidth = Const<8>;
    type Accumulator = <V3 as Architecture>::f32x8;
    type Left = <V3 as Architecture>::f16x8;
    type Right = <V3 as Architecture>::f16x8;
    type Return = f32;
    type Main = Strategy2x4;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <V3>::f32x8);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1787
// Inner-product schema for `Half` operands on AArch64 NEON. The remainder is
// processed one element at a time, broadcast into lane 0 of a zeroed f16
// vector, so the FMA behavior matches the main loop.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for IP {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y, computed lane-wise in f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        // Widen f16 -> f32 before doing arithmetic.
        let x: f32s = x.into();
        let y: f32s = y.into();

        x.mul_add_simd(y, acc)
    }

    // Handle the remainder elements (fewer than one full vector).
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let rest = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            f32s::default(arch),
            |acc, x: Half, y: Half| -> f32s {
                let zero = Half::default();
                let x: f32s = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: f32s = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                x.mul_add_simd(y, acc)
            },
        );
        acc + rest
    }

    // Perform a final horizontal reduction of the accumulator lanes.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1848
1849impl SIMDSchema<Half, Half, Scalar> for IP {
1850    type SIMDWidth = Const<1>;
1851    type Accumulator = Emulated<f32, 1>;
1852    type Left = Emulated<Half, 1>;
1853    type Right = Emulated<Half, 1>;
1854    type Return = f32;
1855    type Main = Strategy1x1;
1856
1857    #[inline(always)]
1858    fn init(&self, arch: Scalar) -> Self::Accumulator {
1859        Self::Accumulator::default(arch)
1860    }
1861
1862    #[inline(always)]
1863    fn accumulate(
1864        &self,
1865        x: Self::Left,
1866        y: Self::Right,
1867        acc: Self::Accumulator,
1868    ) -> Self::Accumulator {
1869        let x: Self::Accumulator = x.into();
1870        let y: Self::Accumulator = y.into();
1871        x * y + acc
1872    }
1873
1874    #[inline(always)]
1875    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
1876        x.to_array()[0]
1877    }
1878}
1879
// Mixed-precision inner-product schema: an f32 left operand against an f16
// right operand, generic over the architecture. The right side is widened to
// f32 before the FMA.
impl<A> SIMDSchema<f32, Half, A> for IP
where
    A: Architecture,
{
    type SIMDWidth = Const<8>;
    type Accumulator = A::f32x8;
    type Left = A::f32x8;
    type Right = A::f16x8;
    type Return = f32;
    type Main = Strategy4x2;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: A) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y with `y` widened f16 -> f32.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let y: A::f32x8 = y.into();
        x.mul_add_simd(y, acc)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
1913
// Inner-product schema for `i8` operands on the x86-64 `V4` target.
// Operands are widened to i16, then `dot_simd(x, y)` accumulates products
// into i32 lanes.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V4> for IP {
    type SIMDWidth = Const<32>;
    type Accumulator = <V4 as Architecture>::i32x16;
    type Left = <V4 as Architecture>::i8x32;
    type Right = <V4 as Architecture>::i8x32;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V4) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via a widening dot product.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V4>::i16x32);

        // Widen i8 -> i16 before the dot product.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Horizontally reduce, then convert the i32 total (lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1947
// Inner-product schema for `i8` operands on the x86-64 `V3` target. Mirrors
// the `V4` implementation at half the vector width.
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<i8, i8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::i8x16;
    type Right = <V3 as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy4x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via a widening dot product.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // Widen i8 -> i16 before the dot product.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
1982
// Inner-product schema for `i8` operands on AArch64 NEON. The i8x16 dot
// product accumulates directly into i32x4 lanes; the remainder is handled in
// scalar i32 arithmetic and merged into lane 0.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::i32x4;
    type Left = <Neon as Architecture>::i8x16;
    type Right = <Neon as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    // Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // acc += x * y via the widening dot product.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.dot_simd(x, y)
    }

    // Handle the remainder elements (fewer than one full vector) in scalar
    // i32 arithmetic, then fold the tail sum into lane 0.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let scalar = scalar_epilogue(
            x,
            y,
            // The remainder is at most one full vector minus one element.
            len.min(Self::SIMDWidth::value() - 1),
            0i32,
            |acc, x: i8, y: i8| -> i32 { acc + (x as i32) * (y as i32) },
        );
        acc + Self::Accumulator::from_array(arch, [scalar, 0, 0, 0])
    }

    // Horizontally reduce, then convert the i32 total (lossily) to f32.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2031
2032impl SIMDSchema<i8, i8, Scalar> for IP {
2033    type SIMDWidth = Const<1>;
2034    type Accumulator = Emulated<i32, 1>;
2035    type Left = Emulated<i8, 1>;
2036    type Right = Emulated<i8, 1>;
2037    type Return = f32;
2038    type Main = Strategy1x1;
2039
2040    #[inline(always)]
2041    fn init(&self, arch: Scalar) -> Self::Accumulator {
2042        Self::Accumulator::default(arch)
2043    }
2044
2045    #[inline(always)]
2046    fn accumulate(
2047        &self,
2048        x: Self::Left,
2049        y: Self::Right,
2050        acc: Self::Accumulator,
2051    ) -> Self::Accumulator {
2052        let x: Self::Accumulator = x.into();
2053        let y: Self::Accumulator = y.into();
2054        x * y + acc
2055    }
2056
2057    // Perform a final reduction.
2058    #[inline(always)]
2059    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2060        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
2061    }
2062
2063    #[inline(always)]
2064    unsafe fn epilogue(
2065        &self,
2066        _arch: Scalar,
2067        _x: *const i8,
2068        _y: *const i8,
2069        _len: usize,
2070        _acc: Self::Accumulator,
2071    ) -> Self::Accumulator {
2072        unreachable!("The SIMD width is 1, so there should be no epilogue")
2073    }
2074}
2075
2076#[cfg(target_arch = "x86_64")]
2077impl SIMDSchema<u8, u8, V4> for IP {
2078    type SIMDWidth = Const<32>;
2079    type Accumulator = <V4 as Architecture>::i32x16;
2080    type Left = <V4 as Architecture>::u8x32;
2081    type Right = <V4 as Architecture>::u8x32;
2082    type Return = f32;
2083    type Main = Strategy4x1;
2084
2085    #[inline(always)]
2086    fn init(&self, arch: V4) -> Self::Accumulator {
2087        Self::Accumulator::default(arch)
2088    }
2089
2090    #[inline(always)]
2091    fn accumulate(
2092        &self,
2093        x: Self::Left,
2094        y: Self::Right,
2095        acc: Self::Accumulator,
2096    ) -> Self::Accumulator {
2097        diskann_wide::alias!(i16s = <V4>::i16x32);
2098
2099        let x: i16s = x.into();
2100        let y: i16s = y.into();
2101        acc.dot_simd(x, y)
2102    }
2103
2104    #[inline(always)]
2105    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2106        x.sum_tree().as_f32_lossy()
2107    }
2108}
2109
#[cfg(target_arch = "x86_64")]
impl SIMDSchema<u8, u8, V3> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <V3 as Architecture>::i32x8;
    type Left = <V3 as Architecture>::u8x16;
    type Right = <V3 as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy4x1;

    /// Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: V3) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    /// Widen both operands and fold their dot product into `acc`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(i16s = <V3>::i16x16);

        // NOTE: Promoting to `i16` rather than `u16` to hit specialized AVX2
        // instructions on x86 hardware. The promotion is lossless since every
        // `u8` value fits in an `i16`.
        let x: i16s = x.into();
        let y: i16s = y.into();
        acc.dot_simd(x, y)
    }

    // Perform a final reduction: horizontally sum the partial products and
    // convert to `f32`.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree().as_f32_lossy()
    }
}
2146
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<u8, u8, Neon> for IP {
    type SIMDWidth = Const<16>;
    type Accumulator = <Neon as Architecture>::u32x4;
    type Left = <Neon as Architecture>::u8x16;
    type Right = <Neon as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    /// Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    /// Fold the widening dot product of `x` and `y` into `acc`.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.dot_simd(x, y)
    }

    /// Handle the final partial lane-group with scalar arithmetic, placing the
    /// result in lane 0 of a fresh accumulator and adding it to `acc`.
    ///
    /// # Safety
    /// The ranges `[x, x.add(len))` and `[y, y.add(len))` must be valid for reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        // The epilogue handles strictly fewer than `SIMDWidth` elements.
        let tail = len.min(Self::SIMDWidth::value() - 1);
        let partial = scalar_epilogue(x, y, tail, 0u32, |sum, a: u8, b: u8| -> u32 {
            sum + u32::from(a) * u32::from(b)
        });
        acc + Self::Accumulator::from_array(arch, [partial, 0, 0, 0])
    }

    /// Horizontally sum the partial products and convert to `f32`.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum_tree().as_f32_lossy()
    }
}
2195
2196impl SIMDSchema<u8, u8, Scalar> for IP {
2197    type SIMDWidth = Const<1>;
2198    type Accumulator = Emulated<i32, 1>;
2199    type Left = Emulated<u8, 1>;
2200    type Right = Emulated<u8, 1>;
2201    type Return = f32;
2202    type Main = Strategy1x1;
2203
2204    #[inline(always)]
2205    fn init(&self, arch: Scalar) -> Self::Accumulator {
2206        Self::Accumulator::default(arch)
2207    }
2208
2209    #[inline(always)]
2210    fn accumulate(
2211        &self,
2212        x: Self::Left,
2213        y: Self::Right,
2214        acc: Self::Accumulator,
2215    ) -> Self::Accumulator {
2216        let x: Self::Accumulator = x.into();
2217        let y: Self::Accumulator = y.into();
2218        x * y + acc
2219    }
2220
2221    // Perform a final reduction.
2222    #[inline(always)]
2223    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2224        x.to_array().into_iter().sum::<i32>().as_f32_lossy()
2225    }
2226
2227    #[inline(always)]
2228    unsafe fn epilogue(
2229        &self,
2230        _arch: Scalar,
2231        _x: *const u8,
2232        _y: *const u8,
2233        _len: usize,
2234        _acc: Self::Accumulator,
2235    ) -> Self::Accumulator {
2236        unreachable!("The SIMD width is 1, so there should be no epilogue")
2237    }
2238}
2239
/// An IP distance function that defers a final reduction.
///
/// Holds a running SIMD accumulator so that partial dot products from several
/// calls can be combined and reduced to a scalar only once, at the end.
#[derive(Clone, Copy, Debug)]
pub struct ResumableIP<A = diskann_wide::arch::Current>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A>,
{
    // Partial products accumulated so far; reduced lazily by `sum`.
    acc: <IP as SIMDSchema<f32, f32, A>>::Accumulator,
}
2249
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableIP<A>
where
    A: Architecture,
    IP: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = IP;
    type FinalReturn = f32;

    /// Create a resumable accumulator initialized to zero.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self { acc: IP.init(arch) }
    }

    /// Fold another partial accumulator into the running total, returning the
    /// combined state (this value is `Copy`; `self` is not mutated).
    #[inline(always)]
    fn combine_with(&self, other: <IP as SIMDSchema<f32, f32, A>>::Accumulator) -> Self {
        Self {
            acc: self.acc + other,
        }
    }

    /// Perform the deferred reduction: collapse the accumulated partial
    /// products to the final `f32` inner product.
    #[inline(always)]
    fn sum(&self) -> f32 {
        IP.reduce(self.acc)
    }
}
2275
2276/////////////////////////////////
2277// Stateless Cosine Similarity //
2278/////////////////////////////////
2279
/// Accumulator of partial products for a full cosine distance computation (where
/// the norms of both the query and the dataset vector are computed on the fly).
#[derive(Debug, Clone, Copy)]
pub struct FullCosineAccumulator<T> {
    // Partial sum of x[i] * x[i] (squared norm of the left operand).
    normx: T,
    // Partial sum of y[i] * y[i] (squared norm of the right operand).
    normy: T,
    // Partial sum of x[i] * y[i] (the inner product).
    xy: T,
}
2288
impl<T> FullCosineAccumulator<T>
where
    T: SIMDVector
        + SIMDSumTree
        + SIMDMulAdd
        + std::ops::Mul<Output = T>
        + std::ops::Add<Output = T>,
    T::Scalar: LossyF32Conversion,
{
    /// Create a zero-initialized accumulator for the given architecture token.
    #[inline(always)]
    pub fn new(arch: T::Arch) -> Self {
        // SAFETY: Zero initializing a SIMD vector is safe.
        let zero = T::default(arch);
        Self {
            normx: zero,
            normy: zero,
            xy: zero,
        }
    }

    /// Fold one lane-group into the three partial sums using fused
    /// multiply-add: `x*x` into `normx`, `y*y` into `normy`, `x*y` into `xy`.
    #[inline(always)]
    pub fn add_with(&self, x: T, y: T) -> Self {
        // SAFETY: Arithmetic on valid arguments is valid.
        FullCosineAccumulator {
            normx: x.mul_add_simd(x, self.normx),
            normy: y.mul_add_simd(y, self.normy),
            xy: x.mul_add_simd(y, self.xy),
        }
    }

    /// Like [`Self::add_with`], but with separate multiply and add instead of
    /// a fused multiply-add (used by the scalar fallback schemas).
    #[inline(always)]
    pub fn add_with_unfused(&self, x: T, y: T) -> Self {
        // SAFETY: Arithmetic on valid arguments is valid.
        FullCosineAccumulator {
            normx: x * x + self.normx,
            normy: y * y + self.normy,
            xy: x * y + self.xy,
        }
    }

    /// Reduce the partial sums to the cosine similarity, clamped to `[-1, 1]`.
    ///
    /// Returns `0.0` when either vector's squared norm is subnormal (the
    /// vector is treated as zero).
    ///
    /// NOTE: The statement order in this function is deliberate; do not
    /// reorder it (see the `force_eval` comments below).
    #[inline(always)]
    pub fn sum(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();

        // Evaluate the denominator early and use `force_eval`.
        // This will allow the long `sqrt` to be overlapped with some other instructions
        // rather than waiting at the end of the function.
        //
        // There is some worry of subnormal numbers, but we're optimizing for the common
        // case where norms are reasonable values.
        let denominator = normx.sqrt() * normy.sqrt();
        let prod = self.xy.sum_tree().as_f32_lossy();

        // Force the final products to be completely computed before the range check.
        //
        // This prevents LLVM from trying to compute `normy` or `prod` *after* the check
        // to `normx`, which causes it to spill heavily to the stack.
        //
        // Unfortunately, this results in a reduction pattern that appears to be slightly
        // slower on AMD or Windows.
        force_eval(denominator);
        force_eval(prod);

        // This basically checks if either norm is subnormal and if so, we treat the vector
        // as having norm zero.
        //
        // The reason to do this rather than checking `denominator` directly is to have
        // consistent behavior when one vector has a small norm (i.e., always treat it as
        // zero) rather than potentially changing behavior when the other vector has a very
        // large norm to compensate.
        if normx < f32::MIN_POSITIVE || normy < f32::MIN_POSITIVE {
            return 0.0;
        }

        let v = prod / denominator;
        // Clamp to [-1, 1]. Since `f32::min`/`f32::max` return the other operand
        // when one is NaN, a NaN ratio collapses to 1.0 here.
        (-1.0f32).max(1.0f32.min(v))
    }

    /// Compute the L2 distance from the partial products rather than the cosine similarity.
    ///
    /// Expands to `||x||^2 + ||y||^2 - 2 * <x, y>`, i.e. the *squared* L2 distance.
    #[inline(always)]
    pub fn sum_as_l2(&self) -> f32 {
        let normx = self.normx.sum_tree().as_f32_lossy();
        let normy = self.normy.sum_tree().as_f32_lossy();
        let xy = self.xy.sum_tree().as_f32_lossy();
        normx + normy - (xy + xy)
    }
}
2377
2378impl<T> std::ops::Add for FullCosineAccumulator<T>
2379where
2380    T: std::ops::Add<Output = T>,
2381{
2382    type Output = Self;
2383    #[inline(always)]
2384    fn add(self, other: Self) -> Self {
2385        FullCosineAccumulator {
2386            normx: self.normx + other.normx,
2387            normy: self.normy + other.normy,
2388            xy: self.xy + other.xy,
2389        }
2390    }
2391}
2392
/// A pure Cosine Similarity function that provides a final reduction.
///
/// Carries no state of its own; per-architecture behavior is supplied by the
/// `SIMDSchema` implementations below, keyed on element type and instruction set.
#[derive(Default, Clone, Copy)]
pub struct CosineStateless;
2396
2397#[cfg(target_arch = "x86_64")]
2398impl SIMDSchema<f32, f32, V4> for CosineStateless {
2399    type SIMDWidth = Const<16>;
2400    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
2401    type Left = <V4 as Architecture>::f32x16;
2402    type Right = <V4 as Architecture>::f32x16;
2403    type Return = f32;
2404
2405    // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a
2406    // hefty unroll factor.
2407    type Main = Strategy2x4;
2408
2409    #[inline(always)]
2410    fn init(&self, arch: V4) -> Self::Accumulator {
2411        Self::Accumulator::new(arch)
2412    }
2413
2414    #[inline(always)]
2415    fn accumulate(
2416        &self,
2417        x: Self::Left,
2418        y: Self::Right,
2419        acc: Self::Accumulator,
2420    ) -> Self::Accumulator {
2421        acc.add_with(x, y)
2422    }
2423
2424    // Perform a final reduction.
2425    #[inline(always)]
2426    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2427        acc.sum()
2428    }
2429}
2430
2431#[cfg(target_arch = "x86_64")]
2432impl SIMDSchema<f32, f32, V3> for CosineStateless {
2433    type SIMDWidth = Const<8>;
2434    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
2435    type Left = <V3 as Architecture>::f32x8;
2436    type Right = <V3 as Architecture>::f32x8;
2437    type Return = f32;
2438
2439    // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a
2440    // hefty unroll factor.
2441    type Main = Strategy2x4;
2442
2443    #[inline(always)]
2444    fn init(&self, arch: V3) -> Self::Accumulator {
2445        Self::Accumulator::new(arch)
2446    }
2447
2448    #[inline(always)]
2449    fn accumulate(
2450        &self,
2451        x: Self::Left,
2452        y: Self::Right,
2453        acc: Self::Accumulator,
2454    ) -> Self::Accumulator {
2455        acc.add_with(x, y)
2456    }
2457
2458    // Perform a final reduction.
2459    #[inline(always)]
2460    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2461        acc.sum()
2462    }
2463}
2464
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::f32x4>;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;

    // Cosine accumulators are pretty large, so only use 2 parallel accumulator with a
    // hefty unroll factor.
    type Main = Strategy2x4;

    /// Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    /// Fold one lane-group into the three partial sums via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        acc.add_with(x, y)
    }

    /// Handle the final partial lane-group with scalar FMA arithmetic, placing
    /// each partial sum in lane 0 of a fresh accumulator and adding it to `acc`.
    ///
    /// # Safety
    /// The ranges `[x, x.add(len))` and `[y, y.add(len))` must be valid for reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: f32 = 0.0;
        let mut yy: f32 = 0.0;
        let mut xy: f32 = 0.0;
        // The epilogue handles strictly fewer than `SIMDWidth` elements.
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx = unsafe { x.add(i).read_unaligned() };
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy = unsafe { y.add(i).read_unaligned() };
            xx = vx.mul_add(vx, xx);
            yy = vy.mul_add(vy, yy);
            xy = vx.mul_add(vy, xy);
        }
        type V = <Neon as Architecture>::f32x4;
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0.0, 0.0, 0.0]),
            normy: V::from_array(arch, [yy, 0.0, 0.0, 0.0]),
            xy: V::from_array(arch, [xy, 0.0, 0.0, 0.0]),
        }
    }

    // Perform a final reduction: collapse the partial sums to the clamped
    // cosine similarity.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2527
2528impl SIMDSchema<f32, f32, Scalar> for CosineStateless {
2529    type SIMDWidth = Const<4>;
2530    type Accumulator = FullCosineAccumulator<Emulated<f32, 4>>;
2531    type Left = Emulated<f32, 4>;
2532    type Right = Emulated<f32, 4>;
2533    type Return = f32;
2534
2535    type Main = Strategy2x1;
2536
2537    #[inline(always)]
2538    fn init(&self, arch: Scalar) -> Self::Accumulator {
2539        Self::Accumulator::new(arch)
2540    }
2541
2542    #[inline(always)]
2543    fn accumulate(
2544        &self,
2545        x: Self::Left,
2546        y: Self::Right,
2547        acc: Self::Accumulator,
2548    ) -> Self::Accumulator {
2549        acc.add_with_unfused(x, y)
2550    }
2551
2552    #[inline(always)]
2553    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2554        acc.sum()
2555    }
2556}
2557
2558#[cfg(target_arch = "x86_64")]
2559impl SIMDSchema<Half, Half, V4> for CosineStateless {
2560    type SIMDWidth = Const<16>;
2561    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::f32x16>;
2562    type Left = <V4 as Architecture>::f16x16;
2563    type Right = <V4 as Architecture>::f16x16;
2564    type Return = f32;
2565    type Main = Strategy2x4;
2566
2567    #[inline(always)]
2568    fn init(&self, arch: V4) -> Self::Accumulator {
2569        Self::Accumulator::new(arch)
2570    }
2571
2572    #[inline(always)]
2573    fn accumulate(
2574        &self,
2575        x: Self::Left,
2576        y: Self::Right,
2577        acc: Self::Accumulator,
2578    ) -> Self::Accumulator {
2579        diskann_wide::alias!(f32s = <V4>::f32x16);
2580
2581        let x: f32s = x.into();
2582        let y: f32s = y.into();
2583        acc.add_with(x, y)
2584    }
2585
2586    #[inline(always)]
2587    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2588        acc.sum()
2589    }
2590}
2591
2592#[cfg(target_arch = "x86_64")]
2593impl SIMDSchema<Half, Half, V3> for CosineStateless {
2594    type SIMDWidth = Const<8>;
2595    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::f32x8>;
2596    type Left = <V3 as Architecture>::f16x8;
2597    type Right = <V3 as Architecture>::f16x8;
2598    type Return = f32;
2599    type Main = Strategy2x4;
2600
2601    #[inline(always)]
2602    fn init(&self, arch: V3) -> Self::Accumulator {
2603        Self::Accumulator::new(arch)
2604    }
2605
2606    #[inline(always)]
2607    fn accumulate(
2608        &self,
2609        x: Self::Left,
2610        y: Self::Right,
2611        acc: Self::Accumulator,
2612    ) -> Self::Accumulator {
2613        diskann_wide::alias!(f32s = <V3>::f32x8);
2614
2615        let x: f32s = x.into();
2616        let y: f32s = y.into();
2617        acc.add_with(x, y)
2618    }
2619
2620    // Perform a final reduction.
2621    #[inline(always)]
2622    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2623        acc.sum()
2624    }
2625}
2626
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for CosineStateless {
    type SIMDWidth = Const<4>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::f32x4>;
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;

    type Main = Strategy2x4;

    /// Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    /// Widen both half-precision operands to `f32` lanes, then fold them into
    /// the three partial sums via fused multiply-add.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        diskann_wide::alias!(f32s = <Neon>::f32x4);

        let x: f32s = x.into();
        let y: f32s = y.into();
        acc.add_with(x, y)
    }

    /// Handle the final partial lane-group: each remaining element pair is
    /// broadcast into lane 0 of a temporary vector, converted to `f32`, and
    /// folded into a scratch accumulator which is then added to `acc`.
    ///
    /// # Safety
    /// The ranges `[x, x.add(len))` and `[y, y.add(len))` must be valid for reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        type V = <Neon as Architecture>::f32x4;

        let rest = scalar_epilogue(
            x,
            y,
            // The epilogue handles strictly fewer than `SIMDWidth` elements.
            len.min(Self::SIMDWidth::value() - 1),
            FullCosineAccumulator::<V>::new(arch),
            |acc, x: Half, y: Half| -> FullCosineAccumulator<V> {
                let zero = Half::default();
                let x: V = Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                let y: V = Self::Right::from_array(arch, [y, zero, zero, zero]).into();
                acc.add_with(x, y)
            },
        );
        acc + rest
    }

    /// Collapse the partial sums to the clamped cosine similarity.
    #[inline(always)]
    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
        acc.sum()
    }
}
2687
2688impl SIMDSchema<Half, Half, Scalar> for CosineStateless {
2689    type SIMDWidth = Const<1>;
2690    type Accumulator = FullCosineAccumulator<Emulated<f32, 1>>;
2691    type Left = Emulated<Half, 1>;
2692    type Right = Emulated<Half, 1>;
2693    type Return = f32;
2694    type Main = Strategy1x1;
2695
2696    #[inline(always)]
2697    fn init(&self, arch: Scalar) -> Self::Accumulator {
2698        Self::Accumulator::new(arch)
2699    }
2700
2701    #[inline(always)]
2702    fn accumulate(
2703        &self,
2704        x: Self::Left,
2705        y: Self::Right,
2706        acc: Self::Accumulator,
2707    ) -> Self::Accumulator {
2708        let x: Emulated<f32, 1> = x.into();
2709        let y: Emulated<f32, 1> = y.into();
2710        acc.add_with_unfused(x, y)
2711    }
2712
2713    #[inline(always)]
2714    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2715        acc.sum()
2716    }
2717}
2718impl<A> SIMDSchema<f32, Half, A> for CosineStateless
2719where
2720    A: Architecture,
2721{
2722    type SIMDWidth = Const<8>;
2723    type Accumulator = FullCosineAccumulator<A::f32x8>;
2724    type Left = A::f32x8;
2725    type Right = A::f16x8;
2726    type Return = f32;
2727    type Main = Strategy2x4;
2728
2729    #[inline(always)]
2730    fn init(&self, arch: A) -> Self::Accumulator {
2731        Self::Accumulator::new(arch)
2732    }
2733
2734    #[inline(always)]
2735    fn accumulate(
2736        &self,
2737        x: Self::Left,
2738        y: Self::Right,
2739        acc: Self::Accumulator,
2740    ) -> Self::Accumulator {
2741        let y: A::f32x8 = y.into();
2742        acc.add_with(x, y)
2743    }
2744
2745    #[inline(always)]
2746    fn reduce(&self, acc: Self::Accumulator) -> Self::Return {
2747        acc.sum()
2748    }
2749}
2750
2751#[cfg(target_arch = "x86_64")]
2752impl SIMDSchema<i8, i8, V4> for CosineStateless {
2753    type SIMDWidth = Const<32>;
2754    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
2755    type Left = <V4 as Architecture>::i8x32;
2756    type Right = <V4 as Architecture>::i8x32;
2757    type Return = f32;
2758    type Main = Strategy4x1;
2759
2760    #[inline(always)]
2761    fn init(&self, arch: V4) -> Self::Accumulator {
2762        Self::Accumulator::new(arch)
2763    }
2764
2765    #[inline(always)]
2766    fn accumulate(
2767        &self,
2768        x: Self::Left,
2769        y: Self::Right,
2770        acc: Self::Accumulator,
2771    ) -> Self::Accumulator {
2772        diskann_wide::alias!(i16s = <V4>::i16x32);
2773
2774        let x: i16s = x.into();
2775        let y: i16s = y.into();
2776
2777        FullCosineAccumulator {
2778            normx: acc.normx.dot_simd(x, x),
2779            normy: acc.normy.dot_simd(y, y),
2780            xy: acc.xy.dot_simd(x, y),
2781        }
2782    }
2783
2784    // Perform a final reduction.
2785    #[inline(always)]
2786    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2787        x.sum()
2788    }
2789}
2790
2791#[cfg(target_arch = "x86_64")]
2792impl SIMDSchema<i8, i8, V3> for CosineStateless {
2793    type SIMDWidth = Const<16>;
2794    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
2795    type Left = <V3 as Architecture>::i8x16;
2796    type Right = <V3 as Architecture>::i8x16;
2797    type Return = f32;
2798    type Main = Strategy4x1;
2799
2800    #[inline(always)]
2801    fn init(&self, arch: V3) -> Self::Accumulator {
2802        Self::Accumulator::new(arch)
2803    }
2804
2805    #[inline(always)]
2806    fn accumulate(
2807        &self,
2808        x: Self::Left,
2809        y: Self::Right,
2810        acc: Self::Accumulator,
2811    ) -> Self::Accumulator {
2812        diskann_wide::alias!(i16s = <V3>::i16x16);
2813
2814        let x: i16s = x.into();
2815        let y: i16s = y.into();
2816
2817        FullCosineAccumulator {
2818            normx: acc.normx.dot_simd(x, x),
2819            normy: acc.normy.dot_simd(y, y),
2820            xy: acc.xy.dot_simd(x, y),
2821        }
2822    }
2823
2824    // Perform a final reduction.
2825    #[inline(always)]
2826    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2827        x.sum()
2828    }
2829}
2830
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<i8, i8, Neon> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::i32x4>;
    type Left = <Neon as Architecture>::i8x16;
    type Right = <Neon as Architecture>::i8x16;
    type Return = f32;
    type Main = Strategy2x1;

    /// Start from an all-zero accumulator.
    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    /// Fold the three widening dot products (`x·x`, `y·y`, `x·y`) into their
    /// respective partial sums.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    /// Handle the final partial lane-group with scalar arithmetic, placing
    /// each partial sum in lane 0 of a fresh accumulator and adding it to `acc`.
    ///
    /// # Safety
    /// The ranges `[x, x.add(len))` and `[y, y.add(len))` must be valid for reads.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const i8,
        y: *const i8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: i32 = 0;
        let mut yy: i32 = 0;
        let mut xy: i32 = 0;
        // The epilogue handles strictly fewer than `SIMDWidth` elements.
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }
        type V = <Neon as Architecture>::i32x4;
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0, 0, 0]),
            normy: V::from_array(arch, [yy, 0, 0, 0]),
            xy: V::from_array(arch, [xy, 0, 0, 0]),
        }
    }

    /// Collapse the partial sums to the clamped cosine similarity.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
2893
2894impl SIMDSchema<i8, i8, Scalar> for CosineStateless {
2895    type SIMDWidth = Const<4>;
2896    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
2897    type Left = Emulated<i8, 4>;
2898    type Right = Emulated<i8, 4>;
2899    type Return = f32;
2900    type Main = Strategy1x1;
2901
2902    #[inline(always)]
2903    fn init(&self, arch: Scalar) -> Self::Accumulator {
2904        Self::Accumulator::new(arch)
2905    }
2906
2907    #[inline(always)]
2908    fn accumulate(
2909        &self,
2910        x: Self::Left,
2911        y: Self::Right,
2912        acc: Self::Accumulator,
2913    ) -> Self::Accumulator {
2914        let x: Emulated<i32, 4> = x.into();
2915        let y: Emulated<i32, 4> = y.into();
2916        acc.add_with(x, y)
2917    }
2918
2919    // Perform a final reduction.
2920    #[inline(always)]
2921    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2922        x.sum()
2923    }
2924
2925    #[inline(always)]
2926    unsafe fn epilogue(
2927        &self,
2928        arch: Scalar,
2929        x: *const i8,
2930        y: *const i8,
2931        len: usize,
2932        acc: Self::Accumulator,
2933    ) -> Self::Accumulator {
2934        let mut xy: i32 = 0;
2935        let mut xx: i32 = 0;
2936        let mut yy: i32 = 0;
2937
2938        for i in 0..len {
2939            // SAFETY: The range `[x, x.add(len))` is valid for reads.
2940            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
2941            // SAFETY: The range `[y, y.add(len))` is valid for reads.
2942            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
2943
2944            xx += vx * vx;
2945            xy += vx * vy;
2946            yy += vy * vy;
2947        }
2948
2949        acc + FullCosineAccumulator {
2950            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
2951            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
2952            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
2953        }
2954    }
2955}
2956
2957#[cfg(target_arch = "x86_64")]
2958impl SIMDSchema<u8, u8, V4> for CosineStateless {
2959    type SIMDWidth = Const<32>;
2960    type Accumulator = FullCosineAccumulator<<V4 as Architecture>::i32x16>;
2961    type Left = <V4 as Architecture>::u8x32;
2962    type Right = <V4 as Architecture>::u8x32;
2963    type Return = f32;
2964    type Main = Strategy4x1;
2965
2966    #[inline(always)]
2967    fn init(&self, arch: V4) -> Self::Accumulator {
2968        Self::Accumulator::new(arch)
2969    }
2970
2971    #[inline(always)]
2972    fn accumulate(
2973        &self,
2974        x: Self::Left,
2975        y: Self::Right,
2976        acc: Self::Accumulator,
2977    ) -> Self::Accumulator {
2978        diskann_wide::alias!(i16s = <V4>::i16x32);
2979
2980        let x: i16s = x.into();
2981        let y: i16s = y.into();
2982
2983        FullCosineAccumulator {
2984            normx: acc.normx.dot_simd(x, x),
2985            normy: acc.normy.dot_simd(y, y),
2986            xy: acc.xy.dot_simd(x, y),
2987        }
2988    }
2989
2990    // Perform a final reduction.
2991    #[inline(always)]
2992    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
2993        x.sum()
2994    }
2995}
2996
2997#[cfg(target_arch = "x86_64")]
2998impl SIMDSchema<u8, u8, V3> for CosineStateless {
2999    type SIMDWidth = Const<16>;
3000    type Accumulator = FullCosineAccumulator<<V3 as Architecture>::i32x8>;
3001    type Left = <V3 as Architecture>::u8x16;
3002    type Right = <V3 as Architecture>::u8x16;
3003    type Return = f32;
3004    type Main = Strategy4x1;
3005
3006    #[inline(always)]
3007    fn init(&self, arch: V3) -> Self::Accumulator {
3008        Self::Accumulator::new(arch)
3009    }
3010
3011    #[inline(always)]
3012    fn accumulate(
3013        &self,
3014        x: Self::Left,
3015        y: Self::Right,
3016        acc: Self::Accumulator,
3017    ) -> Self::Accumulator {
3018        diskann_wide::alias!(i16s = <V3>::i16x16);
3019
3020        let x: i16s = x.into();
3021        let y: i16s = y.into();
3022
3023        FullCosineAccumulator {
3024            normx: acc.normx.dot_simd(x, x),
3025            normy: acc.normy.dot_simd(y, y),
3026            xy: acc.xy.dot_simd(x, y),
3027        }
3028    }
3029
3030    // Perform a final reduction.
3031    #[inline(always)]
3032    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3033        x.sum()
3034    }
3035}
3036
// Cosine accumulation for unsigned bytes on NEON. The accumulator type
// (`u32x4` per running sum) shows the `u8x16` dot products widen to u32.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<u8, u8, Neon> for CosineStateless {
    type SIMDWidth = Const<16>;
    type Accumulator = FullCosineAccumulator<<Neon as Architecture>::u32x4>;
    type Left = <Neon as Architecture>::u8x16;
    type Right = <Neon as Architecture>::u8x16;
    type Return = f32;
    type Main = Strategy2x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::new(arch)
    }

    // Fold one 16-lane block into the running |x|^2, |y|^2 and x.y sums.
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        FullCosineAccumulator {
            normx: acc.normx.dot_simd(x, x),
            normy: acc.normy.dot_simd(y, y),
            xy: acc.xy.dot_simd(x, y),
        }
    }

    // Scalar handling for the final partial block.
    //
    // NOTE(review): this assumes the driver only invokes `epilogue` with
    // `len < SIMDWidth`; the `min` below restates that bound so the compiler
    // sees a statically bounded trip count — confirm against the `simd_op`
    // driver.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const u8,
        y: *const u8,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut xx: u32 = 0;
        let mut yy: u32 = 0;
        let mut xy: u32 = 0;
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx: u32 = unsafe { x.add(i).read_unaligned() }.into();
            // SAFETY: The range `[y, y.add(len))` is valid for reads.
            let vy: u32 = unsafe { y.add(i).read_unaligned() }.into();
            xx += vx * vx;
            xy += vx * vy;
            yy += vy * vy;
        }
        // Inject the scalar partial sums through lane 0 of a zero vector.
        type V = <Neon as Architecture>::u32x4;
        acc + FullCosineAccumulator {
            normx: V::from_array(arch, [xx, 0, 0, 0]),
            normy: V::from_array(arch, [yy, 0, 0, 0]),
            xy: V::from_array(arch, [xy, 0, 0, 0]),
        }
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum()
    }
}
3099
3100impl SIMDSchema<u8, u8, Scalar> for CosineStateless {
3101    type SIMDWidth = Const<4>;
3102    type Accumulator = FullCosineAccumulator<Emulated<i32, 4>>;
3103    type Left = Emulated<u8, 4>;
3104    type Right = Emulated<u8, 4>;
3105    type Return = f32;
3106    type Main = Strategy1x1;
3107
3108    #[inline(always)]
3109    fn init(&self, arch: Scalar) -> Self::Accumulator {
3110        Self::Accumulator::new(arch)
3111    }
3112
3113    #[inline(always)]
3114    fn accumulate(
3115        &self,
3116        x: Self::Left,
3117        y: Self::Right,
3118        acc: Self::Accumulator,
3119    ) -> Self::Accumulator {
3120        let x: Emulated<i32, 4> = x.into();
3121        let y: Emulated<i32, 4> = y.into();
3122        acc.add_with(x, y)
3123    }
3124
3125    // Perform a final reduction.
3126    #[inline(always)]
3127    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3128        x.sum()
3129    }
3130
3131    #[inline(always)]
3132    unsafe fn epilogue(
3133        &self,
3134        arch: Scalar,
3135        x: *const u8,
3136        y: *const u8,
3137        len: usize,
3138        acc: Self::Accumulator,
3139    ) -> Self::Accumulator {
3140        let mut xy: i32 = 0;
3141        let mut xx: i32 = 0;
3142        let mut yy: i32 = 0;
3143
3144        for i in 0..len {
3145            // SAFETY: The range `[x, x.add(len))` is valid for reads.
3146            let vx: i32 = unsafe { x.add(i).read_unaligned() }.into();
3147            // SAFETY: The range `[y, y.add(len))` is valid for reads.
3148            let vy: i32 = unsafe { y.add(i).read_unaligned() }.into();
3149
3150            xx += vx * vx;
3151            xy += vx * vy;
3152            yy += vy * vy;
3153        }
3154
3155        acc + FullCosineAccumulator {
3156            normx: Emulated::from_array(arch, [xx, 0, 0, 0]),
3157            normy: Emulated::from_array(arch, [yy, 0, 0, 0]),
3158            xy: Emulated::from_array(arch, [xy, 0, 0, 0]),
3159        }
3160    }
3161}
3162
/// A resumable cosine similarity computation.
///
/// Wraps the accumulator of the stateless cosine schema so that a distance
/// computation can be split across several chunks of the input: each chunk is
/// folded in via `combine_with` and the final similarity is produced by `sum`
/// (see the `ResumableSIMDSchema` impl below).
#[derive(Debug, Clone, Copy)]
pub struct ResumableCosine<A = diskann_wide::arch::Current>(
    <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
)
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A>;
3171
impl<A> ResumableSIMDSchema<f32, f32, A> for ResumableCosine<A>
where
    A: Architecture,
    CosineStateless: SIMDSchema<f32, f32, A, Return = f32>,
{
    type NonResumable = CosineStateless;
    type FinalReturn = f32;

    // Begin with the stateless schema's empty accumulator.
    #[inline(always)]
    fn init(arch: A) -> Self {
        Self(CosineStateless.init(arch))
    }

    // Merge a chunk's accumulator into the stored running state.
    #[inline(always)]
    fn combine_with(
        &self,
        other: <CosineStateless as SIMDSchema<f32, f32, A>>::Accumulator,
    ) -> Self {
        Self(self.0 + other)
    }

    // Final reduction: delegate to the stateless schema's `reduce`.
    #[inline(always)]
    fn sum(&self) -> f32 {
        CosineStateless.reduce(self.0)
    }
}
3198
3199/////
3200///// L1 Norm Implementations
3201/////
3202
3203// ==================================================================================================
3204// NOTE: L1Norm IS A LOGICAL UNARY OPERATION
3205// --------------------------------------------------------------------------------------------------
3206// Although wired through the generic binary 'SIMDSchema'/'simd_op' infrastructure (which expects
3207// two input slices of equal length), 'L1Norm' conceptually computes: sum_i |x_i|
3208// The right-hand operand is completely ignored and exists ONLY to satisfy the shared execution
3209// machinery (loop tiling, epilogue handling, etc.).
3210// ==================================================================================================
3211
/// A pure L1 norm function that provides a final reduction.
///
/// Computes `sum_i |x_i|` over the left-hand slice; the right-hand operand is
/// ignored and exists only to satisfy the shared binary `SIMDSchema`
/// machinery (see the note above).
#[derive(Clone, Copy, Debug, Default)]
pub struct L1Norm;
3215
3216#[cfg(target_arch = "x86_64")]
3217impl SIMDSchema<f32, f32, V4> for L1Norm {
3218    type SIMDWidth = Const<16>;
3219    type Accumulator = <V4 as Architecture>::f32x16;
3220    type Left = <V4 as Architecture>::f32x16;
3221    type Right = <V4 as Architecture>::f32x16;
3222    type Return = f32;
3223    type Main = Strategy4x1;
3224
3225    #[inline(always)]
3226    fn init(&self, arch: V4) -> Self::Accumulator {
3227        Self::Accumulator::default(arch)
3228    }
3229
3230    #[inline(always)]
3231    fn accumulate(
3232        &self,
3233        x: Self::Left,
3234        _y: Self::Right,
3235        acc: Self::Accumulator,
3236    ) -> Self::Accumulator {
3237        x.abs_simd() + acc
3238    }
3239
3240    // Perform a final reduction.
3241    #[inline(always)]
3242    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3243        x.sum_tree()
3244    }
3245}
3246
3247#[cfg(target_arch = "x86_64")]
3248impl SIMDSchema<f32, f32, V3> for L1Norm {
3249    type SIMDWidth = Const<8>;
3250    type Accumulator = <V3 as Architecture>::f32x8;
3251    type Left = <V3 as Architecture>::f32x8;
3252    type Right = <V3 as Architecture>::f32x8;
3253    type Return = f32;
3254    type Main = Strategy4x1;
3255
3256    #[inline(always)]
3257    fn init(&self, arch: V3) -> Self::Accumulator {
3258        Self::Accumulator::default(arch)
3259    }
3260
3261    #[inline(always)]
3262    fn accumulate(
3263        &self,
3264        x: Self::Left,
3265        _y: Self::Right,
3266        acc: Self::Accumulator,
3267    ) -> Self::Accumulator {
3268        x.abs_simd() + acc
3269    }
3270
3271    // Perform a final reduction.
3272    #[inline(always)]
3273    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3274        x.sum_tree()
3275    }
3276}
3277
// L1 norm for f32 on NEON, four lanes at a time.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<f32, f32, Neon> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    type Left = <Neon as Architecture>::f32x4;
    type Right = <Neon as Architecture>::f32x4;
    type Return = f32;
    type Main = Strategy4x1;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Accumulate lane-wise |x|; the right operand is ignored (L1 norm is a
    // logical unary operation — see the note above `L1Norm`).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        x.abs_simd() + acc
    }

    // Scalar handling for the final partial block.
    //
    // NOTE(review): assumes the driver only calls `epilogue` with
    // `len < SIMDWidth`; the `min` restates that bound for the optimizer —
    // confirm against the `simd_op` driver.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const f32,
        _y: *const f32,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let mut s: f32 = 0.0;
        for i in 0..len.min(Self::SIMDWidth::value() - 1) {
            // SAFETY: The range `[x, x.add(len))` is valid for reads.
            let vx = unsafe { x.add(i).read_unaligned() };
            s += vx.abs();
        }
        // Fold the scalar partial sum in through lane 0.
        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
    }

    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3325
3326impl SIMDSchema<f32, f32, Scalar> for L1Norm {
3327    type SIMDWidth = Const<4>;
3328    type Accumulator = Emulated<f32, 4>;
3329    type Left = Emulated<f32, 4>;
3330    type Right = Emulated<f32, 4>;
3331    type Return = f32;
3332    type Main = Strategy2x1;
3333
3334    #[inline(always)]
3335    fn init(&self, arch: Scalar) -> Self::Accumulator {
3336        Self::Accumulator::default(arch)
3337    }
3338
3339    #[inline(always)]
3340    fn accumulate(
3341        &self,
3342        x: Self::Left,
3343        _y: Self::Right,
3344        acc: Self::Accumulator,
3345    ) -> Self::Accumulator {
3346        x.abs_simd() + acc
3347    }
3348
3349    // Perform a final reduction.
3350    #[inline(always)]
3351    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3352        x.sum_tree()
3353    }
3354
3355    #[inline(always)]
3356    unsafe fn epilogue(
3357        &self,
3358        arch: Scalar,
3359        x: *const f32,
3360        _y: *const f32,
3361        len: usize,
3362        acc: Self::Accumulator,
3363    ) -> Self::Accumulator {
3364        let mut s: f32 = 0.0;
3365        for i in 0..len {
3366            // SAFETY: The range `[x, x.add(len))` is valid for reads.
3367            let vx = unsafe { x.add(i).read_unaligned() };
3368            s += vx.abs();
3369        }
3370        acc + Self::Accumulator::from_array(arch, [s, 0.0, 0.0, 0.0])
3371    }
3372}
3373
3374#[cfg(target_arch = "x86_64")]
3375impl SIMDSchema<Half, Half, V4> for L1Norm {
3376    type SIMDWidth = Const<8>;
3377    type Accumulator = <V4 as Architecture>::f32x8;
3378    type Left = <V4 as Architecture>::f16x8;
3379    type Right = <V4 as Architecture>::f16x8;
3380    type Return = f32;
3381    type Main = Strategy2x4;
3382
3383    #[inline(always)]
3384    fn init(&self, arch: V4) -> Self::Accumulator {
3385        Self::Accumulator::default(arch)
3386    }
3387
3388    #[inline(always)]
3389    fn accumulate(
3390        &self,
3391        x: Self::Left,
3392        _y: Self::Right,
3393        acc: Self::Accumulator,
3394    ) -> Self::Accumulator {
3395        let x: <V4 as Architecture>::f32x8 = x.into();
3396        x.abs_simd() + acc
3397    }
3398
3399    // Perform a final reduction.
3400    #[inline(always)]
3401    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3402        x.sum_tree()
3403    }
3404}
3405
3406#[cfg(target_arch = "x86_64")]
3407impl SIMDSchema<Half, Half, V3> for L1Norm {
3408    type SIMDWidth = Const<8>;
3409    type Accumulator = <V3 as Architecture>::f32x8;
3410    type Left = <V3 as Architecture>::f16x8;
3411    type Right = <V3 as Architecture>::f16x8;
3412    type Return = f32;
3413    type Main = Strategy2x4;
3414
3415    #[inline(always)]
3416    fn init(&self, arch: V3) -> Self::Accumulator {
3417        Self::Accumulator::default(arch)
3418    }
3419
3420    #[inline(always)]
3421    fn accumulate(
3422        &self,
3423        x: Self::Left,
3424        _y: Self::Right,
3425        acc: Self::Accumulator,
3426    ) -> Self::Accumulator {
3427        let x: <V3 as Architecture>::f32x8 = x.into();
3428        x.abs_simd() + acc
3429    }
3430
3431    // Perform a final reduction.
3432    #[inline(always)]
3433    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3434        x.sum_tree()
3435    }
3436}
3437
// L1 norm for f16 on NEON: lanes are widened to f32 before the absolute
// value, so all accumulation happens in single precision.
#[cfg(target_arch = "aarch64")]
impl SIMDSchema<Half, Half, Neon> for L1Norm {
    type SIMDWidth = Const<4>;
    type Accumulator = <Neon as Architecture>::f32x4;
    // Four f16 lanes per block.
    type Left = diskann_wide::arch::aarch64::f16x4;
    type Right = diskann_wide::arch::aarch64::f16x4;
    type Return = f32;
    type Main = Strategy2x4;

    #[inline(always)]
    fn init(&self, arch: Neon) -> Self::Accumulator {
        Self::Accumulator::default(arch)
    }

    // Widen f16 -> f32, then accumulate lane-wise |x|; the right operand is
    // ignored (L1 norm is a logical unary operation — see the note above).
    #[inline(always)]
    fn accumulate(
        &self,
        x: Self::Left,
        _y: Self::Right,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let x: <Neon as Architecture>::f32x4 = x.into();
        x.abs_simd() + acc
    }

    // Final partial block: each remaining element is placed in lane 0 of an
    // otherwise-zero f16x4, widened to f32, and accumulated.
    //
    // NOTE(review): assumes the driver only calls `epilogue` with
    // `len < SIMDWidth`; the `min` restates that bound for the optimizer —
    // confirm against the `simd_op` driver.
    #[inline(always)]
    unsafe fn epilogue(
        &self,
        arch: Neon,
        x: *const Half,
        _y: *const Half,
        len: usize,
        acc: Self::Accumulator,
    ) -> Self::Accumulator {
        let rest = scalar_epilogue(
            x,
            x, // unused, but scalar_epilogue requires a right pointer
            len.min(Self::SIMDWidth::value() - 1),
            Self::Accumulator::default(arch),
            |acc, x: Half, _: Half| -> Self::Accumulator {
                let zero = Half::default();
                let x: Self::Accumulator =
                    Self::Left::from_array(arch, [x, zero, zero, zero]).into();
                x.abs_simd() + acc
            },
        );
        acc + rest
    }

    // Perform a final reduction.
    #[inline(always)]
    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
        x.sum_tree()
    }
}
3493
3494impl SIMDSchema<Half, Half, Scalar> for L1Norm {
3495    type SIMDWidth = Const<1>;
3496    type Accumulator = Emulated<f32, 1>;
3497    type Left = Emulated<Half, 1>;
3498    type Right = Emulated<Half, 1>;
3499    type Return = f32;
3500    type Main = Strategy1x1;
3501
3502    #[inline(always)]
3503    fn init(&self, arch: Scalar) -> Self::Accumulator {
3504        Self::Accumulator::default(arch)
3505    }
3506
3507    #[inline(always)]
3508    fn accumulate(
3509        &self,
3510        x: Self::Left,
3511        _y: Self::Right,
3512        acc: Self::Accumulator,
3513    ) -> Self::Accumulator {
3514        let x: Self::Accumulator = x.into();
3515        x.abs_simd() + acc
3516    }
3517
3518    // Perform a final reduction.
3519    #[inline(always)]
3520    fn reduce(&self, x: Self::Accumulator) -> Self::Return {
3521        x.to_array()[0]
3522    }
3523
3524    #[inline(always)]
3525    unsafe fn epilogue(
3526        &self,
3527        _arch: Scalar,
3528        _x: *const Half,
3529        _y: *const Half,
3530        _len: usize,
3531        _acc: Self::Accumulator,
3532    ) -> Self::Accumulator {
3533        unreachable!("The SIMD width is 1, so there should be no epilogue")
3534    }
3535}
3536
3537///////////
3538// Tests //
3539///////////
3540
3541#[cfg(test)]
3542mod tests {
3543    use std::{collections::HashMap, sync::LazyLock};
3544
3545    use approx::assert_relative_eq;
3546    use diskann_wide::{arch::Target1, ARCH};
3547    use half::f16;
3548    use rand::{distr::StandardUniform, rngs::StdRng, Rng, SeedableRng};
3549    use rand_distr;
3550
3551    use super::*;
3552    use crate::{distance::reference, norm::LInfNorm, test_util};
3553
3554    ///////////////////////
3555    // Cosine Norm Check //
3556    ///////////////////////
3557
3558    fn cosine_norm_check_impl<A>(arch: A)
3559    where
3560        A: diskann_wide::Architecture,
3561        CosineStateless:
3562            SIMDSchema<f32, f32, A, Return = f32> + SIMDSchema<Half, Half, A, Return = f32>,
3563    {
3564        // Zero - f32
3565        {
3566            let x: [f32; 2] = [0.0, 0.0];
3567            let y: [f32; 2] = [0.0, 1.0];
3568            assert_eq!(
3569                simd_op(&CosineStateless {}, arch, x, x),
3570                0.0,
3571                "when both vectors are zero, similarity should be zero",
3572            );
3573            assert_eq!(
3574                simd_op(&CosineStateless {}, arch, x, y),
3575                0.0,
3576                "when one vector is zero, similarity should be zero",
3577            );
3578            assert_eq!(
3579                simd_op(&CosineStateless {}, arch, y, x),
3580                0.0,
3581                "when one vector is zero, similarity should be zero",
3582            );
3583        }
3584
3585        // Subnormal - f32
3586        {
3587            let x: [f32; 4] = [0.0, 0.0, 2.938736e-39f32, 0.0];
3588            let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
3589            assert_eq!(
3590                simd_op(&CosineStateless {}, arch, x, x),
3591                0.0,
3592                "when both vectors are almost zero, similarity should be zero",
3593            );
3594            assert_eq!(
3595                simd_op(&CosineStateless {}, arch, x, y),
3596                0.0,
3597                "when one vector is almost zero, similarity should be zero",
3598            );
3599            assert_eq!(
3600                simd_op(&CosineStateless {}, arch, y, x),
3601                0.0,
3602                "when one vector is almost zero, similarity should be zero",
3603            );
3604        }
3605
3606        // Small - f32
3607        {
3608            let x: [f32; 4] = [0.0, 0.0, 1.0842022e-19f32, 0.0];
3609            let y: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
3610            assert_eq!(
3611                simd_op(&CosineStateless {}, arch, x, x),
3612                1.0,
3613                "cosine-stateless should handle vectors this small",
3614            );
3615            assert_eq!(
3616                simd_op(&CosineStateless {}, arch, x, y),
3617                1.0,
3618                "cosine-stateless should handle vectors this small",
3619            );
3620            assert_eq!(
3621                simd_op(&CosineStateless {}, arch, y, x),
3622                1.0,
3623                "cosine-stateless should handle vectors this small",
3624            );
3625        }
3626
3627        let cvt = diskann_wide::cast_f32_to_f16;
3628
3629        // Zero - f16
3630        {
3631            let x: [Half; 2] = [Half::default(), Half::default()];
3632            let y: [Half; 2] = [Half::default(), cvt(1.0)];
3633            assert_eq!(
3634                simd_op(&CosineStateless {}, arch, x, x),
3635                0.0,
3636                "when both vectors are zero, similarity should be zero",
3637            );
3638            assert_eq!(
3639                simd_op(&CosineStateless {}, arch, x, y),
3640                0.0,
3641                "when one vector is zero, similarity should be zero",
3642            );
3643            assert_eq!(
3644                simd_op(&CosineStateless {}, arch, y, x),
3645                0.0,
3646                "when one vector is zero, similarity should be zero",
3647            );
3648        }
3649
3650        // Subnormal - f16
3651        {
3652            let x: [Half; 4] = [
3653                Half::default(),
3654                Half::default(),
3655                Half::MIN_POSITIVE_SUBNORMAL,
3656                Half::default(),
3657            ];
3658            let y: [Half; 4] = [Half::default(), Half::default(), cvt(1.0), Half::default()];
3659            assert_eq!(
3660                simd_op(&CosineStateless {}, arch, x, x),
3661                1.0,
3662                "when both vectors are almost zero, similarity should be zero",
3663            );
3664            assert_eq!(
3665                simd_op(&CosineStateless {}, arch, x, y),
3666                1.0,
3667                "when one vector is almost zero, similarity should be zero",
3668            );
3669            assert_eq!(
3670                simd_op(&CosineStateless {}, arch, y, x),
3671                1.0,
3672                "when one vector is almost zero, similarity should be zero",
3673            );
3674
3675            // Grab a range of floating point numbers whose squares cover the range of
3676            // our target threshold.
3677            //
3678            // Ensure that all combinations of values within this critical range to not
3679            // result in a misrounding.
3680            let threshold = f32::MIN_POSITIVE;
3681            let bound = 50;
3682            let values = {
3683                let mut down = threshold;
3684                let mut up = threshold;
3685                for _ in 0..bound {
3686                    down = down.next_down();
3687                    up = up.next_up();
3688                }
3689                assert!(down > 0.0);
3690                let min = down.sqrt();
3691                let max = up.sqrt();
3692                let mut v = min;
3693                let mut values = Vec::new();
3694                while v <= max {
3695                    values.push(v);
3696                    v = v.next_up();
3697                }
3698                values
3699            };
3700
3701            let mut lo = 0;
3702            let mut hi = 0;
3703            for i in values.iter() {
3704                for j in values.iter() {
3705                    let s: f32 = simd_op(&CosineStateless {}, arch, [*i], [*j]);
3706                    if i * i < threshold || j * j < threshold {
3707                        lo += 1;
3708                        assert_eq!(s, 0.0, "failed for i = {}, j = {}", i, j);
3709                    } else {
3710                        hi += 1;
3711                        assert_eq!(s, 1.0, "failed for i = {}, j = {}", i, j);
3712                    }
3713                }
3714            }
3715            assert_ne!(lo, 0);
3716            assert_ne!(hi, 0);
3717        }
3718    }
3719
    // Run the norm boundary checks on the runtime-selected (best available)
    // architecture and on the portable scalar fallback.
    #[test]
    fn cosine_norm_check() {
        cosine_norm_check_impl::<diskann_wide::arch::Current>(diskann_wide::arch::current());
        cosine_norm_check_impl::<diskann_wide::arch::Scalar>(diskann_wide::arch::Scalar::new());
    }
3725
    // Run the norm boundary checks on each x86-64 SIMD level, skipping any the
    // host CPU does not support.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn cosine_norm_check_x86_64() {
        if let Some(arch) = V3::new_checked() {
            cosine_norm_check_impl::<V3>(arch);
        }

        // NOTE(review): `new_checked_miri` presumably also gates on Miri —
        // confirm in `diskann_wide`.
        if let Some(arch) = V4::new_checked_miri() {
            cosine_norm_check_impl::<V4>(arch);
        }
    }
3737
3738    ////////////
3739    // Schema //
3740    ////////////
3741
3742    // Chunk the left and right hand slices and compute the result using a resumable function.
3743    fn test_resumable<T, L, R, A>(arch: A, x: &[L], y: &[R], chunk_size: usize) -> f32
3744    where
3745        A: Architecture,
3746        T: ResumableSIMDSchema<L, R, A, FinalReturn = f32>,
3747    {
3748        let mut acc = Resumable(<T as ResumableSIMDSchema<L, R, A>>::init(arch));
3749        let iter = std::iter::zip(x.chunks(chunk_size), y.chunks(chunk_size));
3750        for (a, b) in iter {
3751            acc = simd_op(&acc, arch, a, b);
3752        }
3753        acc.0.sum()
3754    }
3755
    /// Stress-test a stateless schema `O` against a scalar `reference`, and
    /// additionally verify that its resumable counterpart `T` produces the
    /// same result when the inputs are processed in chunks of various sizes.
    fn stress_test_with_resumable<
        A: Architecture,
        O: Default + SIMDSchema<f32, f32, A, Return = f32>,
        T: ResumableSIMDSchema<f32, f32, A, NonResumable = O, FinalReturn = f32>,
        Rand: Rng,
    >(
        arch: A,
        reference: fn(&[f32], &[f32]) -> f32,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) {
        // Pick chunk sizes that exercise combinations of the unrolled loops.
        let chunk_divisors: Vec<usize> = vec![1, 2, 3, 4, 16, 54, 64, 65, 70, 77];
        let checker = test_util::AdHocChecker::<f32, f32>::new(|a: &[f32], b: &[f32]| {
            let expected = reference(a, b);
            let got = simd_op(&O::default(), arch, a, b);
            println!("dim = {}", dim);
            assert_relative_eq!(
                expected,
                got,
                epsilon = epsilon,
                max_relative = max_relative,
            );

            // An empty input cannot be chunked; the single-shot check above
            // is all that applies.
            if dim == 0 {
                return;
            }

            for d in &chunk_divisors {
                // chunk_size = ceil(dim / d), so every divisor yields at
                // least one non-empty chunk.
                let chunk_size = dim / d + (!dim.is_multiple_of(*d) as usize);
                let chunked = test_resumable::<T, f32, f32, _>(arch, a, b, chunk_size);
                assert_relative_eq!(chunked, got, epsilon = epsilon, max_relative = max_relative);
            }
        });

        test_util::test_distance_function(
            checker,
            rand_distr::Normal::new(0.0, 10.0).unwrap(),
            rand_distr::Normal::new(0.0, 10.0).unwrap(),
            dim,
            10,
            rng,
        )
    }
3802
    /// Stress-test a stateless `SIMDSchema` `O` against a scalar `reference`
    /// over randomly generated inputs drawn from the given distributions.
    ///
    /// `epsilon` / `max_relative` bound the accepted floating-point divergence
    /// between the SIMD result and the reference.
    #[allow(clippy::too_many_arguments)]
    fn stress_test<L, R, DistLeft, DistRight, O, Rand, A>(
        arch: A,
        reference: fn(&[L], &[R]) -> f32,
        left_dist: DistLeft,
        right_dist: DistRight,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) where
        L: test_util::CornerCases,
        R: test_util::CornerCases,
        DistLeft: test_util::GenerateRandomArguments<L>,
        DistRight: test_util::GenerateRandomArguments<R>,
        O: Default + SIMDSchema<L, R, A, Return = f32>,
        Rand: Rng,
        A: Architecture,
    {
        let checker = test_util::Checker::<L, R, f32>::new(
            |x: &[L], y: &[R]| simd_op(&O::default(), arch, x, y),
            reference,
            |got, expected| {
                assert_relative_eq!(
                    expected,
                    got,
                    epsilon = epsilon,
                    max_relative = max_relative
                );
            },
        );

        // Skip random trials under Miri to keep interpreter run time sane —
        // presumably `test_distance_function` still runs its corner cases with
        // zero trials; confirm in `test_util`.
        let trials = if cfg!(miri) { 0 } else { 10 };

        test_util::test_distance_function(checker, left_dist, right_dist, dim, trials, rng);
    }
3839
    /// Stress-test the `LInfNorm` unary reduction against a scalar
    /// `reference`; the right-hand slice required by the checker interface is
    /// ignored by both sides.
    fn stress_test_linf<L, Dist, Rand, A>(
        arch: A,
        reference: fn(&[L]) -> f32,
        dist: Dist,
        dim: usize,
        epsilon: f32,
        max_relative: f32,
        rng: &mut Rand,
    ) where
        L: test_util::CornerCases + Copy,
        Dist: Clone + test_util::GenerateRandomArguments<L>,
        Rand: Rng,
        A: Architecture,
        LInfNorm: for<'a> Target1<A, f32, &'a [L]>,
    {
        let checker = test_util::Checker::<L, L, f32>::new(
            |x: &[L], _y: &[L]| (LInfNorm).run(arch, x),
            |x: &[L], _y: &[L]| reference(x),
            |got, expected| {
                assert_relative_eq!(
                    expected,
                    got,
                    epsilon = epsilon,
                    max_relative = max_relative
                );
            },
        );

        println!("checking {dim}");
        test_util::test_distance_function(checker, dist.clone(), dist, dim, 10, rng);
    }
3871
3872    /////////
3873    // f32 //
3874    /////////
3875
    /// Generate a `#[test]` that stress-tests a SIMD distance schema (with its
    /// resumable wrapper) against a scalar reference over every dimension in
    /// `0..$upper`.
    ///
    /// - `$name`: name of the generated test function.
    /// - `$impl`: the stateless `SIMDSchema` implementation under test.
    /// - `$resumable`: the matching resumable wrapper type.
    /// - `$reference`: scalar reference; its result is unwrapped via
    ///   `into_inner()`.
    /// - `$eps` / `$relative`: tolerances passed to `assert_relative_eq!`.
    /// - `$seed`: fixed RNG seed for reproducibility.
    /// - `$upper`: exclusive upper bound on tested dimensions.
    /// - `$arch`: expression yielding `Option<arch>`; the test is a no-op when
    ///   the architecture is unavailable at runtime.
    macro_rules! float_test {
        ($name:ident,
         $impl:ty,
         $resumable:ident,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_with_resumable::<_, $impl, $resumable<_>, StdRng>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
3905
3906    //----//
3907    // L2 //
3908    //----//
3909
    // Squared L2 over f32: dimensions 0..64 on the runtime-selected and scalar
    // architectures, 0..256 on each explicitly-gated SIMD backend.
    float_test!(
        test_l2_f32_current,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_l2_f32_scalar,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v3,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_l2_f32_x86_64_v4,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    float_test!(
        test_l2_f32_aarch64_neon,
        L2,
        ResumableL2,
        reference::reference_squared_l2_f32_mathematical,
        1e-5,
        1e-5,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );
3972
3973    //----//
3974    // IP //
3975    //----//
3976
    // Inner product over f32: looser tolerances than L2 since the terms can
    // cancel, amplifying relative error.
    float_test!(
        test_ip_f32_current,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::ARCH)
    );

    float_test!(
        test_ip_f32_scalar,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v3,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    float_test!(
        test_ip_f32_x86_64_v4,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    float_test!(
        test_ip_f32_aarch64_neon,
        IP,
        ResumableIP,
        reference::reference_innerproduct_f32_mathematical,
        2e-4,
        1e-3,
        0xb4687c17a9ea9866,
        256,
        Neon::new_checked()
    );
4039
4040    //--------//
4041    // Cosine //
4042    //--------//
4043
4044    float_test!(
4045        test_cosine_f32_current,
4046        CosineStateless,
4047        ResumableCosine,
4048        reference::reference_cosine_f32_mathematical,
4049        1e-5,
4050        1e-5,
4051        0xe860e9dc65f38bb8,
4052        64,
4053        Some(diskann_wide::ARCH)
4054    );
4055
4056    float_test!(
4057        test_cosine_f32_scalar,
4058        CosineStateless,
4059        ResumableCosine,
4060        reference::reference_cosine_f32_mathematical,
4061        1e-5,
4062        1e-5,
4063        0xe860e9dc65f38bb8,
4064        64,
4065        Some(diskann_wide::arch::Scalar)
4066    );
4067
4068    #[cfg(target_arch = "x86_64")]
4069    float_test!(
4070        test_cosine_f32_x86_64_v3,
4071        CosineStateless,
4072        ResumableCosine,
4073        reference::reference_cosine_f32_mathematical,
4074        1e-5,
4075        1e-5,
4076        0xe860e9dc65f38bb8,
4077        256,
4078        V3::new_checked()
4079    );
4080
4081    #[cfg(target_arch = "x86_64")]
4082    float_test!(
4083        test_cosine_f32_x86_64_v4,
4084        CosineStateless,
4085        ResumableCosine,
4086        reference::reference_cosine_f32_mathematical,
4087        1e-5,
4088        1e-5,
4089        0xe860e9dc65f38bb8,
4090        256,
4091        V4::new_checked_miri()
4092    );
4093
4094    #[cfg(target_arch = "aarch64")]
4095    float_test!(
4096        test_cosine_f32_aarch64_neon,
4097        CosineStateless,
4098        ResumableCosine,
4099        reference::reference_cosine_f32_mathematical,
4100        1e-5,
4101        1e-5,
4102        0xe860e9dc65f38bb8,
4103        256,
4104        Neon::new_checked()
4105    );
4106
4107    /////////
4108    // f16 //
4109    /////////
4110
    /// Generates a `#[test]` that stress-tests an f16 (`Half`) distance kernel
    /// against a mathematical reference implementation.
    ///
    /// Arguments, in order:
    /// - `$name`: name of the generated test function.
    /// - `$impl`: distance functor type under test (e.g. `L2`, `IP`).
    /// - `$reference`: reference implementation to compare against.
    /// - `$eps` / `$relative`: absolute and relative error tolerances.
    /// - `$seed`: RNG seed, fixed for reproducibility.
    /// - `$upper`: exclusive upper bound on the tested vector dimension.
    /// - `$arch`: expression yielding an `Option` architecture token; the
    ///   test body is skipped when it evaluates to `None`.
    macro_rules! half_test {
        ($name:ident,
         $impl:ty,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                // Skip entirely when the architecture is unavailable at runtime.
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    // Sweep every dimension below the bound, including 0, so
                    // the kernels' tail/remainder handling is exercised.
                    for dim in 0..$upper {
                        stress_test::<
                            Half,
                            Half,
                            rand_distr::Normal<f32>,
                            rand_distr::Normal<f32>,
                            $impl,
                            StdRng,
                            _
                        >(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            // Both operands drawn from N(mean = 0, sigma = 10).
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            rand_distr::Normal::new(0.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng
                        );
                    }
                }
            }
        }
    }
4149
    //----//
    // L2 //
    //----//

    // f16 (`Half`) kernels: current auto-detected architecture, forced scalar
    // fallback, then each concrete architecture over a wider dimension sweep.
    half_test!(
        test_l2_f16_current,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_l2_f16_scalar,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v3,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_l2_f16_x86_64_v4,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_l2_f16_aarch64_neon,
        L2,
        reference::reference_squared_l2_f16_mathematical,
        1e-5,
        1e-5,
        0x87ca6f1051667500,
        256,
        Neon::new_checked()
    );

    //----//
    // IP //
    //----//

    half_test!(
        test_ip_f16_current,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_ip_f16_scalar,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v3,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_ip_f16_x86_64_v4,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_ip_f16_aarch64_neon,
        IP,
        reference::reference_innerproduct_f16_mathematical,
        2e-4,
        2e-4,
        0x5909f5f20307ccbe,
        256,
        Neon::new_checked()
    );

    //--------//
    // Cosine //
    //--------//

    half_test!(
        test_cosine_f16_current,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::ARCH)
    );

    half_test!(
        test_cosine_f16_scalar,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        64,
        Some(diskann_wide::arch::Scalar)
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v3,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    half_test!(
        test_cosine_f16_x86_64_v4,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    half_test!(
        test_cosine_f16_aarch64_neon,
        CosineStateless,
        reference::reference_cosine_f16_mathematical,
        1e-5,
        1e-5,
        0x41dda34655f05ef6,
        256,
        Neon::new_checked()
    );
4335
4336    /////////////
4337    // Integer //
4338    /////////////
4339
    /// Generates a `#[test]` that stress-tests an integer (`u8`/`i8`) distance
    /// kernel against a mathematical reference implementation.
    ///
    /// Integer kernels are expected to be exact, so both tolerances are
    /// hard-wired to `0.0`. Inputs are drawn uniformly over the full value
    /// range via `StandardUniform`. The `$arch` expression is wrapped in
    /// braces at the call site and must yield an `Option` architecture token;
    /// `None` skips the test.
    macro_rules! int_test {
        (
            $name:ident,
            $T:ty,
            $impl:ty,
            $reference:path,
            $seed:literal,
            $upper:literal,
            { $($arch:tt)* }
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    // Include dim == 0 to cover the empty-slice edge case.
                    for dim in 0..$upper {
                        stress_test::<$T, $T, _, _, $impl, _, _>(
                            arch,
                            |l, r| $reference(l, r).into_inner(),
                            StandardUniform,
                            StandardUniform,
                            dim,
                            0.0,
                            0.0,
                            &mut rng,
                        )
                    }
                }
            }
        }
    }
4370
    //----//
    // U8 //
    //----//

    // Unsigned 8-bit kernels. Results are compared exactly (tolerances are
    // fixed at 0.0 inside `int_test!`).
    int_test!(
        test_l2_u8_current,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x945bdc37d8279d4b,
        128,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_u8_scalar,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        128,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v3,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_u8_x86_64_v4,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_l2_u8_aarch64_neon,
        u8,
        L2,
        reference::reference_squared_l2_u8_mathematical,
        0x74c86334ab7a51f9,
        320,
        { Neon::new_checked() }
    );

    int_test!(
        test_ip_u8_current,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0xcbe0342c75085fd5,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_u8_scalar,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v3,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_u8_x86_64_v4,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_ip_u8_aarch64_neon,
        u8,
        IP,
        reference::reference_innerproduct_u8_mathematical,
        0x888e07fc489e773f,
        320,
        { Neon::new_checked() }
    );

    int_test!(
        test_cosine_u8_current,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0x96867b6aff616b28,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_u8_scalar,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v3,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_u8_x86_64_v4,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_cosine_u8_aarch64_neon,
        u8,
        CosineStateless,
        reference::reference_cosine_u8_mathematical,
        0xcc258c9391733211,
        320,
        { Neon::new_checked() }
    );

    //----//
    // I8 //
    //----//

    // Signed 8-bit kernels, mirroring the u8 suite above.
    int_test!(
        test_l2_i8_current,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0xa60136248cd3c2f0,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_l2_i8_scalar,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v3,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_l2_i8_x86_64_v4,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_l2_i8_aarch64_neon,
        i8,
        L2,
        reference::reference_squared_l2_i8_mathematical,
        0x3e8bada709e176be,
        320,
        { Neon::new_checked() }
    );

    int_test!(
        test_ip_i8_current,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0xe8306104740509e1,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_ip_i8_scalar,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v3,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_ip_i8_x86_64_v4,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_ip_i8_aarch64_neon,
        i8,
        IP,
        reference::reference_innerproduct_i8_mathematical,
        0x8a263408c7b31d85,
        320,
        { Neon::new_checked() }
    );

    int_test!(
        test_cosine_i8_current,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x818c210190701e4b,
        64,
        { Some(ARCH) }
    );

    int_test!(
        test_cosine_i8_scalar,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        64,
        { Some(diskann_wide::arch::Scalar) }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v3,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        256,
        { V3::new_checked() }
    );

    #[cfg(target_arch = "x86_64")]
    int_test!(
        test_cosine_i8_x86_64_v4,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        320,
        { V4::new_checked_miri() }
    );

    #[cfg(target_arch = "aarch64")]
    int_test!(
        test_cosine_i8_aarch64_neon,
        i8,
        CosineStateless,
        reference::reference_cosine_i8_mathematical,
        0x2d077bed2629b18e,
        320,
        { Neon::new_checked() }
    );
4696
4697    //////////
4698    // LInf //
4699    //////////
4700
    /// Generates a `#[test]` for the unary L-infinity kernel, comparing it
    /// against a reference implementation over dimensions `0..$upper`.
    ///
    /// Unlike the binary distance macros above, the kernel and reference take
    /// a single vector. `$arch` must yield an `Option` architecture token;
    /// `None` skips the test.
    macro_rules! linf_test {
        ($name:ident,
         $T:ty,
         $reference:path,
         $eps:literal,
         $relative:literal,
         $seed:literal,
         $upper:literal,
         $($arch:tt)*
        ) => {
            #[test]
            fn $name() {
                if let Some(arch) = $($arch)* {
                    let mut rng = StdRng::seed_from_u64($seed);
                    for dim in 0..$upper {
                        stress_test_linf::<$T, _, StdRng, _>(
                            arch,
                            |l| $reference(l).into_inner(),
                            // N(mean = -10, sigma = 10): a non-zero mean, so
                            // inputs of both signs are generated.
                            rand_distr::Normal::new(-10.0, 10.0).unwrap(),
                            dim,
                            $eps,
                            $relative,
                            &mut rng,
                        );
                    }
                }
            }
        }
    }
4730
    // L-infinity kernels for f32 and f16 across scalar, V3, V4, and Neon.
    linf_test!(
        test_linf_f32_scalar,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v3,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f32_v4,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    linf_test!(
        test_linf_f32_neon,
        f32,
        reference::reference_linf_f32_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );

    linf_test!(
        test_linf_f16_scalar,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Some(Scalar::new())
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v3,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V3::new_checked()
    );

    #[cfg(target_arch = "x86_64")]
    linf_test!(
        test_linf_f16_v4,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        V4::new_checked_miri()
    );

    #[cfg(target_arch = "aarch64")]
    linf_test!(
        test_linf_f16_neon,
        f16,
        reference::reference_linf_f16_mathematical,
        1e-6,
        1e-6,
        0xf149c2bcde660128,
        256,
        Neon::new_checked()
    );
4824
4825    ////////////////
4826    // Miri Tests //
4827    ////////////////
4828
    /// Element-type tags used to key the Miri dimension-bound table.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum DataType {
        Float32,
        Float16,
        UInt8,
        Int8,
    }
4836
    /// Maps a concrete element type to its `DataType` tag, letting the
    /// `test_bounds!` macro look up bounds generically from a type parameter.
    trait AsDataType {
        fn as_data_type() -> DataType;
    }

    impl AsDataType for f32 {
        fn as_data_type() -> DataType {
            DataType::Float32
        }
    }

    impl AsDataType for f16 {
        fn as_data_type() -> DataType {
            DataType::Float16
        }
    }

    impl AsDataType for u8 {
        fn as_data_type() -> DataType {
            DataType::UInt8
        }
    }

    impl AsDataType for i8 {
        fn as_data_type() -> DataType {
            DataType::Int8
        }
    }
4864
    /// Architecture tags used to key the Miri dimension-bound table.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum Arch {
        Scalar,
        // Underscored variant names intentionally mirror the x86-64
        // microarchitecture level names (x86-64-v3 / x86-64-v4).
        #[expect(non_camel_case_types)]
        X86_64_V3,
        #[expect(non_camel_case_types)]
        X86_64_V4,
        Aarch64Neon,
    }
4874
4875    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
4876    struct Key {
4877        arch: Arch,
4878        left: DataType,
4879        right: DataType,
4880    }
4881
4882    impl Key {
4883        fn new(arch: Arch, left: DataType, right: DataType) -> Self {
4884            Self { arch, left, right }
4885        }
4886    }
4887
4888    static MIRI_BOUNDS: LazyLock<HashMap<Key, usize>> = LazyLock::new(|| {
4889        use Arch::{Aarch64Neon, Scalar, X86_64_V3, X86_64_V4};
4890        use DataType::{Float16, Float32, Int8, UInt8};
4891
4892        [
4893            (Key::new(Scalar, Float32, Float32), 64),
4894            (Key::new(X86_64_V3, Float32, Float32), 256),
4895            (Key::new(X86_64_V4, Float32, Float32), 256),
4896            (Key::new(Aarch64Neon, Float32, Float32), 128),
4897            (Key::new(Scalar, Float16, Float16), 64),
4898            (Key::new(X86_64_V3, Float16, Float16), 256),
4899            (Key::new(X86_64_V4, Float16, Float16), 256),
4900            (Key::new(Aarch64Neon, Float16, Float16), 128),
4901            (Key::new(Scalar, Float32, Float16), 64),
4902            (Key::new(X86_64_V3, Float32, Float16), 256),
4903            (Key::new(X86_64_V4, Float32, Float16), 256),
4904            (Key::new(Aarch64Neon, Float32, Float16), 128),
4905            (Key::new(Scalar, UInt8, UInt8), 64),
4906            (Key::new(X86_64_V3, UInt8, UInt8), 256),
4907            (Key::new(X86_64_V4, UInt8, UInt8), 320),
4908            (Key::new(Aarch64Neon, UInt8, UInt8), 128),
4909            (Key::new(Scalar, Int8, Int8), 64),
4910            (Key::new(X86_64_V3, Int8, Int8), 256),
4911            (Key::new(X86_64_V4, Int8, Int8), 320),
4912            (Key::new(Aarch64Neon, Int8, Int8), 128),
4913        ]
4914        .into_iter()
4915        .collect()
4916    });
4917
    /// Generates a `#[test]` that runs `L2`, `IP`, and `CosineStateless`
    /// through `simd_op` for every dimension below the per-architecture bound
    /// recorded in `MIRI_BOUNDS`. These exist so that Miri can check the
    /// kernels within a feasible runtime.
    ///
    /// Arguments: test name, left element type and fill value, right element
    /// type and fill value. Each vector is filled with copies of its value.
    macro_rules! test_bounds {
        (
            $function:ident,
            $left:ty,
            $left_ex:expr,
            $right:ty,
            $right_ex:expr
        ) => {
            #[test]
            fn $function() {
                let left: $left = $left_ex;
                let right: $right = $right_ex;

                let left_type = <$left>::as_data_type();
                let right_type = <$right>::as_data_type();

                // Scalar
                // No runtime check needed: the scalar path is always available.
                {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Scalar, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        let arch = diskann_wide::arch::Scalar;
                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                // Each architecture below is exercised only when available.
                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V3::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V3, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                #[cfg(target_arch = "x86_64")]
                if let Some(arch) = V4::new_checked_miri() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::X86_64_V4, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }

                #[cfg(target_arch = "aarch64")]
                if let Some(arch) = Neon::new_checked() {
                    let max = MIRI_BOUNDS[&Key::new(Arch::Aarch64Neon, left_type, right_type)];
                    for dim in 0..max {
                        let left: Vec<$left> = vec![left; dim];
                        let right: Vec<$right> = vec![right; dim];

                        simd_op(&L2, arch, left.as_slice(), right.as_slice());
                        simd_op(&IP, arch, left.as_slice(), right.as_slice());
                        simd_op(&CosineStateless, arch, left.as_slice(), right.as_slice());
                    }
                }
            }
        };
    }
4989
    // One bounds test per supported (left, right) element-type pairing.
    // f16 fill values are produced via `diskann_wide::cast_f32_to_f16`.
    test_bounds!(miri_test_bounds_f32xf32, f32, 1.0f32, f32, 2.0f32);
    test_bounds!(
        miri_test_bounds_f16xf16,
        f16,
        diskann_wide::cast_f32_to_f16(1.0f32),
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(
        miri_test_bounds_f32xf16,
        f32,
        1.0f32,
        f16,
        diskann_wide::cast_f32_to_f16(2.0f32)
    );
    test_bounds!(miri_test_bounds_u8xu8, u8, 1u8, u8, 1u8);
    test_bounds!(miri_test_bounds_i8xi8, i8, 1i8, i8, 1i8);
5007}