Skip to main content

diskann_vector/distance/
implementations.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_wide::{arch::Target2, Architecture, ARCH};
7
8/// Experimental traits for distance functions.
9use super::simd;
10use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};
11
/// Uniform way to view slice-like arguments as a plain slice.
///
/// Implemented for slice references, array references and arrays, so that generic code
/// can accept any of them and operate on `&[Target]`.
trait ToSlice {
    /// Element type of the slice produced by [`ToSlice::to_slice`].
    type Target;

    /// Borrow `self` as a slice of `Self::Target`.
    fn to_slice(&self) -> &[Self::Target];
}

impl<T> ToSlice for &[T] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        // `self` is `&&[T]` here: strip the outer reference.
        *self
    }
}

impl<T, const N: usize> ToSlice for &[T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        self.as_slice()
    }
}

impl<T, const N: usize> ToSlice for [T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        self.as_slice()
    }
}
35
// Wire a distance functor into the `diskann_wide` architecture-dispatch traits.
//
// `$functor` is the functor type and `$impl` is the `simd` schema (used as a value via
// `&$impl`) that performs the low-level computation. Two impls are generated:
//
// * `Target2`: views both arguments as slices, runs `simd::simd_op` with `$impl`, and
//   converts the raw result through the functor's `PostOp` implementation.
// * `FTarget2`: the same computation without a functor value; it default-constructs the
//   functor and dispatches through `arch.run2`.
macro_rules! architecture_hook {
    ($functor:ty, $impl:path) => {
        impl<A, T, L, R> diskann_wide::arch::Target2<A, T, L, R> for $functor
        where
            A: Architecture,
            L: ToSlice,
            R: ToSlice,
            // The schema must support this element-type pair on architecture `A` ...
            $impl: simd::SIMDSchema<L::Target, R::Target, A>,
            // ... and the functor must know how to turn the schema's return type into `T`.
            Self: PostOp<<$impl as simd::SIMDSchema<L::Target, R::Target, A>>::Return, T>,
        {
            #[inline(always)]
            fn run(self, arch: A, left: L, right: R) -> T {
                Self::post_op(simd::simd_op(
                    &$impl,
                    arch,
                    left.to_slice(),
                    right.to_slice(),
                ))
            }
        }

        impl<A, T, L, R> diskann_wide::arch::FTarget2<A, T, L, R> for $functor
        where
            A: Architecture,
            L: ToSlice,
            R: ToSlice,
            Self: diskann_wide::arch::Target2<A, T, L, R>,
        {
            #[inline(always)]
            fn run(arch: A, left: L, right: R) -> T {
                // No functor value is available here, so build the default one.
                arch.run2(Self::default(), left, right)
            }
        }
    };
}
71
72/// A utility for specializing distance computations for fixed-length slices.
73#[derive(Debug, Clone, Copy)]
74pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);
75
76impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
77    for Specialize<N, F>
78where
79    A: Architecture,
80    F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
81{
82    #[inline(always)]
83    fn run(arch: A, x: &[L], y: &[R]) -> T {
84        if (x.len() != N) | (y.len() != N) {
85            fail_length_check(x, y, N);
86        }
87
88        // SAFETY: We have checked that both arguments have the correct length.
89        //
90        // The alignment requirements of arrays are the alignment requirements of
91        // `Left` and `Right` respectively, which is provided by the corresponding slices.
92        arch.run2(
93            F::default(),
94            unsafe { &*(x.as_ptr() as *const [L; N]) },
95            unsafe { &*(y.as_ptr() as *const [R; N]) },
96        )
97    }
98}
99
// Failure path of the length check in `Specialize::run`.
//
// The panic formatting lives in this never-inlined function, which takes its arguments
// in the same order as the caller, so that the caller's generated code stays small.
//
// Always panics; the message names the first argument whose length differs from `len`.
#[inline(never)]
#[allow(clippy::panic)]
fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
    let first_is_wrong = x.len() != len;
    let position = if first_is_wrong { "first" } else { "second" };
    let actual = if first_is_wrong { x.len() } else { y.len() };

    panic!(
        "expected {} argument to have length {}, instead it has length {}",
        position, len, actual
    );
}
115
116/// An internal trait to transform the result of the low-level SIMD ops into a value
117/// expected by the rest of DiskANN.
118///
119/// Keep this trait private as it is likely to either change or be removed completely in the
120/// near future once better integer implementations come online.
121pub(super) trait PostOp<From, To> {
122    fn post_op(x: From) -> To;
123}
124
/// Provide explicit dynamic and sized implementations for a distance functor.
///
/// For the element types `$T` (left argument) and `$U` (right argument), this implements
/// `PureDistanceFunction` on `$functor` for slices (`&[$T]`, `&[$U]`) and for equally
/// sized array references (`&[$T; N]`, `&[$U; N]`), once for each of the return types
/// `SimilarityScore<f32>`, `MathematicalValue<f32>` and `f32`.
///
/// Every generated `evaluate` default-constructs the functor and runs it on `ARCH`.
macro_rules! use_simd_implementation {
    ($functor:ty, $T:ty, $U:ty) => {
        //////////////////////
        // Similarity Score //
        //////////////////////

        // Dynamically Sized.
        impl PureDistanceFunction<&[$T], &[$U], SimilarityScore<f32>> for $functor {
            #[inline]
            fn evaluate(x: &[$T], y: &[$U]) -> SimilarityScore<f32> {
                <$functor>::default().run(ARCH, x, y)
            }
        }
        // Statically Sized
        impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], SimilarityScore<f32>>
            for $functor
        {
            #[inline]
            fn evaluate(x: &[$T; N], y: &[$U; N]) -> SimilarityScore<f32> {
                <$functor>::default().run(ARCH, x, y)
            }
        }

        ////////////////////////
        // Mathematical Value //
        ////////////////////////

        // Dynamically Sized.
        impl PureDistanceFunction<&[$T], &[$U], MathematicalValue<f32>> for $functor {
            #[inline]
            fn evaluate(x: &[$T], y: &[$U]) -> MathematicalValue<f32> {
                <$functor>::default().run(ARCH, x, y)
            }
        }
        // Statically Sized
        impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], MathematicalValue<f32>>
            for $functor
        {
            #[inline]
            fn evaluate(x: &[$T; N], y: &[$U; N]) -> MathematicalValue<f32> {
                <$functor>::default().run(ARCH, x, y)
            }
        }

        /////////
        // f32 //
        /////////

        // Dynamically Sized
        // NOTE(review): this is the only impl in the macro marked `inline(always)`; all
        // others use plain `inline`. Confirm the difference is intentional.
        impl PureDistanceFunction<&[$T], &[$U], f32> for $functor {
            #[inline(always)]
            fn evaluate(x: &[$T], y: &[$U]) -> f32 {
                <$functor>::default().run(ARCH, x, y)
            }
        }

        // Statically Sized
        impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], f32> for $functor {
            #[inline]
            fn evaluate(x: &[$T; N], y: &[$U; N]) -> f32 {
                <$functor>::default().run(ARCH, x, y)
            }
        }
    };
}
191
192///////////////
193// SquaredL2 //
194///////////////
195
196/// Compute the squared L2 distance between two vectors.
197#[derive(Debug, Clone, Copy, Default)]
198pub struct SquaredL2 {}
199
200impl PostOp<f32, SimilarityScore<f32>> for SquaredL2 {
201    #[inline(always)]
202    fn post_op(x: f32) -> SimilarityScore<f32> {
203        SimilarityScore::new(x)
204    }
205}
206
207impl PostOp<f32, f32> for SquaredL2 {
208    #[inline(always)]
209    fn post_op(x: f32) -> f32 {
210        x
211    }
212}
213
214impl PostOp<f32, MathematicalValue<f32>> for SquaredL2 {
215    #[inline(always)]
216    fn post_op(x: f32) -> MathematicalValue<f32> {
217        MathematicalValue::new(x)
218    }
219}
220
221architecture_hook!(SquaredL2, simd::L2);
222use_simd_implementation!(SquaredL2, f32, f32);
223use_simd_implementation!(SquaredL2, f32, Half);
224use_simd_implementation!(SquaredL2, Half, Half);
225use_simd_implementation!(SquaredL2, i8, i8);
226use_simd_implementation!(SquaredL2, u8, u8);
227
228////////////
229// FullL2 //
230////////////
231
232/// Computes the full L2 distance between two vectors.
233///
234/// Unlike `SquaredL2`, this function-like object will perform compute the full L2 distance
235/// including the trailing square root.
236#[derive(Debug, Clone, Copy, Default)]
237pub struct FullL2 {}
238
239impl PostOp<f32, SimilarityScore<f32>> for FullL2 {
240    #[inline(always)]
241    fn post_op(x: f32) -> SimilarityScore<f32> {
242        SimilarityScore::new(x.sqrt())
243    }
244}
245
246impl PostOp<f32, f32> for FullL2 {
247    #[inline(always)]
248    fn post_op(x: f32) -> f32 {
249        x.sqrt()
250    }
251}
252
253impl PostOp<f32, MathematicalValue<f32>> for FullL2 {
254    #[inline(always)]
255    fn post_op(x: f32) -> MathematicalValue<f32> {
256        MathematicalValue::new(x.sqrt())
257    }
258}
259
260architecture_hook!(FullL2, simd::L2);
261use_simd_implementation!(FullL2, f32, f32);
262use_simd_implementation!(FullL2, f32, Half);
263use_simd_implementation!(FullL2, Half, Half);
264use_simd_implementation!(FullL2, i8, i8);
265use_simd_implementation!(FullL2, u8, u8);
266
267//////////////////
268// InnerProduct //
269//////////////////
270
271/// Compute the inner product between two vectors.
272#[derive(Debug, Clone, Copy, Default)]
273pub struct InnerProduct {}
274
275impl PostOp<f32, SimilarityScore<f32>> for InnerProduct {
276    // The low-level operations compute the mathematical dot product.
277    // Similarity scores used in DiskANN expect the InnerProduct to be negated.
278    // This PostOp does that negation.
279    #[inline(always)]
280    fn post_op(x: f32) -> SimilarityScore<f32> {
281        SimilarityScore::new(-x)
282    }
283}
284
285impl PostOp<f32, MathematicalValue<f32>> for InnerProduct {
286    #[inline(always)]
287    fn post_op(x: f32) -> MathematicalValue<f32> {
288        MathematicalValue::new(x)
289    }
290}
291
292impl PostOp<f32, f32> for InnerProduct {
293    #[inline(always)]
294    fn post_op(x: f32) -> f32 {
295        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
296    }
297}
298
299architecture_hook!(InnerProduct, simd::IP);
300use_simd_implementation!(InnerProduct, f32, f32);
301use_simd_implementation!(InnerProduct, f32, Half);
302use_simd_implementation!(InnerProduct, Half, Half);
303use_simd_implementation!(InnerProduct, i8, i8);
304use_simd_implementation!(InnerProduct, u8, u8);
305
306////////////
307// Cosine //
308////////////
309
/// Map `x` to `1 - x`.
///
/// The result is not clamped; the caller is expected to pass an already clamped value.
fn cosine_transformation(x: f32) -> f32 {
    const ONE: f32 = 1.0;
    ONE - x
}
316
317/// Compute the cosine similarity between two vectors.
318#[derive(Debug, Clone, Copy, Default)]
319pub struct Cosine {}
320
321impl PostOp<f32, SimilarityScore<f32>> for Cosine {
322    fn post_op(x: f32) -> SimilarityScore<f32> {
323        debug_assert!(x >= -1.0);
324        debug_assert!(x <= 1.0);
325        SimilarityScore::new(cosine_transformation(x))
326    }
327}
328
329impl PostOp<f32, MathematicalValue<f32>> for Cosine {
330    fn post_op(x: f32) -> MathematicalValue<f32> {
331        debug_assert!(x >= -1.0);
332        debug_assert!(x <= 1.0);
333        MathematicalValue::new(x)
334    }
335}
336
337impl PostOp<f32, f32> for Cosine {
338    fn post_op(x: f32) -> f32 {
339        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
340    }
341}
342
343architecture_hook!(Cosine, simd::CosineStateless);
344use_simd_implementation!(Cosine, f32, f32);
345use_simd_implementation!(Cosine, f32, Half);
346use_simd_implementation!(Cosine, Half, Half);
347use_simd_implementation!(Cosine, i8, i8);
348use_simd_implementation!(Cosine, u8, u8);
349
350//////////////////////
351// CosineNormalized //
352//////////////////////
353
354/// Compute the cosine similarity between two normalized vectors.
355#[derive(Debug, Clone, Copy, Default)]
356pub struct CosineNormalized {}
357
358impl PostOp<f32, SimilarityScore<f32>> for CosineNormalized {
359    #[inline(always)]
360    fn post_op(x: f32) -> SimilarityScore<f32> {
361        // If the vectors are assumed to be normalized, then the implementation of
362        // normalized cosine can be expressed in terms of an inner product inner loop.
363        //
364        // Don't use `clamp` at the end since the simple non-vector implementations do not
365        // clamp their outputs.
366        SimilarityScore::new(cosine_transformation(x))
367    }
368}
369
370impl PostOp<f32, MathematicalValue<f32>> for CosineNormalized {
371    #[inline(always)]
372    fn post_op(x: f32) -> MathematicalValue<f32> {
373        MathematicalValue::new(x)
374    }
375}
376
377impl PostOp<f32, f32> for CosineNormalized {
378    #[inline(always)]
379    fn post_op(x: f32) -> f32 {
380        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
381    }
382}
383
384architecture_hook!(CosineNormalized, simd::IP);
385use_simd_implementation!(CosineNormalized, f32, f32);
386use_simd_implementation!(CosineNormalized, f32, Half);
387use_simd_implementation!(CosineNormalized, Half, Half);
388
389////////////
390// L1Norm //
391////////////
392
393/// Compute the L1 norm of a vector.
394#[derive(Debug, Clone, Copy, Default)]
395pub struct L1NormFunctor {}
396
397impl PostOp<f32, f32> for L1NormFunctor {
398    #[inline(always)]
399    fn post_op(x: f32) -> f32 {
400        x
401    }
402}
403
404architecture_hook!(L1NormFunctor, simd::L1Norm);
405
406impl PureDistanceFunction<&[f32], &[f32], f32> for L1NormFunctor {
407    #[inline]
408    fn evaluate(x: &[f32], y: &[f32]) -> f32 {
409        L1NormFunctor::default().run(ARCH, x, y)
410    }
411}
412
413////////////
414// Tests //
415////////////
416
417#[cfg(test)]
418mod tests {
419
420    use std::hash::{Hash, Hasher};
421
422    use approx::assert_relative_eq;
423    use rand::{Rng, SeedableRng};
424
425    use super::*;
426    use crate::{
427        distance::{
428            reference::{self, ReferenceProvider},
429            Metric,
430        },
431        test_util::{self, Normalize},
432    };
433
434    pub fn as_function_pointer<T, Left, Right, Return>(x: &[Left], y: &[Right]) -> Return
435    where
436        T: for<'a, 'b> PureDistanceFunction<&'a [Left], &'b [Right], Return>,
437    {
438        T::evaluate(x, y)
439    }
440
441    fn simd_provider(metric: Metric) -> fn(&[f32], &[f32]) -> f32 {
442        match metric {
443            Metric::L2 => as_function_pointer::<SquaredL2, _, _, _>,
444            Metric::InnerProduct => as_function_pointer::<InnerProduct, _, _, _>,
445            Metric::Cosine => as_function_pointer::<Cosine, _, _, _>,
446            Metric::CosineNormalized => as_function_pointer::<CosineNormalized, _, _, _>,
447        }
448    }
449
450    fn random_normal_arguments(dim: usize, lo: f32, hi: f32, seed: u64) -> (Vec<f32>, Vec<f32>) {
451        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
452        let x: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
453        let y: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
454        (x, y)
455    }
456
    /// One pair of input vectors for a distance-function test.
    struct LeftRightPair {
        /// Vector passed as the first (left) argument.
        pub x: Vec<f32>,
        /// Vector passed as the second (right) argument.
        pub y: Vec<f32>,
    }
461
462    fn generate_corner_cases(dim: usize) -> Vec<LeftRightPair> {
463        let mut output = Vec::<LeftRightPair>::new();
464        let fixed_values = [0.0, -5.0, 5.0, 10.0];
465
466        for va in fixed_values.iter() {
467            for vb in fixed_values.iter() {
468                let x: Vec<f32> = vec![*va; dim];
469                let y: Vec<f32> = vec![*vb; dim];
470                output.push(LeftRightPair { x, y });
471            }
472        }
473        output
474    }
475
476    fn collect_random_arguments(
477        dim: usize,
478        num_trials: usize,
479        lo: f32,
480        hi: f32,
481        mut seed: u64,
482    ) -> Vec<LeftRightPair> {
483        (0..num_trials)
484            .map(|_| {
485                let (x, y) = random_normal_arguments(dim, lo, hi, seed);
486
487                // update the seed.
488                let mut hasher = std::hash::DefaultHasher::new();
489                seed.hash(&mut hasher);
490                seed = hasher.finish();
491
492                LeftRightPair { x, y }
493            })
494            .collect()
495    }
496
497    fn test_pure_functions_impl<T>(metric: Metric, _func: T, normalize: bool)
498    where
499        T: for<'a, 'b> PureDistanceFunction<&'a [f32], &'b [f32], f32> + Clone,
500    {
501        let epsilon: f32 = 1e-4;
502        let max_relative: f32 = 1e-4;
503
504        let max_dim = 256;
505        let num_trials = 10;
506
507        let f_reference = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
508        let f_simd = simd_provider(metric);
509
510        // Inner test that loops over a vector of arguments.
511        let run_tests = |argument_pairs: Vec<LeftRightPair>| {
512            for LeftRightPair { mut x, mut y } in argument_pairs {
513                if normalize {
514                    x.normalize();
515                    y.normalize();
516                }
517
518                let reference: f32 = f_reference(&x, &y).into_inner();
519                let simd = f_simd(&x, &y);
520
521                assert_relative_eq!(
522                    reference,
523                    simd,
524                    epsilon = epsilon,
525                    max_relative = max_relative
526                );
527
528                // Compute via direct call.
529                let simd_direct = T::evaluate(&x, &y);
530                assert_eq!(simd_direct, simd);
531            }
532        };
533
534        // Corner Cases
535        for dim in 0..max_dim {
536            run_tests(generate_corner_cases(dim));
537        }
538
539        // Generated tests
540        for dim in 0..max_dim {
541            run_tests(collect_random_arguments(
542                dim, num_trials, -10.0, 10.0, 0x5643,
543            ));
544        }
545    }
546
547    #[test]
548    fn test_pure_functions() {
549        println!("L2");
550        test_pure_functions_impl(Metric::L2, SquaredL2 {}, false);
551        println!("InnerProduct");
552        test_pure_functions_impl(Metric::InnerProduct, InnerProduct {}, false);
553        println!("Cosine");
554        test_pure_functions_impl(Metric::Cosine, Cosine {}, false);
555        println!("CosineNormalized");
556        test_pure_functions_impl(Metric::CosineNormalized, CosineNormalized {}, true);
557    }
558
559    /// Test that the constant function pointer implementation returns the same result as
560    /// non-sized counterpart..
561    #[test]
562    fn test_specialize() {
563        use diskann_wide::arch::FTarget2;
564
565        const DIM: usize = 123;
566        let (x, y) = random_normal_arguments(DIM, -100.0, 100.0, 0x023457AA);
567
568        let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
569        let evaluated: f32 =
570            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
571
572        // Equality should be exact.
573        assert_eq!(reference, evaluated);
574    }
575
576    #[test]
577    #[should_panic]
578    fn test_function_pointer_const_panics_left() {
579        use diskann_wide::arch::FTarget2;
580
581        const DIM: usize = 34;
582        let x = vec![0.0f32; DIM + 1];
583        let y = vec![0.0f32; DIM];
584        // Since `x` does not have the correct dimensions, this should panic.
585        let _: f32 =
586            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
587    }
588
589    #[test]
590    #[should_panic]
591    fn test_function_pointer_const_panics_right() {
592        use diskann_wide::arch::FTarget2;
593
594        const DIM: usize = 34;
595        let x = vec![0.0f32; DIM];
596        let y = vec![0.0f32; DIM + 1];
597        // Since `y` does not have the correct dimensions, this should panic.
598        let _: f32 =
599            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
600    }
601
602    ////////////////////
603    // Test Version 2 //
604    ////////////////////
605
    /// Extract the underlying `f32` from any of the return types a distance functor can
    /// produce, so that the test helpers can compare them uniformly.
    trait GetInner {
        /// Consume `self` and return the wrapped `f32`.
        fn get_inner(self) -> f32;
    }

    // A bare `f32` is already the inner value.
    impl GetInner for f32 {
        fn get_inner(self) -> f32 {
            self
        }
    }

    impl GetInner for SimilarityScore<f32> {
        fn get_inner(self) -> f32 {
            self.into_inner()
        }
    }

    impl GetInner for MathematicalValue<f32> {
        fn get_inner(self) -> f32 {
            self.into_inner()
        }
    }
627
    // Comparison Bounds
    //
    // Tolerances forwarded to `assert_relative_eq!` when comparing an implementation
    // against its reference.
    #[derive(Clone, Copy)]
    struct EpsilonAndRelative {
        // Passed as the `epsilon` argument of `assert_relative_eq!`.
        epsilon: f32,
        // Passed as the `max_relative` argument of `assert_relative_eq!`.
        max_relative: f32,
    }
634
635    #[allow(clippy::too_many_arguments)]
636    fn run_test<L, R, To, Distribution, Callback>(
637        under_test: fn(&[L], &[R]) -> To,
638        reference: fn(&[L], &[R]) -> To,
639        bounds: EpsilonAndRelative,
640        dim: usize,
641        num_trials: usize,
642        distribution: Distribution,
643        rng: &mut impl Rng,
644        mut cb: Callback,
645    ) where
646        L: test_util::CornerCases,
647        R: test_util::CornerCases,
648        Distribution:
649            test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
650        To: GetInner + Copy,
651        Callback: FnMut(To, To),
652    {
653        let checker =
654            test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
655                // Invoke the callback with the received numbers.
656                cb(got, expected);
657                assert_relative_eq!(
658                    got.get_inner(),
659                    expected.get_inner(),
660                    epsilon = bounds.epsilon,
661                    max_relative = bounds.max_relative
662                );
663            });
664
665        test_util::test_distance_function(
666            checker,
667            distribution.clone(),
668            distribution.clone(),
669            dim,
670            num_trials,
671            rng,
672        );
673    }
674
    /// The maximum dimension tested for these tests.
    ///
    /// The test drivers below sweep `0..MAX_DIM`; this is the value used when
    /// `debug_assertions` are disabled.
    #[cfg(not(debug_assertions))]
    const MAX_DIM: usize = 256;

    /// Smaller upper bound on the tested dimensions when `debug_assertions` are enabled.
    #[cfg(debug_assertions)]
    const MAX_DIM: usize = 160;

    // Decrease the number of trials in debug mode to keep test run-time down.
    //
    // Number of random trials per dimension. Despite the name, the float test driver
    // uses this constant as well.
    #[cfg(not(debug_assertions))]
    const INTEGER_TRIALS: usize = 10000;

    #[cfg(debug_assertions)]
    const INTEGER_TRIALS: usize = 100;
688
689    ////////////////////
690    // Integer Tester //
691    ////////////////////
692
693    // For integer tests - we expect exact reproducibility with the reference
694    // implementations.
695    fn run_integer_test<T, R>(
696        under_test: fn(&[T], &[T]) -> R,
697        reference: fn(&[T], &[T]) -> R,
698        rng: &mut impl Rng,
699    ) where
700        T: test_util::CornerCases,
701        R: GetInner + Copy,
702        rand::distr::StandardUniform: test_util::GenerateRandomArguments<T> + Clone,
703    {
704        let distribution = rand::distr::StandardUniform {};
705        let num_corner_cases = <T as test_util::CornerCases>::corner_cases().len();
706
707        for dim in 0..MAX_DIM {
708            let mut callcount = 0;
709            let callback = |_, _| {
710                callcount += 1;
711            };
712
713            run_test(
714                under_test,
715                reference,
716                EpsilonAndRelative {
717                    epsilon: 0.0,
718                    max_relative: 0.0,
719                },
720                dim,
721                INTEGER_TRIALS,
722                distribution,
723                rng,
724                callback,
725            );
726
727            // Make sure the expected number of callbacks were made.
728            assert_eq!(
729                callcount,
730                INTEGER_TRIALS + num_corner_cases * num_corner_cases
731            );
732        }
733    }
734
735    //////////////////
736    // L2 - Integer //
737    //////////////////
738
739    #[test]
740    fn test_l2_i8_mathematical() {
741        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2bb701074c2b81c9);
742        run_integer_test(
743            as_function_pointer::<FullL2, i8, i8, MathematicalValue<f32>>,
744            reference::reference_l2_i8_mathematical,
745            &mut rng,
746        );
747    }
748
749    #[test]
750    fn test_l2_u8_mathematical() {
751        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9284ced6d080808c);
752        run_integer_test(
753            as_function_pointer::<FullL2, u8, u8, MathematicalValue<f32>>,
754            reference::reference_l2_u8_mathematical,
755            &mut rng,
756        );
757    }
758
759    #[test]
760    fn test_l2_i8_similarity() {
761        let mut rng = rand::rngs::StdRng::seed_from_u64(0xb196fecc4def04fa);
762        run_integer_test(
763            as_function_pointer::<FullL2, i8, i8, SimilarityScore<f32>>,
764            reference::reference_l2_i8_similarity,
765            &mut rng,
766        );
767    }
768
769    #[test]
770    fn test_l2_u8_similarity() {
771        let mut rng = rand::rngs::StdRng::seed_from_u64(0x07f6463e4a654aea);
772        run_integer_test(
773            as_function_pointer::<FullL2, u8, u8, SimilarityScore<f32>>,
774            reference::reference_l2_u8_similarity,
775            &mut rng,
776        );
777    }
778
779    ////////////////////////////
780    // InnerProduct - Integer //
781    ////////////////////////////
782
783    #[test]
784    fn test_innerproduct_i8_mathematical() {
785        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2c1b1bddda5774be);
786        run_integer_test(
787            as_function_pointer::<InnerProduct, i8, i8, MathematicalValue<f32>>,
788            reference::reference_innerproduct_i8_mathematical,
789            &mut rng,
790        );
791    }
792
793    #[test]
794    fn test_innerproduct_u8_mathematical() {
795        let mut rng = rand::rngs::StdRng::seed_from_u64(0x757e363832d7f215);
796        run_integer_test(
797            as_function_pointer::<InnerProduct, u8, u8, MathematicalValue<f32>>,
798            reference::reference_innerproduct_u8_mathematical,
799            &mut rng,
800        );
801    }
802
803    #[test]
804    fn test_innerproduct_i8_similarity() {
805        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4788ce0b991eb15a);
806        run_integer_test(
807            as_function_pointer::<InnerProduct, i8, i8, SimilarityScore<f32>>,
808            reference::reference_innerproduct_i8_similarity,
809            &mut rng,
810        );
811    }
812
813    #[test]
814    fn test_innerproduct_u8_similarity() {
815        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4994adb68f814d96);
816        run_integer_test(
817            as_function_pointer::<InnerProduct, u8, u8, SimilarityScore<f32>>,
818            reference::reference_innerproduct_u8_similarity,
819            &mut rng,
820        );
821    }
822
823    //////////////////////
824    // Cosine - Integer //
825    //////////////////////
826
827    #[test]
828    fn test_cosine_i8_mathematical() {
829        let mut rng = rand::rngs::StdRng::seed_from_u64(0xedef81c780491ada);
830        run_integer_test(
831            as_function_pointer::<Cosine, i8, i8, MathematicalValue<f32>>,
832            reference::reference_cosine_i8_mathematical,
833            &mut rng,
834        );
835    }
836
837    #[test]
838    fn test_cosine_u8_mathematical() {
839        let mut rng = rand::rngs::StdRng::seed_from_u64(0x107cee2adcc58b73);
840        run_integer_test(
841            as_function_pointer::<Cosine, u8, u8, MathematicalValue<f32>>,
842            reference::reference_cosine_u8_mathematical,
843            &mut rng,
844        );
845    }
846
847    #[test]
848    fn test_cosine_i8_similarity() {
849        let mut rng = rand::rngs::StdRng::seed_from_u64(0x02d95c1cc0843647);
850        run_integer_test(
851            as_function_pointer::<Cosine, i8, i8, SimilarityScore<f32>>,
852            reference::reference_cosine_i8_similarity,
853            &mut rng,
854        );
855    }
856
857    #[test]
858    fn test_cosine_u8_similarity() {
859        let mut rng = rand::rngs::StdRng::seed_from_u64(0xf5ea1974bf8d8b3b);
860        run_integer_test(
861            as_function_pointer::<Cosine, u8, u8, SimilarityScore<f32>>,
862            reference::reference_cosine_u8_similarity,
863            &mut rng,
864        );
865    }
866
867    //////////////////
868    // Float Tester //
869    //////////////////
870
871    // For integer tests - we expect exact reproducibility with the reference
872    // implementations.
873    fn run_float_test<L, R, To, Dist>(
874        under_test: fn(&[L], &[R]) -> To,
875        reference: fn(&[L], &[R]) -> To,
876        rng: &mut impl Rng,
877        distribution: Dist,
878        bounds: EpsilonAndRelative,
879    ) where
880        L: test_util::CornerCases,
881        R: test_util::CornerCases,
882        To: GetInner + Copy,
883        Dist: test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
884    {
885        let left_corner_cases = <L as test_util::CornerCases>::corner_cases().len();
886        let right_corner_cases = <R as test_util::CornerCases>::corner_cases().len();
887        for dim in 0..MAX_DIM {
888            let mut callcount = 0;
889            let callback = |_, _| {
890                callcount += 1;
891            };
892
893            run_test(
894                under_test,
895                reference,
896                bounds,
897                dim,
898                INTEGER_TRIALS,
899                distribution.clone(),
900                rng,
901                callback,
902            );
903
904            // Make sure the expected number of callbacks were made.
905            assert_eq!(
906                callcount,
907                INTEGER_TRIALS + left_corner_cases * right_corner_cases
908            );
909        }
910    }
911
912    ////////////////
913    // L2 - Float //
914    ////////////////
915
916    fn expected_l2_errors() -> EpsilonAndRelative {
917        EpsilonAndRelative {
918            epsilon: 0.0,
919            max_relative: 1.2e-6,
920        }
921    }
922
923    #[test]
924    fn test_l2_f32_mathematical() {
925        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6d22d320bdf35aec);
926        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
927        run_float_test(
928            as_function_pointer::<FullL2, f32, f32, MathematicalValue<f32>>,
929            reference::reference_l2_f32_mathematical,
930            &mut rng,
931            distribution,
932            expected_l2_errors(),
933        );
934    }
935
936    #[test]
937    fn test_l2_f16_mathematical() {
938        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
939        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
940        run_float_test(
941            as_function_pointer::<FullL2, Half, Half, MathematicalValue<f32>>,
942            reference::reference_l2_f16_mathematical,
943            &mut rng,
944            distribution,
945            expected_l2_errors(),
946        );
947    }
948
949    #[test]
950    fn test_l2_f32xf16_mathematical() {
951        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
952        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
953
954        run_float_test(
955            as_function_pointer::<FullL2, f32, Half, MathematicalValue<f32>>,
956            reference::reference_l2_f32xf16_mathematical,
957            &mut rng,
958            distribution,
959            expected_l2_errors(),
960        );
961    }
962
963    #[test]
964    fn test_l2_f32_similarity() {
965        let mut rng = rand::rngs::StdRng::seed_from_u64(0xbfc5f4b42b5bc0c1);
966        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
967        run_float_test(
968            as_function_pointer::<FullL2, f32, f32, SimilarityScore<f32>>,
969            reference::reference_l2_f32_similarity,
970            &mut rng,
971            distribution,
972            expected_l2_errors(),
973        );
974    }
975
976    #[test]
977    fn test_l2_f16_similarity() {
978        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9d3809d84f54e4b6);
979        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
980        run_float_test(
981            as_function_pointer::<FullL2, Half, Half, SimilarityScore<f32>>,
982            reference::reference_l2_f16_similarity,
983            &mut rng,
984            distribution,
985            expected_l2_errors(),
986        );
987    }
988
989    #[test]
990    fn test_l2_f32xf16_similarity() {
991        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
992        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
993
994        run_float_test(
995            as_function_pointer::<FullL2, f32, Half, SimilarityScore<f32>>,
996            reference::reference_l2_f32xf16_similarity,
997            &mut rng,
998            distribution,
999            expected_l2_errors(),
1000        );
1001    }
1002
1003    ///////////////////////////
1004    // InnerProduct - Floats //
1005    ///////////////////////////
1006
1007    fn expected_innerproduct_errors() -> EpsilonAndRelative {
1008        EpsilonAndRelative {
1009            epsilon: 2.5e-5,
1010            max_relative: 1.6e-5,
1011        }
1012    }
1013
1014    #[test]
1015    fn test_innerproduct_f32_mathematical() {
1016        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1ef6ac3b65869792);
1017        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1018        run_float_test(
1019            as_function_pointer::<InnerProduct, f32, f32, MathematicalValue<f32>>,
1020            reference::reference_innerproduct_f32_mathematical,
1021            &mut rng,
1022            distribution,
1023            expected_innerproduct_errors(),
1024        );
1025    }
1026
1027    #[test]
1028    fn test_innerproduct_f16_mathematical() {
1029        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1030        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1031        run_float_test(
1032            as_function_pointer::<InnerProduct, Half, Half, MathematicalValue<f32>>,
1033            reference::reference_innerproduct_f16_mathematical,
1034            &mut rng,
1035            distribution,
1036            expected_innerproduct_errors(),
1037        );
1038    }
1039
1040    #[test]
1041    fn test_innerproduct_f32xf16_mathematical() {
1042        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1043        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1044        run_float_test(
1045            as_function_pointer::<InnerProduct, f32, Half, MathematicalValue<f32>>,
1046            reference::reference_innerproduct_f32xf16_mathematical,
1047            &mut rng,
1048            distribution,
1049            expected_innerproduct_errors(),
1050        );
1051    }
1052
1053    #[test]
1054    fn test_innerproduct_f32_similarity() {
1055        let mut rng = rand::rngs::StdRng::seed_from_u64(0x40326b22a57db0d7);
1056        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1057        run_float_test(
1058            as_function_pointer::<InnerProduct, f32, f32, SimilarityScore<f32>>,
1059            reference::reference_innerproduct_f32_similarity,
1060            &mut rng,
1061            distribution,
1062            expected_innerproduct_errors(),
1063        );
1064    }
1065
1066    #[test]
1067    fn test_innerproduct_f16_similarity() {
1068        let mut rng = rand::rngs::StdRng::seed_from_u64(0xfb8cff47bcbc9528);
1069        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1070        run_float_test(
1071            as_function_pointer::<InnerProduct, Half, Half, SimilarityScore<f32>>,
1072            reference::reference_innerproduct_f16_similarity,
1073            &mut rng,
1074            distribution,
1075            expected_innerproduct_errors(),
1076        );
1077    }
1078
1079    #[test]
1080    fn test_innerproduct_f32xf16_similarity() {
1081        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1082        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1083        run_float_test(
1084            as_function_pointer::<InnerProduct, f32, Half, SimilarityScore<f32>>,
1085            reference::reference_innerproduct_f32xf16_similarity,
1086            &mut rng,
1087            distribution,
1088            expected_innerproduct_errors(),
1089        );
1090    }
1091
1092    /////////////////////
1093    // Cosine - Floats //
1094    /////////////////////
1095
1096    fn expected_cosine_errors() -> EpsilonAndRelative {
1097        EpsilonAndRelative {
1098            epsilon: 3e-7,
1099            max_relative: 5e-6,
1100        }
1101    }
1102
1103    #[test]
1104    fn test_cosine_f32_mathematical() {
1105        let mut rng = rand::rngs::StdRng::seed_from_u64(0xca6eaac942999500);
1106        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1107        run_float_test(
1108            as_function_pointer::<Cosine, f32, f32, MathematicalValue<f32>>,
1109            reference::reference_cosine_f32_mathematical,
1110            &mut rng,
1111            distribution,
1112            expected_cosine_errors(),
1113        );
1114    }
1115
1116    #[test]
1117    fn test_cosine_f16_mathematical() {
1118        let mut rng = rand::rngs::StdRng::seed_from_u64(0xa736c789aa16ce86);
1119        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1120        run_float_test(
1121            as_function_pointer::<Cosine, Half, Half, MathematicalValue<f32>>,
1122            reference::reference_cosine_f16_mathematical,
1123            &mut rng,
1124            distribution,
1125            expected_cosine_errors(),
1126        );
1127    }
1128
1129    #[test]
1130    fn test_cosine_f32xf16_mathematical() {
1131        let mut rng = rand::rngs::StdRng::seed_from_u64(0xac550231088a0d5c);
1132        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1133        run_float_test(
1134            as_function_pointer::<Cosine, f32, Half, MathematicalValue<f32>>,
1135            reference::reference_cosine_f32xf16_mathematical,
1136            &mut rng,
1137            distribution,
1138            expected_cosine_errors(),
1139        );
1140    }
1141
1142    #[test]
1143    fn test_cosine_f32_similarity() {
1144        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4a09ad987a6204f3);
1145        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1146        run_float_test(
1147            as_function_pointer::<Cosine, f32, f32, SimilarityScore<f32>>,
1148            reference::reference_cosine_f32_similarity,
1149            &mut rng,
1150            distribution,
1151            expected_cosine_errors(),
1152        );
1153    }
1154
1155    #[test]
1156    fn test_cosine_f16_similarity() {
1157        let mut rng = rand::rngs::StdRng::seed_from_u64(0x77a48d1914f850f2);
1158        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1159        run_float_test(
1160            as_function_pointer::<Cosine, Half, Half, SimilarityScore<f32>>,
1161            reference::reference_cosine_f16_similarity,
1162            &mut rng,
1163            distribution,
1164            expected_cosine_errors(),
1165        );
1166    }
1167
1168    #[test]
1169    fn test_cosine_f32xf16_similarity() {
1170        let mut rng = rand::rngs::StdRng::seed_from_u64(0xbd7471b815655ca1);
1171        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1172        run_float_test(
1173            as_function_pointer::<Cosine, f32, Half, SimilarityScore<f32>>,
1174            reference::reference_cosine_f32xf16_similarity,
1175            &mut rng,
1176            distribution,
1177            expected_cosine_errors(),
1178        );
1179    }
1180
1181    ///////////////////////////////
1182    // CosineNormalized - Floats //
1183    ///////////////////////////////
1184
1185    fn expected_cosine_normalized_errors() -> EpsilonAndRelative {
1186        EpsilonAndRelative {
1187            epsilon: 3e-7,
1188            max_relative: 5e-6,
1189        }
1190    }
1191
1192    #[test]
1193    fn test_cosine_normalized_f32_mathematical() {
1194        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1fda98112747f8dd);
1195        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1196        run_float_test(
1197            as_function_pointer::<CosineNormalized, f32, f32, MathematicalValue<f32>>,
1198            reference::reference_cosine_normalized_f32_mathematical,
1199            &mut rng,
1200            test_util::Normalized(distribution),
1201            expected_cosine_normalized_errors(),
1202        );
1203    }
1204
1205    #[test]
1206    fn test_cosine_normalized_f16_mathematical() {
1207        let mut rng = rand::rngs::StdRng::seed_from_u64(0x5e8c5d5e19cdd840);
1208        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1209        run_float_test(
1210            as_function_pointer::<CosineNormalized, Half, Half, MathematicalValue<f32>>,
1211            reference::reference_cosine_normalized_f16_mathematical,
1212            &mut rng,
1213            test_util::Normalized(distribution),
1214            expected_cosine_normalized_errors(),
1215        );
1216    }
1217
1218    #[test]
1219    fn test_cosine_normalized_f32xf16_mathematical() {
1220        let mut rng = rand::rngs::StdRng::seed_from_u64(0x3fd01e1c11c9bc45);
1221        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1222        run_float_test(
1223            as_function_pointer::<CosineNormalized, f32, Half, MathematicalValue<f32>>,
1224            reference::reference_cosine_normalized_f32xf16_mathematical,
1225            &mut rng,
1226            test_util::Normalized(distribution),
1227            expected_cosine_normalized_errors(),
1228        );
1229    }
1230
1231    #[test]
1232    fn test_cosine_normalized_f32_similarity() {
1233        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9446d057870e5605);
1234        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1235        run_float_test(
1236            as_function_pointer::<CosineNormalized, f32, f32, SimilarityScore<f32>>,
1237            reference::reference_cosine_normalized_f32_similarity,
1238            &mut rng,
1239            test_util::Normalized(distribution),
1240            expected_cosine_normalized_errors(),
1241        );
1242    }
1243
1244    #[test]
1245    fn test_cosine_normalized_f16_similarity() {
1246        let mut rng = rand::rngs::StdRng::seed_from_u64(0x885c371801f18174);
1247        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1248        run_float_test(
1249            as_function_pointer::<CosineNormalized, Half, Half, SimilarityScore<f32>>,
1250            reference::reference_cosine_normalized_f16_similarity,
1251            &mut rng,
1252            test_util::Normalized(distribution),
1253            expected_cosine_normalized_errors(),
1254        );
1255    }
1256
1257    #[test]
1258    fn test_cosine_normalized_f32xf16_similarity() {
1259        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1c356c92d0522c0f);
1260        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1261        run_float_test(
1262            as_function_pointer::<CosineNormalized, f32, Half, SimilarityScore<f32>>,
1263            reference::reference_cosine_normalized_f32xf16_similarity,
1264            &mut rng,
1265            test_util::Normalized(distribution),
1266            expected_cosine_normalized_errors(),
1267        );
1268    }
1269}