1#[cfg(target_arch = "x86_64")]
7use diskann_wide::arch::x86_64::{V3, V4};
8use diskann_wide::{
9 arch::{Dispatched2, FTarget2, Scalar},
10 lifetime::Ref,
11 Architecture,
12};
13use half::f16;
14
15use super::{Cosine, CosineNormalized, InnerProduct, SquaredL2};
16use crate::distance::Metric;
17
18#[cfg(target_arch = "x86_64")]
19use super::implementations::Specialize;
20
/// Element types that can build a [`Distance`] functor comparing a `[Self]` slice
/// (left operand) against a `[T]` slice (right operand).
///
/// In this file the implementations are generated by the `provider!` macro for the pairs
/// `(f32, f32)`, `(f16, f16)`, `(f32, f16)`, `(i8, i8)` and `(u8, u8)`.
pub trait DistanceProvider<T>: Sized + 'static {
    /// Returns a distance functor for `metric`.
    ///
    /// `dimension` is a specialization hint. When it is `Some(n)` and `n` is one of the
    /// dimensions listed for the selected architecture (see the `specialize!` invocations),
    /// a kernel specialized for exactly `n` is returned. For any other value, including
    /// `None`, the architecture's generic kernel is returned.
    ///
    /// For the integer element types (`i8`, `u8`), `Metric::CosineNormalized` is served by
    /// the same kernel as `Metric::Cosine`.
    ///
    /// NOTE(review): a functor obtained with `Some(n)` is presumably only meant for inputs of
    /// length `n` — confirm against the `Specialize` kernels.
    fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
}
49
/// A distance functor between a `[T]` slice and a `[U]` slice, returning an `f32`.
///
/// It wraps a single `diskann_wide` dispatched function, chosen once (by architecture,
/// metric and optional dimension) in [`DistanceProvider::distance_comparer`]. The type is
/// `Copy`, so it can be passed around by value freely.
#[derive(Debug, Clone, Copy)]
pub struct Distance<T, U>
where
    T: 'static,
    U: 'static,
{
    // The dispatched kernel; `Ref<[T]>` / `Ref<[U]>` describe borrowed slice arguments.
    f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
}
61
62impl<T, U> Distance<T, U>
63where
64 T: 'static,
65 U: 'static,
66{
67 fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
68 Self { f }
69 }
70
71 #[inline]
78 pub fn call(&self, x: &[T], y: &[U]) -> f32 {
79 self.f.call(x, y)
80 }
81}
82
83impl<T, U> crate::DistanceFunction<&[T], &[U], f32> for Distance<T, U>
84where
85 T: 'static,
86 U: 'static,
87{
88 fn evaluate_similarity(&self, x: &[T], y: &[U]) -> f32 {
89 self.call(x, y)
90 }
91}
92
/// Implements `DistanceProvider<$U>` for `$T`.
///
/// The generated `distance_comparer` forwards `(metric, dimension)` to
/// `diskann_wide::arch::dispatch2_no_features`, together with an `ArgumentTypes<$T, $U>`
/// marker. The `Target2` impls for that marker, one per architecture, are generated by
/// `specialize!` further down in this file.
///
/// NOTE(review): which architecture's impl is invoked is decided inside `diskann_wide`
/// (presumably the best one available at runtime) — confirm there.
macro_rules! provider {
    ($T:ty, $U:ty) => {
        impl DistanceProvider<$U> for $T {
            fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<$T, $U> {
                diskann_wide::arch::dispatch2_no_features(
                    ArgumentTypes::<$T, $U>::new(),
                    metric,
                    dimension,
                )
            }
        }
    };
}
162
// (left, right) element-type pairs that get a `DistanceProvider` implementation.
// Each pair also needs matching `specialize!` invocations for every architecture.
provider!(f32, f32);
provider!(f16, f16);
provider!(f32, f16);
provider!(i8, i8);
provider!(u8, u8);
168
/// Builds the compile-time list of candidate dimensions consumed by `Cons::specialize`.
///
/// `spec_list!(a, b, ...)` expands to nested `Cons` cells holding `Spec::<a>`, `Spec::<b>`,
/// and so on. For example, `spec_list!(768, 384, 128)` becomes
/// `Cons::new(Spec::<768>, Cons::new(Spec::<384>, Spec::<128>))`.
/// An empty list expands to `Cons::new(Null, Null)`, which never specializes.
///
/// `Spec` is only defined on x86_64, so non-empty lists only compile there.
/// The `@value` arms are internal recursion helpers. Their order matters: the fixed-arity
/// arms must be tried before the recursive one.
macro_rules! spec_list {
    ($($Ns:literal),* $(,)?) => {
        spec_list!(@value, $($Ns,)*)
    };
    (@value $(,)?) => {
        Cons::new(Null, Null)
    };
    (@value, $N0:literal $(,)?) => {
        // A single dimension is padded with `Null` so the result is still a `Cons`.
        Cons::new(Spec::<$N0>, Null)
    };
    (@value, $N0:literal, $N1:literal $(,)?) => {
        Cons::new(Spec::<$N0>, Spec::<$N1>)
    };
    (@value, $N0:literal, $N1:literal, $($Ns:literal),+ $(,)?) => {
        // Three or more: peel off the head and recurse on the rest.
        Cons::new(Spec::<$N0>, spec_list!(@value, $N1, $($Ns,)+))
    };
}
190
/// Zero-sized marker carrying the left (`T`) and right (`U`) element types.
///
/// It is the value handed to `diskann_wide`'s dispatch. `specialize!` implements
/// `diskann_wide::arch::Target2` for it once per architecture and element-type pair.
struct ArgumentTypes<T, U>(std::marker::PhantomData<(T, U)>)
where
    T: 'static,
    U: 'static;

impl<T, U> ArgumentTypes<T, U>
where
    T: 'static,
    U: 'static,
{
    /// Creates the marker value.
    fn new() -> Self {
        ArgumentTypes(std::marker::PhantomData)
    }
}
202
/// Implements `diskann_wide::arch::Target2<$arch, Distance<$T, $U>, Metric, Option<usize>>`
/// for `ArgumentTypes<$T, $U>`.
///
/// The generated `run` builds a `spec_list!` from the dimension literals `$Ns`. For the
/// requested metric, it returns the kernel specialized for `dim` if `dim` is one of `$Ns`,
/// and the architecture's generic kernel for that metric otherwise (see `Cons::specialize`).
///
/// * Default arm: each `Metric` variant maps to its own kernel type.
/// * `@integer` arm: for integer element types, where `Metric::CosineNormalized` is served
///   by the `Cosine` kernel.
///
/// An empty `$Ns` list (note the required trailing comma, e.g. `specialize!(Scalar, f32, f32,)`)
/// means "never specialize".
macro_rules! specialize {
    ($arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $arch,
            Distance<$T, $U>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$T, $U> {
            fn run(
                self,
                arch: $arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$T, $U> {
                let spec = spec_list!($($Ns),*);
                match metric {
                    Metric::L2 => spec.specialize(arch, SquaredL2 {}, dim),
                    Metric::Cosine => spec.specialize(arch, Cosine {}, dim),
                    Metric::CosineNormalized => spec.specialize(arch, CosineNormalized {}, dim),
                    Metric::InnerProduct => spec.specialize(arch, InnerProduct {}, dim),
                }
            }
        }
    };
    (@integer, $arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $arch,
            Distance<$T, $U>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$T, $U> {
            fn run(
                self,
                arch: $arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$T, $U> {
                let spec = spec_list!($($Ns),*);
                match metric {
                    Metric::L2 => spec.specialize(arch, SquaredL2 {}, dim),
                    // Integer inputs: both cosine variants use the full `Cosine` kernel.
                    // NOTE(review): presumably because integer vectors are not assumed to be
                    // unit-normalized — confirm intent.
                    Metric::Cosine | Metric::CosineNormalized => {
                        spec.specialize(arch, Cosine {}, dim)
                    },
                    Metric::InnerProduct => spec.specialize(arch, InnerProduct {}, dim),
                }
            }
        }
    };
}
253
// Portable `Scalar` fallback: the dimension lists are empty, so every dimension
// (and `None`) uses the generic kernel for the metric.
specialize!(Scalar, f32, f32,);
specialize!(Scalar, f32, f16,);
specialize!(Scalar, f16, f16,);
specialize!(@integer, Scalar, u8, u8,);
specialize!(@integer, Scalar, i8, i8,);
259
/// `Target2` impls for `diskann_wide`'s x86_64 architecture levels `V3` and `V4`.
///
/// Each invocation lists the dimensions that get a dedicated kernel for that element-type
/// pair. Any other dimension, or `None`, falls back to the generic kernel for the
/// architecture.
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use super::*;

    specialize!(V3, f32, f32, 768, 384, 128, 100);
    specialize!(V4, f32, f32, 768, 384, 128, 100);

    specialize!(V3, f32, f16, 768, 384, 128, 100);
    specialize!(V4, f32, f16, 768, 384, 128, 100);

    specialize!(V3, f16, f16, 768, 384, 128, 100);
    specialize!(V4, f16, f16, 768, 384, 128, 100);

    // u8 only specializes 128; i8 additionally specializes 100.
    specialize!(@integer, V3, u8, u8, 128);
    specialize!(@integer, V4, u8, u8, 128);

    specialize!(@integer, V3, i8, i8, 128, 100);
    specialize!(@integer, V4, i8, i8, 128, 100);
}
279
/// A (possibly empty) source of dimension-specialized kernels.
///
/// `F` is the metric kernel type (e.g. `SquaredL2`, `Cosine`). `try_specialize` returns
/// `Some` when the implementor has a kernel for `arch`, `F` and `dim`. Otherwise it returns
/// `None`, so the caller can fall back to the generic kernel.
trait TrySpecialize<A, F, T, U>
where
    A: Architecture,
    T: 'static,
    U: 'static,
{
    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>>;
}
290
291#[cfg(target_arch = "x86_64")]
293struct Spec<const N: usize>;
294
295#[cfg(target_arch = "x86_64")]
296impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
297where
298 A: Architecture,
299 Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
300 T: 'static,
301 U: 'static,
302{
303 fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
304 if let Some(d) = dim {
305 if d == N {
306 return Some(Distance::new(
307 arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
309 ));
310 }
311 }
312 None
313 }
314}
315
316struct Null;
318
319impl<A, F, T, U> TrySpecialize<A, F, T, U> for Null
320where
321 A: Architecture,
322 T: 'static,
323 U: 'static,
324{
325 fn try_specialize(&self, _arch: A, _dim: Option<usize>) -> Option<Distance<T, U>> {
326 None
327 }
328}
329
330struct Cons<Head, Tail> {
332 head: Head,
333 tail: Tail,
334}
335
336impl<Head, Tail> Cons<Head, Tail> {
337 const fn new(head: Head, tail: Tail) -> Self {
338 Self { head, tail }
339 }
340
341 fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
344 where
345 A: Architecture,
346 F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
347 Head: TrySpecialize<A, F, T, U>,
348 Tail: TrySpecialize<A, F, T, U>,
349 T: 'static,
350 U: 'static,
351 {
352 if let Some(f) = self.try_specialize(arch, dim) {
353 f
354 } else {
355 Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
356 }
357 }
358}
359
360impl<A, Head, Tail, F, T, U> TrySpecialize<A, F, T, U> for Cons<Head, Tail>
362where
363 A: Architecture,
364 Head: TrySpecialize<A, F, T, U>,
365 Tail: TrySpecialize<A, F, T, U>,
366 T: 'static,
367 U: 'static,
368{
369 fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
370 if let Some(f) = self.head.try_specialize(arch, dim) {
371 Some(f)
372 } else {
373 self.tail.try_specialize(arch, dim)
374 }
375 }
376}
377
#[cfg(test)]
mod test_unaligned_distance_provider {
    // Compares the generic kernels, obtained with `dimension = None` so that no
    // dimension-specialized kernel is selected, against the reference implementations.
    // This covers every metric and every dimension in `0..MAX_DIM`.

    use approx::assert_relative_eq;
    use rand::{self, SeedableRng};

    use super::*;
    use crate::{
        distance::{reference::ReferenceProvider, Metric},
        test_util, SimilarityScore,
    };

    /// Tolerances forwarded to `assert_relative_eq!`.
    ///
    /// `Copy` so a single value can be reused for every dimension of a metric.
    #[derive(Debug, Clone, Copy)]
    struct EpsilonAndRelative {
        /// Absolute tolerance (`epsilon`).
        epsilon: f32,
        /// Relative tolerance (`max_relative`).
        max_relative: f32,
    }

    /// Per-metric tolerances for floating-point element types.
    fn get_float_bounds(metric: Metric) -> EpsilonAndRelative {
        match metric {
            Metric::L2 => EpsilonAndRelative {
                epsilon: 1e-5,
                max_relative: 1e-5,
            },
            Metric::InnerProduct | Metric::Cosine | Metric::CosineNormalized => {
                EpsilonAndRelative {
                    epsilon: 1e-4,
                    max_relative: 1e-4,
                }
            }
        }
    }

    /// Per-metric tolerances for integer element types.
    fn get_int_bounds(metric: Metric) -> EpsilonAndRelative {
        match metric {
            // Cosine results get a tiny tolerance.
            Metric::Cosine | Metric::CosineNormalized => EpsilonAndRelative {
                epsilon: 1e-6,
                max_relative: 1e-6,
            },
            // L2 and inner product must match the reference exactly.
            Metric::L2 | Metric::InnerProduct => EpsilonAndRelative {
                epsilon: 0.0,
                max_relative: 0.0,
            },
        }
    }

    /// Checks `under_test` against `reference` on inputs of length `dim`.
    ///
    /// Inputs are generated from `distribution` by `test_util::test_distance_function`,
    /// using a fixed seed, and each result pair is compared with `bounds`.
    fn do_test<T, Distribution>(
        under_test: Distance<T, T>,
        reference: fn(&[T], &[T]) -> SimilarityScore<f32>,
        bounds: EpsilonAndRelative,
        dim: usize,
        distribution: Distribution,
    ) where
        T: test_util::CornerCases,
        Distribution: test_util::GenerateRandomArguments<T> + Clone,
    {
        // Fixed seed so that failures are reproducible.
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xef0053c);

        let converted = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };

        let checker = test_util::Checker::<T, T, f32>::new(
            |a, b| under_test.call(a, b),
            converted,
            |got: f32, expected: f32| {
                assert_relative_eq!(
                    got,
                    expected,
                    epsilon = bounds.epsilon,
                    max_relative = bounds.max_relative
                );
            },
        );

        // One clone for the left operand; the original value is moved in for the right.
        test_util::test_distance_function(
            checker,
            distribution.clone(),
            distribution,
            dim,
            10,
            &mut rng,
        );
    }

    /// Every `Metric` variant, in a fixed order.
    fn all_metrics() -> [Metric; 4] {
        [
            Metric::L2,
            Metric::InnerProduct,
            Metric::Cosine,
            Metric::CosineNormalized,
        ]
    }

    /// Dimensions `0..MAX_DIM` are tested.
    const MAX_DIM: usize = 256;

    #[test]
    fn test_unaligned_f32() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            // The comparer, reference and tolerances depend only on the metric.
            let unaligned = <f32 as DistanceProvider<f32>>::distance_comparer(metric, None);
            let simple = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
            let bounds = get_float_bounds(metric);
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                do_test(unaligned, simple, bounds, dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_f16() {
        let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
        for metric in all_metrics() {
            let unaligned = <f16 as DistanceProvider<f16>>::distance_comparer(metric, None);
            let simple = <f16 as ReferenceProvider<f16>>::reference_implementation(metric);
            let bounds = get_float_bounds(metric);
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                do_test(unaligned, simple, bounds, dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_u8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            let unaligned = <u8 as DistanceProvider<u8>>::distance_comparer(metric, None);
            let simple = <u8 as ReferenceProvider<u8>>::reference_implementation(metric);
            let bounds = get_int_bounds(metric);
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                do_test(unaligned, simple, bounds, dim, dist);
            }
        }
    }

    #[test]
    fn test_unaligned_i8() {
        let dist = rand::distr::StandardUniform {};
        for metric in all_metrics() {
            let unaligned = <i8 as DistanceProvider<i8>>::distance_comparer(metric, None);
            let simple = <i8 as ReferenceProvider<i8>>::reference_implementation(metric);
            let bounds = get_int_bounds(metric);
            for dim in 0..MAX_DIM {
                println!("Metric = {:?}, dim = {}", metric, dim);
                do_test(unaligned, simple, bounds, dim, dist);
            }
        }
    }
}
551
#[cfg(test)]
mod distance_provider_f32_tests {
    use approx::assert_abs_diff_eq;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::{distance::reference, test_util::*};

    // 32-byte-aligned fixed-size buffers, so the "turing" tests run the kernels on
    // aligned input of several lengths.
    #[repr(C, align(32))]
    pub struct F32Slice112([f32; 112]);
    #[repr(C, align(32))]
    pub struct F32Slice104([f32; 104]);
    #[repr(C, align(32))]
    pub struct F32Slice128([f32; 128]);
    #[repr(C, align(32))]
    pub struct F32Slice256([f32; 256]);
    #[repr(C, align(32))]
    pub struct F32Slice4096([f32; 4096]);

    /// Returns two deterministic pseudo-random vectors of length `dim`, with entries drawn
    /// from `-1.0..1.0` using a `StdRng` seeded with `42`.
    ///
    /// Values are drawn pairwise (`a[i]` then `b[i]`), so the first `k` entries are the same
    /// for every `dim >= k`. Also used by the f16 tests.
    pub fn get_turing_test_data_f32_dim(dim: usize) -> (Vec<f32>, Vec<f32>) {
        let mut rng = StdRng::seed_from_u64(42);
        (0..dim)
            .map(|_| {
                // Draw order matters: `a[i]` before `b[i]` keeps the data identical to an
                // interleaved fill of both vectors.
                let a = rng.random_range(-1.0f32..1.0);
                let b = rng.random_range(-1.0f32..1.0);
                (a, b)
            })
            .unzip()
    }

    /// Asserts that the provider's L2 distance between `a` and `b` agrees with
    /// `no_vector_compare_f32_as_f64` to within `epsilon`.
    ///
    /// The distance is requested with `dimension = Some(a.len())`, so a specialized kernel is
    /// used when one exists for that length.
    fn assert_l2_matches_reference(a: &[f32], b: &[f32], epsilon: f64) {
        assert_eq!(a.len(), b.len(), "test vectors must have equal length");
        let distance: f32 = compare_two_vec::<f32>(a.len(), Metric::L2, a, b);
        assert_abs_diff_eq!(
            f64::from(distance),
            no_vector_compare_f32_as_f64(a, b),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_dist_l2_float_turing_104() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(104);
        let (a_slice, b_slice) = (
            F32Slice104(a_data.try_into().unwrap()),
            F32Slice104(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-4);
    }

    #[test]
    fn test_dist_l2_float_turing_112() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(112);
        let (a_slice, b_slice) = (
            F32Slice112(a_data.try_into().unwrap()),
            F32Slice112(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-4);
    }

    #[test]
    fn test_dist_l2_float_turing_128() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(128);
        let (a_slice, b_slice) = (
            F32Slice128(a_data.try_into().unwrap()),
            F32Slice128(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-4);
    }

    #[test]
    fn test_dist_l2_float_turing_256() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(256);
        let (a_slice, b_slice) = (
            F32Slice256(a_data.try_into().unwrap()),
            F32Slice256(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_float_turing_4096() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(4096);
        let (a_slice, b_slice) = (
            F32Slice4096(a_data.try_into().unwrap()),
            F32Slice4096(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-2);
    }

    #[test]
    fn test_dist_ip_float_turing_112() {
        let (a_data, b_data) = get_turing_test_data_f32_dim(112);
        let (a_slice, b_slice) = (
            F32Slice112(a_data.try_into().unwrap()),
            F32Slice112(b_data.try_into().unwrap()),
        );

        let distance: f32 =
            compare_two_vec::<f32>(112, Metric::InnerProduct, &a_slice.0, &b_slice.0);

        assert_abs_diff_eq!(
            distance,
            reference::reference_innerproduct_f32_similarity(&a_slice.0, &b_slice.0).into_inner(),
            epsilon = 1e-4f32
        );
    }

    /// L2 distance between the first and second halves of a fixed 512-element, 32-byte
    /// aligned buffer, compared against a recorded value.
    ///
    /// NOTE(review): this uses exact float equality. Dimension 256 is not in any
    /// specialization list, so the generic kernel of whichever architecture is selected
    /// runs. If different architectures accumulate in a different order, the result could
    /// differ by a few ULPs — consider a tolerance if this proves flaky.
    #[test]
    fn distance_test() {
        #[repr(C, align(32))]
        struct Vector32ByteAligned {
            v: [f32; 512],
        }

        let two_vec = Box::new(Vector32ByteAligned {
            v: [
                69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
                20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
                65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
                95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
                71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
                87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
                40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
                66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
                27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
                26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
                35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
                1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
                72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
                62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
                70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
                35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
                31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
                93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
                94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
                80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
                14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
                14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
                64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
                55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
                69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
                76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
                70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
                87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
                57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
                23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
                34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
                94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
                40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
                11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
                95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
                73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
                39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
                14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
                2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
                18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
                23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
                63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
                26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
                28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
                99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
                77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
                72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
                66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
                77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
                18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
                7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
                11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
                51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
                46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
                6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
                3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
                96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
                12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
                44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
                54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
                26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
                40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
                25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
                34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
                86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
                32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
                9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
                50.41517, 28.156603, 42.644154,
            ],
        });

        let distance: f32 = compare::<f32>(256, Metric::L2, &two_vec.v);

        assert_eq!(distance, 429141.2);
    }

    /// Splits `v` into `v[..dim]` and `v[dim..]` and returns their distance under `metric`.
    ///
    /// The comparer is requested with `dimension = Some(dim)`.
    fn compare<T>(dim: usize, metric: Metric, v: &[T]) -> f32
    where
        T: DistanceProvider<T>,
    {
        let distance_comparer = T::distance_comparer(metric, Some(dim));
        distance_comparer.call(&v[..dim], &v[dim..])
    }

    /// Returns the distance under `metric` between the first `dim` elements of `v1` and
    /// `v2`, requesting the comparer with `dimension = Some(dim)`.
    ///
    /// # Panics
    ///
    /// Panics if either slice is shorter than `dim`.
    pub fn compare_two_vec<T>(dim: usize, metric: Metric, v1: &[T], v2: &[T]) -> f32
    where
        T: DistanceProvider<T>,
    {
        let distance_comparer = T::distance_comparer(metric, Some(dim));
        distance_comparer.call(&v1[..dim], &v2[..dim])
    }
}
789
#[cfg(test)]
mod distance_provider_f16_tests {
    use approx::assert_abs_diff_eq;

    // Both shared helpers come from the sibling f32 test module via `super`, so the path
    // does not hard-code where this file sits in the crate.
    use super::{
        distance_provider_f32_tests::{compare_two_vec, get_turing_test_data_f32_dim},
        *,
    };
    use crate::test_util::no_vector_compare_f16_as_f64;

    // 32-byte-aligned fixed-size f16 buffers for the aligned "turing" tests.
    #[repr(C, align(32))]
    pub struct F16Slice112([f16; 112]);
    #[repr(C, align(32))]
    pub struct F16Slice104([f16; 104]);
    #[repr(C, align(32))]
    pub struct F16Slice128([f16; 128]);
    #[repr(C, align(32))]
    pub struct F16Slice256([f16; 256]);
    #[repr(C, align(32))]
    pub struct F16Slice4096([f16; 4096]);

    /// The f32 "turing" test data of length `dim`, with every element converted to `f16`.
    fn get_turing_test_data_f16_dim(dim: usize) -> (Vec<f16>, Vec<f16>) {
        let (a, b) = get_turing_test_data_f32_dim(dim);
        let to_f16 = |v: Vec<f32>| -> Vec<f16> { v.into_iter().map(f16::from_f32).collect() };
        (to_f16(a), to_f16(b))
    }

    /// Asserts that the provider's L2 distance between `a` and `b` agrees with
    /// `no_vector_compare_f16_as_f64` to within `epsilon`.
    ///
    /// The distance is requested with `dimension = Some(a.len())`.
    fn assert_l2_matches_reference(a: &[f16], b: &[f16], epsilon: f64) {
        assert_eq!(a.len(), b.len(), "test vectors must have equal length");
        let distance: f32 = compare_two_vec::<f16>(a.len(), Metric::L2, a, b);
        assert_abs_diff_eq!(
            f64::from(distance),
            no_vector_compare_f16_as_f64(a, b),
            epsilon = epsilon
        );
    }

    #[test]
    fn test_dist_l2_f16_turing_112() {
        let (a_data, b_data) = get_turing_test_data_f16_dim(112);
        let (a_slice, b_slice) = (
            F16Slice112(a_data.try_into().unwrap()),
            F16Slice112(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_104() {
        let (a_data, b_data) = get_turing_test_data_f16_dim(104);
        let (a_slice, b_slice) = (
            F16Slice104(a_data.try_into().unwrap()),
            F16Slice104(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_256() {
        let (a_data, b_data) = get_turing_test_data_f16_dim(256);
        let (a_slice, b_slice) = (
            F16Slice256(a_data.try_into().unwrap()),
            F16Slice256(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_128() {
        let (a_data, b_data) = get_turing_test_data_f16_dim(128);
        let (a_slice, b_slice) = (
            F16Slice128(a_data.try_into().unwrap()),
            F16Slice128(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-3);
    }

    #[test]
    fn test_dist_l2_f16_turing_4096() {
        let (a_data, b_data) = get_turing_test_data_f16_dim(4096);
        let (a_slice, b_slice) = (
            F16Slice4096(a_data.try_into().unwrap()),
            F16Slice4096(b_data.try_into().unwrap()),
        );
        assert_l2_matches_reference(&a_slice.0, &b_slice.0, 1e-2);
    }

    /// All-infinity inputs must yield NaN rather than a finite or infinite distance.
    ///
    /// Dimension 384 is in the x86_64 f16 specialization list, so on x86_64 this goes
    /// through a dimension-specialized kernel.
    #[test]
    fn test_dist_l2_f16_produces_nan_distance_for_infinity_vectors() {
        let a_data = vec![f16::INFINITY; 384];
        let b_data = vec![f16::INFINITY; 384];

        let distance: f32 = compare_two_vec::<f16>(384, Metric::L2, &a_data, &b_data);
        assert!(distance.is_nan());
    }
}
921}