1#[cfg(target_arch = "aarch64")]
7use diskann_wide::arch::aarch64::Neon;
8#[cfg(target_arch = "x86_64")]
9use diskann_wide::arch::x86_64::{V3, V4};
10use diskann_wide::{
11 arch::{Dispatched2, FTarget2, Scalar},
12 lifetime::Ref,
13 Architecture,
14};
15use half::f16;
16
17use super::{implementations::Specialize, Cosine, CosineNormalized, InnerProduct, SquaredL2};
18use crate::distance::Metric;
19
/// Factory trait for distance functions between slices of `Self` and slices of `T`.
///
/// The implementing type is the element type of the left-hand argument and the
/// trait parameter `T` is the element type of the right-hand argument.
pub trait DistanceProvider<T>: Sized + 'static {
    /// Builds a [`Distance`] that evaluates `metric` on a `&[Self]` / `&[T]` pair.
    ///
    /// `dimension` is an optional hint. The implementations generated in this file
    /// use it to select a kernel specialized for exactly that length when one is
    /// registered for the running architecture, and otherwise return the generic
    /// kernel.
    // NOTE(review): what happens when the returned comparer is later called with
    // slices whose length differs from `dimension` is not visible here -- confirm
    // against the kernel implementations.
    fn distance_comparer(metric: Metric, dimension: Option<usize>) -> Distance<Self, T>;
}
48
/// A distance function taking `&[T]` and `&[U]` and returning `f32`.
///
/// Wraps the `Dispatched2` handle produced by architecture dispatch. Values are
/// obtained through [`DistanceProvider::distance_comparer`].
#[derive(Debug, Clone, Copy)]
pub struct Distance<T, U>
where
    T: 'static,
    U: 'static,
{
    // The dispatched kernel that `call` forwards to.
    f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>,
}
60
61impl<T, U> Distance<T, U>
62where
63 T: 'static,
64 U: 'static,
65{
66 fn new(f: Dispatched2<f32, Ref<[T]>, Ref<[U]>>) -> Self {
67 Self { f }
68 }
69
70 #[inline]
77 pub fn call(&self, x: &[T], y: &[U]) -> f32 {
78 self.f.call(x, y)
79 }
80}
81
82impl<T, U> crate::DistanceFunction<&[T], &[U], f32> for Distance<T, U>
83where
84 T: 'static,
85 U: 'static,
86{
87 fn evaluate_similarity(&self, x: &[T], y: &[U]) -> f32 {
88 self.call(x, y)
89 }
90}
91
/// Implements `DistanceProvider<$right>` for `$left`.
///
/// The generated `distance_comparer` hands an `ArgumentTypes<$left, $right>` marker,
/// the metric and the dimension hint to `diskann_wide::arch::dispatch2_no_features`.
// NOTE(review): presumably `dispatch2_no_features` resolves the runtime architecture
// and invokes the matching `Target2::run` impl from `specialize!` -- confirm in
// `diskann_wide`.
macro_rules! provider {
    ($left:ty, $right:ty) => {
        impl DistanceProvider<$right> for $left {
            fn distance_comparer(
                metric: Metric,
                dimension: Option<usize>,
            ) -> Distance<$left, $right> {
                let element_types = ArgumentTypes::<$left, $right>::new();
                diskann_wide::arch::dispatch2_no_features(element_types, metric, dimension)
            }
        }
    };
}
161
// Element-type pairs `(left, right)` for which a `DistanceProvider` is generated.
// The only mixed-precision pair is `f32` on the left with `f16` on the right.
provider!(f32, f32);
provider!(f16, f16);
provider!(f32, f16);
provider!(i8, i8);
provider!(u8, u8);
167
/// Builds a `Cons` list of `Spec::<N>` markers from a comma-separated list of
/// dimension literals.
///
/// The result is always a `Cons`, even for an empty list, so the caller can invoke
/// `Cons::specialize` on it unconditionally.
macro_rules! spec_list {
    // Public entry point: normalize to the internal `@value` form in which every
    // literal is followed by a comma.
    ($($Ns:literal),* $(,)?) => {
        spec_list!(@value, $($Ns,)*)
    };
    // No dimensions: both slots are `Null`, so no specialization can ever match.
    (@value $(,)?) => {
        Cons::new(Null, Null)
    };
    // One dimension: terminate the list with `Null`.
    (@value, $N0:literal $(,)?) => {
        Cons::new(Spec::<$N0>, Null)
    };
    // Two dimensions: both fit directly into a single `Cons`.
    (@value, $N0:literal, $N1:literal $(,)?) => {
        Cons::new(Spec::<$N0>, Spec::<$N1>)
    };
    // Three or more: peel off the first literal and recurse on the remainder.
    (@value, $N0:literal, $N1:literal, $($Ns:literal),+ $(,)?) => {
        Cons::new(Spec::<$N0>, spec_list!(@value, $N1, $($Ns,)+))
    };
}
189
/// Zero-sized marker naming the `(left, right)` element types of a distance function.
///
/// The `Target2` impls generated by `specialize!` are attached to this type.
struct ArgumentTypes<T: 'static, U: 'static>(std::marker::PhantomData<(T, U)>);

impl<T: 'static, U: 'static> ArgumentTypes<T, U> {
    /// Creates the marker value.
    fn new() -> Self {
        ArgumentTypes(std::marker::PhantomData)
    }
}
201
/// Implements `diskann_wide::arch::Target2` for `ArgumentTypes<$T, $U>` on the
/// architecture `$arch`.
///
/// The generated `run` builds the list of dimension specializations from the
/// trailing literals, then selects the kernel type from the requested `Metric`.
/// `Cons::specialize` returns the dimension-specialized kernel when `dim` matches
/// one of the literals and the generic kernel otherwise.
macro_rules! specialize {
    // Default form: every metric maps to its own kernel type.
    ($arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $arch,
            Distance<$T, $U>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$T, $U> {
            fn run(
                self,
                arch: $arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$T, $U> {
                let spec = spec_list!($($Ns),*);
                match metric {
                    Metric::L2 => spec.specialize(arch, SquaredL2 {}, dim),
                    Metric::Cosine => spec.specialize(arch, Cosine {}, dim),
                    Metric::CosineNormalized => spec.specialize(arch, CosineNormalized {}, dim),
                    Metric::InnerProduct => spec.specialize(arch, InnerProduct {}, dim),
                }
            }
        }
    };
    // `@integer` form: identical except that `CosineNormalized` is routed to the
    // plain `Cosine` kernel.
    // NOTE(review): presumably because integer vectors cannot be assumed to be
    // unit-normalized -- confirm the intent.
    (@integer, $arch:ty, $T:ty, $U:ty, $($Ns:literal),* $(,)?) => {
        impl diskann_wide::arch::Target2<
            $arch,
            Distance<$T, $U>,
            Metric,
            Option<usize>,
        > for ArgumentTypes<$T, $U> {
            fn run(
                self,
                arch: $arch,
                metric: Metric,
                dim: Option<usize>,
            ) -> Distance<$T, $U> {
                let spec = spec_list!($($Ns),*);
                match metric {
                    Metric::L2 => spec.specialize(arch, SquaredL2 {}, dim),
                    Metric::Cosine | Metric::CosineNormalized => {
                        spec.specialize(arch, Cosine {}, dim)
                    },
                    Metric::InnerProduct => spec.specialize(arch, InnerProduct {}, dim),
                }
            }
        }
    };
}
252
// `Scalar` targets are registered with an empty dimension list, so they always
// resolve to the generic (non-dimension-specialized) kernel.
specialize!(Scalar, f32, f32,);
specialize!(Scalar, f32, f16,);
specialize!(Scalar, f16, f16,);
specialize!(@integer, Scalar, u8, u8,);
specialize!(@integer, Scalar, i8, i8,);
258
/// Dispatch targets for the x86_64 `V3` and `V4` architectures.
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use super::*;

    // Floating-point pairs: specialized kernels for dimensions 768, 384, 128 and 100.
    specialize!(V3, f32, f32, 768, 384, 128, 100);
    specialize!(V4, f32, f32, 768, 384, 128, 100);

    specialize!(V3, f32, f16, 768, 384, 128, 100);
    specialize!(V4, f32, f16, 768, 384, 128, 100);

    specialize!(V3, f16, f16, 768, 384, 128, 100);
    specialize!(V4, f16, f16, 768, 384, 128, 100);

    // `u8`: only dimension 128 is specialized.
    specialize!(@integer, V3, u8, u8, 128);
    specialize!(@integer, V4, u8, u8, 128);

    // `i8`: dimensions 128 and 100 are specialized.
    specialize!(@integer, V3, i8, i8, 128, 100);
    specialize!(@integer, V4, i8, i8, 128, 100);
}
278
/// Dispatch targets for the aarch64 `Neon` architecture.
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use super::*;

    // Floating-point pairs: specialized kernels for dimensions 768, 384, 128 and 100.
    specialize!(Neon, f32, f32, 768, 384, 128, 100);
    specialize!(Neon, f32, f16, 768, 384, 128, 100);
    specialize!(Neon, f16, f16, 768, 384, 128, 100);

    // Integer pairs: `u8` specializes 128 only; `i8` specializes 128 and 100.
    specialize!(@integer, Neon, u8, u8, 128);
    specialize!(@integer, Neon, i8, i8, 128, 100);
}
290
/// Attempts to build a dimension-specialized [`Distance`] on architecture `A`.
///
/// `F` is the kernel type being specialized; `T` and `U` are the left and right
/// element types of the resulting distance function.
trait TrySpecialize<A, F, T, U>
where
    A: Architecture,
    T: 'static,
    U: 'static,
{
    /// Returns `Some` specialized distance function when this specializer applies
    /// to `dim`, and `None` otherwise.
    fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>>;
}
301
302struct Spec<const N: usize>;
304
305impl<A, F, const N: usize, T, U> TrySpecialize<A, F, T, U> for Spec<N>
306where
307 A: Architecture,
308 Specialize<N, F>: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
309 T: 'static,
310 U: 'static,
311{
312 fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
313 if let Some(d) = dim {
314 if d == N {
315 return Some(Distance::new(
316 arch.dispatch2::<Specialize<N, F>, f32, Ref<[T]>, Ref<[U]>>(),
318 ));
319 }
320 }
321 None
322 }
323}
324
325struct Null;
327
328impl<A, F, T, U> TrySpecialize<A, F, T, U> for Null
329where
330 A: Architecture,
331 T: 'static,
332 U: 'static,
333{
334 fn try_specialize(&self, _arch: A, _dim: Option<usize>) -> Option<Distance<T, U>> {
335 None
336 }
337}
338
/// A pair of specializers that are tried in order: `head` first, then `tail`.
///
/// Nesting another `Cons` in `tail` forms a list of arbitrary length.
struct Cons<Head, Tail> {
    // Specializer consulted first.
    head: Head,
    // Specializer consulted only when `head` does not match.
    tail: Tail,
}
344
345impl<Head, Tail> Cons<Head, Tail> {
346 const fn new(head: Head, tail: Tail) -> Self {
347 Self { head, tail }
348 }
349
350 fn specialize<A, F, T, U>(&self, arch: A, _f: F, dim: Option<usize>) -> Distance<T, U>
353 where
354 A: Architecture,
355 F: for<'a, 'b> FTarget2<A, f32, &'a [T], &'b [U]>,
356 Head: TrySpecialize<A, F, T, U>,
357 Tail: TrySpecialize<A, F, T, U>,
358 T: 'static,
359 U: 'static,
360 {
361 if let Some(f) = self.try_specialize(arch, dim) {
362 f
363 } else {
364 Distance::new(arch.dispatch2::<F, f32, Ref<[T]>, Ref<[U]>>())
365 }
366 }
367}
368
369impl<A, Head, Tail, F, T, U> TrySpecialize<A, F, T, U> for Cons<Head, Tail>
371where
372 A: Architecture,
373 Head: TrySpecialize<A, F, T, U>,
374 Tail: TrySpecialize<A, F, T, U>,
375 T: 'static,
376 U: 'static,
377{
378 fn try_specialize(&self, arch: A, dim: Option<usize>) -> Option<Distance<T, U>> {
379 if let Some(f) = self.head.try_specialize(arch, dim) {
380 Some(f)
381 } else {
382 self.tail.try_specialize(arch, dim)
383 }
384 }
385}
386
387#[cfg(test)]
392mod test_unaligned_distance_provider {
393 use approx::assert_relative_eq;
394 use rand::{self, SeedableRng};
395
396 use super::*;
397 use crate::{
398 distance::{reference::ReferenceProvider, Metric},
399 test_util, SimilarityScore,
400 };
401
    /// Tolerances forwarded to `assert_relative_eq!` when checking a distance
    /// function against its reference implementation.
    struct EpsilonAndRelative {
        /// Value passed as the `epsilon` argument.
        epsilon: f32,
        /// Value passed as the `max_relative` argument.
        max_relative: f32,
    }
407
408 fn get_float_bounds(metric: Metric) -> EpsilonAndRelative {
418 match metric {
419 Metric::L2 => EpsilonAndRelative {
420 epsilon: 1e-5,
421 max_relative: 1e-5,
422 },
423 Metric::InnerProduct => EpsilonAndRelative {
424 epsilon: 1e-4,
425 max_relative: 1e-4,
426 },
427 Metric::Cosine => EpsilonAndRelative {
428 epsilon: 1e-4,
429 max_relative: 1e-4,
430 },
431 Metric::CosineNormalized => EpsilonAndRelative {
432 epsilon: 1e-4,
433 max_relative: 1e-4,
434 },
435 }
436 }
437
438 fn get_int_bounds(metric: Metric) -> EpsilonAndRelative {
439 match metric {
440 Metric::Cosine | Metric::CosineNormalized => EpsilonAndRelative {
442 epsilon: 1e-6,
443 max_relative: 1e-6,
444 },
445 Metric::L2 | Metric::InnerProduct => EpsilonAndRelative {
447 epsilon: 0.0,
448 max_relative: 0.0,
449 },
450 }
451 }
452
453 fn do_test<T, Distribution>(
454 under_test: Distance<T, T>,
455 reference: fn(&[T], &[T]) -> SimilarityScore<f32>,
456 bounds: EpsilonAndRelative,
457 dim: usize,
458 distribution: Distribution,
459 ) where
460 T: test_util::CornerCases,
461 Distribution: test_util::GenerateRandomArguments<T> + Clone,
462 {
463 let mut rng = rand::rngs::StdRng::seed_from_u64(0xef0053c);
464
465 let converted = |a: &[T], b: &[T]| -> f32 { reference(a, b).into_inner() };
467
468 let checker = test_util::Checker::<T, T, f32>::new(
469 |a, b| under_test.call(a, b),
470 converted,
471 |got: f32, expected: f32| {
472 assert_relative_eq!(
473 got,
474 expected,
475 epsilon = bounds.epsilon,
476 max_relative = bounds.max_relative
477 );
478 },
479 );
480
481 test_util::test_distance_function(
482 checker,
483 distribution.clone(),
484 distribution.clone(),
485 dim,
486 10,
487 &mut rng,
488 );
489 }
490
491 fn all_metrics() -> [Metric; 4] {
492 [
493 Metric::L2,
494 Metric::InnerProduct,
495 Metric::Cosine,
496 Metric::CosineNormalized,
497 ]
498 }
499
    /// Exclusive upper bound on the vector lengths exercised below (`0..MAX_DIM`).
    const MAX_DIM: usize = 256;
502
503 #[test]
504 fn test_unaligned_f32() {
505 let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
506 for metric in all_metrics() {
507 for dim in 0..MAX_DIM {
508 println!("Metric = {:?}, dim = {}", metric, dim);
509 let unaligned = <f32 as DistanceProvider<f32>>::distance_comparer(metric, None);
510 let simple = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
511 let bounds = get_float_bounds(metric);
512 do_test(unaligned, simple, bounds, dim, dist);
513 }
514 }
515 }
516
517 #[test]
518 fn test_unaligned_f16() {
519 let dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
520 for metric in all_metrics() {
521 for dim in 0..MAX_DIM {
522 println!("Metric = {:?}, dim = {}", metric, dim);
523 let unaligned = <f16 as DistanceProvider<f16>>::distance_comparer(metric, None);
524 let simple = <f16 as ReferenceProvider<f16>>::reference_implementation(metric);
525 let bounds = get_float_bounds(metric);
526 do_test(unaligned, simple, bounds, dim, dist);
527 }
528 }
529 }
530
531 #[test]
532 fn test_unaligned_u8() {
533 let dist = rand::distr::StandardUniform {};
534 for metric in all_metrics() {
535 for dim in 0..MAX_DIM {
536 println!("Metric = {:?}, dim = {}", metric, dim);
537 let unaligned = <u8 as DistanceProvider<u8>>::distance_comparer(metric, None);
538 let simple = <u8 as ReferenceProvider<u8>>::reference_implementation(metric);
539 let bounds = get_int_bounds(metric);
540 do_test(unaligned, simple, bounds, dim, dist);
541 }
542 }
543 }
544
545 #[test]
546 fn test_unaligned_i8() {
547 let dist = rand::distr::StandardUniform {};
548 for metric in all_metrics() {
549 for dim in 0..MAX_DIM {
550 println!("Metric = {:?}, dim = {}", metric, dim);
551 let unaligned = <i8 as DistanceProvider<i8>>::distance_comparer(metric, None);
552 let simple = <i8 as ReferenceProvider<i8>>::reference_implementation(metric);
553
554 let bounds = get_int_bounds(metric);
555 do_test(unaligned, simple, bounds, dim, dist);
556 }
557 }
558 }
559}
560
561#[cfg(test)]
562mod distance_provider_f32_tests {
563 use approx::assert_abs_diff_eq;
564 use rand::{rngs::StdRng, Rng, SeedableRng};
565
566 use super::*;
567 use crate::{distance::reference, test_util::*};
568
    // Fixed-size `f32` buffers declared with 32-byte alignment, used as inputs by the
    // tests in this module. The numeric suffix is the element count.
    #[repr(C, align(32))]
    pub struct F32Slice112([f32; 112]);
    #[repr(C, align(32))]
    pub struct F32Slice104([f32; 104]);
    #[repr(C, align(32))]
    pub struct F32Slice128([f32; 128]);
    #[repr(C, align(32))]
    pub struct F32Slice256([f32; 256]);
    #[repr(C, align(32))]
    pub struct F32Slice4096([f32; 4096]);
579
580 pub fn get_turing_test_data_f32_dim(dim: usize) -> (Vec<f32>, Vec<f32>) {
581 let mut a_slice = vec![0.0f32; dim];
582 let mut b_slice = vec![0.0f32; dim];
583
584 let mut rng = StdRng::seed_from_u64(42);
585 for i in 0..dim {
586 a_slice[i] = rng.random_range(-1.0..1.0);
587 b_slice[i] = rng.random_range(-1.0..1.0);
588 }
589
590 ((a_slice), (b_slice))
591 }
592
593 #[test]
594 fn test_dist_l2_float_turing_104() {
595 let (a_data, b_data) = get_turing_test_data_f32_dim(104);
596 let (a_slice, b_slice) = (
597 F32Slice104(a_data.try_into().unwrap()),
598 F32Slice104(b_data.try_into().unwrap()),
599 );
600
601 let distance: f32 = compare_two_vec::<f32>(104, Metric::L2, &a_slice.0, &b_slice.0);
602
603 assert_abs_diff_eq!(
604 distance as f64,
605 no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
606 epsilon = 1e-4f64
607 );
608 }
609
610 #[test]
611 fn test_dist_l2_float_turing_112() {
612 let (a_data, b_data) = get_turing_test_data_f32_dim(112);
613 let (a_slice, b_slice) = (
614 F32Slice112(a_data.try_into().unwrap()),
615 F32Slice112(b_data.try_into().unwrap()),
616 );
617
618 let distance: f32 = compare_two_vec::<f32>(112, Metric::L2, &a_slice.0, &b_slice.0);
619
620 assert_abs_diff_eq!(
621 distance as f64,
622 no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
623 epsilon = 1e-4f64
624 );
625 }
626
627 #[test]
628 fn test_dist_l2_float_turing_128() {
629 let (a_data, b_data) = get_turing_test_data_f32_dim(128);
630 let (a_slice, b_slice) = (
631 F32Slice128(a_data.try_into().unwrap()),
632 F32Slice128(b_data.try_into().unwrap()),
633 );
634
635 let distance: f32 = compare_two_vec::<f32>(128, Metric::L2, &a_slice.0, &b_slice.0);
636
637 assert_abs_diff_eq!(
638 distance as f64,
639 no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
640 epsilon = 1e-4f64
641 );
642 }
643
644 #[test]
645 fn test_dist_l2_float_turing_256() {
646 let (a_data, b_data) = get_turing_test_data_f32_dim(256);
647 let (a_slice, b_slice) = (
648 F32Slice256(a_data.try_into().unwrap()),
649 F32Slice256(b_data.try_into().unwrap()),
650 );
651
652 let distance: f32 = compare_two_vec::<f32>(256, Metric::L2, &a_slice.0, &b_slice.0);
653
654 assert_abs_diff_eq!(
655 distance as f64,
656 no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
657 epsilon = 1e-3f64
658 );
659 }
660
661 #[test]
662 fn test_dist_l2_float_turing_4096() {
663 let (a_data, b_data) = get_turing_test_data_f32_dim(4096);
664 let (a_slice, b_slice) = (
665 F32Slice4096(a_data.try_into().unwrap()),
666 F32Slice4096(b_data.try_into().unwrap()),
667 );
668
669 let distance: f32 = compare_two_vec::<f32>(4096, Metric::L2, &a_slice.0, &b_slice.0);
670
671 assert_abs_diff_eq!(
672 distance as f64,
673 no_vector_compare_f32_as_f64(&a_slice.0, &b_slice.0),
674 epsilon = 1e-2f64
675 );
676 }
677
678 #[test]
679 fn test_dist_ip_float_turing_112() {
680 let (a_data, b_data) = get_turing_test_data_f32_dim(112);
681 let (a_slice, b_slice) = (
682 F32Slice112(a_data.try_into().unwrap()),
683 F32Slice112(b_data.try_into().unwrap()),
684 );
685
686 let distance: f32 =
687 compare_two_vec::<f32>(112, Metric::InnerProduct, &a_slice.0, &b_slice.0);
688
689 assert_abs_diff_eq!(
690 distance,
691 reference::reference_innerproduct_f32_similarity(&a_slice.0, &b_slice.0).into_inner(),
692 epsilon = 1e-4f32
693 );
694 }
695
    /// Golden-value test: L2 between the two halves of a fixed 512-element buffer.
    ///
    /// `compare` splits the buffer at index 256, so elements `0..256` form the
    /// left-hand vector and elements `256..512` the right-hand vector.
    #[test]
    fn distance_test() {
        // Local wrapper declaring 32-byte alignment for the fixture data.
        #[repr(C, align(32))]
        struct Vector32ByteAligned {
            v: [f32; 512],
        }

        // Boxed so the fixture lives on the heap.
        let two_vec = Box::new(Vector32ByteAligned {
            v: [
                69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
                20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
                65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
                95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
                71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
                87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
                40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
                66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
                27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
                26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
                35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
                1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
                72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
                62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
                70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
                35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
                31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
                93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
                94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
                80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
                14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
                14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
                64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
                55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
                69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
                76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
                70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
                87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
                57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
                23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
                34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
                94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
                40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
                11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
                95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
                73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
                39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
                14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
                2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
                18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
                23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
                63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
                26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
                28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
                99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
                77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
                72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
                66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
                77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
                18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
                7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
                11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
                51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
                46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
                6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
                3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
                96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
                12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
                44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
                54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
                26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
                40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
                25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
                34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
                86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
                32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
                9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
                50.41517, 28.156603, 42.644154,
            ],
        });

        let distance: f32 = compare::<f32>(256, Metric::L2, &two_vec.v);

        // NOTE(review): exact `f32` equality against a golden value. If the dispatched
        // architectures accumulate in different orders this may not be bit-identical
        // everywhere -- confirm it holds on all supported targets.
        assert_eq!(distance, 429141.2);
    }
781
782 fn compare<T>(dim: usize, metric: Metric, v: &[T]) -> f32
783 where
784 T: DistanceProvider<T>,
785 {
786 let distance_comparer = T::distance_comparer(metric, Some(dim));
787 distance_comparer.call(&v[..dim], &v[dim..])
788 }
789
790 pub fn compare_two_vec<T>(dim: usize, metric: Metric, v1: &[T], v2: &[T]) -> f32
791 where
792 T: DistanceProvider<T>,
793 {
794 let distance_comparer = T::distance_comparer(metric, Some(dim));
795 distance_comparer.call(&v1[..dim], &v2[..dim])
796 }
797}
798
799#[cfg(test)]
800mod distance_provider_f16_tests {
801 use approx::assert_abs_diff_eq;
802
803 use super::{distance_provider_f32_tests::get_turing_test_data_f32_dim, *};
804 use crate::{
805 distance::distance_provider::distance_provider_f32_tests::compare_two_vec,
806 test_util::no_vector_compare_f16_as_f64,
807 };
808
    // Fixed-size `f16` buffers declared with 32-byte alignment, used as inputs by the
    // tests in this module. The numeric suffix is the element count.
    #[repr(C, align(32))]
    pub struct F16Slice112([f16; 112]);
    #[repr(C, align(32))]
    pub struct F16Slice104([f16; 104]);
    #[repr(C, align(32))]
    pub struct F16Slice128([f16; 128]);
    #[repr(C, align(32))]
    pub struct F16Slice256([f16; 256]);
    #[repr(C, align(32))]
    pub struct F16Slice4096([f16; 4096]);
819
820 fn get_turing_test_data_f16_dim(dim: usize) -> (Vec<f16>, Vec<f16>) {
821 let (a_slice, b_slice) = get_turing_test_data_f32_dim(dim);
822 let a_data = a_slice.iter().map(|x| f16::from_f32(*x)).collect();
823 let b_data = b_slice.iter().map(|x| f16::from_f32(*x)).collect();
824 (a_data, b_data)
825 }
826
827 #[test]
828 fn test_dist_l2_f16_turing_112() {
829 let (a_data, b_data) = get_turing_test_data_f16_dim(112);
831 let (a_slice, b_slice) = (
832 F16Slice112(a_data.try_into().unwrap()),
833 F16Slice112(b_data.try_into().unwrap()),
834 );
835
836 let distance: f32 = compare_two_vec::<f16>(112, Metric::L2, &a_slice.0, &b_slice.0);
837
838 assert_abs_diff_eq!(
840 distance as f64,
841 no_vector_compare_f16_as_f64(&a_slice.0, &b_slice.0),
842 epsilon = 1e-3f64
843 );
844 }
845
846 #[test]
847 fn test_dist_l2_f16_turing_104() {
848 let (a_data, b_data) = get_turing_test_data_f16_dim(104);
850 let (a_slice, b_slice) = (
851 F16Slice104(a_data.try_into().unwrap()),
852 F16Slice104(b_data.try_into().unwrap()),
853 );
854
855 let distance: f32 = compare_two_vec::<f16>(104, Metric::L2, &a_slice.0, &b_slice.0);
856
857 assert_abs_diff_eq!(
859 distance as f64,
860 no_vector_compare_f16_as_f64(&a_slice.0, &b_slice.0),
861 epsilon = 1e-3f64
862 );
863 }
864
865 #[test]
866 fn test_dist_l2_f16_turing_256() {
867 let (a_data, b_data) = get_turing_test_data_f16_dim(256);
869 let (a_slice, b_slice) = (
870 F16Slice256(a_data.try_into().unwrap()),
871 F16Slice256(b_data.try_into().unwrap()),
872 );
873
874 let distance: f32 = compare_two_vec::<f16>(256, Metric::L2, &a_slice.0, &b_slice.0);
875
876 assert_abs_diff_eq!(
878 distance as f64,
879 no_vector_compare_f16_as_f64(&a_slice.0, &b_slice.0),
880 epsilon = 1e-3f64
881 );
882 }
883
884 #[test]
885 fn test_dist_l2_f16_turing_128() {
886 let (a_data, b_data) = get_turing_test_data_f16_dim(128);
888 let (a_slice, b_slice) = (
889 F16Slice128(a_data.try_into().unwrap()),
890 F16Slice128(b_data.try_into().unwrap()),
891 );
892
893 let distance: f32 = compare_two_vec::<f16>(128, Metric::L2, &a_slice.0, &b_slice.0);
894
895 assert_abs_diff_eq!(
897 distance as f64,
898 no_vector_compare_f16_as_f64(&a_slice.0, &b_slice.0),
899 epsilon = 1e-3f64
900 );
901 }
902
903 #[test]
904 fn test_dist_l2_f16_turing_4096() {
905 let (a_data, b_data) = get_turing_test_data_f16_dim(4096);
907 let (a_slice, b_slice) = (
908 F16Slice4096(a_data.try_into().unwrap()),
909 F16Slice4096(b_data.try_into().unwrap()),
910 );
911
912 let distance: f32 = compare_two_vec::<f16>(4096, Metric::L2, &a_slice.0, &b_slice.0);
913
914 assert_abs_diff_eq!(
916 distance as f64,
917 no_vector_compare_f16_as_f64(&a_slice.0, &b_slice.0),
918 epsilon = 1e-2f64
919 );
920 }
921
922 #[test]
923 fn test_dist_l2_f16_produces_nan_distance_for_infinity_vectors() {
924 let a_data = vec![f16::INFINITY; 384];
925 let b_data = vec![f16::INFINITY; 384];
926
927 let distance: f32 = compare_two_vec::<f16>(384, Metric::L2, &a_data, &b_data);
928 assert!(distance.is_nan());
929 }
930}