Skip to main content

diskann_vector/distance/
distance_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6#[cfg(target_arch = "x86_64")]
7use diskann_wide::arch::x86_64::{V3, V4};
8use diskann_wide::{
9    arch::{Dispatched2, FTarget2, Scalar},
10    lifetime::Ref,
11    Architecture,
12};
13use half::f16;
14
15use super::{Cosine, CosineNormalized, InnerProduct, SquaredL2};
16use crate::distance::Metric;
17
18#[cfg(target_arch = "x86_64")]
19use super::implementations::Specialize;
20
21/// Return a function pointer-like [`Distance`] to compute the requested metric.
22///
23/// If `dimension` is provided, then the returned function may **only** be used on
24/// slices with length `dimension`. Calling the returned function with a different sized
25/// slice **may** panic.
26///
27/// If `dimension` is not provided, then the returned function will work for all sizes.
28///
29/// The functions returned by `distance_comparer` do not have strict alignment
30/// requirements, though aligning your data *may* yield better memory performance.
31///
32/// # Metric Semantics
33///
34/// The values computed by the returned functions may be modified from the true mathematical
35/// definition of the metric to ensure that values closer to `-infinity` imply more similar.
36///
37/// * `L2`: Computes the squared L2 distance between vectors.
38/// * `InnerProduct`: Returns the **negative** inner-product.
39/// * `Cosine`: Returns `1 - cosine-similarity` and will work on un-normalized vectors.
40/// * `CosineNormalized`: Returns `1 - cosinesimilarity` with the hint that the provided
41///   vectors have norm 1. This allows for potentially more-efficient implementations but the
42///   results may be incorrect if called with unnormalized data.
43///
44///   When provided with integer arguments (for which normalization does not make sense), this
45///   behaves as if `Cosine` was provided.
46pub trait DistanceProvider<T>: Sized + 'static {
47    fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
48}
49
50/// A function pointer-like type for computing distances between `&[T]` and `&[U]`.
51///
52/// See: [`DistanceProvider`].
53#[derive(Debug, Clone, Copy)]
54pub struct Distance<T, U>
55where
56    T: 'static,
57    U: 'static,
58{
59    f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
60}
61
62impl<T, U> Distance<T, U>
63where
64    T: 'static,
65    U: 'static,
66{
67    fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
68        Self { f }
69    }
70
71    /// Compute a distances between `x` and `y`.
72    ///
73    /// The actual distance computed depends on the metric supplied to [`DistanceProvider`].
74    ///
75    /// Additionally, if a dimension were given to [`DistanceProvider`], this function may
76    /// panic if provided with slices with a length not equal to this dimension.
77    #[inline]
78    pub fn call(&self, x: &[T], y: &[U]) -> f32 {
79        self.f.call(x, y)
80    }
81}
82
83impl<T, U> crate::DistanceFunction<&[T], &[U], f32> for Distance<T, U>
84where
85    T: 'static,
86    U: 'static,
87{
88    fn evaluate_similarity(&self, x: &[T], y: &[U]) -> f32 {
89        self.call(x, y)
90    }
91}
92
93////////////////////
94// Implementation //
95////////////////////
96
97// Implementation Notes
98//
99// Our implementation of `DistanceProvider` dispatches across:
100//
101// * Data Types
102// * Metric
103// * Dimensions
104// * Runtime Micro-architecture
105//
// This is a combinatorial explosion of potentially compiled kernels. To get a handle on the
107// sheer number of compiled functions, we manually control the dimensional specialization on
108// a case-by-case basis.
109//
110// This is facilitated by the `specialize!` macro, which accepts a list of dimensions and
111// instantiates the necessary machinery.
112//
// To explain the machinery a little, a [`Cons`] compile-time list is constructed. This type
114// might look like
115//
116// * `Cons<Spec<100>, Cons<Spec<64>, Spec<32>>>`: To specialize dimensions 100, 64, and 32.
117// * `Cons<Spec<100>, Null>`: To specialize just dimension 100.
118// * `Cons<Null, Null>`: To specialize no dimensions.
119//
// The `TrySpecialize` trait is then used to specialize a kernel `F` for an architecture `A`
121// with implementations
122//
123// * `Spec<N>`: Check if the requested dimension is equal to `N` and if so, return the
124//   specialized method.
125// * `Null`: Never specialize.
126// * `Cons<Head, Tail>`: Try to specialize using `Head` returning if successful. Otherwise,
127//   return the specialization of `Tail`.
128//
129//   This definition is what allows nested `Cons` structures to specialize multiple dimensions.
130//
131//   The `Cons` list also compiles a generic-dimensional fallback if none of the
132//   specializations match.
133//
134// The overall flow is
135//
136// 1. Enter the `DistanceProvider` implementation.
137//
138// 2. First dispatch across micro-architecture using `ArgumentTypes` to hold the data types.
139//    `ArgumentTypes` implements `Target2` to facilitate this dispatch.
140//
141// 3. The implementations of `Target2` for `ArgumentTypes` are performed by the `specialize!`
142//    macro, which creates a `Cons` list of requested specializations, switches across
//    metrics and invokes `Cons::specialize` on the requested metric.
144
/// Implement `DistanceProvider<$Rhs>` for `$Lhs` by dispatching an
/// `ArgumentTypes<$Lhs, $Rhs>` through `diskann_wide::arch::dispatch2_no_features`.
macro_rules! provider {
    ($Lhs:ty, $Rhs:ty) => {
        impl DistanceProvider<$Rhs> for $Lhs {
            fn distance_comparer(
                metric: Metric,
                dimension: Option<usize>,
            ) -> Distance<$Lhs, $Rhs> {
                let types = ArgumentTypes::<$Lhs, $Rhs>::new();

                // The `no-features` variant is sufficient: we do not care whether this
                // dispatch target itself gets compiled for higher micro-architecture
                // levels. Only the kernel that is returned matters.
                diskann_wide::arch::dispatch2_no_features(types, metric, dimension)
            }
        }
    };
}
162
// Supported element-type pairs. `provider!(L, R)` implements `DistanceProvider<R>` for `L`,
// i.e. it yields a `Distance<L, R>`.
provider!(f32, f32);
provider!(f16, f16);
provider!(f32, f16);
provider!(i8, i8);
provider!(u8, u8);
168
169/////////////////////////
170// Specialization List //
171/////////////////////////
172
/// Build the `Cons` value describing which dimensions to specialize.
///
/// Accepts a comma separated (optionally trailing-comma terminated) list of integer literals.
macro_rules! spec_list {
    // Entry point: re-emit the list with a trailing comma after every element and hand it
    // to the internal `@value` rules.
    ($($dims:literal),* $(,)?) => {
        spec_list!(@value, $($dims,)*)
    };
    // No dimensions: nothing is ever specialized.
    (@value $(,)?) => {
        Cons::new(Null, Null)
    };
    // Exactly one dimension.
    (@value, $only:literal $(,)?) => {
        Cons::new(Spec::<$only>, Null)
    };
    // Exactly two dimensions.
    (@value, $first:literal, $second:literal $(,)?) => {
        Cons::new(Spec::<$first>, Spec::<$second>)
    };
    // Three or more dimensions: peel off the first and recurse on the remainder.
    (@value, $first:literal, $second:literal, $($rest:literal),+ $(,)?) => {
        Cons::new(Spec::<$first>, spec_list!(@value, $second, $($rest,)+))
    };
}
190
/// Marker that carries the element types `T` and `U` without storing any data.
struct ArgumentTypes<T, U>(std::marker::PhantomData<(T, U)>)
where
    T: 'static,
    U: 'static;

impl<T: 'static, U: 'static> ArgumentTypes<T, U> {
    /// Create the marker.
    fn new() -> Self {
        ArgumentTypes(std::marker::PhantomData)
    }
}
202
/// Implement `diskann_wide::arch::Target2` for `ArgumentTypes<$Lhs, $Rhs>` on architecture
/// `$Arch`.
///
/// The generated `run` builds the specialization list for the given dimensions, selects the
/// distance functor matching the requested `Metric`, and returns whatever
/// `Cons::specialize` produces for it.
macro_rules! specialize {
    // Default flavour: every metric maps to its own functor.
    ($Arch:ty, $Lhs:ty, $Rhs:ty, $($dims:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $Arch,
            Distance<$Lhs, $Rhs>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$Lhs, $Rhs> {
            fn run(
                self,
                arch: $Arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$Lhs, $Rhs> {
                let list = spec_list!($($dims),*);
                match metric {
                    Metric::InnerProduct => list.specialize(arch, InnerProduct {}, dim),
                    Metric::CosineNormalized => list.specialize(arch, CosineNormalized {}, dim),
                    Metric::Cosine => list.specialize(arch, Cosine {}, dim),
                    Metric::L2 => list.specialize(arch, SquaredL2 {}, dim),
                }
            }
        }
    };
    // Integer flavour: `CosineNormalized` is redirected to `Cosine`.
    (@integer, $Arch:ty, $Lhs:ty, $Rhs:ty, $($dims:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $Arch,
            Distance<$Lhs, $Rhs>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$Lhs, $Rhs> {
            fn run(
                self,
                arch: $Arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$Lhs, $Rhs> {
                let list = spec_list!($($dims),*);
                match metric {
                    Metric::InnerProduct => list.specialize(arch, InnerProduct {}, dim),
                    Metric::Cosine | Metric::CosineNormalized => {
                        list.specialize(arch, Cosine {}, dim)
                    },
                    Metric::L2 => list.specialize(arch, SquaredL2 {}, dim),
                }
            }
        }
    };
}
253
// `Scalar` architecture: the dimension list is empty for every type pair, so no
// dimension-specialized kernel is requested and only the generic-dimensional fallback is
// used. The integer pairs go through the `@integer` arm, which treats `CosineNormalized`
// as `Cosine`.
specialize!(Scalar, f32, f32,);
specialize!(Scalar, f32, f16,);
specialize!(Scalar, f16, f16,);
specialize!(@integer, Scalar, u8, u8,);
specialize!(@integer, Scalar, i8, i8,);
259
// x86-64 only: dispatch targets for the `V3` and `V4` architectures. In addition to the
// generic-dimensional fallback, each invocation requests kernels specialized for the listed
// dimensions.
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use super::*;

    // f32 x f32: dimensions 768, 384, 128 and 100.
    specialize!(V3, f32, f32, 768, 384, 128, 100);
    specialize!(V4, f32, f32, 768, 384, 128, 100);

    // f32 x f16: dimensions 768, 384, 128 and 100.
    specialize!(V3, f32, f16, 768, 384, 128, 100);
    specialize!(V4, f32, f16, 768, 384, 128, 100);

    // f16 x f16: dimensions 768, 384, 128 and 100.
    specialize!(V3, f16, f16, 768, 384, 128, 100);
    specialize!(V4, f16, f16, 768, 384, 128, 100);

    // u8 x u8 (integer arm: `CosineNormalized` behaves as `Cosine`): dimension 128 only.
    specialize!(@integer, V3, u8, u8, 128);
    specialize!(@integer, V4, u8, u8, 128);

    // i8 x i8 (integer arm: `CosineNormalized` behaves as `Cosine`): dimensions 128 and 100.
    specialize!(@integer, V3, i8, i8, 128, 100);
    specialize!(@integer, V4, i8, i8, 128, 100);
}
279
280/// Specialize a distance function `F` for the dimension `dim` if possible. Otherwise,
281/// return `None`.
282trait TrySpecialize<A, F, T, U>
283where
284    A: Architecture,
285    T: 'static,
286    U: 'static,
287{
288    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>>;
289}
290
291/// Specialize a distance function for the requested dimensionality.
292#[cfg(target_arch = "x86_64")]
293struct Spec<const N: usize>;
294
295#[cfg(target_arch = "x86_64")]
296impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
297where
298    A: Architecture,
299    Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
300    T: 'static,
301    U: 'static,
302{
303    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
304        if let Some(d) = dim {
305            if d == N {
306                return Some(Distance::new(
307                    // NOTE: This line here is what actually compiles the specialized kernel.
308                    arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
309                ));
310            }
311        }
312        None
313    }
314}
315
316/// Don't specialize at all.
317struct Null;
318
319impl<A, F, T, U> TrySpecialize<A, F, T, U> for Null
320where
321    A: Architecture,
322    T: 'static,
323    U: 'static,
324{
325    fn try_specialize(&self, _arch: A, _dim: Option<usize>) -> Option<Distance<T, U>> {
326        None
327    }
328}
329
330/// A recursive compile-time list for building a list of specializations.
331struct Cons<Head, Tail> {
332    head: Head,
333    tail: Tail,
334}
335
336impl<Head, Tail> Cons<Head, Tail> {
337    const fn new(head: Head, tail: Tail) -> Self {
338        Self { head, tail }
339    }
340
341    /// Try to specialize `F`. If no such specialization is available, return a fallback
342    /// implementation.
343    fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
344    where
345        A: Architecture,
346        F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
347        Head: TrySpecialize<A, F, T, U>,
348        Tail: TrySpecialize<A, F, T, U>,
349        T: 'static,
350        U: 'static,
351    {
352        if let Some(f) = self.try_specialize(arch, dim) {
353            f
354        } else {
355            Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
356        }
357    }
358}
359
360// Try `Head` and then `Tail`.
361impl<A, Head, Tail, F, T, U> TrySpecialize<A, F, T, U> for Cons<Head, Tail>
362where
363    A: Architecture,
364    Head: TrySpecialize<A, F, T, U>,
365    Tail: TrySpecialize<A, F, T, U>,
366    T: 'static,
367    U: 'static,
368{
369    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
370        if let Some(f) = self.head.try_specialize(arch, dim) {
371            Some(f)
372        } else {
373            self.tail.try_specialize(arch, dim)
374        }
375    }
376}
377
378///////////
379// Tests //
380///////////
381
#[cfg(test)]
mod test_unaligned_distance_provider {
    use approx::assert_relative_eq;
    use rand::{self, SeedableRng};

    use super::*;
    use crate::{
        distance::{reference::ReferenceProvider, Metric},
        test_util, SimilarityScore,
    };

    /// Tolerances forwarded to `assert_relative_eq!` as `epsilon` and `max_relative`.
    struct EpsilonAndRelative {
        epsilon: f32,
        max_relative: f32,
    }

    /// Tolerances for floating point element types.
    ///
    /// These are rough, heuristically chosen values for now. Once implementations using
    /// compensated arithmetic exist, the bounds should be derived empirically from
    ///
    /// 1. Input Distribution
    /// 2. Distance Function
    /// 3. Dimensionality
    ///
    /// so that they are known to be tight.
    fn get_float_bounds(metric: Metric) -> EpsilonAndRelative {
        let tolerance = match metric {
            Metric::L2 => 1e-5,
            Metric::InnerProduct | Metric::Cosine | Metric::CosineNormalized => 1e-4,
        };
        EpsilonAndRelative {
            epsilon: tolerance,
            max_relative: tolerance,
        }
    }

    /// Tolerances for integer element types.
    fn get_int_bounds(metric: Metric) -> EpsilonAndRelative {
        let tolerance = match metric {
            // The final normalization step is allowed a little error.
            Metric::Cosine | Metric::CosineNormalized => 1e-6,
            // These must match the reference exactly.
            Metric::L2 | Metric::InnerProduct => 0.0,
        };
        EpsilonAndRelative {
            epsilon: tolerance,
            max_relative: tolerance,
        }
    }

    /// Compare `under_test` against `reference` on `dim`-dimensional inputs drawn from
    /// `distribution`, asserting agreement within `bounds`.
    fn do_test<T, Distribution>(
        under_test: Distance<T, T>,
        reference: fn(&[T], &[T]) -> SimilarityScore<f32>,
        bounds: EpsilonAndRelative,
        dim: usize,
        distribution: Distribution,
    ) where
        T: test_util::CornerCases,
        Distribution: test_util::GenerateRandomArguments<T> + Clone,
    {
        let EpsilonAndRelative {
            epsilon,
            max_relative,
        } = bounds;

        // The reference implementation returns a `SimilarityScore`; strip the wrapper.
        let expected_fn = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };

        let checker = test_util::Checker::<T, T, f32>::new(
            |a, b| under_test.call(a, b),
            expected_fn,
            |got: f32, expected: f32| {
                assert_relative_eq!(
                    got,
                    expected,
                    epsilon = epsilon,
                    max_relative = max_relative
                );
            },
        );

        // Fixed seed keeps the generated inputs reproducible.
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xef0053c);

        test_util::test_distance_function(
            checker,
            distribution.clone(),
            distribution.clone(),
            dim,
            10,
            &mut rng,
        );
    }

    /// Every metric exercised by the tests in this module.
    fn all_metrics() -> [Metric; 4] {
        [
            Metric::L2,
            Metric::InnerProduct,
            Metric::Cosine,
            Metric::CosineNormalized,
        ]
    }

    /// Exclusive upper bound of the dimensions swept by the tests below.
    const MAX_DIM: usize = 256;

    #[test]
    fn test_unaligned_f32() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                let kernel = <f32 as DistanceProvider<f32>>::distance_comparer(metric, None);
                let expected = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
                do_test(kernel, expected, get_float_bounds(metric), dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_f16() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                let kernel = <f16 as DistanceProvider<f16>>::distance_comparer(metric, None);
                let expected = <f16 as ReferenceProvider<f16>>::reference_implementation(metric);
                do_test(kernel, expected, get_float_bounds(metric), dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_u8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                let kernel = <u8 as DistanceProvider<u8>>::distance_comparer(metric, None);
                let expected = <u8 as ReferenceProvider<u8>>::reference_implementation(metric);
                do_test(kernel, expected, get_int_bounds(metric), dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_i8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                let kernel = <i8 as DistanceProvider<i8>>::distance_comparer(metric, None);
                let expected = <i8 as ReferenceProvider<i8>>::reference_implementation(metric);
                do_test(kernel, expected, get_int_bounds(metric), dim, dist);
            }
        }
    }
}
551
#[cfg(test)]
mod distance_provider_f32_tests {
    use approx::assert_abs_diff_eq;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::{distance::reference, test_util::*};

    // 32-byte aligned fixed-size buffers, one per dimension exercised below.
    #[repr(C, align(32))]
    pub struct F32Slice104([f32; 104]);
    #[repr(C, align(32))]
    pub struct F32Slice112([f32; 112]);
    #[repr(C, align(32))]
    pub struct F32Slice128([f32; 128]);
    #[repr(C, align(32))]
    pub struct F32Slice256([f32; 256]);
    #[repr(C, align(32))]
    pub struct F32Slice4096([f32; 4096]);

    /// Generate two `dim`-element vectors with values drawn from `-1.0..1.0`.
    ///
    /// The generator is seeded with a fixed value and the draws alternate between the
    /// first and the second vector, so the output is the same on every call.
    pub fn get_turing_test_data_f32_dim(dim: usize) -> (Vec<f32>, Vec<f32>) {
        let mut rng = StdRng::seed_from_u64(42);

        let mut lhs = vec![0.0f32; dim];
        let mut rhs = vec![0.0f32; dim];
        for (a, b) in lhs.iter_mut().zip(rhs.iter_mut()) {
            *a = rng.random_range(-1.0..1.0);
            *b = rng.random_range(-1.0..1.0);
        }

        (lhs, rhs)
    }

    #[test]
    fn test_dist_l2_float_turing_104() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(104);
        let lhs = F32Slice104(lhs.try_into().unwrap());
        let rhs = F32Slice104(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(104, Metric::L2, &lhs.0, &rhs.0);
        let expected = no_vector_compare_f32_as_f64(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(f64::from(got), expected, epsilon = 1e-4f64);
    }

    #[test]
    fn test_dist_l2_float_turing_112() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(112);
        let lhs = F32Slice112(lhs.try_into().unwrap());
        let rhs = F32Slice112(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(112, Metric::L2, &lhs.0, &rhs.0);
        let expected = no_vector_compare_f32_as_f64(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(f64::from(got), expected, epsilon = 1e-4f64);
    }

    #[test]
    fn test_dist_l2_float_turing_128() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(128);
        let lhs = F32Slice128(lhs.try_into().unwrap());
        let rhs = F32Slice128(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(128, Metric::L2, &lhs.0, &rhs.0);
        let expected = no_vector_compare_f32_as_f64(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(f64::from(got), expected, epsilon = 1e-4f64);
    }

    #[test]
    fn test_dist_l2_float_turing_256() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(256);
        let lhs = F32Slice256(lhs.try_into().unwrap());
        let rhs = F32Slice256(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(256, Metric::L2, &lhs.0, &rhs.0);
        let expected = no_vector_compare_f32_as_f64(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(f64::from(got), expected, epsilon = 1e-3f64);
    }

    #[test]
    fn test_dist_l2_float_turing_4096() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(4096);
        let lhs = F32Slice4096(lhs.try_into().unwrap());
        let rhs = F32Slice4096(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(4096, Metric::L2, &lhs.0, &rhs.0);
        let expected = no_vector_compare_f32_as_f64(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(f64::from(got), expected, epsilon = 1e-2f64);
    }

    #[test]
    fn test_dist_ip_float_turing_112() {
        let (lhs, rhs) = get_turing_test_data_f32_dim(112);
        let lhs = F32Slice112(lhs.try_into().unwrap());
        let rhs = F32Slice112(rhs.try_into().unwrap());

        let got: f32 = compare_two_vec::<f32>(112, Metric::InnerProduct, &lhs.0, &rhs.0);
        let expected =
            reference::reference_innerproduct_f32_similarity(&lhs.0, &rhs.0).into_inner();

        assert_abs_diff_eq!(got, expected, epsilon = 1e-4f32);
    }

    #[test]
    fn distance_test() {
        #[repr(C, align(32))]
        struct Vector32ByteAligned {
            v: [f32; 512],
        }

        // Both 256-element vectors live back to back in one boxed, 32-byte aligned buffer.
        let packed = Box::new(Vector32ByteAligned {
            v: [
                69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
                20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
                65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
                95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
                71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
                87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
                40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
                66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
                27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
                26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
                35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
                1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
                72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
                62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
                70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
                35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
                31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
                93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
                94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
                80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
                14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
                14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
                64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
                55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
                69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
                76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
                70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
                87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
                57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
                23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
                34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
                94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
                40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
                11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
                95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
                73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
                39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
                14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
                2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
                18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
                23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
                63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
                26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
                28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
                99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
                77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
                72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
                66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
                77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
                18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
                7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
                11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
                51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
                46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
                6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
                3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
                96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
                12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
                44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
                54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
                26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
                40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
                25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
                34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
                86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
                32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
                9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
                50.41517, 28.156603, 42.644154,
            ],
        });

        let got: f32 = compare::<f32>(256, Metric::L2, &packed.v);

        assert_eq!(got, 429141.2);
    }

    /// Distance between the first `dim` elements of `v` and everything after them, using a
    /// comparer created for `metric` and dimension `dim`.
    fn compare<T: DistanceProvider<T>>(dim: usize, metric: Metric, v: &[T]) -> f32 {
        <T as DistanceProvider<T>>::distance_comparer(metric, Some(dim)).call(&v[..dim], &v[dim..])
    }

    /// Distance between the first `dim` elements of `v1` and of `v2`, using a comparer
    /// created for `metric` and dimension `dim`.
    pub fn compare_two_vec<T: DistanceProvider<T>>(
        dim: usize,
        metric: Metric,
        v1: &[T],
        v2: &[T],
    ) -> f32 {
        <T as DistanceProvider<T>>::distance_comparer(metric, Some(dim))
            .call(&v1[..dim], &v2[..dim])
    }
}
789
#[cfg(test)]
mod distance_provider_f16_tests {
    //! L2 distance checks for `f16` vectors.
    //!
    //! Each fixed-dimension test converts the shared f32 fixture to `f16`, moves the data
    //! into a 32-byte-aligned array wrapper, and compares the distance obtained through
    //! `compare_two_vec` against the value returned by `no_vector_compare_f16_as_f64`.

    use approx::assert_abs_diff_eq;

    use super::{distance_provider_f32_tests::get_turing_test_data_f32_dim, *};
    use crate::{
        distance::distance_provider::distance_provider_f32_tests::compare_two_vec,
        test_util::no_vector_compare_f16_as_f64,
    };

    /// 32-byte-aligned storage for a 112-element `f16` vector.
    #[repr(C, align(32))]
    pub struct F16Slice112([f16; 112]);
    /// 32-byte-aligned storage for a 104-element `f16` vector.
    #[repr(C, align(32))]
    pub struct F16Slice104([f16; 104]);
    /// 32-byte-aligned storage for a 128-element `f16` vector.
    #[repr(C, align(32))]
    pub struct F16Slice128([f16; 128]);
    /// 32-byte-aligned storage for a 256-element `f16` vector.
    #[repr(C, align(32))]
    pub struct F16Slice256([f16; 256]);
    /// 32-byte-aligned storage for a 4096-element `f16` vector.
    #[repr(C, align(32))]
    pub struct F16Slice4096([f16; 4096]);

    /// Returns the pair produced by `get_turing_test_data_f32_dim(dim)` with every
    /// element converted to `f16` via `f16::from_f32`.
    fn get_turing_test_data_f16_dim(dim: usize) -> (Vec<f16>, Vec<f16>) {
        let (a_f32, b_f32) = get_turing_test_data_f32_dim(dim);
        let a_f16 = a_f32.iter().copied().map(f16::from_f32).collect();
        let b_f16 = b_f32.iter().copied().map(f16::from_f32).collect();
        (a_f16, b_f16)
    }

    /// Asserts that the `f16` L2 distance between `a` and `b`, computed through
    /// `compare_two_vec`, is within `epsilon` of `no_vector_compare_f16_as_f64(a, b)`.
    ///
    /// The tolerance is supplied by the caller because the f32 result and the f64
    /// reference value are not expected to agree bit-for-bit.
    #[track_caller]
    fn assert_l2_matches_reference<const N: usize>(a: &[f16; N], b: &[f16; N], epsilon: f64) {
        let distance: f32 = compare_two_vec::<f16>(N, Metric::L2, a, b);

        assert_abs_diff_eq!(
            f64::from(distance),
            no_vector_compare_f16_as_f64(a, b),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_dist_l2_f16_turing_112() {
        // Move each vector into its own 32-byte-aligned array wrapper.
        let (a_data, b_data) = get_turing_test_data_f16_dim(112);
        let a = F16Slice112(a_data.try_into().expect("fixture must hold 112 elements"));
        let b = F16Slice112(b_data.try_into().expect("fixture must hold 112 elements"));

        assert_l2_matches_reference(&a.0, &b.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_104() {
        // Move each vector into its own 32-byte-aligned array wrapper.
        let (a_data, b_data) = get_turing_test_data_f16_dim(104);
        let a = F16Slice104(a_data.try_into().expect("fixture must hold 104 elements"));
        let b = F16Slice104(b_data.try_into().expect("fixture must hold 104 elements"));

        assert_l2_matches_reference(&a.0, &b.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_256() {
        // Move each vector into its own 32-byte-aligned array wrapper.
        let (a_data, b_data) = get_turing_test_data_f16_dim(256);
        let a = F16Slice256(a_data.try_into().expect("fixture must hold 256 elements"));
        let b = F16Slice256(b_data.try_into().expect("fixture must hold 256 elements"));

        assert_l2_matches_reference(&a.0, &b.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_128() {
        // Move each vector into its own 32-byte-aligned array wrapper.
        let (a_data, b_data) = get_turing_test_data_f16_dim(128);
        let a = F16Slice128(a_data.try_into().expect("fixture must hold 128 elements"));
        let b = F16Slice128(b_data.try_into().expect("fixture must hold 128 elements"));

        assert_l2_matches_reference(&a.0, &b.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_4096() {
        // Move each vector into its own 32-byte-aligned array wrapper.
        let (a_data, b_data) = get_turing_test_data_f16_dim(4096);
        let a = F16Slice4096(a_data.try_into().expect("fixture must hold 4096 elements"));
        let b = F16Slice4096(b_data.try_into().expect("fixture must hold 4096 elements"));

        // This dimension uses a looser tolerance (1e-2) than the smaller ones (1e-3).
        assert_l2_matches_reference(&a.0, &b.0, 1e-2);
    }

    #[test]
    fn test_dist_l2_f16_produces_nan_distance_for_infinity_vectors() {
        // `inf - inf` is NaN in IEEE 754 arithmetic; the distance must surface that NaN
        // rather than report a finite value for these inputs.
        let a_data = vec![f16::INFINITY; 384];
        let b_data = vec![f16::INFINITY; 384];

        let distance: f32 = compare_two_vec::<f16>(384, Metric::L2, &a_data, &b_data);
        assert!(distance.is_nan(), "expected NaN, got {}", distance);
    }
}