1use diskann_wide::{arch::Target2, Architecture, ARCH};
7
8use super::simd;
10use crate::{Half, MathematicalValue, PureDistanceFunction, SimilarityScore};
11
12trait ToSlice {
13 type Target;
14 fn to_slice(&self) -> &[Self::Target];
15}
16
17impl<T> ToSlice for &[T] {
18 type Target = T;
19 fn to_slice(&self) -> &[T] {
20 self
21 }
22}
23impl<T, const N: usize> ToSlice for &[T; N] {
24 type Target = T;
25 fn to_slice(&self) -> &[T] {
26 &self[..]
27 }
28}
29impl<T, const N: usize> ToSlice for [T; N] {
30 type Target = T;
31 fn to_slice(&self) -> &[T] {
32 &self[..]
33 }
34}
35
36macro_rules! architecture_hook {
37 ($functor:ty, $impl:path) => {
38 impl<A, T, L, R> diskann_wide::arch::Target2<A, T, L, R> for $functor
39 where
40 A: Architecture,
41 L: ToSlice,
42 R: ToSlice,
43 $impl: simd::SIMDSchema<L::Target, R::Target, A>,
44 Self: PostOp<<$impl as simd::SIMDSchema<L::Target, R::Target, A>>::Return, T>,
45 {
46 #[inline(always)]
47 fn run(self, arch: A, left: L, right: R) -> T {
48 Self::post_op(simd::simd_op(
49 &$impl,
50 arch,
51 left.to_slice(),
52 right.to_slice(),
53 ))
54 }
55 }
56
57 impl<A, T, L, R> diskann_wide::arch::FTarget2<A, T, L, R> for $functor
58 where
59 A: Architecture,
60 L: ToSlice,
61 R: ToSlice,
62 Self: diskann_wide::arch::Target2<A, T, L, R>,
63 {
64 #[inline(always)]
65 fn run(arch: A, left: L, right: R) -> T {
66 arch.run2(Self::default(), left, right)
67 }
68 }
69 };
70}
71
72#[cfg(any(test, target_arch = "x86_64"))]
74#[derive(Debug, Clone, Copy)]
75pub(crate) struct Specialize<const N: usize, F>(std::marker::PhantomData<F>);
76
77#[cfg(any(test, target_arch = "x86_64"))]
78impl<A, T, L, R, const N: usize, F> diskann_wide::arch::FTarget2<A, T, &[L], &[R]>
79 for Specialize<N, F>
80where
81 A: Architecture,
82 F: for<'a, 'b> diskann_wide::arch::Target2<A, T, &'a [L; N], &'b [R; N]> + Default,
83{
84 #[inline(always)]
85 fn run(arch: A, x: &[L], y: &[R]) -> T {
86 if (x.len() != N) | (y.len() != N) {
87 fail_length_check(x, y, N);
88 }
89
90 arch.run2(
95 F::default(),
96 unsafe { &*(x.as_ptr() as *const [L; N]) },
97 unsafe { &*(y.as_ptr() as *const [R; N]) },
98 )
99 }
100}
101
102#[cfg(any(test, target_arch = "x86_64"))]
105#[inline(never)]
106#[allow(clippy::panic)]
107fn fail_length_check<L, R>(x: &[L], y: &[R], len: usize) -> ! {
108 let message = if x.len() != len {
109 ("first", x.len())
110 } else {
111 ("second", y.len())
112 };
113 panic!(
114 "expected {} argument to have length {}, instead it has length {}",
115 message.0, len, message.1
116 );
117}
118
119pub(super) trait PostOp<From, To> {
125 fn post_op(x: From) -> To;
126}
127
128macro_rules! use_simd_implementation {
130 ($functor:ty, $T:ty, $U:ty) => {
131 impl PureDistanceFunction<&[$T], &[$U], SimilarityScore<f32>> for $functor {
137 #[inline]
138 fn evaluate(x: &[$T], y: &[$U]) -> SimilarityScore<f32> {
139 <$functor>::default().run(ARCH, x, y)
140 }
141 }
142 impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], SimilarityScore<f32>>
144 for $functor
145 {
146 #[inline]
147 fn evaluate(x: &[$T; N], y: &[$U; N]) -> SimilarityScore<f32> {
148 <$functor>::default().run(ARCH, x, y)
149 }
150 }
151
152 impl PureDistanceFunction<&[$T], &[$U], MathematicalValue<f32>> for $functor {
158 #[inline]
159 fn evaluate(x: &[$T], y: &[$U]) -> MathematicalValue<f32> {
160 <$functor>::default().run(ARCH, x, y)
161 }
162 }
163 impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], MathematicalValue<f32>>
165 for $functor
166 {
167 #[inline]
168 fn evaluate(x: &[$T; N], y: &[$U; N]) -> MathematicalValue<f32> {
169 <$functor>::default().run(ARCH, x, y)
170 }
171 }
172
173 impl PureDistanceFunction<&[$T], &[$U], f32> for $functor {
179 #[inline(always)]
180 fn evaluate(x: &[$T], y: &[$U]) -> f32 {
181 <$functor>::default().run(ARCH, x, y)
182 }
183 }
184
185 impl<const N: usize> PureDistanceFunction<&[$T; N], &[$U; N], f32> for $functor {
187 #[inline]
188 fn evaluate(x: &[$T; N], y: &[$U; N]) -> f32 {
189 <$functor>::default().run(ARCH, x, y)
190 }
191 }
192 };
193}
194
195#[derive(Debug, Clone, Copy, Default)]
201pub struct SquaredL2 {}
202
203impl PostOp<f32, SimilarityScore<f32>> for SquaredL2 {
204 #[inline(always)]
205 fn post_op(x: f32) -> SimilarityScore<f32> {
206 SimilarityScore::new(x)
207 }
208}
209
210impl PostOp<f32, f32> for SquaredL2 {
211 #[inline(always)]
212 fn post_op(x: f32) -> f32 {
213 x
214 }
215}
216
217impl PostOp<f32, MathematicalValue<f32>> for SquaredL2 {
218 #[inline(always)]
219 fn post_op(x: f32) -> MathematicalValue<f32> {
220 MathematicalValue::new(x)
221 }
222}
223
224architecture_hook!(SquaredL2, simd::L2);
225use_simd_implementation!(SquaredL2, f32, f32);
226use_simd_implementation!(SquaredL2, f32, Half);
227use_simd_implementation!(SquaredL2, Half, Half);
228use_simd_implementation!(SquaredL2, i8, i8);
229use_simd_implementation!(SquaredL2, u8, u8);
230
231#[derive(Debug, Clone, Copy, Default)]
240pub struct FullL2 {}
241
242impl PostOp<f32, SimilarityScore<f32>> for FullL2 {
243 #[inline(always)]
244 fn post_op(x: f32) -> SimilarityScore<f32> {
245 SimilarityScore::new(x.sqrt())
246 }
247}
248
249impl PostOp<f32, f32> for FullL2 {
250 #[inline(always)]
251 fn post_op(x: f32) -> f32 {
252 x.sqrt()
253 }
254}
255
256impl PostOp<f32, MathematicalValue<f32>> for FullL2 {
257 #[inline(always)]
258 fn post_op(x: f32) -> MathematicalValue<f32> {
259 MathematicalValue::new(x.sqrt())
260 }
261}
262
263architecture_hook!(FullL2, simd::L2);
264use_simd_implementation!(FullL2, f32, f32);
265use_simd_implementation!(FullL2, f32, Half);
266use_simd_implementation!(FullL2, Half, Half);
267use_simd_implementation!(FullL2, i8, i8);
268use_simd_implementation!(FullL2, u8, u8);
269
270#[derive(Debug, Clone, Copy, Default)]
276pub struct InnerProduct {}
277
278impl PostOp<f32, SimilarityScore<f32>> for InnerProduct {
279 #[inline(always)]
283 fn post_op(x: f32) -> SimilarityScore<f32> {
284 SimilarityScore::new(-x)
285 }
286}
287
288impl PostOp<f32, MathematicalValue<f32>> for InnerProduct {
289 #[inline(always)]
290 fn post_op(x: f32) -> MathematicalValue<f32> {
291 MathematicalValue::new(x)
292 }
293}
294
295impl PostOp<f32, f32> for InnerProduct {
296 #[inline(always)]
297 fn post_op(x: f32) -> f32 {
298 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
299 }
300}
301
302architecture_hook!(InnerProduct, simd::IP);
303use_simd_implementation!(InnerProduct, f32, f32);
304use_simd_implementation!(InnerProduct, f32, Half);
305use_simd_implementation!(InnerProduct, Half, Half);
306use_simd_implementation!(InnerProduct, i8, i8);
307use_simd_implementation!(InnerProduct, u8, u8);
308
309fn cosine_transformation(x: f32) -> f32 {
317 1.0 - x
318}
319
320#[derive(Debug, Clone, Copy, Default)]
322pub struct Cosine {}
323
324impl PostOp<f32, SimilarityScore<f32>> for Cosine {
325 fn post_op(x: f32) -> SimilarityScore<f32> {
326 debug_assert!(x >= -1.0);
327 debug_assert!(x <= 1.0);
328 SimilarityScore::new(cosine_transformation(x))
329 }
330}
331
332impl PostOp<f32, MathematicalValue<f32>> for Cosine {
333 fn post_op(x: f32) -> MathematicalValue<f32> {
334 debug_assert!(x >= -1.0);
335 debug_assert!(x <= 1.0);
336 MathematicalValue::new(x)
337 }
338}
339
340impl PostOp<f32, f32> for Cosine {
341 fn post_op(x: f32) -> f32 {
342 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
343 }
344}
345
346architecture_hook!(Cosine, simd::CosineStateless);
347use_simd_implementation!(Cosine, f32, f32);
348use_simd_implementation!(Cosine, f32, Half);
349use_simd_implementation!(Cosine, Half, Half);
350use_simd_implementation!(Cosine, i8, i8);
351use_simd_implementation!(Cosine, u8, u8);
352
353#[derive(Debug, Clone, Copy, Default)]
359pub struct CosineNormalized {}
360
361impl PostOp<f32, SimilarityScore<f32>> for CosineNormalized {
362 #[inline(always)]
363 fn post_op(x: f32) -> SimilarityScore<f32> {
364 SimilarityScore::new(cosine_transformation(x))
370 }
371}
372
373impl PostOp<f32, MathematicalValue<f32>> for CosineNormalized {
374 #[inline(always)]
375 fn post_op(x: f32) -> MathematicalValue<f32> {
376 MathematicalValue::new(x)
377 }
378}
379
380impl PostOp<f32, f32> for CosineNormalized {
381 #[inline(always)]
382 fn post_op(x: f32) -> f32 {
383 <Self as PostOp<f32, SimilarityScore<f32>>>::post_op(x).into_inner()
384 }
385}
386
387architecture_hook!(CosineNormalized, simd::IP);
388use_simd_implementation!(CosineNormalized, f32, f32);
389use_simd_implementation!(CosineNormalized, f32, Half);
390use_simd_implementation!(CosineNormalized, Half, Half);
391
392#[derive(Debug, Clone, Copy, Default)]
398pub struct L1NormFunctor {}
399
400impl PostOp<f32, f32> for L1NormFunctor {
401 #[inline(always)]
402 fn post_op(x: f32) -> f32 {
403 x
404 }
405}
406
407architecture_hook!(L1NormFunctor, simd::L1Norm);
408
409impl PureDistanceFunction<&[f32], &[f32], f32> for L1NormFunctor {
410 #[inline]
411 fn evaluate(x: &[f32], y: &[f32]) -> f32 {
412 L1NormFunctor::default().run(ARCH, x, y)
413 }
414}
415
416#[cfg(test)]
421mod tests {
422
423 use std::hash::{Hash, Hasher};
424
425 use approx::assert_relative_eq;
426 use rand::{Rng, SeedableRng};
427
428 use super::*;
429 use crate::{
430 distance::{
431 reference::{self, ReferenceProvider},
432 Metric,
433 },
434 test_util::{self, Normalize},
435 };
436
437 pub fn as_function_pointer<T, Left, Right, Return>(x: &[Left], y: &[Right]) -> Return
438 where
439 T: for<'a, 'b> PureDistanceFunction<&'a [Left], &'b [Right], Return>,
440 {
441 T::evaluate(x, y)
442 }
443
444 fn simd_provider(metric: Metric) -> fn(&[f32], &[f32]) -> f32 {
445 match metric {
446 Metric::L2 => as_function_pointer::<SquaredL2, _, _, _>,
447 Metric::InnerProduct => as_function_pointer::<InnerProduct, _, _, _>,
448 Metric::Cosine => as_function_pointer::<Cosine, _, _, _>,
449 Metric::CosineNormalized => as_function_pointer::<CosineNormalized, _, _, _>,
450 }
451 }
452
453 fn random_normal_arguments(dim: usize, lo: f32, hi: f32, seed: u64) -> (Vec<f32>, Vec<f32>) {
454 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
455 let x: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
456 let y: Vec<f32> = (0..dim).map(|_| rng.random_range(lo..hi)).collect();
457 (x, y)
458 }
459
460 struct LeftRightPair {
461 pub x: Vec<f32>,
462 pub y: Vec<f32>,
463 }
464
465 fn generate_corner_cases(dim: usize) -> Vec<LeftRightPair> {
466 let mut output = Vec::<LeftRightPair>::new();
467 let fixed_values = [0.0, -5.0, 5.0, 10.0];
468
469 for va in fixed_values.iter() {
470 for vb in fixed_values.iter() {
471 let x: Vec<f32> = vec![*va; dim];
472 let y: Vec<f32> = vec![*vb; dim];
473 output.push(LeftRightPair { x, y });
474 }
475 }
476 output
477 }
478
479 fn collect_random_arguments(
480 dim: usize,
481 num_trials: usize,
482 lo: f32,
483 hi: f32,
484 mut seed: u64,
485 ) -> Vec<LeftRightPair> {
486 (0..num_trials)
487 .map(|_| {
488 let (x, y) = random_normal_arguments(dim, lo, hi, seed);
489
490 let mut hasher = std::hash::DefaultHasher::new();
492 seed.hash(&mut hasher);
493 seed = hasher.finish();
494
495 LeftRightPair { x, y }
496 })
497 .collect()
498 }
499
500 fn test_pure_functions_impl<T>(metric: Metric, _func: T, normalize: bool)
501 where
502 T: for<'a, 'b> PureDistanceFunction<&'a [f32], &'b [f32], f32> + Clone,
503 {
504 let epsilon: f32 = 1e-4;
505 let max_relative: f32 = 1e-4;
506
507 let max_dim = 256;
508 let num_trials = 10;
509
510 let f_reference = <f32 as ReferenceProvider<f32>>::reference_implementation(metric);
511 let f_simd = simd_provider(metric);
512
513 let run_tests = |argument_pairs: Vec<LeftRightPair>| {
515 for LeftRightPair { mut x, mut y } in argument_pairs {
516 if normalize {
517 x.normalize();
518 y.normalize();
519 }
520
521 let reference: f32 = f_reference(&x, &y).into_inner();
522 let simd = f_simd(&x, &y);
523
524 assert_relative_eq!(
525 reference,
526 simd,
527 epsilon = epsilon,
528 max_relative = max_relative
529 );
530
531 let simd_direct = T::evaluate(&x, &y);
533 assert_eq!(simd_direct, simd);
534 }
535 };
536
537 for dim in 0..max_dim {
539 run_tests(generate_corner_cases(dim));
540 }
541
542 for dim in 0..max_dim {
544 run_tests(collect_random_arguments(
545 dim, num_trials, -10.0, 10.0, 0x5643,
546 ));
547 }
548 }
549
550 #[test]
551 fn test_pure_functions() {
552 println!("L2");
553 test_pure_functions_impl(Metric::L2, SquaredL2 {}, false);
554 println!("InnerProduct");
555 test_pure_functions_impl(Metric::InnerProduct, InnerProduct {}, false);
556 println!("Cosine");
557 test_pure_functions_impl(Metric::Cosine, Cosine {}, false);
558 println!("CosineNormalized");
559 test_pure_functions_impl(Metric::CosineNormalized, CosineNormalized {}, true);
560 }
561
562 #[test]
565 fn test_specialize() {
566 use diskann_wide::arch::FTarget2;
567
568 const DIM: usize = 123;
569 let (x, y) = random_normal_arguments(DIM, -100.0, 100.0, 0x023457AA);
570
571 let reference: f32 = SquaredL2::evaluate(x.as_slice(), y.as_slice());
572 let evaluated: f32 =
573 Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
574
575 assert_eq!(reference, evaluated);
577 }
578
579 #[test]
580 #[should_panic]
581 fn test_function_pointer_const_panics_left() {
582 use diskann_wide::arch::FTarget2;
583
584 const DIM: usize = 34;
585 let x = vec![0.0f32; DIM + 1];
586 let y = vec![0.0f32; DIM];
587 let _: f32 =
589 Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
590 }
591
592 #[test]
593 #[should_panic]
594 fn test_function_pointer_const_panics_right() {
595 use diskann_wide::arch::FTarget2;
596
597 const DIM: usize = 34;
598 let x = vec![0.0f32; DIM];
599 let y = vec![0.0f32; DIM + 1];
600 let _: f32 =
602 Specialize::<DIM, SquaredL2>::run(diskann_wide::ARCH, x.as_slice(), y.as_slice());
603 }
604
605 trait GetInner {
610 fn get_inner(self) -> f32;
611 }
612
613 impl GetInner for f32 {
614 fn get_inner(self) -> f32 {
615 self
616 }
617 }
618
619 impl GetInner for SimilarityScore<f32> {
620 fn get_inner(self) -> f32 {
621 self.into_inner()
622 }
623 }
624
625 impl GetInner for MathematicalValue<f32> {
626 fn get_inner(self) -> f32 {
627 self.into_inner()
628 }
629 }
630
631 #[derive(Clone, Copy)]
633 struct EpsilonAndRelative {
634 epsilon: f32,
635 max_relative: f32,
636 }
637
638 #[allow(clippy::too_many_arguments)]
639 fn run_test<L, R, To, Distribution, Callback>(
640 under_test: fn(&[L], &[R]) -> To,
641 reference: fn(&[L], &[R]) -> To,
642 bounds: EpsilonAndRelative,
643 dim: usize,
644 num_trials: usize,
645 distribution: Distribution,
646 rng: &mut impl Rng,
647 mut cb: Callback,
648 ) where
649 L: test_util::CornerCases,
650 R: test_util::CornerCases,
651 Distribution:
652 test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
653 To: GetInner + Copy,
654 Callback: FnMut(To, To),
655 {
656 let checker =
657 test_util::Checker::<L, R, To>::new(under_test, reference, |got, expected| {
658 cb(got, expected);
660 assert_relative_eq!(
661 got.get_inner(),
662 expected.get_inner(),
663 epsilon = bounds.epsilon,
664 max_relative = bounds.max_relative
665 );
666 });
667
668 test_util::test_distance_function(
669 checker,
670 distribution.clone(),
671 distribution.clone(),
672 dim,
673 num_trials,
674 rng,
675 );
676 }
677
678 #[cfg(not(debug_assertions))]
680 const MAX_DIM: usize = 256;
681
682 #[cfg(debug_assertions)]
683 const MAX_DIM: usize = 160;
684
685 #[cfg(not(debug_assertions))]
687 const INTEGER_TRIALS: usize = 10000;
688
689 #[cfg(debug_assertions)]
690 const INTEGER_TRIALS: usize = 100;
691
692 fn run_integer_test<T, R>(
699 under_test: fn(&[T], &[T]) -> R,
700 reference: fn(&[T], &[T]) -> R,
701 rng: &mut impl Rng,
702 ) where
703 T: test_util::CornerCases,
704 R: GetInner + Copy,
705 rand::distr::StandardUniform: test_util::GenerateRandomArguments<T> + Clone,
706 {
707 let distribution = rand::distr::StandardUniform {};
708 let num_corner_cases = <T as test_util::CornerCases>::corner_cases().len();
709
710 for dim in 0..MAX_DIM {
711 let mut callcount = 0;
712 let callback = |_, _| {
713 callcount += 1;
714 };
715
716 run_test(
717 under_test,
718 reference,
719 EpsilonAndRelative {
720 epsilon: 0.0,
721 max_relative: 0.0,
722 },
723 dim,
724 INTEGER_TRIALS,
725 distribution,
726 rng,
727 callback,
728 );
729
730 assert_eq!(
732 callcount,
733 INTEGER_TRIALS + num_corner_cases * num_corner_cases
734 );
735 }
736 }
737
738 #[test]
743 fn test_l2_i8_mathematical() {
744 let mut rng = rand::rngs::StdRng::seed_from_u64(0x2bb701074c2b81c9);
745 run_integer_test(
746 as_function_pointer::<FullL2, i8, i8, MathematicalValue<f32>>,
747 reference::reference_l2_i8_mathematical,
748 &mut rng,
749 );
750 }
751
752 #[test]
753 fn test_l2_u8_mathematical() {
754 let mut rng = rand::rngs::StdRng::seed_from_u64(0x9284ced6d080808c);
755 run_integer_test(
756 as_function_pointer::<FullL2, u8, u8, MathematicalValue<f32>>,
757 reference::reference_l2_u8_mathematical,
758 &mut rng,
759 );
760 }
761
762 #[test]
763 fn test_l2_i8_similarity() {
764 let mut rng = rand::rngs::StdRng::seed_from_u64(0xb196fecc4def04fa);
765 run_integer_test(
766 as_function_pointer::<FullL2, i8, i8, SimilarityScore<f32>>,
767 reference::reference_l2_i8_similarity,
768 &mut rng,
769 );
770 }
771
772 #[test]
773 fn test_l2_u8_similarity() {
774 let mut rng = rand::rngs::StdRng::seed_from_u64(0x07f6463e4a654aea);
775 run_integer_test(
776 as_function_pointer::<FullL2, u8, u8, SimilarityScore<f32>>,
777 reference::reference_l2_u8_similarity,
778 &mut rng,
779 );
780 }
781
782 #[test]
787 fn test_innerproduct_i8_mathematical() {
788 let mut rng = rand::rngs::StdRng::seed_from_u64(0x2c1b1bddda5774be);
789 run_integer_test(
790 as_function_pointer::<InnerProduct, i8, i8, MathematicalValue<f32>>,
791 reference::reference_innerproduct_i8_mathematical,
792 &mut rng,
793 );
794 }
795
796 #[test]
797 fn test_innerproduct_u8_mathematical() {
798 let mut rng = rand::rngs::StdRng::seed_from_u64(0x757e363832d7f215);
799 run_integer_test(
800 as_function_pointer::<InnerProduct, u8, u8, MathematicalValue<f32>>,
801 reference::reference_innerproduct_u8_mathematical,
802 &mut rng,
803 );
804 }
805
806 #[test]
807 fn test_innerproduct_i8_similarity() {
808 let mut rng = rand::rngs::StdRng::seed_from_u64(0x4788ce0b991eb15a);
809 run_integer_test(
810 as_function_pointer::<InnerProduct, i8, i8, SimilarityScore<f32>>,
811 reference::reference_innerproduct_i8_similarity,
812 &mut rng,
813 );
814 }
815
816 #[test]
817 fn test_innerproduct_u8_similarity() {
818 let mut rng = rand::rngs::StdRng::seed_from_u64(0x4994adb68f814d96);
819 run_integer_test(
820 as_function_pointer::<InnerProduct, u8, u8, SimilarityScore<f32>>,
821 reference::reference_innerproduct_u8_similarity,
822 &mut rng,
823 );
824 }
825
826 #[test]
831 fn test_cosine_i8_mathematical() {
832 let mut rng = rand::rngs::StdRng::seed_from_u64(0xedef81c780491ada);
833 run_integer_test(
834 as_function_pointer::<Cosine, i8, i8, MathematicalValue<f32>>,
835 reference::reference_cosine_i8_mathematical,
836 &mut rng,
837 );
838 }
839
840 #[test]
841 fn test_cosine_u8_mathematical() {
842 let mut rng = rand::rngs::StdRng::seed_from_u64(0x107cee2adcc58b73);
843 run_integer_test(
844 as_function_pointer::<Cosine, u8, u8, MathematicalValue<f32>>,
845 reference::reference_cosine_u8_mathematical,
846 &mut rng,
847 );
848 }
849
850 #[test]
851 fn test_cosine_i8_similarity() {
852 let mut rng = rand::rngs::StdRng::seed_from_u64(0x02d95c1cc0843647);
853 run_integer_test(
854 as_function_pointer::<Cosine, i8, i8, SimilarityScore<f32>>,
855 reference::reference_cosine_i8_similarity,
856 &mut rng,
857 );
858 }
859
860 #[test]
861 fn test_cosine_u8_similarity() {
862 let mut rng = rand::rngs::StdRng::seed_from_u64(0xf5ea1974bf8d8b3b);
863 run_integer_test(
864 as_function_pointer::<Cosine, u8, u8, SimilarityScore<f32>>,
865 reference::reference_cosine_u8_similarity,
866 &mut rng,
867 );
868 }
869
870 fn run_float_test<L, R, To, Dist>(
877 under_test: fn(&[L], &[R]) -> To,
878 reference: fn(&[L], &[R]) -> To,
879 rng: &mut impl Rng,
880 distribution: Dist,
881 bounds: EpsilonAndRelative,
882 ) where
883 L: test_util::CornerCases,
884 R: test_util::CornerCases,
885 To: GetInner + Copy,
886 Dist: test_util::GenerateRandomArguments<L> + test_util::GenerateRandomArguments<R> + Clone,
887 {
888 let left_corner_cases = <L as test_util::CornerCases>::corner_cases().len();
889 let right_corner_cases = <R as test_util::CornerCases>::corner_cases().len();
890 for dim in 0..MAX_DIM {
891 let mut callcount = 0;
892 let callback = |_, _| {
893 callcount += 1;
894 };
895
896 run_test(
897 under_test,
898 reference,
899 bounds,
900 dim,
901 INTEGER_TRIALS,
902 distribution.clone(),
903 rng,
904 callback,
905 );
906
907 assert_eq!(
909 callcount,
910 INTEGER_TRIALS + left_corner_cases * right_corner_cases
911 );
912 }
913 }
914
915 fn expected_l2_errors() -> EpsilonAndRelative {
920 EpsilonAndRelative {
921 epsilon: 0.0,
922 max_relative: 1.2e-6,
923 }
924 }
925
926 #[test]
927 fn test_l2_f32_mathematical() {
928 let mut rng = rand::rngs::StdRng::seed_from_u64(0x6d22d320bdf35aec);
929 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
930 run_float_test(
931 as_function_pointer::<FullL2, f32, f32, MathematicalValue<f32>>,
932 reference::reference_l2_f32_mathematical,
933 &mut rng,
934 distribution,
935 expected_l2_errors(),
936 );
937 }
938
939 #[test]
940 fn test_l2_f16_mathematical() {
941 let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
942 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
943 run_float_test(
944 as_function_pointer::<FullL2, Half, Half, MathematicalValue<f32>>,
945 reference::reference_l2_f16_mathematical,
946 &mut rng,
947 distribution,
948 expected_l2_errors(),
949 );
950 }
951
952 #[test]
953 fn test_l2_f32xf16_mathematical() {
954 let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
955 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
956
957 run_float_test(
958 as_function_pointer::<FullL2, f32, Half, MathematicalValue<f32>>,
959 reference::reference_l2_f32xf16_mathematical,
960 &mut rng,
961 distribution,
962 expected_l2_errors(),
963 );
964 }
965
966 #[test]
967 fn test_l2_f32_similarity() {
968 let mut rng = rand::rngs::StdRng::seed_from_u64(0xbfc5f4b42b5bc0c1);
969 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
970 run_float_test(
971 as_function_pointer::<FullL2, f32, f32, SimilarityScore<f32>>,
972 reference::reference_l2_f32_similarity,
973 &mut rng,
974 distribution,
975 expected_l2_errors(),
976 );
977 }
978
979 #[test]
980 fn test_l2_f16_similarity() {
981 let mut rng = rand::rngs::StdRng::seed_from_u64(0x9d3809d84f54e4b6);
982 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
983 run_float_test(
984 as_function_pointer::<FullL2, Half, Half, SimilarityScore<f32>>,
985 reference::reference_l2_f16_similarity,
986 &mut rng,
987 distribution,
988 expected_l2_errors(),
989 );
990 }
991
992 #[test]
993 fn test_l2_f32xf16_similarity() {
994 let mut rng = rand::rngs::StdRng::seed_from_u64(0x755819460c190db4);
995 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
996
997 run_float_test(
998 as_function_pointer::<FullL2, f32, Half, SimilarityScore<f32>>,
999 reference::reference_l2_f32xf16_similarity,
1000 &mut rng,
1001 distribution,
1002 expected_l2_errors(),
1003 );
1004 }
1005
1006 fn expected_innerproduct_errors() -> EpsilonAndRelative {
1011 EpsilonAndRelative {
1012 epsilon: 2.5e-5,
1013 max_relative: 1.6e-5,
1014 }
1015 }
1016
1017 #[test]
1018 fn test_innerproduct_f32_mathematical() {
1019 let mut rng = rand::rngs::StdRng::seed_from_u64(0x1ef6ac3b65869792);
1020 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1021 run_float_test(
1022 as_function_pointer::<InnerProduct, f32, f32, MathematicalValue<f32>>,
1023 reference::reference_innerproduct_f32_mathematical,
1024 &mut rng,
1025 distribution,
1026 expected_innerproduct_errors(),
1027 );
1028 }
1029
1030 #[test]
1031 fn test_innerproduct_f16_mathematical() {
1032 let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1033 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1034 run_float_test(
1035 as_function_pointer::<InnerProduct, Half, Half, MathematicalValue<f32>>,
1036 reference::reference_innerproduct_f16_mathematical,
1037 &mut rng,
1038 distribution,
1039 expected_innerproduct_errors(),
1040 );
1041 }
1042
1043 #[test]
1044 fn test_innerproduct_f32xf16_mathematical() {
1045 let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1046 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1047 run_float_test(
1048 as_function_pointer::<InnerProduct, f32, Half, MathematicalValue<f32>>,
1049 reference::reference_innerproduct_f32xf16_mathematical,
1050 &mut rng,
1051 distribution,
1052 expected_innerproduct_errors(),
1053 );
1054 }
1055
1056 #[test]
1057 fn test_innerproduct_f32_similarity() {
1058 let mut rng = rand::rngs::StdRng::seed_from_u64(0x40326b22a57db0d7);
1059 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1060 run_float_test(
1061 as_function_pointer::<InnerProduct, f32, f32, SimilarityScore<f32>>,
1062 reference::reference_innerproduct_f32_similarity,
1063 &mut rng,
1064 distribution,
1065 expected_innerproduct_errors(),
1066 );
1067 }
1068
1069 #[test]
1070 fn test_innerproduct_f16_similarity() {
1071 let mut rng = rand::rngs::StdRng::seed_from_u64(0xfb8cff47bcbc9528);
1072 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1073 run_float_test(
1074 as_function_pointer::<InnerProduct, Half, Half, SimilarityScore<f32>>,
1075 reference::reference_innerproduct_f16_similarity,
1076 &mut rng,
1077 distribution,
1078 expected_innerproduct_errors(),
1079 );
1080 }
1081
1082 #[test]
1083 fn test_innerproduct_f32xf16_similarity() {
1084 let mut rng = rand::rngs::StdRng::seed_from_u64(0x24c51e4b825b0329);
1085 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1086 run_float_test(
1087 as_function_pointer::<InnerProduct, f32, Half, SimilarityScore<f32>>,
1088 reference::reference_innerproduct_f32xf16_similarity,
1089 &mut rng,
1090 distribution,
1091 expected_innerproduct_errors(),
1092 );
1093 }
1094
1095 fn expected_cosine_errors() -> EpsilonAndRelative {
1100 EpsilonAndRelative {
1101 epsilon: 3e-7,
1102 max_relative: 5e-6,
1103 }
1104 }
1105
1106 #[test]
1107 fn test_cosine_f32_mathematical() {
1108 let mut rng = rand::rngs::StdRng::seed_from_u64(0xca6eaac942999500);
1109 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1110 run_float_test(
1111 as_function_pointer::<Cosine, f32, f32, MathematicalValue<f32>>,
1112 reference::reference_cosine_f32_mathematical,
1113 &mut rng,
1114 distribution,
1115 expected_cosine_errors(),
1116 );
1117 }
1118
1119 #[test]
1120 fn test_cosine_f16_mathematical() {
1121 let mut rng = rand::rngs::StdRng::seed_from_u64(0xa736c789aa16ce86);
1122 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1123 run_float_test(
1124 as_function_pointer::<Cosine, Half, Half, MathematicalValue<f32>>,
1125 reference::reference_cosine_f16_mathematical,
1126 &mut rng,
1127 distribution,
1128 expected_cosine_errors(),
1129 );
1130 }
1131
1132 #[test]
1133 fn test_cosine_f32xf16_mathematical() {
1134 let mut rng = rand::rngs::StdRng::seed_from_u64(0xac550231088a0d5c);
1135 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1136 run_float_test(
1137 as_function_pointer::<Cosine, f32, Half, MathematicalValue<f32>>,
1138 reference::reference_cosine_f32xf16_mathematical,
1139 &mut rng,
1140 distribution,
1141 expected_cosine_errors(),
1142 );
1143 }
1144
1145 #[test]
1146 fn test_cosine_f32_similarity() {
1147 let mut rng = rand::rngs::StdRng::seed_from_u64(0x4a09ad987a6204f3);
1148 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1149 run_float_test(
1150 as_function_pointer::<Cosine, f32, f32, SimilarityScore<f32>>,
1151 reference::reference_cosine_f32_similarity,
1152 &mut rng,
1153 distribution,
1154 expected_cosine_errors(),
1155 );
1156 }
1157
1158 #[test]
1159 fn test_cosine_f16_similarity() {
1160 let mut rng = rand::rngs::StdRng::seed_from_u64(0x77a48d1914f850f2);
1161 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1162 run_float_test(
1163 as_function_pointer::<Cosine, Half, Half, SimilarityScore<f32>>,
1164 reference::reference_cosine_f16_similarity,
1165 &mut rng,
1166 distribution,
1167 expected_cosine_errors(),
1168 );
1169 }
1170
1171 #[test]
1172 fn test_cosine_f32xf16_similarity() {
1173 let mut rng = rand::rngs::StdRng::seed_from_u64(0xbd7471b815655ca1);
1174 let distribution = rand_distr::Normal::new(0.0, 1.0).unwrap();
1175 run_float_test(
1176 as_function_pointer::<Cosine, f32, Half, SimilarityScore<f32>>,
1177 reference::reference_cosine_f32xf16_similarity,
1178 &mut rng,
1179 distribution,
1180 expected_cosine_errors(),
1181 );
1182 }
1183
1184 fn expected_cosine_normalized_errors() -> EpsilonAndRelative {
1189 EpsilonAndRelative {
1190 epsilon: 3e-7,
1191 max_relative: 5e-6,
1192 }
1193 }
1194
1195 #[test]
1196 fn test_cosine_normalized_f32_mathematical() {
1197 let mut rng = rand::rngs::StdRng::seed_from_u64(0x1fda98112747f8dd);
1198 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1199 run_float_test(
1200 as_function_pointer::<CosineNormalized, f32, f32, MathematicalValue<f32>>,
1201 reference::reference_cosine_normalized_f32_mathematical,
1202 &mut rng,
1203 test_util::Normalized(distribution),
1204 expected_cosine_normalized_errors(),
1205 );
1206 }
1207
1208 #[test]
1209 fn test_cosine_normalized_f16_mathematical() {
1210 let mut rng = rand::rngs::StdRng::seed_from_u64(0x5e8c5d5e19cdd840);
1211 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1212 run_float_test(
1213 as_function_pointer::<CosineNormalized, Half, Half, MathematicalValue<f32>>,
1214 reference::reference_cosine_normalized_f16_mathematical,
1215 &mut rng,
1216 test_util::Normalized(distribution),
1217 expected_cosine_normalized_errors(),
1218 );
1219 }
1220
1221 #[test]
1222 fn test_cosine_normalized_f32xf16_mathematical() {
1223 let mut rng = rand::rngs::StdRng::seed_from_u64(0x3fd01e1c11c9bc45);
1224 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1225 run_float_test(
1226 as_function_pointer::<CosineNormalized, f32, Half, MathematicalValue<f32>>,
1227 reference::reference_cosine_normalized_f32xf16_mathematical,
1228 &mut rng,
1229 test_util::Normalized(distribution),
1230 expected_cosine_normalized_errors(),
1231 );
1232 }
1233
1234 #[test]
1235 fn test_cosine_normalized_f32_similarity() {
1236 let mut rng = rand::rngs::StdRng::seed_from_u64(0x9446d057870e5605);
1237 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1238 run_float_test(
1239 as_function_pointer::<CosineNormalized, f32, f32, SimilarityScore<f32>>,
1240 reference::reference_cosine_normalized_f32_similarity,
1241 &mut rng,
1242 test_util::Normalized(distribution),
1243 expected_cosine_normalized_errors(),
1244 );
1245 }
1246
1247 #[test]
1248 fn test_cosine_normalized_f16_similarity() {
1249 let mut rng = rand::rngs::StdRng::seed_from_u64(0x885c371801f18174);
1250 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1251 run_float_test(
1252 as_function_pointer::<CosineNormalized, Half, Half, SimilarityScore<f32>>,
1253 reference::reference_cosine_normalized_f16_similarity,
1254 &mut rng,
1255 test_util::Normalized(distribution),
1256 expected_cosine_normalized_errors(),
1257 );
1258 }
1259
1260 #[test]
1261 fn test_cosine_normalized_f32xf16_similarity() {
1262 let mut rng = rand::rngs::StdRng::seed_from_u64(0x1c356c92d0522c0f);
1263 let distribution = rand_distr::Normal::new(-1.0, 1.0).unwrap();
1264 run_float_test(
1265 as_function_pointer::<CosineNormalized, f32, Half, SimilarityScore<f32>>,
1266 reference::reference_cosine_normalized_f32xf16_similarity,
1267 &mut rng,
1268 test_util::Normalized(distribution),
1269 expected_cosine_normalized_errors(),
1270 );
1271 }
1272}