Skip to main content

diskann_vector/distance/
distance_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6#[cfg(target_arch = "aarch64")]
7use diskann_wide::arch::aarch64::Neon;
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10use diskann_wide::{
11    arch::{Dispatched2, FTarget2, Scalar},
12    lifetime::Ref,
13    Architecture,
14};
15use half::f16;
16
17use super::{implementations::Specialize, Cosine, CosineNormalized, InnerProduct, SquaredL2};
18use crate::distance::Metric;
19
20/// Return a function pointer-like [`Distance`] to compute the requested metric.
21///
22/// If `dimension` is provided, then the returned function may **only** be used on
23/// slices with length `dimension`. Calling the returned function with a different sized
24/// slice **may** panic.
25///
26/// If `dimension` is not provided, then the returned function will work for all sizes.
27///
28/// The functions returned by `distance_comparer` do not have strict alignment
29/// requirements, though aligning your data *may* yield better memory performance.
30///
31/// # Metric Semantics
32///
33/// The values computed by the returned functions may be modified from the true mathematical
34/// definition of the metric to ensure that values closer to `-infinity` imply more similar.
35///
36/// * `L2`: Computes the squared L2 distance between vectors.
37/// * `InnerProduct`: Returns the **negative** inner-product.
38/// * `Cosine`: Returns `1 - cosine-similarity` and will work on un-normalized vectors.
39/// * `CosineNormalized`: Returns `1 - cosinesimilarity` with the hint that the provided
40///   vectors have norm 1. This allows for potentially more-efficient implementations but the
41///   results may be incorrect if called with unnormalized data.
42///
43///   When provided with integer arguments (for which normalization does not make sense), this
44///   behaves as if `Cosine` was provided.
45pub trait DistanceProvider<T>: Sized + 'static {
46    fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
47}
48
49/// A function pointer-like type for computing distances between `&[T]` and `&[U]`.
50///
51/// See: [`DistanceProvider`].
52#[derive(Debug, Clone, Copy)]
53pub struct Distance<T, U>
54where
55    T: 'static,
56    U: 'static,
57{
58    f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
59}
60
61impl<T, U> Distance<T, U>
62where
63    T: 'static,
64    U: 'static,
65{
66    fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
67        Self { f }
68    }
69
70    /// Compute a distances between `x` and `y`.
71    ///
72    /// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
73    ///
74    /// Additionally, if a dimension were given to [`DistanceProvider`], this function may
75    /// panic if provided with slices with a length not equal to this dimension.
76    #[inline]
77    pub fn call(&self, x: &[T], y: &[U]) -> f32 {
78        self.f.call(x, y)
79    }
80}
81
82impl<T, U> crate::DistanceFunction<&[T], &[U], f32> for Distance<T, U>
83where
84    T: 'static,
85    U: 'static,
86{
87    fn evaluate_similarity(&self, x: &[T], y: &[U]) -> f32 {
88        self.call(x, y)
89    }
90}
91
92////////////////////
93// Implementation //
94////////////////////
95
96// Implementation Notes
97//
98// Our implementation of `DistanceProvider` dispatches across:
99//
100// * Data Types
101// * Metric
102// * Dimensions
103// * Runtime Micro-architecture
104//
// This leads to a combinatorial explosion of potentially compiled kernels. To keep the
// sheer number of compiled functions under control, we manually choose the dimensional
// specializations on a case-by-case basis.
108//
109// This is facilitated by the `specialize!` macro, which accepts a list of dimensions and
110// instantiates the necessary machinery.
111//
// To explain the machinery a little: a compile-time `Cons` list is constructed. This type
// might look like
114//
115// * `Cons<Spec<100>, Cons<Spec<64>, Spec<32>>>`: To specialize dimensions 100, 64, and 32.
116// * `Cons<Spec<100>, Null>`: To specialize just dimension 100.
117// * `Cons<Null, Null>`: To specialize no dimensions.
118//
// The `TrySpecialize` trait is then used to specialize a kernel `F` for an architecture `A`,
// with the following implementations:
121//
122// * `Spec<N>`: Check if the requested dimension is equal to `N` and if so, return the
123//   specialized method.
124// * `Null`: Never specialize.
125// * `Cons<Head, Tail>`: Try to specialize using `Head` returning if successful. Otherwise,
126//   return the specialization of `Tail`.
127//
128//   This definition is what allows nested `Cons` structures to specialize multiple dimensions.
129//
130//   The `Cons` list also compiles a generic-dimensional fallback if none of the
131//   specializations match.
132//
133// The overall flow is
134//
135// 1. Enter the `DistanceProvider` implementation.
136//
137// 2. First dispatch across micro-architecture using `ArgumentTypes` to hold the data types.
138//    `ArgumentTypes` implements `Target2` to facilitate this dispatch.
139//
// 3. The `Target2` implementations for `ArgumentTypes` are generated by the `specialize!`
//    macro, which builds a `Cons` list of the requested specializations, switches on the
//    metric, and invokes `Cons::specialize` with the kernel for that metric.
143
/// Implement `DistanceProvider<$rhs>` for `$lhs`.
///
/// The generated `distance_comparer` hands the metric and optional dimension to
/// `ArgumentTypes::<$lhs, $rhs>`, letting `diskann_wide` select the micro-architecture.
macro_rules! provider {
    ($lhs:ty, $rhs:ty) => {
        impl DistanceProvider<$rhs> for $lhs {
            fn distance_comparer(
                metric: Metric,
                dimension: Option<usize>,
            ) -> Distance<Self, $rhs> {
                // The `no_features` dispatch variant is used on purpose: this trampoline does
                // not itself need to be compiled for higher micro-architecture levels. Only
                // the kernel it returns matters.
                let argument_types = ArgumentTypes::<Self, $rhs>::new();
                diskann_wide::arch::dispatch2_no_features(argument_types, metric, dimension)
            }
        }
    };
}
161
162provider!(f32, f32);
163provider!(f16, f16);
164provider!(f32, f16);
165provider!(i8, i8);
166provider!(u8, u8);
167
168/////////////////////////
169// Specialization List //
170/////////////////////////
171
/// Build a `Cons` list of `Spec` dimension specializations from a list of literals.
///
/// * `spec_list!()` expands to `Cons::new(Null, Null)` (no specializations).
/// * `spec_list!(a)` expands to `Cons::new(Spec::<a>, Null)`.
/// * `spec_list!(a, b)` expands to `Cons::new(Spec::<a>, Spec::<b>)`.
/// * `spec_list!(a, b, c, ...)` expands to `Cons::new(Spec::<a>, spec_list!(b, c, ...))`.
macro_rules! spec_list {
    // Public entry point: normalize the input into the internal `@value` form.
    ($($dims:literal),* $(,)?) => {
        spec_list!(@value, $($dims,)*)
    };
    // Zero dimensions.
    (@value $(,)?) => {
        Cons::new(Null, Null)
    };
    // Exactly one dimension.
    (@value, $only:literal $(,)?) => {
        Cons::new(Spec::<$only>, Null)
    };
    // Exactly two dimensions: terminate the recursion without a trailing `Null`.
    (@value, $first:literal, $second:literal $(,)?) => {
        Cons::new(Spec::<$first>, Spec::<$second>)
    };
    // Three or more dimensions: peel off the head and recurse on the remainder.
    (@value, $first:literal, $second:literal, $($rest:literal),+ $(,)?) => {
        Cons::new(Spec::<$first>, spec_list!(@value, $second, $($rest,)+))
    };
}
189
/// Zero-sized marker carrying the argument element types `T` and `U` through the
/// micro-architecture dispatch performed by `DistanceProvider`.
struct ArgumentTypes<T, U>(std::marker::PhantomData<(T, U)>)
where
    T: 'static,
    U: 'static;

impl<T: 'static, U: 'static> ArgumentTypes<T, U> {
    /// Create the marker; it holds no data.
    fn new() -> Self {
        ArgumentTypes(std::marker::PhantomData)
    }
}
201
/// Implement `diskann_wide::arch::Target2` for `ArgumentTypes<$T, $U>` on architecture
/// `$arch`, specializing the listed dimensions for every metric.
///
/// * `specialize!(Arch, T, U, dims...)`: each metric is served by its own kernel.
/// * `specialize!(@integer, Arch, T, U, dims...)`: `CosineNormalized` is served by the
///   `Cosine` kernel because normalization does not make sense for integer data.
macro_rules! specialize {
    // Shared expansion. `$normalized` names the kernel used for `Metric::CosineNormalized`.
    (@emit, $arch:ty, $T:ty, $U:ty, $normalized:ident, $($Ns:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $arch,
            Distance<$T, $U>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$T, $U> {
            fn run(
                self,
                arch: $arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$T, $U> {
                let list = spec_list!($($Ns),*);
                match metric {
                    Metric::L2 => list.specialize(arch, SquaredL2 {}, dim),
                    Metric::InnerProduct => list.specialize(arch, InnerProduct {}, dim),
                    Metric::Cosine => list.specialize(arch, Cosine {}, dim),
                    Metric::CosineNormalized => list.specialize(arch, $normalized {}, dim),
                }
            }
        }
    };
    // Integer types redirect `CosineNormalized` to `Cosine`.
    (@integer, $arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        specialize!(@emit, $arch, $T, $U, Cosine, $($Ns),*);
    };
    // Floating-point types use the dedicated `CosineNormalized` kernel.
    ($arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        specialize!(@emit, $arch, $T, $U, CosineNormalized, $($Ns),*);
    };
}
252
253specialize!(Scalar, f32, f32,);
254specialize!(Scalar, f32, f16,);
255specialize!(Scalar, f16, f16,);
256specialize!(@integer, Scalar, u8, u8,);
257specialize!(@integer, Scalar, i8, i8,);
258
259#[cfg(target_arch = "x86_64")]
260mod x86_64 {
261    use super::*;
262
263    specialize!(V3, f32, f32, 768, 384, 128, 100);
264    specialize!(V4, f32, f32, 768, 384, 128, 100);
265
266    specialize!(V3, f32, f16, 768, 384, 128, 100);
267    specialize!(V4, f32, f16, 768, 384, 128, 100);
268
269    specialize!(V3, f16, f16, 768, 384, 128, 100);
270    specialize!(V4, f16, f16, 768, 384, 128, 100);
271
272    specialize!(@integer, V3, u8, u8, 128);
273    specialize!(@integer, V4, u8, u8, 128);
274
275    specialize!(@integer, V3, i8, i8, 128, 100);
276    specialize!(@integer, V4, i8, i8, 128, 100);
277}
278
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use super::*;

    // Floating-point kernels, specialized for dimensions 768, 384, 128, and 100.
    specialize!(Neon, f16, f16, 768, 384, 128, 100);
    specialize!(Neon, f32, f16, 768, 384, 128, 100);
    specialize!(Neon, f32, f32, 768, 384, 128, 100);

    // Integer kernels (`CosineNormalized` is routed to `Cosine`).
    specialize!(@integer, Neon, i8, i8, 128, 100);
    specialize!(@integer, Neon, u8, u8, 128);
}
290
291/// Specialize a distance function `F` for the dimension `dim` if possible. Otherwise,
292/// return `None`.
293trait TrySpecialize<A, F, T, U>
294where
295    A: Architecture,
296    T: 'static,
297    U: 'static,
298{
299    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>>;
300}
301
302/// Specialize a distance function for the requested dimensionality.
303struct Spec<const N: usize>;
304
305impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
306where
307    A: Architecture,
308    Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
309    T: 'static,
310    U: 'static,
311{
312    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
313        if let Some(d) = dim {
314            if d == N {
315                return Some(Distance::new(
316                    // NOTE: This line here is what actually compiles the specialized kernel.
317                    arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
318                ));
319            }
320        }
321        None
322    }
323}
324
325/// Don't specialize at all.
326struct Null;
327
328impl<A, F, T, U> TrySpecialize<A, F, T, U> for Null
329where
330    A: Architecture,
331    T: 'static,
332    U: 'static,
333{
334    fn try_specialize(&self, _arch: A, _dim: Option<usize>) -> Option<Distance<T, U>> {
335        None
336    }
337}
338
339/// A recursive compile-time list for building a list of specializations.
340struct Cons<Head, Tail> {
341    head: Head,
342    tail: Tail,
343}
344
345impl<Head, Tail> Cons<Head, Tail> {
346    const fn new(head: Head, tail: Tail) -> Self {
347        Self { head, tail }
348    }
349
350    /// Try to specialize `F`. If no such specialization is available, return a fallback
351    /// implementation.
352    fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
353    where
354        A: Architecture,
355        F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
356        Head: TrySpecialize<A, F, T, U>,
357        Tail: TrySpecialize<A, F, T, U>,
358        T: 'static,
359        U: 'static,
360    {
361        if let Some(f) = self.try_specialize(arch, dim) {
362            f
363        } else {
364            Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
365        }
366    }
367}
368
369// Try `Head` and then `Tail`.
370impl<A, Head, Tail, F, T, U> TrySpecialize<A, F, T, U> for Cons<Head, Tail>
371where
372    A: Architecture,
373    Head: TrySpecialize<A, F, T, U>,
374    Tail: TrySpecialize<A, F, T, U>,
375    T: 'static,
376    U: 'static,
377{
378    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
379        if let Some(f) = self.head.try_specialize(arch, dim) {
380            Some(f)
381        } else {
382            self.tail.try_specialize(arch, dim)
383        }
384    }
385}
386
387///////////
388// Tests //
389///////////
390
#[cfg(test)]
mod test_unaligned_distance_provider {
    use approx::assert_relative_eq;
    use rand::{self, SeedableRng};

    use super::*;
    use crate::{
        distance::{reference::ReferenceProvider, Metric},
        test_util, SimilarityScore,
    };

    // Comparison Bounds
    struct EpsilonAndRelative {
        epsilon: f32,
        max_relative: f32,
    }

    /// For now - these are rough bounds selected heuristically.
    /// Eventually (once we have implementations using compensated arithmetic), we should
    /// empirically derive bounds based on a combination of
    ///
    /// 1. Input Distribution
    /// 2. Distance Function
    /// 3. Dimensionality
    ///
    /// so that these bounds are tight.
    fn get_float_bounds(metric: Metric) -> EpsilonAndRelative {
        match metric {
            Metric::L2 => EpsilonAndRelative {
                epsilon: 1e-5,
                max_relative: 1e-5,
            },
            Metric::InnerProduct => EpsilonAndRelative {
                epsilon: 1e-4,
                max_relative: 1e-4,
            },
            Metric::Cosine => EpsilonAndRelative {
                epsilon: 1e-4,
                max_relative: 1e-4,
            },
            Metric::CosineNormalized => EpsilonAndRelative {
                epsilon: 1e-4,
                max_relative: 1e-4,
            },
        }
    }

    /// Comparison bounds for integer inputs.
    fn get_int_bounds(metric: Metric) -> EpsilonAndRelative {
        match metric {
            // Allow for some error when handling the normalization at the end.
            Metric::Cosine | Metric::CosineNormalized => EpsilonAndRelative {
                epsilon: 1e-6,
                max_relative: 1e-6,
            },
            // These should be exact.
            Metric::L2 | Metric::InnerProduct => EpsilonAndRelative {
                epsilon: 0.0,
                max_relative: 0.0,
            },
        }
    }

    /// Check `under_test` against `reference` on randomly generated arguments of length
    /// `dim` drawn from `distribution`, asserting agreement within `bounds`.
    fn do_test<T, Distribution>(
        under_test: Distance<T, T>,
        reference: fn(&[T], &[T]) -> SimilarityScore<f32>,
        bounds: EpsilonAndRelative,
        dim: usize,
        distribution: Distribution,
    ) where
        T: test_util::CornerCases,
        Distribution: test_util::GenerateRandomArguments<T> + Clone,
    {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xef0053c);

        // Unwrap the SimilarityScore for the reference implementation.
        let converted = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };

        let checker = test_util::Checker::<T, T, f32>::new(
            |a, b| under_test.call(a, b),
            converted,
            |got: f32, expected: f32| {
                assert_relative_eq!(
                    got,
                    expected,
                    epsilon = bounds.epsilon,
                    max_relative = bounds.max_relative
                );
            },
        );

        test_util::test_distance_function(
            checker,
            distribution.clone(),
            distribution,
            dim,
            10,
            &mut rng,
        );
    }

    fn all_metrics() -> [Metric; 4] {
        [
            Metric::L2,
            Metric::InnerProduct,
            Metric::Cosine,
            Metric::CosineNormalized,
        ]
    }

    /// The maximum dimension used for unaligned behavior checking with simple distances.
    const MAX_DIM: usize = 256;

    /// The `dimension` arguments passed to `distance_comparer` for vectors of length `dim`.
    ///
    /// `None` requests the generic-length kernel. `Some(dim)` selects a dimension-specialized
    /// kernel when one was instantiated for `dim` (and falls back to the generic kernel
    /// otherwise), so checking both also covers the specialized code paths.
    fn dimension_arguments(dim: usize) -> [Option<usize>; 2] {
        [None, Some(dim)]
    }

    #[test]
    fn test_unaligned_f32() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                for hint in dimension_arguments(dim) {
                    println!("Metric = {:?}, dim = {}, hint = {:?}", metric, dim, hint);
                    let unaligned = <f32 as DistanceProvider<f32>>::distance_comparer(metric, hint);
                    let simple = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
                    let bounds = get_float_bounds(metric);
                    do_test(unaligned, simple, bounds, dim, dist);
                }
            }
        }
    }

    #[test]
    fn test_unaligned_f16() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                for hint in dimension_arguments(dim) {
                    println!("Metric = {:?}, dim = {}, hint = {:?}", metric, dim, hint);
                    let unaligned = <f16 as DistanceProvider<f16>>::distance_comparer(metric, hint);
                    let simple = <f16 as ReferenceProvider<f16>>::reference_implementation(metric);
                    let bounds = get_float_bounds(metric);
                    do_test(unaligned, simple, bounds, dim, dist);
                }
            }
        }
    }

    #[test]
    fn test_unaligned_u8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                for hint in dimension_arguments(dim) {
                    println!("Metric = {:?}, dim = {}, hint = {:?}", metric, dim, hint);
                    let unaligned = <u8 as DistanceProvider<u8>>::distance_comparer(metric, hint);
                    let simple = <u8 as ReferenceProvider<u8>>::reference_implementation(metric);
                    let bounds = get_int_bounds(metric);
                    do_test(unaligned, simple, bounds, dim, dist);
                }
            }
        }
    }

    #[test]
    fn test_unaligned_i8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                for hint in dimension_arguments(dim) {
                    println!("Metric = {:?}, dim = {}, hint = {:?}", metric, dim, hint);
                    let unaligned = <i8 as DistanceProvider<i8>>::distance_comparer(metric, hint);
                    let simple = <i8 as ReferenceProvider<i8>>::reference_implementation(metric);
                    let bounds = get_int_bounds(metric);
                    do_test(unaligned, simple, bounds, dim, dist);
                }
            }
        }
    }
}
560
#[cfg(test)]
mod distance_provider_f32_tests {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::{distance::reference, test_util::*};

    #[repr(C, align(32))]
    pub struct F32Slice112([f32; 112]);
    #[repr(C, align(32))]
    pub struct F32Slice104([f32; 104]);
    #[repr(C, align(32))]
    pub struct F32Slice128([f32; 128]);
    #[repr(C, align(32))]
    pub struct F32Slice256([f32; 256]);
    #[repr(C, align(32))]
    pub struct F32Slice4096([f32; 4096]);

    /// Generate two seeded (reproducible) vectors of length `dim` with entries drawn from
    /// `-1.0..1.0`.
    ///
    /// Values are drawn in the order `a[0], b[0], a[1], b[1], ...`.
    pub fn get_turing_test_data_f32_dim(dim: usize) -> (Vec<f32>, Vec<f32>) {
        let mut rng = StdRng::seed_from_u64(42);
        (0..dim)
            .map(|_| {
                // Tuple fields are drawn `a` then `b` to keep the original interleaving.
                let a = rng.random_range(-1.0f32..1.0);
                let b = rng.random_range(-1.0f32..1.0);
                (a, b)
            })
            .unzip()
    }

    #[test]
    fn test_dist_l2_float_turing_104() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(104);
        let (a_slice, b_slice) = (
            F32Slice104(a_data.try_into().unwrap()),
            F32Slice104(b_data.try_into().unwrap()),
        );

        let distance: f32 = compare_two_vec::<f32>(104, Metric::L2, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance as f64,
            no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
            epsilon = 1e-4f64
        );
    }

    #[test]
    fn test_dist_l2_float_turing_112() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(112);
        let (a_slice, b_slice) = (
            F32Slice112(a_data.try_into().unwrap()),
            F32Slice112(b_data.try_into().unwrap()),
        );

        let distance: f32 = compare_two_vec::<f32>(112, Metric::L2, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance as f64,
            no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
            epsilon = 1e-4f64
        );
    }

    #[test]
    fn test_dist_l2_float_turing_128() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(128);
        let (a_slice, b_slice) = (
            F32Slice128(a_data.try_into().unwrap()),
            F32Slice128(b_data.try_into().unwrap()),
        );

        let distance: f32 = compare_two_vec::<f32>(128, Metric::L2, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance as f64,
            no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
            epsilon = 1e-4f64
        );
    }

    #[test]
    fn test_dist_l2_float_turing_256() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(256);
        let (a_slice, b_slice) = (
            F32Slice256(a_data.try_into().unwrap()),
            F32Slice256(b_data.try_into().unwrap()),
        );

        let distance: f32 = compare_two_vec::<f32>(256, Metric::L2, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance as f64,
            no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
            epsilon = 1e-3f64
        );
    }

    #[test]
    fn test_dist_l2_float_turing_4096() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(4096);
        let (a_slice, b_slice) = (
            F32Slice4096(a_data.try_into().unwrap()),
            F32Slice4096(b_data.try_into().unwrap()),
        );

        let distance: f32 = compare_two_vec::<f32>(4096, Metric::L2, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance as f64,
            no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
            epsilon = 1e-2f64
        );
    }

    #[test]
    fn test_dist_ip_float_turing_112() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(112);
        let (a_slice, b_slice) = (
            F32Slice112(a_data.try_into().unwrap()),
            F32Slice112(b_data.try_into().unwrap()),
        );

        let distance: f32 =
            compare_two_vec::<f32>(112, Metric::InnerProduct, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance,
            reference::reference_innerproduct_f32_similarity(&a_slice.0, &b_slice.0).into_inner(),
            epsilon = 1e-4f32
        );
    }

    #[test]
    fn distance_test() {
        #[repr(C, align(32))]
        struct Vector32ByteAligned {
            v: [f32; 512],
        }

        // two vectors are allocated in the contiguous heap memory
        let two_vec = Box::new(Vector32ByteAligned {
            v: [
                69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
                20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
                65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
                95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
                71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
                87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
                40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
                66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
                27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
                26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
                35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
                1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
                72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
                62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
                70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
                35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
                31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
                93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
                94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
                80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
                14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
                14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
                64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
                55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
                69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
                76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
                70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
                87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
                57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
                23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
                34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
                94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
                40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
                11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
                95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
                73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
                39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
                14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
                2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
                18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
                23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
                63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
                26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
                28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
                99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
                77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
                72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
                66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
                77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
                18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
                7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
                11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
                51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
                46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
                6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
                3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
                96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
                12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
                44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
                54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
                26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
                40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
                25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
                34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
                86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
                32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
                9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
                50.41517, 28.156603, 42.644154,
            ],
        });

        let distance: f32 = compare::<f32>(256, Metric::L2, &two_vec.v);

        // The kernel is selected per micro-architecture, and different kernels may accumulate
        // in a different order, so compare with a tight relative tolerance rather than exact
        // bitwise equality.
        assert_relative_eq!(distance, 429141.2, max_relative = 1e-6);
    }

    /// Compute `metric` between the two back-to-back vectors of length `dim` stored in `v`.
    ///
    /// Panics if `v` holds fewer than `2 * dim` elements.
    fn compare<T>(dim: usize, metric: Metric, v: &[T]) -> f32
    where
        T: DistanceProvider<T>,
    {
        // Slice both halves explicitly so each has exactly `dim` elements, as required by
        // the dimension-hinted comparer.
        let (x, y) = v[..2 * dim].split_at(dim);
        let distance_comparer = T::distance_comparer(metric, Some(dim));
        distance_comparer.call(x, y)
    }

    /// Compute `metric` between the first `dim` elements of `v1` and of `v2`.
    pub fn compare_two_vec<T>(dim: usize, metric: Metric, v1: &[T], v2: &[T]) -> f32
    where
        T: DistanceProvider<T>,
    {
        let distance_comparer = T::distance_comparer(metric, Some(dim));
        distance_comparer.call(&v1[..dim], &v2[..dim])
    }
}
798
#[cfg(test)]
mod distance_provider_f16_tests {
    use approx::assert_abs_diff_eq;

    use super::{distance_provider_f32_tests::get_turing_test_data_f32_dim, *};
    use crate::{
        distance::distance_provider::distance_provider_f32_tests::compare_two_vec,
        test_util::no_vector_compare_f16_as_f64,
    };

    /// 32-byte-aligned storage for a 112-element `f16` test vector.
    #[repr(C, align(32))]
    pub struct F16Slice112([f16; 112]);
    /// 32-byte-aligned storage for a 104-element `f16` test vector.
    #[repr(C, align(32))]
    pub struct F16Slice104([f16; 104]);
    /// 32-byte-aligned storage for a 128-element `f16` test vector.
    #[repr(C, align(32))]
    pub struct F16Slice128([f16; 128]);
    /// 32-byte-aligned storage for a 256-element `f16` test vector.
    #[repr(C, align(32))]
    pub struct F16Slice256([f16; 256]);
    /// 32-byte-aligned storage for a 4096-element `f16` test vector.
    #[repr(C, align(32))]
    pub struct F16Slice4096([f16; 4096]);

    /// Converts the `dim`-dimensional Turing test pair returned by
    /// `get_turing_test_data_f32_dim` to `f16`, rounding each element with `f16::from_f32`.
    fn get_turing_test_data_f16_dim(dim: usize) -> (Vec<f16>, Vec<f16>) {
        let (a_slice, b_slice) = get_turing_test_data_f32_dim(dim);
        let a_data = a_slice.iter().copied().map(f16::from_f32).collect();
        let b_data = b_slice.iter().copied().map(f16::from_f32).collect();
        (a_data, b_data)
    }

    /// Fetches the `N`-dimensional Turing test pair as fixed-size `f16` arrays.
    ///
    /// Callers move the arrays into one of the aligned wrappers above.
    ///
    /// # Panics
    ///
    /// Panics if the test data does not contain exactly `N` elements per vector.
    fn turing_arrays_f16<const N: usize>() -> ([f16; N], [f16; N]) {
        let (a_data, b_data) = get_turing_test_data_f16_dim(N);
        (
            a_data
                .try_into()
                .expect("first Turing test vector should have exactly N elements"),
            b_data
                .try_into()
                .expect("second Turing test vector should have exactly N elements"),
        )
    }

    /// Asserts that the dispatched L2 comparer agrees with `no_vector_compare_f16_as_f64`
    /// to within an absolute `tolerance`.
    ///
    /// The comparer is requested for dimension `a.len()`, so `a` and `b` must have equal
    /// lengths. The tolerance allows for the precision gap between the comparer's `f32`
    /// result and the `f64` value returned by the reference.
    fn assert_l2_matches_reference(a: &[f16], b: &[f16], tolerance: f64) {
        assert_eq!(a.len(), b.len(), "test vectors must have equal length");
        let distance: f32 = compare_two_vec::<f16>(a.len(), Metric::L2, a, b);
        assert_abs_diff_eq!(
            f64::from(distance),
            no_vector_compare_f16_as_f64(a, b),
            epsilon = tolerance
        );
    }

    #[test]
    fn test_dist_l2_f16_turing_112() {
        // Copy each vector into its own 32-byte-aligned array.
        let (a, b) = turing_arrays_f16::<112>();
        let (a_slice, b_slice) = (F16Slice112(a), F16Slice112(b));
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_104() {
        // Copy each vector into its own 32-byte-aligned array.
        let (a, b) = turing_arrays_f16::<104>();
        let (a_slice, b_slice) = (F16Slice104(a), F16Slice104(b));
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_256() {
        // Copy each vector into its own 32-byte-aligned array.
        let (a, b) = turing_arrays_f16::<256>();
        let (a_slice, b_slice) = (F16Slice256(a), F16Slice256(b));
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_128() {
        // Copy each vector into its own 32-byte-aligned array.
        let (a, b) = turing_arrays_f16::<128>();
        let (a_slice, b_slice) = (F16Slice128(a), F16Slice128(b));
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_4096() {
        // Copy each vector into its own 32-byte-aligned array.
        let (a, b) = turing_arrays_f16::<4096>();
        let (a_slice, b_slice) = (F16Slice4096(a), F16Slice4096(b));
        // The looser tolerance reflects the longer accumulation over 4096 elements.
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-2);
    }

    /// `inf - inf` is NaN under IEEE 754, so the L2 distance between two all-infinity
    /// vectors must come out as NaN.
    #[test]
    fn test_dist_l2_f16_produces_nan_distance_for_infinity_vectors() {
        let a_data = vec![f16::INFINITY; 384];
        let b_data = vec![f16::INFINITY; 384];

        let distance: f32 = compare_two_vec::<f16>(384, Metric::L2, &a_data, &b_data);
        assert!(distance.is_nan());
    }
}