1use diskann_wide::{arch::Target2, Architecture, ARCH};
7
8use super::simd;
10use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};
11
/// Read-only slice view over the argument shapes the distance functors accept:
/// borrowed slices, borrowed fixed-size arrays, and owned fixed-size arrays.
trait ToSlice {
    /// Element type of the viewed slice.
    type Target;

    /// Borrow `self` as a slice of [`Self::Target`].
    fn to_slice(&self) -> &[Self::Target];
}

impl<T> ToSlice for &[T] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        // Copy the inner shared reference out; its lifetime outlives `&self`.
        *self
    }
}

impl<T, const N: usize> ToSlice for &[T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        self.as_slice()
    }
}

impl<T, const N: usize> ToSlice for [T; N] {
    type Target = T;

    fn to_slice(&self) -> &[T] {
        self.as_slice()
    }
}
35
/// Wire a distance functor into `diskann_wide`'s architecture dispatch.
///
/// `architecture_hook!(Functor, simd::Kernel)` emits two impls for `Functor`:
///
/// * `Target2<A, T, L, R>`: views both arguments as slices via [`ToSlice`],
///   evaluates the kernel with `simd::simd_op`, and converts the kernel's
///   `Return` value into `T` through the functor's [`PostOp`] impl.
/// * `FTarget2<A, T, L, R>`: the stateless form, which builds a
///   `Functor::default()` and dispatches it through `arch.run2`.
macro_rules! architecture_hook {
    ($functor:ty, $impl:path) => {
        impl<A, T, L, R> diskann_wide::arch::Target2<A, T, L, R> for $functor
        where
            A: Architecture,
            L: ToSlice,
            R: ToSlice,
            $impl: simd::SIMDSchema<L::Target, R::Target, A>,
            Self: PostOp<<$impl as simd::SIMDSchema<L::Target, R::Target, A>>::Return, T>,
        {
            #[inline(always)]
            fn run(self, arch: A, lhs: L, rhs: R) -> T {
                // Raw kernel output first, then the functor-specific conversion.
                let raw = simd::simd_op(&$impl, arch, lhs.to_slice(), rhs.to_slice());
                Self::post_op(raw)
            }
        }

        impl<A, T, L, R> diskann_wide::arch::FTarget2<A, T, L, R> for $functor
        where
            A: Architecture,
            L: ToSlice,
            R: ToSlice,
            Self: diskann_wide::arch::Target2<A, T, L, R>,
        {
            #[inline(always)]
            fn run(arch: A, lhs: L, rhs: R) -> T {
                arch.run2(Self::default(), lhs, rhs)
            }
        }
    };
}
71
/// Zero-sized adapter that runs the functor `F` on fixed-length `[_; N]` views
/// of slice arguments, so the kernel can be compiled for one known length `N`.
///
/// The `PhantomData` only records `F`; no value of `F` is stored.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);
75
76impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
77 for Specialize<N, F>
78where
79 A: Architecture,
80 F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
81{
82 #[inline(always)]
83 fn run(arch: A, x: &[L], y: &[R]) -> T {
84 if (x.len() != N) | (y.len() != N) {
85 fail_length_check(x, y, N);
86 }
87
88 arch.run2(
93 F::default(),
94 unsafe { &*(x.as_ptr() as *const [L; N]) },
95 unsafe { &*(y.as_ptr() as *const [R; N]) },
96 )
97 }
98}
99
/// Out-of-line panic path for a slice-length mismatch.
///
/// The first argument is reported if its length differs from `len`; otherwise
/// the second argument is reported.
///
/// Kept `#[inline(never)]` so the formatting and panic machinery stay out of
/// the callers' hot paths.
#[inline(never)]
// Panicking is the intended contract here: callers rely on a hard failure.
#[allow(clippy::panic)]
fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
    let (which, actual) = if x.len() == len {
        ("second", y.len())
    } else {
        ("first", x.len())
    };
    panic!(
        "expected {} argument to have length {}, instead it has length {}",
        which, len, actual
    );
}
115
116pub(super) trait PostOp<From, To> {
122 fn post_op(x: From) -> To;
123}
124
/// Implement `PureDistanceFunction` for a functor on one element-type pair.
///
/// `use_simd_implementation!(Functor, T, U)` emits, for each of the output
/// types `SimilarityScore<f32>`, `MathematicalValue<f32>`, and `f32`, an impl
/// taking `(&[T], &[U])` and an impl taking `(&[T; N], &[U; N])`. Every impl
/// forwards to `Functor::default().run(ARCH, x, y)`.
///
/// The slice-to-`f32` impl is `#[inline(always)]`; all others are `#[inline]`.
macro_rules! use_simd_implementation {
    // Internal arm: the slice impl (with the given inline attribute) and the
    // fixed-size-array impl (always `#[inline]`) for a single return type.
    // It must come first so `@pair` is never parsed as a `ty` fragment.
    (@pair $functor:ty, $T:ty, $U:ty, $Ret:ty, $slice_inline:meta) => {
        impl PureDistanceFunction<&[$T], &[$U], $Ret> for $functor {
            #[$slice_inline]
            fn evaluate(x: &[$T], y: &[$U]) -> $Ret {
                <$functor>::default().run(ARCH, x, y)
            }
        }

        impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], $Ret> for $functor {
            #[inline]
            fn evaluate(x: &[$T; N], y: &[$U; N]) -> $Ret {
                <$functor>::default().run(ARCH, x, y)
            }
        }
    };
    ($functor:ty, $T:ty, $U:ty) => {
        use_simd_implementation!(@pair $functor, $T, $U, SimilarityScore<f32>, inline);
        use_simd_implementation!(@pair $functor, $T, $U, MathematicalValue<f32>, inline);
        use_simd_implementation!(@pair $functor, $T, $U, f32, inline(always));
    };
}
191
192#[derive(Debug, Clone, Copy, Default)]
198pub struct SquaredL2 {}
199
200impl PostOp<f32, SimilarityScore<f32>> for SquaredL2 {
201 #[inline(always)]
202 fn post_op(x: f32) -> SimilarityScore<f32> {
203 SimilarityScore::new(x)
204 }
205}
206
207impl PostOp<f32, f32> for SquaredL2 {
208 #[inline(always)]
209 fn post_op(x: f32) -> f32 {
210 x
211 }
212}
213
214impl PostOp<f32, MathematicalValue<f32>> for SquaredL2 {
215 #[inline(always)]
216 fn post_op(x: f32) -> MathematicalValue<f32> {
217 MathematicalValue::new(x)
218 }
219}
220
221architecture_hook!(SquaredL2, simd::L2);
222use_simd_implementation!(SquaredL2, f32, f32);
223use_simd_implementation!(SquaredL2, f32, Half);
224use_simd_implementation!(SquaredL2, Half, Half);
225use_simd_implementation!(SquaredL2, i8, i8);
226use_simd_implementation!(SquaredL2, u8, u8);
227
228#[derive(Debug, Clone, Copy, Default)]
237pub struct FullL2 {}
238
239impl PostOp<f32, SimilarityScore<f32>> for FullL2 {
240 #[inline(always)]
241 fn post_op(x: f32) -> SimilarityScore<f32> {
242 SimilarityScore::new(x.sqrt())
243 }
244}
245
246impl PostOp<f32, f32> for FullL2 {
247 #[inline(always)]
248 fn post_op(x: f32) -> f32 {
249 x.sqrt()
250 }
251}
252
253impl PostOp<f32, MathematicalValue<f32>> for FullL2 {
254 #[inline(always)]
255 fn post_op(x: f32) -> MathematicalValue<f32> {
256 MathematicalValue::new(x.sqrt())
257 }
258}
259
260architecture_hook!(FullL2, simd::L2);
261use_simd_implementation!(FullL2, f32, f32);
262use_simd_implementation!(FullL2, f32, Half);
263use_simd_implementation!(FullL2, Half, Half);
264use_simd_implementation!(FullL2, i8, i8);
265use_simd_implementation!(FullL2, u8, u8);
266
267#[derive(Debug, Clone, Copy, Default)]
273pub struct InnerProduct {}
274
275impl PostOp<f32, SimilarityScore<f32>> for InnerProduct {
276 #[inline(always)]
280 fn post_op(x: f32) -> SimilarityScore<f32> {
281 SimilarityScore::new(-x)
282 }
283}
284
285impl PostOp<f32, MathematicalValue<f32>> for InnerProduct {
286 #[inline(always)]
287 fn post_op(x: f32) -> MathematicalValue<f32> {
288 MathematicalValue::new(x)
289 }
290}
291
292impl PostOp<f32, f32> for InnerProduct {
293 #[inline(always)]
294 fn post_op(x: f32) -> f32 {
295 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
296 }
297}
298
299architecture_hook!(InnerProduct, simd::IP);
300use_simd_implementation!(InnerProduct, f32, f32);
301use_simd_implementation!(InnerProduct, f32, Half);
302use_simd_implementation!(InnerProduct, Half, Half);
303use_simd_implementation!(InnerProduct, i8, i8);
304use_simd_implementation!(InnerProduct, u8, u8);
305
/// Turn a cosine similarity into a cosine distance: `1 - similarity`.
///
/// A similarity in `[-1, 1]` maps to a distance in `[0, 2]`, where identical
/// directions give `0`.
fn cosine_transformation(similarity: f32) -> f32 {
    1.0 - similarity
}
316
317#[derive(Debug, Clone, Copy, Default)]
319pub struct Cosine {}
320
321impl PostOp<f32, SimilarityScore<f32>> for Cosine {
322 fn post_op(x: f32) -> SimilarityScore<f32> {
323 debug_assert!(x >= -1.0);
324 debug_assert!(x <= 1.0);
325 SimilarityScore::new(cosine_transformation(x))
326 }
327}
328
329impl PostOp<f32, MathematicalValue<f32>> for Cosine {
330 fn post_op(x: f32) -> MathematicalValue<f32> {
331 debug_assert!(x >= -1.0);
332 debug_assert!(x <= 1.0);
333 MathematicalValue::new(x)
334 }
335}
336
337impl PostOp<f32, f32> for Cosine {
338 fn post_op(x: f32) -> f32 {
339 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
340 }
341}
342
343architecture_hook!(Cosine, simd::CosineStateless);
344use_simd_implementation!(Cosine, f32, f32);
345use_simd_implementation!(Cosine, f32, Half);
346use_simd_implementation!(Cosine, Half, Half);
347use_simd_implementation!(Cosine, i8, i8);
348use_simd_implementation!(Cosine, u8, u8);
349
350#[derive(Debug, Clone, Copy, Default)]
356pub struct CosineNormalized {}
357
358impl PostOp<f32, SimilarityScore<f32>> for CosineNormalized {
359 #[inline(always)]
360 fn post_op(x: f32) -> SimilarityScore<f32> {
361 SimilarityScore::new(cosine_transformation(x))
367 }
368}
369
370impl PostOp<f32, MathematicalValue<f32>> for CosineNormalized {
371 #[inline(always)]
372 fn post_op(x: f32) -> MathematicalValue<f32> {
373 MathematicalValue::new(x)
374 }
375}
376
377impl PostOp<f32, f32> for CosineNormalized {
378 #[inline(always)]
379 fn post_op(x: f32) -> f32 {
380 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
381 }
382}
383
384architecture_hook!(CosineNormalized, simd::IP);
385use_simd_implementation!(CosineNormalized, f32, f32);
386use_simd_implementation!(CosineNormalized, f32, Half);
387use_simd_implementation!(CosineNormalized, Half, Half);
388
389#[derive(Debug, Clone, Copy, Default)]
395pub struct L1NormFunctor {}
396
397impl PostOp<f32, f32> for L1NormFunctor {
398 #[inline(always)]
399 fn post_op(x: f32) -> f32 {
400 x
401 }
402}
403
404architecture_hook!(L1NormFunctor, simd::L1Norm);
405
406impl PureDistanceFunction<&[f32], &[f32], f32> for L1NormFunctor {
407 #[inline]
408 fn evaluate(x: &[f32], y: &[f32]) -> f32 {
409 L1NormFunctor::default().run(ARCH, x, y)
410 }
411}
412
// Tests comparing the SIMD-backed functors in this module against the
// reference implementations in `crate::distance::reference`.
#[cfg(test)]
mod tests {
    use std::hash::{Hash, Hasher};

    use approx::assert_relative_eq;
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::{
        distance::{
            reference::{self, ReferenceProvider},
            Metric,
        },
        test_util::{self, Normalize},
    };

    /// Evaluate the slice-based `PureDistanceFunction` implementation of `T`.
    ///
    /// Instantiating this with concrete types yields a plain
    /// `fn(&[Left], &[Right]) -> Return`, which is the shape the drivers accept.
    pub fn as_function_pointer<T, Left, Right, Return>(x: &[Left], y: &[Right]) -> Return
    where
        T: for<'a, 'b> PureDistanceFunction<&'a [Left], &'b [Right], Return>,
    {
        T::evaluate(x, y)
    }

    /// Return this module's slice-based `f32` implementation for `metric`.
    fn simd_provider(metric: Metric) -> fn(&[f32], &[f32]) -> f32 {
        match metric {
            Metric::L2 => as_function_pointer::<SquaredL2, _, _, _>,
            Metric::InnerProduct => as_function_pointer::<InnerProduct, _, _, _>,
            Metric::Cosine => as_function_pointer::<Cosine, _, _, _>,
            Metric::CosineNormalized => as_function_pointer::<CosineNormalized, _, _, _>,
        }
    }

    /// Draw two length-`dim` vectors whose entries are sampled with
    /// `random_range(lo..hi)`, i.e. uniformly (not normally) distributed.
    fn random_uniform_arguments(
        dim: usize,
        lo: f32,
        hi: f32,
        seed: u64,
    ) -> (Vec<f32>, Vec<f32>) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let x: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
        let y: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
        (x, y)
    }

    /// A pair of arguments for a two-argument distance function.
    struct LeftRightPair {
        pub x: Vec<f32>,
        pub y: Vec<f32>,
    }

    /// Every ordered pair of constant length-`dim` vectors whose value comes
    /// from a small fixed set that includes zero.
    fn generate_corner_cases(dim: usize) -> Vec<LeftRightPair> {
        let fixed_values = [0.0, -5.0, 5.0, 10.0];
        let mut output = Vec::with_capacity(fixed_values.len() * fixed_values.len());
        for &va in fixed_values.iter() {
            for &vb in fixed_values.iter() {
                output.push(LeftRightPair {
                    x: vec![va; dim],
                    y: vec![vb; dim],
                });
            }
        }
        output
    }

    /// Generate `num_trials` random argument pairs of length `dim` with entries
    /// in `lo..hi`.
    ///
    /// After each trial the seed is replaced by a hash of itself, so the
    /// sequence is reproducible for a given starting `seed` and toolchain.
    fn collect_random_arguments(
        dim: usize,
        num_trials: usize,
        lo: f32,
        hi: f32,
        mut seed: u64,
    ) -> Vec<LeftRightPair> {
        (0..num_trials)
            .map(|_| {
                let (x, y) = random_uniform_arguments(dim, lo, hi, seed);

                let mut hasher = std::hash::DefaultHasher::new();
                seed.hash(&mut hasher);
                seed = hasher.finish();

                LeftRightPair { x, y }
            })
            .collect()
    }

    /// Compare the `simd_provider` entry for `metric` against the `f32`
    /// reference implementation, on corner cases and random arguments for every
    /// dimension below 256.
    ///
    /// Also checks that calling `T::evaluate` directly gives exactly the same
    /// result as the function pointer. When `normalize` is set, both arguments
    /// are normalized first.
    fn test_pure_functions_impl<T>(metric: Metric, _func: T, normalize: bool)
    where
        T: for<'a, 'b> PureDistanceFunction<&'a [f32], &'b [f32], f32> + Clone,
    {
        let epsilon: f32 = 1e-4;
        let max_relative: f32 = 1e-4;

        let max_dim = 256;
        let num_trials = 10;

        let f_reference = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
        let f_simd = simd_provider(metric);

        let run_tests = |argument_pairs: Vec<LeftRightPair>| {
            for LeftRightPair { mut x, mut y } in argument_pairs {
                if normalize {
                    x.normalize();
                    y.normalize();
                }

                let reference: f32 = f_reference(&x, &y).into_inner();
                let simd = f_simd(&x, &y);

                assert_relative_eq!(
                    reference,
                    simd,
                    epsilon = epsilon,
                    max_relative = max_relative
                );

                // Direct evaluation must be bit-identical to the function pointer.
                let simd_direct = T::evaluate(&x, &y);
                assert_eq!(simd_direct, simd);
            }
        };

        for dim in 0..max_dim {
            run_tests(generate_corner_cases(dim));
        }

        for dim in 0..max_dim {
            run_tests(collect_random_arguments(
                dim, num_trials, -10.0, 10.0, 0x5643,
            ));
        }
    }

    #[test]
    fn test_pure_functions() {
        println!("L2");
        test_pure_functions_impl(Metric::L2, SquaredL2 {}, false);
        println!("InnerProduct");
        test_pure_functions_impl(Metric::InnerProduct, InnerProduct {}, false);
        println!("Cosine");
        test_pure_functions_impl(Metric::Cosine, Cosine {}, false);
        println!("CosineNormalized");
        test_pure_functions_impl(Metric::CosineNormalized, CosineNormalized {}, true);
    }

    /// The length-specialized path must agree exactly with the generic slice path.
    #[test]
    fn test_specialize() {
        use diskann_wide::arch::FTarget2;

        const DIM: usize = 123;
        let (x, y) = random_uniform_arguments(DIM, -100.0, 100.0, 0x023457AA);

        let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
        let evaluated: f32 =
            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());

        assert_eq!(reference, evaluated);
    }

    // The expected messages pin the panic to the length check itself, rather
    // than any unrelated panic that might fire along the way.
    #[test]
    #[should_panic(expected = "expected first argument to have length 34, instead it has length 35")]
    fn test_function_pointer_const_panics_left() {
        use diskann_wide::arch::FTarget2;

        const DIM: usize = 34;
        let x = vec![0.0f32; DIM + 1];
        let y = vec![0.0f32; DIM];
        let _: f32 =
            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
    }

    #[test]
    #[should_panic(expected = "expected second argument to have length 34, instead it has length 35")]
    fn test_function_pointer_const_panics_right() {
        use diskann_wide::arch::FTarget2;

        const DIM: usize = 34;
        let x = vec![0.0f32; DIM];
        let y = vec![0.0f32; DIM + 1];
        let _: f32 =
            Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
    }

    /// Extract the raw `f32` from any of the supported output wrappers.
    trait GetInner {
        fn get_inner(self) -> f32;
    }

    impl GetInner for f32 {
        fn get_inner(self) -> f32 {
            self
        }
    }

    impl GetInner for SimilarityScore<f32> {
        fn get_inner(self) -> f32 {
            self.into_inner()
        }
    }

    impl GetInner for MathematicalValue<f32> {
        fn get_inner(self) -> f32 {
            self.into_inner()
        }
    }

    /// Absolute (`epsilon`) and relative (`max_relative`) tolerances for `assert_relative_eq!`.
    #[derive(Clone, Copy)]
    struct EpsilonAndRelative {
        epsilon: f32,
        max_relative: f32,
    }

    /// Run `test_util::test_distance_function` for a single dimension.
    ///
    /// `cb` is invoked once per checked `(got, expected)` pair before the
    /// tolerance assertion, which lets callers count comparisons.
    #[allow(clippy::too_many_arguments)]
    fn run_test<L, R, To, Distribution, Callback>(
        under_test: fn(&[L], &[R]) -> To,
        reference: fn(&[L], &[R]) -> To,
        bounds: EpsilonAndRelative,
        dim: usize,
        num_trials: usize,
        distribution: Distribution,
        rng: &mut impl Rng,
        mut cb: Callback,
    ) where
        L: test_util::CornerCases,
        R: test_util::CornerCases,
        Distribution:
            test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
        To: GetInner + Copy,
        Callback: FnMut(To, To),
    {
        let checker =
            test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
                cb(got, expected);
                assert_relative_eq!(
                    got.get_inner(),
                    expected.get_inner(),
                    epsilon = bounds.epsilon,
                    max_relative = bounds.max_relative
                );
            });

        test_util::test_distance_function(
            checker,
            distribution.clone(),
            distribution.clone(),
            dim,
            num_trials,
            rng,
        );
    }

    /// Exclusive upper bound on the dimensions exercised (smaller in debug builds).
    #[cfg(not(debug_assertions))]
    const MAX_DIM: usize = 256;

    #[cfg(debug_assertions)]
    const MAX_DIM: usize = 160;

    /// Random trials per dimension, shared by the integer and floating-point
    /// drivers (fewer in debug builds).
    #[cfg(not(debug_assertions))]
    const NUM_TRIALS: usize = 10000;

    #[cfg(debug_assertions)]
    const NUM_TRIALS: usize = 100;

    /// Check an integer kernel against its reference with exact equality for
    /// every dimension below `MAX_DIM`.
    ///
    /// Also verifies that every random trial and every corner-case pair was
    /// actually compared.
    fn run_integer_test<T, R>(
        under_test: fn(&[T], &[T]) -> R,
        reference: fn(&[T], &[T]) -> R,
        rng: &mut impl Rng,
    ) where
        T: test_util::CornerCases,
        R: GetInner + Copy,
        rand::distr::StandardUniform: test_util::GenerateRandomArguments<T> + Clone,
    {
        let distribution = rand::distr::StandardUniform {};
        let num_corner_cases = <T as test_util::CornerCases>::corner_cases().len();

        for dim in 0..MAX_DIM {
            let mut callcount = 0;
            let callback = |_, _| {
                callcount += 1;
            };

            run_test(
                under_test,
                reference,
                EpsilonAndRelative {
                    epsilon: 0.0,
                    max_relative: 0.0,
                },
                dim,
                NUM_TRIALS,
                distribution,
                rng,
                callback,
            );

            assert_eq!(
                callcount,
                NUM_TRIALS + num_corner_cases * num_corner_cases
            );
        }
    }

    /// Check a floating-point kernel against its reference within `bounds` for
    /// every dimension below `MAX_DIM`.
    ///
    /// Also verifies that every random trial and every corner-case pair was
    /// actually compared.
    fn run_float_test<L, R, To, Dist>(
        under_test: fn(&[L], &[R]) -> To,
        reference: fn(&[L], &[R]) -> To,
        rng: &mut impl Rng,
        distribution: Dist,
        bounds: EpsilonAndRelative,
    ) where
        L: test_util::CornerCases,
        R: test_util::CornerCases,
        To: GetInner + Copy,
        Dist: test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
    {
        let left_corner_cases = <L as test_util::CornerCases>::corner_cases().len();
        let right_corner_cases = <R as test_util::CornerCases>::corner_cases().len();
        for dim in 0..MAX_DIM {
            let mut callcount = 0;
            let callback = |_, _| {
                callcount += 1;
            };

            run_test(
                under_test,
                reference,
                bounds,
                dim,
                NUM_TRIALS,
                distribution.clone(),
                rng,
                callback,
            );

            assert_eq!(
                callcount,
                NUM_TRIALS + left_corner_cases * right_corner_cases
            );
        }
    }

    fn expected_l2_errors() -> EpsilonAndRelative {
        EpsilonAndRelative {
            epsilon: 0.0,
            max_relative: 1.2e-6,
        }
    }

    fn expected_innerproduct_errors() -> EpsilonAndRelative {
        EpsilonAndRelative {
            epsilon: 2.5e-5,
            max_relative: 1.6e-5,
        }
    }

    fn expected_cosine_errors() -> EpsilonAndRelative {
        EpsilonAndRelative {
            epsilon: 3e-7,
            max_relative: 5e-6,
        }
    }

    fn expected_cosine_normalized_errors() -> EpsilonAndRelative {
        EpsilonAndRelative {
            epsilon: 3e-7,
            max_relative: 5e-6,
        }
    }

    /// Define a `#[test]` named `$name` that runs `run_integer_test` on functor
    /// `$functor` over `$elem x $elem -> $output`, compared against `$reference`,
    /// with an RNG seeded by `$seed`.
    macro_rules! integer_test {
        (
            $name:ident,
            $functor:ty,
            $elem:ty,
            $output:ty,
            $reference:expr,
            $seed:expr $(,)?
        ) => {
            #[test]
            fn $name() {
                let mut rng = rand::rngs::StdRng::seed_from_u64($seed);
                run_integer_test(
                    as_function_pointer::<$functor, $elem, $elem, $output>,
                    $reference,
                    &mut rng,
                );
            }
        };
    }

    /// Define a `#[test]` named `$name` that runs `run_float_test` on functor
    /// `$functor` over `$left x $right -> $output`, compared against
    /// `$reference` within `$bounds`. Arguments are drawn from `$distribution`
    /// with an RNG seeded by `$seed`.
    macro_rules! float_test {
        (
            $name:ident,
            $functor:ty,
            $left:ty,
            $right:ty,
            $output:ty,
            $reference:expr,
            $seed:expr,
            $distribution:expr,
            $bounds:expr $(,)?
        ) => {
            #[test]
            fn $name() {
                let mut rng = rand::rngs::StdRng::seed_from_u64($seed);
                run_float_test(
                    as_function_pointer::<$functor, $left, $right, $output>,
                    $reference,
                    &mut rng,
                    $distribution,
                    $bounds,
                );
            }
        };
    }

    // ---- Integer kernels (exact comparison) ----

    integer_test!(
        test_l2_i8_mathematical,
        FullL2,
        i8,
        MathematicalValue<f32>,
        reference::reference_l2_i8_mathematical,
        0x2bb701074c2b81c9,
    );
    integer_test!(
        test_l2_u8_mathematical,
        FullL2,
        u8,
        MathematicalValue<f32>,
        reference::reference_l2_u8_mathematical,
        0x9284ced6d080808c,
    );
    integer_test!(
        test_l2_i8_similarity,
        FullL2,
        i8,
        SimilarityScore<f32>,
        reference::reference_l2_i8_similarity,
        0xb196fecc4def04fa,
    );
    integer_test!(
        test_l2_u8_similarity,
        FullL2,
        u8,
        SimilarityScore<f32>,
        reference::reference_l2_u8_similarity,
        0x07f6463e4a654aea,
    );

    integer_test!(
        test_innerproduct_i8_mathematical,
        InnerProduct,
        i8,
        MathematicalValue<f32>,
        reference::reference_innerproduct_i8_mathematical,
        0x2c1b1bddda5774be,
    );
    integer_test!(
        test_innerproduct_u8_mathematical,
        InnerProduct,
        u8,
        MathematicalValue<f32>,
        reference::reference_innerproduct_u8_mathematical,
        0x757e363832d7f215,
    );
    integer_test!(
        test_innerproduct_i8_similarity,
        InnerProduct,
        i8,
        SimilarityScore<f32>,
        reference::reference_innerproduct_i8_similarity,
        0x4788ce0b991eb15a,
    );
    integer_test!(
        test_innerproduct_u8_similarity,
        InnerProduct,
        u8,
        SimilarityScore<f32>,
        reference::reference_innerproduct_u8_similarity,
        0x4994adb68f814d96,
    );

    integer_test!(
        test_cosine_i8_mathematical,
        Cosine,
        i8,
        MathematicalValue<f32>,
        reference::reference_cosine_i8_mathematical,
        0xedef81c780491ada,
    );
    integer_test!(
        test_cosine_u8_mathematical,
        Cosine,
        u8,
        MathematicalValue<f32>,
        reference::reference_cosine_u8_mathematical,
        0x107cee2adcc58b73,
    );
    integer_test!(
        test_cosine_i8_similarity,
        Cosine,
        i8,
        SimilarityScore<f32>,
        reference::reference_cosine_i8_similarity,
        0x02d95c1cc0843647,
    );
    integer_test!(
        test_cosine_u8_similarity,
        Cosine,
        u8,
        SimilarityScore<f32>,
        reference::reference_cosine_u8_similarity,
        0xf5ea1974bf8d8b3b,
    );

    // ---- L2 (floating point) ----

    float_test!(
        test_l2_f32_mathematical,
        FullL2,
        f32,
        f32,
        MathematicalValue<f32>,
        reference::reference_l2_f32_mathematical,
        0x6d22d320bdf35aec,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );
    float_test!(
        test_l2_f16_mathematical,
        FullL2,
        Half,
        Half,
        MathematicalValue<f32>,
        reference::reference_l2_f16_mathematical,
        0x755819460c190db4,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );
    float_test!(
        test_l2_f32xf16_mathematical,
        FullL2,
        f32,
        Half,
        MathematicalValue<f32>,
        reference::reference_l2_f32xf16_mathematical,
        0x755819460c190db4,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );
    float_test!(
        test_l2_f32_similarity,
        FullL2,
        f32,
        f32,
        SimilarityScore<f32>,
        reference::reference_l2_f32_similarity,
        0xbfc5f4b42b5bc0c1,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );
    float_test!(
        test_l2_f16_similarity,
        FullL2,
        Half,
        Half,
        SimilarityScore<f32>,
        reference::reference_l2_f16_similarity,
        0x9d3809d84f54e4b6,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );
    float_test!(
        test_l2_f32xf16_similarity,
        FullL2,
        f32,
        Half,
        SimilarityScore<f32>,
        reference::reference_l2_f32xf16_similarity,
        0x755819460c190db4,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_l2_errors(),
    );

    // ---- Inner product (floating point) ----

    float_test!(
        test_innerproduct_f32_mathematical,
        InnerProduct,
        f32,
        f32,
        MathematicalValue<f32>,
        reference::reference_innerproduct_f32_mathematical,
        0x1ef6ac3b65869792,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );
    float_test!(
        test_innerproduct_f16_mathematical,
        InnerProduct,
        Half,
        Half,
        MathematicalValue<f32>,
        reference::reference_innerproduct_f16_mathematical,
        0x24c51e4b825b0329,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );
    float_test!(
        test_innerproduct_f32xf16_mathematical,
        InnerProduct,
        f32,
        Half,
        MathematicalValue<f32>,
        reference::reference_innerproduct_f32xf16_mathematical,
        0x24c51e4b825b0329,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );
    float_test!(
        test_innerproduct_f32_similarity,
        InnerProduct,
        f32,
        f32,
        SimilarityScore<f32>,
        reference::reference_innerproduct_f32_similarity,
        0x40326b22a57db0d7,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );
    float_test!(
        test_innerproduct_f16_similarity,
        InnerProduct,
        Half,
        Half,
        SimilarityScore<f32>,
        reference::reference_innerproduct_f16_similarity,
        0xfb8cff47bcbc9528,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );
    float_test!(
        test_innerproduct_f32xf16_similarity,
        InnerProduct,
        f32,
        Half,
        SimilarityScore<f32>,
        reference::reference_innerproduct_f32xf16_similarity,
        0x24c51e4b825b0329,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_innerproduct_errors(),
    );

    // ---- Cosine (floating point) ----

    float_test!(
        test_cosine_f32_mathematical,
        Cosine,
        f32,
        f32,
        MathematicalValue<f32>,
        reference::reference_cosine_f32_mathematical,
        0xca6eaac942999500,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );
    float_test!(
        test_cosine_f16_mathematical,
        Cosine,
        Half,
        Half,
        MathematicalValue<f32>,
        reference::reference_cosine_f16_mathematical,
        0xa736c789aa16ce86,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );
    float_test!(
        test_cosine_f32xf16_mathematical,
        Cosine,
        f32,
        Half,
        MathematicalValue<f32>,
        reference::reference_cosine_f32xf16_mathematical,
        0xac550231088a0d5c,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );
    float_test!(
        test_cosine_f32_similarity,
        Cosine,
        f32,
        f32,
        SimilarityScore<f32>,
        reference::reference_cosine_f32_similarity,
        0x4a09ad987a6204f3,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );
    float_test!(
        test_cosine_f16_similarity,
        Cosine,
        Half,
        Half,
        SimilarityScore<f32>,
        reference::reference_cosine_f16_similarity,
        0x77a48d1914f850f2,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );
    float_test!(
        test_cosine_f32xf16_similarity,
        Cosine,
        f32,
        Half,
        SimilarityScore<f32>,
        reference::reference_cosine_f32xf16_similarity,
        0xbd7471b815655ca1,
        rand_distr::Normal::new(0.0, 1.0).unwrap(),
        expected_cosine_errors(),
    );

    // ---- Cosine on pre-normalized inputs (floating point) ----

    float_test!(
        test_cosine_normalized_f32_mathematical,
        CosineNormalized,
        f32,
        f32,
        MathematicalValue<f32>,
        reference::reference_cosine_normalized_f32_mathematical,
        0x1fda98112747f8dd,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
    float_test!(
        test_cosine_normalized_f16_mathematical,
        CosineNormalized,
        Half,
        Half,
        MathematicalValue<f32>,
        reference::reference_cosine_normalized_f16_mathematical,
        0x5e8c5d5e19cdd840,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
    float_test!(
        test_cosine_normalized_f32xf16_mathematical,
        CosineNormalized,
        f32,
        Half,
        MathematicalValue<f32>,
        reference::reference_cosine_normalized_f32xf16_mathematical,
        0x3fd01e1c11c9bc45,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
    float_test!(
        test_cosine_normalized_f32_similarity,
        CosineNormalized,
        f32,
        f32,
        SimilarityScore<f32>,
        reference::reference_cosine_normalized_f32_similarity,
        0x9446d057870e5605,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
    float_test!(
        test_cosine_normalized_f16_similarity,
        CosineNormalized,
        Half,
        Half,
        SimilarityScore<f32>,
        reference::reference_cosine_normalized_f16_similarity,
        0x885c371801f18174,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
    float_test!(
        test_cosine_normalized_f32xf16_similarity,
        CosineNormalized,
        f32,
        Half,
        SimilarityScore<f32>,
        reference::reference_cosine_normalized_f32xf16_similarity,
        0x1c356c92d0522c0f,
        test_util::Normalized(rand_distr::Normal::new(-1.0, 1.0).unwrap()),
        expected_cosine_normalized_errors(),
    );
}