Skip to main content

diskann_vector/distance/
implementations.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_wide::{arch::Target2, Architecture, ARCH};
7
8/// Experimental traits for distance functions.
9use super::simd;
10use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};
11
/// Uniform read-only view over slice-like arguments.
///
/// This lets the distance hooks accept dynamically sized slices (`&[T]`) as well as
/// fixed-size arrays (`[T; N]` and `&[T; N]`), handing each of them to the SIMD kernels
/// as a plain `&[T]`.
trait ToSlice {
    /// Element type of the viewed slice.
    type Target;

    /// Borrow `self` as a slice of `Self::Target`.
    fn to_slice(&self) -> &[Self::Target];
}

impl<T> ToSlice for &[T] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        // `self` is `&&[T]`; one dereference yields the inner slice.
        *self
    }
}

impl<T, const N: usize> ToSlice for &[T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        // Auto-deref reaches the array, whose `as_slice` covers all `N` elements.
        self.as_slice()
    }
}

impl<T, const N: usize> ToSlice for [T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        self.as_slice()
    }
}
35
/// Wires a distance functor into the `diskann_wide` dispatch machinery.
///
/// For `$functor` this generates:
///
/// * a `Target2` implementation that views both arguments as slices, runs the SIMD
///   schema `$impl` on them, and converts the raw kernel output into the requested
///   return type through the functor's `PostOp` implementation;
/// * an `FTarget2` implementation that default-constructs the functor and forwards to
///   the `Target2` implementation via `Architecture::run2`.
macro_rules! architecture_hook {
    ($functor:ty, $impl:path) => {
        impl<Arch, Out, Lhs, Rhs> diskann_wide::arch::Target2<Arch, Out, Lhs, Rhs> for $functor
        where
            Arch: Architecture,
            Lhs: ToSlice,
            Rhs: ToSlice,
            $impl: simd::SIMDSchema<Lhs::Target, Rhs::Target, Arch>,
            Self: PostOp<
                <$impl as simd::SIMDSchema<Lhs::Target, Rhs::Target, Arch>>::Return,
                Out,
            >,
        {
            #[inline(always)]
            fn run(self, arch: Arch, left: Lhs, right: Rhs) -> Out {
                let lhs = left.to_slice();
                let rhs = right.to_slice();
                Self::post_op(simd::simd_op(&$impl, arch, lhs, rhs))
            }
        }

        impl<Arch, Out, Lhs, Rhs> diskann_wide::arch::FTarget2<Arch, Out, Lhs, Rhs> for $functor
        where
            Arch: Architecture,
            Lhs: ToSlice,
            Rhs: ToSlice,
            Self: diskann_wide::arch::Target2<Arch, Out, Lhs, Rhs>,
        {
            #[inline(always)]
            fn run(arch: Arch, left: Lhs, right: Rhs) -> Out {
                let functor = Self::default();
                arch.run2(functor, left, right)
            }
        }
    };
}
71
72/// A utility for specializing distance computatiosn for fixed-length slices.
73#[cfg(any(test, target_arch = "x86_64"))]
74#[derive(Debug, Clone, Copy)]
75pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);
76
77#[cfg(any(test, target_arch = "x86_64"))]
78impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
79    for Specialize<N, F>
80where
81    A: Architecture,
82    F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
83{
84    #[inline(always)]
85    fn run(arch: A, x: &[L], y: &[R]) -> T {
86        if (x.len() != N) | (y.len() != N) {
87            fail_length_check(x, y, N);
88        }
89
90        // SAFETY: We have checked that both arguments have the correct length.
91        //
92        // The alignment requirements of arrays are the alignment requirements of
93        // `Left` and `Right` respectively, which is provided by the corresponding slices.
94        arch.run2(
95            F::default(),
96            unsafe { &*(x.as_ptr() as *const [L; N]) },
97            unsafe { &*(y.as_ptr() as *const [R; N]) },
98        )
99    }
100}
101
// Outline the panic formatting and keep the calling convention the same as
// the top function. This keeps code generation extremely lightweight.
//
// `#[cold]` additionally marks this path as unlikely, so the optimizer can treat the
// successful length check as the fall-through branch at call sites.
#[cfg(any(test, target_arch = "x86_64"))]
#[cold]
#[inline(never)]
#[allow(clippy::panic)]
fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
    // Blame the first argument if its length is wrong. Otherwise report the second
    // argument; callers only reach this function after detecting a mismatch.
    let (which, actual) = if x.len() != len {
        ("first", x.len())
    } else {
        ("second", y.len())
    };
    panic!(
        "expected {} argument to have length {}, instead it has length {}",
        which, len, actual
    );
}
118
119/// An internal trait to transform the result of the low-level SIMD ops into a value
120/// expected by the rest of DiskANN.
121///
122/// Keep this trait private as it is likely to either change or be removed completely in the
123/// near future once better integer implementations come online.
124pub(super) trait PostOp<From, To> {
125    fn post_op(x: From) -> To;
126}
127
/// Provide explicit dynamic and sized implementations for a distance functor.
///
/// `use_simd_implementation!(Functor, Left, Right)` implements `PureDistanceFunction` for
/// `Functor` on both `(&[Left], &[Right])` slices and `(&[Left; N], &[Right; N])` arrays.
/// It does so for each of the return types `SimilarityScore<f32>`,
/// `MathematicalValue<f32>`, and `f32`. Every implementation forwards to the functor's
/// `Target2` implementation using the compile-time architecture `ARCH`.
macro_rules! use_simd_implementation {
    // Internal rule: emit the slice and array implementations for a single return type.
    // The supplied attribute decorates the dynamically sized implementation only.
    (@emit $functor:ty, $T:ty, $U:ty => $Ret:ty, #[$slice_attr:meta]) => {
        // Dynamically sized.
        impl PureDistanceFunction<&[$T], &[$U], $Ret> for $functor {
            #[$slice_attr]
            fn evaluate(x: &[$T], y: &[$U]) -> $Ret {
                <$functor>::default().run(ARCH, x, y)
            }
        }

        // Statically sized.
        impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], $Ret> for $functor {
            #[inline]
            fn evaluate(x: &[$T; N], y: &[$U; N]) -> $Ret {
                <$functor>::default().run(ARCH, x, y)
            }
        }
    };
    ($functor:ty, $T:ty, $U:ty) => {
        use_simd_implementation!(@emit $functor, $T, $U => SimilarityScore<f32>, #[inline]);
        use_simd_implementation!(@emit $functor, $T, $U => MathematicalValue<f32>, #[inline]);
        // The raw `f32` slice path is forced inline.
        use_simd_implementation!(@emit $functor, $T, $U => f32, #[inline(always)]);
    };
}
194
195///////////////
196// SquaredL2 //
197///////////////
198
199/// Compute the squared L2 distance between two vectors.
200#[derive(Debug, Clone, Copy, Default)]
201pub struct SquaredL2 {}
202
203impl PostOp<f32, SimilarityScore<f32>> for SquaredL2 {
204    #[inline(always)]
205    fn post_op(x: f32) -> SimilarityScore<f32> {
206        SimilarityScore::new(x)
207    }
208}
209
210impl PostOp<f32, f32> for SquaredL2 {
211    #[inline(always)]
212    fn post_op(x: f32) -> f32 {
213        x
214    }
215}
216
217impl PostOp<f32, MathematicalValue<f32>> for SquaredL2 {
218    #[inline(always)]
219    fn post_op(x: f32) -> MathematicalValue<f32> {
220        MathematicalValue::new(x)
221    }
222}
223
224architecture_hook!(SquaredL2, simd::L2);
225use_simd_implementation!(SquaredL2, f32, f32);
226use_simd_implementation!(SquaredL2, f32, Half);
227use_simd_implementation!(SquaredL2, Half, Half);
228use_simd_implementation!(SquaredL2, i8, i8);
229use_simd_implementation!(SquaredL2, u8, u8);
230
231////////////
232// FullL2 //
233////////////
234
235/// Computes the full L2 distance between two vectors.
236///
237/// Unlike `SquaredL2`, this function-like object will perform compute the full L2 distance
238/// including the trailing square root.
239#[derive(Debug, Clone, Copy, Default)]
240pub struct FullL2 {}
241
242impl PostOp<f32, SimilarityScore<f32>> for FullL2 {
243    #[inline(always)]
244    fn post_op(x: f32) -> SimilarityScore<f32> {
245        SimilarityScore::new(x.sqrt())
246    }
247}
248
249impl PostOp<f32, f32> for FullL2 {
250    #[inline(always)]
251    fn post_op(x: f32) -> f32 {
252        x.sqrt()
253    }
254}
255
256impl PostOp<f32, MathematicalValue<f32>> for FullL2 {
257    #[inline(always)]
258    fn post_op(x: f32) -> MathematicalValue<f32> {
259        MathematicalValue::new(x.sqrt())
260    }
261}
262
263architecture_hook!(FullL2, simd::L2);
264use_simd_implementation!(FullL2, f32, f32);
265use_simd_implementation!(FullL2, f32, Half);
266use_simd_implementation!(FullL2, Half, Half);
267use_simd_implementation!(FullL2, i8, i8);
268use_simd_implementation!(FullL2, u8, u8);
269
270//////////////////
271// InnerProduct //
272//////////////////
273
274/// Compute the inner product between two vectors.
275#[derive(Debug, Clone, Copy, Default)]
276pub struct InnerProduct {}
277
278impl PostOp<f32, SimilarityScore<f32>> for InnerProduct {
279    // The low-level operations compute the mathematical dot product.
280    // Similarity scores used in DiskANN expect the InnerProduct to be negated.
281    // This PostOp does that negation.
282    #[inline(always)]
283    fn post_op(x: f32) -> SimilarityScore<f32> {
284        SimilarityScore::new(-x)
285    }
286}
287
288impl PostOp<f32, MathematicalValue<f32>> for InnerProduct {
289    #[inline(always)]
290    fn post_op(x: f32) -> MathematicalValue<f32> {
291        MathematicalValue::new(x)
292    }
293}
294
295impl PostOp<f32, f32> for InnerProduct {
296    #[inline(always)]
297    fn post_op(x: f32) -> f32 {
298        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
299    }
300}
301
302architecture_hook!(InnerProduct, simd::IP);
303use_simd_implementation!(InnerProduct, f32, f32);
304use_simd_implementation!(InnerProduct, f32, Half);
305use_simd_implementation!(InnerProduct, Half, Half);
306use_simd_implementation!(InnerProduct, i8, i8);
307use_simd_implementation!(InnerProduct, u8, u8);
308
309////////////
310// Cosine //
311////////////
312
/// Map a cosine similarity `x` to the distance-like value `1 - x`.
///
/// The result is not clamped. Callers rely on the inner computation to have already
/// produced an in-range `x`.
fn cosine_transformation(x: f32) -> f32 {
    const ONE: f32 = 1.0;
    ONE - x
}
319
320/// Compute the cosine similarity between two vectors.
321#[derive(Debug, Clone, Copy, Default)]
322pub struct Cosine {}
323
324impl PostOp<f32, SimilarityScore<f32>> for Cosine {
325    fn post_op(x: f32) -> SimilarityScore<f32> {
326        debug_assert!(x >= -1.0);
327        debug_assert!(x <= 1.0);
328        SimilarityScore::new(cosine_transformation(x))
329    }
330}
331
332impl PostOp<f32, MathematicalValue<f32>> for Cosine {
333    fn post_op(x: f32) -> MathematicalValue<f32> {
334        debug_assert!(x >= -1.0);
335        debug_assert!(x <= 1.0);
336        MathematicalValue::new(x)
337    }
338}
339
340impl PostOp<f32, f32> for Cosine {
341    fn post_op(x: f32) -> f32 {
342        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
343    }
344}
345
346architecture_hook!(Cosine, simd::CosineStateless);
347use_simd_implementation!(Cosine, f32, f32);
348use_simd_implementation!(Cosine, f32, Half);
349use_simd_implementation!(Cosine, Half, Half);
350use_simd_implementation!(Cosine, i8, i8);
351use_simd_implementation!(Cosine, u8, u8);
352
353//////////////////////
354// CosineNormalized //
355//////////////////////
356
357/// Compute the cosine similarity between two normalized vectors.
358#[derive(Debug, Clone, Copy, Default)]
359pub struct CosineNormalized {}
360
361impl PostOp<f32, SimilarityScore<f32>> for CosineNormalized {
362    #[inline(always)]
363    fn post_op(x: f32) -> SimilarityScore<f32> {
364        // If the vectors are assumed to be normalized, then the implementation of
365        // normalized cosine can be expressed in terms of an inner product inner loop.
366        //
367        // Don't use `clamp` at the end since the simple non-vector implementations do not
368        // clamp their outputs.
369        SimilarityScore::new(cosine_transformation(x))
370    }
371}
372
373impl PostOp<f32, MathematicalValue<f32>> for CosineNormalized {
374    #[inline(always)]
375    fn post_op(x: f32) -> MathematicalValue<f32> {
376        MathematicalValue::new(x)
377    }
378}
379
380impl PostOp<f32, f32> for CosineNormalized {
381    #[inline(always)]
382    fn post_op(x: f32) -> f32 {
383        <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
384    }
385}
386
387architecture_hook!(CosineNormalized, simd::IP);
388use_simd_implementation!(CosineNormalized, f32, f32);
389use_simd_implementation!(CosineNormalized, f32, Half);
390use_simd_implementation!(CosineNormalized, Half, Half);
391
392////////////
393// L1Norm //
394////////////
395
396/// Compute the L1 norm of a vector.
397#[derive(Debug, Clone, Copy, Default)]
398pub struct L1NormFunctor {}
399
400impl PostOp<f32, f32> for L1NormFunctor {
401    #[inline(always)]
402    fn post_op(x: f32) -> f32 {
403        x
404    }
405}
406
407architecture_hook!(L1NormFunctor, simd::L1Norm);
408
409impl PureDistanceFunction<&[f32], &[f32], f32> for L1NormFunctor {
410    #[inline]
411    fn evaluate(x: &[f32], y: &[f32]) -> f32 {
412        L1NormFunctor::default().run(ARCH, x, y)
413    }
414}
415
416////////////
417// Tests //
418////////////
419
420#[cfg(test)]
421mod tests {
422
423    use std::hash::{Hash, Hasher};
424
425    use approx::assert_relative_eq;
426    use rand::{Rng, SeedableRng};
427
428    use super::*;
429    use crate::{
430        distance::{
431            reference::{self, ReferenceProvider},
432            Metric,
433        },
434        test_util::{self, Normalize},
435    };
436
437    pub fn as_function_pointer<T, Left, Right, Return>(x: &[Left], y: &[Right]) -> Return
438    where
439        T: for<'a, 'b> PureDistanceFunction<&'a [Left], &'b [Right], Return>,
440    {
441        T::evaluate(x, y)
442    }
443
444    fn simd_provider(metric: Metric) -> fn(&[f32], &[f32]) -> f32 {
445        match metric {
446            Metric::L2 => as_function_pointer::<SquaredL2, _, _, _>,
447            Metric::InnerProduct => as_function_pointer::<InnerProduct, _, _, _>,
448            Metric::Cosine => as_function_pointer::<Cosine, _, _, _>,
449            Metric::CosineNormalized => as_function_pointer::<CosineNormalized, _, _, _>,
450        }
451    }
452
453    fn random_normal_arguments(dim: usize, lo: f32, hi: f32, seed: u64) -> (Vec<f32>, Vec<f32>) {
454        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
455        let x: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
456        let y: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
457        (x, y)
458    }
459
460    struct LeftRightPair {
461        pub x: Vec<f32>,
462        pub y: Vec<f32>,
463    }
464
465    fn generate_corner_cases(dim: usize) -> Vec<LeftRightPair> {
466        let mut output = Vec::<LeftRightPair>::new();
467        let fixed_values = [0.0, -5.0, 5.0, 10.0];
468
469        for va in fixed_values.iter() {
470            for vb in fixed_values.iter() {
471                let x: Vec<f32> = vec![*va; dim];
472                let y: Vec<f32> = vec![*vb; dim];
473                output.push(LeftRightPair { x, y });
474            }
475        }
476        output
477    }
478
479    fn collect_random_arguments(
480        dim: usize,
481        num_trials: usize,
482        lo: f32,
483        hi: f32,
484        mut seed: u64,
485    ) -> Vec<LeftRightPair> {
486        (0..num_trials)
487            .map(|_| {
488                let (x, y) = random_normal_arguments(dim, lo, hi, seed);
489
490                // update the seed.
491                let mut hasher = std::hash::DefaultHasher::new();
492                seed.hash(&mut hasher);
493                seed = hasher.finish();
494
495                LeftRightPair { x, y }
496            })
497            .collect()
498    }
499
500    fn test_pure_functions_impl<T>(metric: Metric, _func: T, normalize: bool)
501    where
502        T: for<'a, 'b> PureDistanceFunction<&'a [f32], &'b [f32], f32> + Clone,
503    {
504        let epsilon: f32 = 1e-4;
505        let max_relative: f32 = 1e-4;
506
507        let max_dim = 256;
508        let num_trials = 10;
509
510        let f_reference = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
511        let f_simd = simd_provider(metric);
512
513        // Inner test that loops over a vector of arguments.
514        let run_tests = |argument_pairs: Vec<LeftRightPair>| {
515            for LeftRightPair { mut x, mut y } in argument_pairs {
516                if normalize {
517                    x.normalize();
518                    y.normalize();
519                }
520
521                let reference: f32 = f_reference(&x, &y).into_inner();
522                let simd = f_simd(&x, &y);
523
524                assert_relative_eq!(
525                    reference,
526                    simd,
527                    epsilon = epsilon,
528                    max_relative = max_relative
529                );
530
531                // Compute via direct call.
532                let simd_direct = T::evaluate(&x, &y);
533                assert_eq!(simd_direct, simd);
534            }
535        };
536
537        // Corner Cases
538        for dim in 0..max_dim {
539            run_tests(generate_corner_cases(dim));
540        }
541
542        // Generated tests
543        for dim in 0..max_dim {
544            run_tests(collect_random_arguments(
545                dim, num_trials, -10.0, 10.0, 0x5643,
546            ));
547        }
548    }
549
550    #[test]
551    fn test_pure_functions() {
552        println!("L2");
553        test_pure_functions_impl(Metric::L2, SquaredL2 {}, false);
554        println!("InnerProduct");
555        test_pure_functions_impl(Metric::InnerProduct, InnerProduct {}, false);
556        println!("Cosine");
557        test_pure_functions_impl(Metric::Cosine, Cosine {}, false);
558        println!("CosineNormalized");
559        test_pure_functions_impl(Metric::CosineNormalized, CosineNormalized {}, true);
560    }
561
562    /// Test that the constant function pointer implementation returns the same result as
563    /// non-sized counterpart..
564    #[test]
565    fn test_specialize() {
566        use diskann_wide::arch::FTarget2;
567
568        const DIM: usize = 123;
569        let (x, y) = random_normal_arguments(DIM, -100.0, 100.0, 0x023457AA);
570
571        let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
572        let evaluated: f32 =
573            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
574
575        // Equality should be exact.
576        assert_eq!(reference, evaluated);
577    }
578
579    #[test]
580    #[should_panic]
581    fn test_function_pointer_const_panics_left() {
582        use diskann_wide::arch::FTarget2;
583
584        const DIM: usize = 34;
585        let x = vec![0.0f32; DIM + 1];
586        let y = vec![0.0f32; DIM];
587        // Since `x` does not have the correct dimensions, this should panic.
588        let _: f32 =
589            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
590    }
591
592    #[test]
593    #[should_panic]
594    fn test_function_pointer_const_panics_right() {
595        use diskann_wide::arch::FTarget2;
596
597        const DIM: usize = 34;
598        let x = vec![0.0f32; DIM];
599        let y = vec![0.0f32; DIM + 1];
600        // Since `y` does not have the correct dimensions, this should panic.
601        let _: f32 =
602            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
603    }
604
605    ////////////////////
606    // Test Version 2 //
607    ////////////////////
608
609    trait GetInner {
610        fn get_inner(self) -> f32;
611    }
612
613    impl GetInner for f32 {
614        fn get_inner(self) -> f32 {
615            self
616        }
617    }
618
619    impl GetInner for SimilarityScore<f32> {
620        fn get_inner(self) -> f32 {
621            self.into_inner()
622        }
623    }
624
625    impl GetInner for MathematicalValue<f32> {
626        fn get_inner(self) -> f32 {
627            self.into_inner()
628        }
629    }
630
631    // Comparison Bounds
632    #[derive(Clone, Copy)]
633    struct EpsilonAndRelative {
634        epsilon: f32,
635        max_relative: f32,
636    }
637
638    #[allow(clippy::too_many_arguments)]
639    fn run_test<L, R, To, Distribution, Callback>(
640        under_test: fn(&[L], &[R]) -> To,
641        reference: fn(&[L], &[R]) -> To,
642        bounds: EpsilonAndRelative,
643        dim: usize,
644        num_trials: usize,
645        distribution: Distribution,
646        rng: &mut impl Rng,
647        mut cb: Callback,
648    ) where
649        L: test_util::CornerCases,
650        R: test_util::CornerCases,
651        Distribution:
652            test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
653        To: GetInner + Copy,
654        Callback: FnMut(To, To),
655    {
656        let checker =
657            test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
658                // Invoke the callback with the received numbers.
659                cb(got, expected);
660                assert_relative_eq!(
661                    got.get_inner(),
662                    expected.get_inner(),
663                    epsilon = bounds.epsilon,
664                    max_relative = bounds.max_relative
665                );
666            });
667
668        test_util::test_distance_function(
669            checker,
670            distribution.clone(),
671            distribution.clone(),
672            dim,
673            num_trials,
674            rng,
675        );
676    }
677
678    /// The maximum dimension tested for these tests.
679    #[cfg(not(debug_assertions))]
680    const MAX_DIM: usize = 256;
681
682    #[cfg(debug_assertions)]
683    const MAX_DIM: usize = 160;
684
685    // Decrease the number of trials in debug mode to keep test run-time down.
686    #[cfg(not(debug_assertions))]
687    const INTEGER_TRIALS: usize = 10000;
688
689    #[cfg(debug_assertions)]
690    const INTEGER_TRIALS: usize = 100;
691
692    ////////////////////
693    // Integer Tester //
694    ////////////////////
695
696    // For integer tests - we expect exact reproducibility with the reference
697    // implementations.
698    fn run_integer_test<T, R>(
699        under_test: fn(&[T], &[T]) -> R,
700        reference: fn(&[T], &[T]) -> R,
701        rng: &mut impl Rng,
702    ) where
703        T: test_util::CornerCases,
704        R: GetInner + Copy,
705        rand::distr::StandardUniform: test_util::GenerateRandomArguments<T> + Clone,
706    {
707        let distribution = rand::distr::StandardUniform {};
708        let num_corner_cases = <T as test_util::CornerCases>::corner_cases().len();
709
710        for dim in 0..MAX_DIM {
711            let mut callcount = 0;
712            let callback = |_, _| {
713                callcount += 1;
714            };
715
716            run_test(
717                under_test,
718                reference,
719                EpsilonAndRelative {
720                    epsilon: 0.0,
721                    max_relative: 0.0,
722                },
723                dim,
724                INTEGER_TRIALS,
725                distribution,
726                rng,
727                callback,
728            );
729
730            // Make sure the expected number of callbacks were made.
731            assert_eq!(
732                callcount,
733                INTEGER_TRIALS + num_corner_cases * num_corner_cases
734            );
735        }
736    }
737
738    //////////////////
739    // L2 - Integer //
740    //////////////////
741
742    #[test]
743    fn test_l2_i8_mathematical() {
744        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2bb701074c2b81c9);
745        run_integer_test(
746            as_function_pointer::<FullL2, i8, i8, MathematicalValue<f32>>,
747            reference::reference_l2_i8_mathematical,
748            &mut rng,
749        );
750    }
751
752    #[test]
753    fn test_l2_u8_mathematical() {
754        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9284ced6d080808c);
755        run_integer_test(
756            as_function_pointer::<FullL2, u8, u8, MathematicalValue<f32>>,
757            reference::reference_l2_u8_mathematical,
758            &mut rng,
759        );
760    }
761
762    #[test]
763    fn test_l2_i8_similarity() {
764        let mut rng = rand::rngs::StdRng::seed_from_u64(0xb196fecc4def04fa);
765        run_integer_test(
766            as_function_pointer::<FullL2, i8, i8, SimilarityScore<f32>>,
767            reference::reference_l2_i8_similarity,
768            &mut rng,
769        );
770    }
771
772    #[test]
773    fn test_l2_u8_similarity() {
774        let mut rng = rand::rngs::StdRng::seed_from_u64(0x07f6463e4a654aea);
775        run_integer_test(
776            as_function_pointer::<FullL2, u8, u8, SimilarityScore<f32>>,
777            reference::reference_l2_u8_similarity,
778            &mut rng,
779        );
780    }
781
782    ////////////////////////////
783    // InnerProduct - Integer //
784    ////////////////////////////
785
786    #[test]
787    fn test_innerproduct_i8_mathematical() {
788        let mut rng = rand::rngs::StdRng::seed_from_u64(0x2c1b1bddda5774be);
789        run_integer_test(
790            as_function_pointer::<InnerProduct, i8, i8, MathematicalValue<f32>>,
791            reference::reference_innerproduct_i8_mathematical,
792            &mut rng,
793        );
794    }
795
796    #[test]
797    fn test_innerproduct_u8_mathematical() {
798        let mut rng = rand::rngs::StdRng::seed_from_u64(0x757e363832d7f215);
799        run_integer_test(
800            as_function_pointer::<InnerProduct, u8, u8, MathematicalValue<f32>>,
801            reference::reference_innerproduct_u8_mathematical,
802            &mut rng,
803        );
804    }
805
806    #[test]
807    fn test_innerproduct_i8_similarity() {
808        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4788ce0b991eb15a);
809        run_integer_test(
810            as_function_pointer::<InnerProduct, i8, i8, SimilarityScore<f32>>,
811            reference::reference_innerproduct_i8_similarity,
812            &mut rng,
813        );
814    }
815
816    #[test]
817    fn test_innerproduct_u8_similarity() {
818        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4994adb68f814d96);
819        run_integer_test(
820            as_function_pointer::<InnerProduct, u8, u8, SimilarityScore<f32>>,
821            reference::reference_innerproduct_u8_similarity,
822            &mut rng,
823        );
824    }
825
826    //////////////////////
827    // Cosine - Integer //
828    //////////////////////
829
830    #[test]
831    fn test_cosine_i8_mathematical() {
832        let mut rng = rand::rngs::StdRng::seed_from_u64(0xedef81c780491ada);
833        run_integer_test(
834            as_function_pointer::<Cosine, i8, i8, MathematicalValue<f32>>,
835            reference::reference_cosine_i8_mathematical,
836            &mut rng,
837        );
838    }
839
840    #[test]
841    fn test_cosine_u8_mathematical() {
842        let mut rng = rand::rngs::StdRng::seed_from_u64(0x107cee2adcc58b73);
843        run_integer_test(
844            as_function_pointer::<Cosine, u8, u8, MathematicalValue<f32>>,
845            reference::reference_cosine_u8_mathematical,
846            &mut rng,
847        );
848    }
849
850    #[test]
851    fn test_cosine_i8_similarity() {
852        let mut rng = rand::rngs::StdRng::seed_from_u64(0x02d95c1cc0843647);
853        run_integer_test(
854            as_function_pointer::<Cosine, i8, i8, SimilarityScore<f32>>,
855            reference::reference_cosine_i8_similarity,
856            &mut rng,
857        );
858    }
859
860    #[test]
861    fn test_cosine_u8_similarity() {
862        let mut rng = rand::rngs::StdRng::seed_from_u64(0xf5ea1974bf8d8b3b);
863        run_integer_test(
864            as_function_pointer::<Cosine, u8, u8, SimilarityScore<f32>>,
865            reference::reference_cosine_u8_similarity,
866            &mut rng,
867        );
868    }
869
870    //////////////////
871    // Float Tester //
872    //////////////////
873
874    // For integer tests - we expect exact reproducibility with the reference
875    // implementations.
876    fn run_float_test<L, R, To, Dist>(
877        under_test: fn(&[L], &[R]) -> To,
878        reference: fn(&[L], &[R]) -> To,
879        rng: &mut impl Rng,
880        distribution: Dist,
881        bounds: EpsilonAndRelative,
882    ) where
883        L: test_util::CornerCases,
884        R: test_util::CornerCases,
885        To: GetInner + Copy,
886        Dist: test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
887    {
888        let left_corner_cases = <L as test_util::CornerCases>::corner_cases().len();
889        let right_corner_cases = <R as test_util::CornerCases>::corner_cases().len();
890        for dim in 0..MAX_DIM {
891            let mut callcount = 0;
892            let callback = |_, _| {
893                callcount += 1;
894            };
895
896            run_test(
897                under_test,
898                reference,
899                bounds,
900                dim,
901                INTEGER_TRIALS,
902                distribution.clone(),
903                rng,
904                callback,
905            );
906
907            // Make sure the expected number of callbacks were made.
908            assert_eq!(
909                callcount,
910                INTEGER_TRIALS + left_corner_cases * right_corner_cases
911            );
912        }
913    }
914
915    ////////////////
916    // L2 - Float //
917    ////////////////
918
919    fn expected_l2_errors() -> EpsilonAndRelative {
920        EpsilonAndRelative {
921            epsilon: 0.0,
922            max_relative: 1.2e-6,
923        }
924    }
925
926    #[test]
927    fn test_l2_f32_mathematical() {
928        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6d22d320bdf35aec);
929        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
930        run_float_test(
931            as_function_pointer::<FullL2, f32, f32, MathematicalValue<f32>>,
932            reference::reference_l2_f32_mathematical,
933            &mut rng,
934            distribution,
935            expected_l2_errors(),
936        );
937    }
938
939    #[test]
940    fn test_l2_f16_mathematical() {
941        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
942        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
943        run_float_test(
944            as_function_pointer::<FullL2, Half, Half, MathematicalValue<f32>>,
945            reference::reference_l2_f16_mathematical,
946            &mut rng,
947            distribution,
948            expected_l2_errors(),
949        );
950    }
951
952    #[test]
953    fn test_l2_f32xf16_mathematical() {
954        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
955        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
956
957        run_float_test(
958            as_function_pointer::<FullL2, f32, Half, MathematicalValue<f32>>,
959            reference::reference_l2_f32xf16_mathematical,
960            &mut rng,
961            distribution,
962            expected_l2_errors(),
963        );
964    }
965
966    #[test]
967    fn test_l2_f32_similarity() {
968        let mut rng = rand::rngs::StdRng::seed_from_u64(0xbfc5f4b42b5bc0c1);
969        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
970        run_float_test(
971            as_function_pointer::<FullL2, f32, f32, SimilarityScore<f32>>,
972            reference::reference_l2_f32_similarity,
973            &mut rng,
974            distribution,
975            expected_l2_errors(),
976        );
977    }
978
979    #[test]
980    fn test_l2_f16_similarity() {
981        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9d3809d84f54e4b6);
982        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
983        run_float_test(
984            as_function_pointer::<FullL2, Half, Half, SimilarityScore<f32>>,
985            reference::reference_l2_f16_similarity,
986            &mut rng,
987            distribution,
988            expected_l2_errors(),
989        );
990    }
991
992    #[test]
993    fn test_l2_f32xf16_similarity() {
994        let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
995        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
996
997        run_float_test(
998            as_function_pointer::<FullL2, f32, Half, SimilarityScore<f32>>,
999            reference::reference_l2_f32xf16_similarity,
1000            &mut rng,
1001            distribution,
1002            expected_l2_errors(),
1003        );
1004    }
1005
1006    ///////////////////////////
1007    // InnerProduct - Floats //
1008    ///////////////////////////
1009
1010    fn expected_innerproduct_errors() -> EpsilonAndRelative {
1011        EpsilonAndRelative {
1012            epsilon: 2.5e-5,
1013            max_relative: 1.6e-5,
1014        }
1015    }
1016
1017    #[test]
1018    fn test_innerproduct_f32_mathematical() {
1019        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1ef6ac3b65869792);
1020        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1021        run_float_test(
1022            as_function_pointer::<InnerProduct, f32, f32, MathematicalValue<f32>>,
1023            reference::reference_innerproduct_f32_mathematical,
1024            &mut rng,
1025            distribution,
1026            expected_innerproduct_errors(),
1027        );
1028    }
1029
1030    #[test]
1031    fn test_innerproduct_f16_mathematical() {
1032        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1033        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1034        run_float_test(
1035            as_function_pointer::<InnerProduct, Half, Half, MathematicalValue<f32>>,
1036            reference::reference_innerproduct_f16_mathematical,
1037            &mut rng,
1038            distribution,
1039            expected_innerproduct_errors(),
1040        );
1041    }
1042
1043    #[test]
1044    fn test_innerproduct_f32xf16_mathematical() {
1045        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1046        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1047        run_float_test(
1048            as_function_pointer::<InnerProduct, f32, Half, MathematicalValue<f32>>,
1049            reference::reference_innerproduct_f32xf16_mathematical,
1050            &mut rng,
1051            distribution,
1052            expected_innerproduct_errors(),
1053        );
1054    }
1055
1056    #[test]
1057    fn test_innerproduct_f32_similarity() {
1058        let mut rng = rand::rngs::StdRng::seed_from_u64(0x40326b22a57db0d7);
1059        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1060        run_float_test(
1061            as_function_pointer::<InnerProduct, f32, f32, SimilarityScore<f32>>,
1062            reference::reference_innerproduct_f32_similarity,
1063            &mut rng,
1064            distribution,
1065            expected_innerproduct_errors(),
1066        );
1067    }
1068
1069    #[test]
1070    fn test_innerproduct_f16_similarity() {
1071        let mut rng = rand::rngs::StdRng::seed_from_u64(0xfb8cff47bcbc9528);
1072        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1073        run_float_test(
1074            as_function_pointer::<InnerProduct, Half, Half, SimilarityScore<f32>>,
1075            reference::reference_innerproduct_f16_similarity,
1076            &mut rng,
1077            distribution,
1078            expected_innerproduct_errors(),
1079        );
1080    }
1081
1082    #[test]
1083    fn test_innerproduct_f32xf16_similarity() {
1084        let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1085        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1086        run_float_test(
1087            as_function_pointer::<InnerProduct, f32, Half, SimilarityScore<f32>>,
1088            reference::reference_innerproduct_f32xf16_similarity,
1089            &mut rng,
1090            distribution,
1091            expected_innerproduct_errors(),
1092        );
1093    }
1094
1095    /////////////////////
1096    // Cosine - Floats //
1097    /////////////////////
1098
1099    fn expected_cosine_errors() -> EpsilonAndRelative {
1100        EpsilonAndRelative {
1101            epsilon: 3e-7,
1102            max_relative: 5e-6,
1103        }
1104    }
1105
1106    #[test]
1107    fn test_cosine_f32_mathematical() {
1108        let mut rng = rand::rngs::StdRng::seed_from_u64(0xca6eaac942999500);
1109        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1110        run_float_test(
1111            as_function_pointer::<Cosine, f32, f32, MathematicalValue<f32>>,
1112            reference::reference_cosine_f32_mathematical,
1113            &mut rng,
1114            distribution,
1115            expected_cosine_errors(),
1116        );
1117    }
1118
1119    #[test]
1120    fn test_cosine_f16_mathematical() {
1121        let mut rng = rand::rngs::StdRng::seed_from_u64(0xa736c789aa16ce86);
1122        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1123        run_float_test(
1124            as_function_pointer::<Cosine, Half, Half, MathematicalValue<f32>>,
1125            reference::reference_cosine_f16_mathematical,
1126            &mut rng,
1127            distribution,
1128            expected_cosine_errors(),
1129        );
1130    }
1131
1132    #[test]
1133    fn test_cosine_f32xf16_mathematical() {
1134        let mut rng = rand::rngs::StdRng::seed_from_u64(0xac550231088a0d5c);
1135        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1136        run_float_test(
1137            as_function_pointer::<Cosine, f32, Half, MathematicalValue<f32>>,
1138            reference::reference_cosine_f32xf16_mathematical,
1139            &mut rng,
1140            distribution,
1141            expected_cosine_errors(),
1142        );
1143    }
1144
1145    #[test]
1146    fn test_cosine_f32_similarity() {
1147        let mut rng = rand::rngs::StdRng::seed_from_u64(0x4a09ad987a6204f3);
1148        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1149        run_float_test(
1150            as_function_pointer::<Cosine, f32, f32, SimilarityScore<f32>>,
1151            reference::reference_cosine_f32_similarity,
1152            &mut rng,
1153            distribution,
1154            expected_cosine_errors(),
1155        );
1156    }
1157
1158    #[test]
1159    fn test_cosine_f16_similarity() {
1160        let mut rng = rand::rngs::StdRng::seed_from_u64(0x77a48d1914f850f2);
1161        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1162        run_float_test(
1163            as_function_pointer::<Cosine, Half, Half, SimilarityScore<f32>>,
1164            reference::reference_cosine_f16_similarity,
1165            &mut rng,
1166            distribution,
1167            expected_cosine_errors(),
1168        );
1169    }
1170
1171    #[test]
1172    fn test_cosine_f32xf16_similarity() {
1173        let mut rng = rand::rngs::StdRng::seed_from_u64(0xbd7471b815655ca1);
1174        let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1175        run_float_test(
1176            as_function_pointer::<Cosine, f32, Half, SimilarityScore<f32>>,
1177            reference::reference_cosine_f32xf16_similarity,
1178            &mut rng,
1179            distribution,
1180            expected_cosine_errors(),
1181        );
1182    }
1183
1184    ///////////////////////////////
1185    // CosineNormalized - Floats //
1186    ///////////////////////////////
1187
1188    fn expected_cosine_normalized_errors() -> EpsilonAndRelative {
1189        EpsilonAndRelative {
1190            epsilon: 3e-7,
1191            max_relative: 5e-6,
1192        }
1193    }
1194
1195    #[test]
1196    fn test_cosine_normalized_f32_mathematical() {
1197        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1fda98112747f8dd);
1198        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1199        run_float_test(
1200            as_function_pointer::<CosineNormalized, f32, f32, MathematicalValue<f32>>,
1201            reference::reference_cosine_normalized_f32_mathematical,
1202            &mut rng,
1203            test_util::Normalized(distribution),
1204            expected_cosine_normalized_errors(),
1205        );
1206    }
1207
1208    #[test]
1209    fn test_cosine_normalized_f16_mathematical() {
1210        let mut rng = rand::rngs::StdRng::seed_from_u64(0x5e8c5d5e19cdd840);
1211        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1212        run_float_test(
1213            as_function_pointer::<CosineNormalized, Half, Half, MathematicalValue<f32>>,
1214            reference::reference_cosine_normalized_f16_mathematical,
1215            &mut rng,
1216            test_util::Normalized(distribution),
1217            expected_cosine_normalized_errors(),
1218        );
1219    }
1220
1221    #[test]
1222    fn test_cosine_normalized_f32xf16_mathematical() {
1223        let mut rng = rand::rngs::StdRng::seed_from_u64(0x3fd01e1c11c9bc45);
1224        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1225        run_float_test(
1226            as_function_pointer::<CosineNormalized, f32, Half, MathematicalValue<f32>>,
1227            reference::reference_cosine_normalized_f32xf16_mathematical,
1228            &mut rng,
1229            test_util::Normalized(distribution),
1230            expected_cosine_normalized_errors(),
1231        );
1232    }
1233
1234    #[test]
1235    fn test_cosine_normalized_f32_similarity() {
1236        let mut rng = rand::rngs::StdRng::seed_from_u64(0x9446d057870e5605);
1237        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1238        run_float_test(
1239            as_function_pointer::<CosineNormalized, f32, f32, SimilarityScore<f32>>,
1240            reference::reference_cosine_normalized_f32_similarity,
1241            &mut rng,
1242            test_util::Normalized(distribution),
1243            expected_cosine_normalized_errors(),
1244        );
1245    }
1246
1247    #[test]
1248    fn test_cosine_normalized_f16_similarity() {
1249        let mut rng = rand::rngs::StdRng::seed_from_u64(0x885c371801f18174);
1250        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1251        run_float_test(
1252            as_function_pointer::<CosineNormalized, Half, Half, SimilarityScore<f32>>,
1253            reference::reference_cosine_normalized_f16_similarity,
1254            &mut rng,
1255            test_util::Normalized(distribution),
1256            expected_cosine_normalized_errors(),
1257        );
1258    }
1259
1260    #[test]
1261    fn test_cosine_normalized_f32xf16_similarity() {
1262        let mut rng = rand::rngs::StdRng::seed_from_u64(0x1c356c92d0522c0f);
1263        let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1264        run_float_test(
1265            as_function_pointer::<CosineNormalized, f32, Half, SimilarityScore<f32>>,
1266            reference::reference_cosine_normalized_f32xf16_similarity,
1267            &mut rng,
1268            test_util::Normalized(distribution),
1269            expected_cosine_normalized_errors(),
1270        );
1271    }
1272}