1use diskann_vector::{MathematicalValue, PureDistanceFunction};
7use thiserror::Error;
8
9use crate::{
10 alloc::GlobalAllocator,
11 bits::{BitSlice, Dense, Representation, Unsigned},
12 distances,
13 distances::{InnerProduct, MV},
14 meta::{self, slice},
15};
16
17#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
44#[repr(C)]
45pub struct MinMaxCompensation {
46 pub dim: u32, pub b: f32, pub n: f32, pub a: f32, pub norm_squared: f32, }
52
53const META_BYTES: usize = std::mem::size_of::<MinMaxCompensation>(); #[derive(Debug, Error, Clone, PartialEq, Eq)]
58pub enum MetaParseError {
59 #[error("Invalid size: {0}, must contain at least {META_BYTES} bytes")]
60 NotCanonical(usize),
61}
62
63impl MinMaxCompensation {
64 #[inline(always)]
80 pub fn read_dimension(bytes: &[u8]) -> Result<usize, MetaParseError> {
81 if bytes.len() < META_BYTES {
82 return Err(MetaParseError::NotCanonical(bytes.len()));
83 }
84
85 let dim_bytes: [u8; 4] = bytes.get(..4).map_or_else(
87 || Err(MetaParseError::NotCanonical(bytes.len())),
88 |slice| {
89 slice
90 .try_into()
91 .map_err(|_| MetaParseError::NotCanonical(bytes.len()))
92 },
93 )?;
94
95 let dim = u32::from_le_bytes(dim_bytes) as usize;
96
97 Ok(dim)
98 }
99}
100
101pub type Data<const NBITS: usize> = meta::Vector<NBITS, Unsigned, MinMaxCompensation, Dense>;
105
106pub type DataRef<'a, const NBITS: usize> =
110 meta::VectorRef<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
111
112#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
113pub enum DecompressError {
114 #[error("expected src and dst length to be identical, instead src is {0}, and dst is {1}")]
115 LengthMismatch(usize, usize),
116}
117impl<const NBITS: usize> DataRef<'_, NBITS>
118where
119 Unsigned: Representation<NBITS>,
120{
121 pub fn decompress_into(&self, dst: &mut [f32]) -> Result<(), DecompressError> {
137 if dst.len() != self.len() {
138 return Err(DecompressError::LengthMismatch(self.len(), dst.len()));
139 }
140 let meta = self.meta();
141
142 dst.iter_mut().enumerate().for_each(|(i, d)| unsafe {
145 *d = self.vector().get_unchecked(i) as f32 * meta.a + meta.b
146 });
147 Ok(())
148 }
149}
150
151pub type DataMutRef<'a, const NBITS: usize> =
155 meta::VectorMut<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
156
157#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
179#[repr(C)]
180pub struct FullQueryMeta {
181 pub sum: f32,
183 pub norm_squared: f32,
185}
186
187pub type FullQuery<A = GlobalAllocator> = slice::PolySlice<f32, FullQueryMeta, A>;
191
192pub type FullQueryRef<'a> = slice::SliceRef<'a, f32, FullQueryMeta>;
196
197pub type FullQueryMut<'a> = slice::SliceMut<'a, f32, FullQueryMeta>;
201
202#[inline(always)]
206fn kernel<const NBITS: usize, const MBITS: usize, F>(
207 x: DataRef<'_, NBITS>,
208 y: DataRef<'_, MBITS>,
209 f: F,
210) -> distances::MathematicalResult<f32>
211where
212 Unsigned: Representation<NBITS> + Representation<MBITS>,
213 InnerProduct: for<'a, 'b> PureDistanceFunction<
214 BitSlice<'a, NBITS, Unsigned>,
215 BitSlice<'b, MBITS, Unsigned>,
216 distances::MathematicalResult<u32>,
217 >,
218 F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
219{
220 let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?;
221 let (xm, ym) = (x.meta(), y.meta());
222 let term0 = xm.a * ym.a * raw_product.into_inner() as f32;
223 let term1_x = xm.n * ym.b;
224 let term1_y = ym.n * xm.b;
225 let term2 = xm.b * ym.b * (x.len() as f32);
226
227 let v = term0 + term1_x + term1_y + term2;
228 Ok(MV::new(f(v, &xm, &ym)))
229}
230
231pub struct MinMaxIP;
232
233impl<const NBITS: usize, const MBITS: usize>
234 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
235 for MinMaxIP
236where
237 Unsigned: Representation<NBITS> + Representation<MBITS>,
238 InnerProduct: for<'a, 'b> PureDistanceFunction<
239 BitSlice<'a, NBITS, Unsigned>,
240 BitSlice<'b, MBITS, Unsigned>,
241 distances::MathematicalResult<u32>,
242 >,
243{
244 #[inline(always)]
245 fn evaluate(
246 x: DataRef<'_, NBITS>,
247 y: DataRef<'_, MBITS>,
248 ) -> distances::MathematicalResult<f32> {
249 kernel(x, y, |v, _, _| v)
250 }
251}
252
253impl<const NBITS: usize, const MBITS: usize>
254 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
255 for MinMaxIP
256where
257 Unsigned: Representation<NBITS> + Representation<MBITS>,
258 InnerProduct: for<'a, 'b> PureDistanceFunction<
259 BitSlice<'a, NBITS, Unsigned>,
260 BitSlice<'b, MBITS, Unsigned>,
261 distances::MathematicalResult<u32>,
262 >,
263{
264 #[inline(always)]
265 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
266 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
267 Ok(-v?.into_inner())
268 }
269}
270
271impl<const NBITS: usize>
272 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
273 for MinMaxIP
274where
275 Unsigned: Representation<NBITS>,
276 InnerProduct: for<'a, 'b> PureDistanceFunction<
277 &'a [f32],
278 BitSlice<'b, NBITS, Unsigned>,
279 distances::MathematicalResult<f32>,
280 >,
281{
282 #[inline(always)]
283 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
284 let raw_product: f32 = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
285 Ok(MathematicalValue::new(
286 raw_product * y.meta().a + x.meta().sum * y.meta().b,
287 ))
288 }
289}
290
291impl<const NBITS: usize>
292 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxIP
293where
294 Unsigned: Representation<NBITS>,
295 InnerProduct: for<'a, 'b> PureDistanceFunction<
296 &'a [f32],
297 BitSlice<'b, NBITS, Unsigned>,
298 distances::MathematicalResult<f32>,
299 >,
300{
301 #[inline(always)]
302 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
303 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
304 Ok(-v?.into_inner())
305 }
306}
307
308pub struct MinMaxL2Squared;
309
310impl<const NBITS: usize, const MBITS: usize>
311 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
312 for MinMaxL2Squared
313where
314 Unsigned: Representation<NBITS> + Representation<MBITS>,
315 InnerProduct: for<'a, 'b> PureDistanceFunction<
316 BitSlice<'a, NBITS, Unsigned>,
317 BitSlice<'b, MBITS, Unsigned>,
318 distances::MathematicalResult<u32>,
319 >,
320{
321 #[inline(always)]
322 fn evaluate(
323 x: DataRef<'_, NBITS>,
324 y: DataRef<'_, MBITS>,
325 ) -> distances::MathematicalResult<f32> {
326 kernel(x, y, |v, xm, ym| {
327 -2.0 * v + xm.norm_squared + ym.norm_squared
328 })
329 }
330}
331
332impl<const NBITS: usize, const MBITS: usize>
333 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
334 for MinMaxL2Squared
335where
336 Unsigned: Representation<NBITS> + Representation<MBITS>,
337 InnerProduct: for<'a, 'b> PureDistanceFunction<
338 BitSlice<'a, NBITS, Unsigned>,
339 BitSlice<'b, MBITS, Unsigned>,
340 distances::MathematicalResult<u32>,
341 >,
342{
343 #[inline(always)]
344 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
345 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
346 Ok(v?.into_inner())
347 }
348}
349
350impl<const NBITS: usize>
351 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
352 for MinMaxL2Squared
353where
354 Unsigned: Representation<NBITS>,
355 InnerProduct: for<'a, 'b> PureDistanceFunction<
356 &'a [f32],
357 BitSlice<'b, NBITS, Unsigned>,
358 distances::MathematicalResult<f32>,
359 >,
360{
361 #[inline(always)]
362 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
363 let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
364
365 let ym = y.meta();
366 let compensated_ip = raw_product * ym.a + x.meta().sum * ym.b;
367 Ok(MV::new(
368 x.meta().norm_squared + ym.norm_squared - 2.0 * compensated_ip,
369 ))
370 }
371}
372
373impl<const NBITS: usize>
374 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
375 for MinMaxL2Squared
376where
377 Unsigned: Representation<NBITS>,
378 InnerProduct: for<'a, 'b> PureDistanceFunction<
379 &'a [f32],
380 BitSlice<'b, NBITS, Unsigned>,
381 distances::MathematicalResult<f32>,
382 >,
383{
384 #[inline(always)]
385 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
386 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
387 Ok(v?.into_inner())
388 }
389}
390
391pub struct MinMaxCosine;
396
397impl<const NBITS: usize, const MBITS: usize>
398 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
399 for MinMaxCosine
400where
401 Unsigned: Representation<NBITS> + Representation<MBITS>,
402 MinMaxIP: for<'a, 'b> PureDistanceFunction<
403 DataRef<'a, NBITS>,
404 DataRef<'b, MBITS>,
405 distances::MathematicalResult<f32>,
406 >,
407{
408 #[inline(always)]
410 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
411 let ip: MV<f32> = MinMaxIP::evaluate(x, y)?;
412 let (xm, ym) = (x.meta(), y.meta());
413 Ok(1.0 - ip.into_inner() / (xm.norm_squared.sqrt() * ym.norm_squared.sqrt()))
414 }
415}
416
417impl<const NBITS: usize>
418 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
419 for MinMaxCosine
420where
421 Unsigned: Representation<NBITS>,
422 MinMaxIP: for<'a, 'b> PureDistanceFunction<
423 FullQueryRef<'a>,
424 DataRef<'b, NBITS>,
425 distances::MathematicalResult<f32>,
426 >,
427{
428 #[inline(always)]
429 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
430 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
431 let (xm, ym) = (x.meta().norm_squared, y.meta());
432 Ok(1.0 - ip.into_inner() / (xm.sqrt() * ym.norm_squared.sqrt()))
433 }
435}
436
437pub struct MinMaxCosineNormalized;
438
439impl<const NBITS: usize, const MBITS: usize>
440 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
441 for MinMaxCosineNormalized
442where
443 Unsigned: Representation<NBITS> + Representation<MBITS>,
444 MinMaxIP: for<'a, 'b> PureDistanceFunction<
445 DataRef<'a, NBITS>,
446 DataRef<'b, MBITS>,
447 distances::MathematicalResult<f32>,
448 >,
449{
450 #[inline(always)]
451 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
452 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
453 Ok(1.0 - ip.into_inner()) }
455}
456
457impl<const NBITS: usize>
458 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
459 for MinMaxCosineNormalized
460where
461 Unsigned: Representation<NBITS>,
462 MinMaxIP: for<'a, 'b> PureDistanceFunction<
463 FullQueryRef<'a>,
464 DataRef<'b, NBITS>,
465 distances::MathematicalResult<f32>,
466 >,
467{
468 #[inline(always)]
469 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
470 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
471 Ok(1.0 - ip.into_inner()) }
473}
474
// Tests comparing every min-max distance function, decompression and `read_dimension`
// against values computed directly from the decoded `f32` vectors.
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_vector_tests {
    use diskann_utils::Reborrow;
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{alloc::GlobalAllocator, scalar::bit_scale};

    /// Creates a random `NBITS`-bit min-max vector of length `dim`.
    ///
    /// Codes are drawn uniformly from the representable domain, and `a` and `b` uniformly
    /// from `[0, 2]`. Returns the quantized vector, whose metadata is consistent with its
    /// codes, together with the decoded values `a * code + b` it represents.
    fn random_minmax_vector<const NBITS: usize>(
        dim: usize,
        rng: &mut impl Rng,
    ) -> (Data<NBITS>, Vec<f32>)
    where
        Unsigned: Representation<NBITS>,
    {
        let mut v = Data::<NBITS>::new_boxed(dim);

        let domain = Unsigned::domain_const::<NBITS>();
        let code_dist = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        {
            let mut bs = v.vector_mut();
            for i in 0..dim {
                bs.set(i, code_dist.sample(rng)).unwrap();
            }
        }

        let a: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
        let b: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);

        let original: Vec<f32> = (0..dim)
            .map(|i| a * v.vector().get(i).unwrap() as f32 + b)
            .collect();

        let code_sum: f32 = (0..dim).map(|i| v.vector().get(i).unwrap() as f32).sum();
        let norm_squared: f32 = original.iter().map(|x| x * x).sum();

        v.set_meta(MinMaxCompensation {
            a,
            b,
            n: a * code_sum,
            norm_squared,
            dim: dim as u32,
        });

        (v, original)
    }

    /// Exercises every metric twice: quantized x quantized, then full-precision query x
    /// quantized. It also covers `decompress_into` and `read_dimension` for one
    /// random pair of vectors.
    fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                &'a [f32],
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<f32>,
            >,
        R: Rng,
    {
        assert!(dim <= bit_scale::<NBITS>() as usize);

        let (v1, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v2, original2) = random_minmax_vector::<NBITS>(dim, rng);

        let norm1_squared = v1.meta().norm_squared;
        let norm2_squared = v2.meta().norm_squared;

        let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();

        // Inner product. The `distances::Result` flavor reports the negated value.
        let computed_ip_f32: distances::Result<f32> =
            MinMaxIP::evaluate(v1.reborrow(), v2.reborrow());
        let computed_ip_f32 = computed_ip_f32.unwrap();
        assert!(
            (expected_ip - (-computed_ip_f32)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            computed_ip_f32,
            dim
        );

        // Squared L2.
        let expected_l2 = (0..dim)
            .map(|i| original1[i] - original2[i])
            .map(|x| x.powf(2.0))
            .sum::<f32>();

        let computed_l2_f32: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v1.reborrow(), v2.reborrow());
        let computed_l2_f32 = computed_l2_f32.unwrap();
        assert!(
            ((computed_l2_f32 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            computed_l2_f32,
            dim
        );

        // Cosine. Accept a small absolute OR relative error, because the expected distance
        // can be very close to zero.
        let expected_cosine = 1.0 - expected_ip / (norm1_squared.sqrt() * norm2_squared.sqrt());

        let computed_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(v1.reborrow(), v2.reborrow());
        let computed_cosine = computed_cosine.unwrap();

        {
            let passed = (computed_cosine - expected_cosine).abs() < 1e-6
                || ((computed_cosine - expected_cosine).abs() / expected_cosine) < 1e-3;

            assert!(
                passed,
                "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
                expected_cosine, computed_cosine, dim
            );
        }

        // CosineNormalized is `1 - ip` regardless of the actual norms.
        let cosine_normalized: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(v1.reborrow(), v2.reborrow());
        let cosine_normalized = cosine_normalized.unwrap();
        let expected_cos_normalized = 1.0 - expected_ip;
        assert!(
            ((expected_cos_normalized - cosine_normalized).abs() / expected_cos_normalized.abs())
                < 1e-6,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            cosine_normalized,
            dim
        );

        // Same metrics, with `original1` supplied as a full-precision query.
        let mut fp_query = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        fp_query.vector_mut().copy_from_slice(&original1);
        *fp_query.meta_mut() = FullQueryMeta {
            norm_squared: norm1_squared,
            sum: original1.iter().sum::<f32>(),
        };

        let fp_ip: distances::Result<f32> = MinMaxIP::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_ip = fp_ip.unwrap();
        assert!(
            (expected_ip - (-fp_ip)).abs() / expected_ip.abs() < 1e-3,
            "Full-query inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            fp_ip,
            dim
        );

        let fp_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_l2 = fp_l2.unwrap();
        assert!(
            ((fp_l2 - expected_l2).abs() / expected_l2) < 1e-3,
            "Full-query L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            fp_l2,
            dim
        );

        let fp_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cosine = fp_cosine.unwrap();
        let diff = (fp_cosine - expected_cosine).abs();
        assert!(
            (diff / expected_cosine) < 1e-3 || diff <= 1e-6,
            "Full-query cosine distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cosine,
            fp_cosine,
            dim
        );

        let fp_cos_norm: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cos_norm = fp_cos_norm.unwrap();
        assert!(
            (((1.0 - expected_ip) - fp_cos_norm).abs() / (1.0 - expected_ip)) < 1e-3,
            "Full-query CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            (1.0 - expected_ip),
            fp_cos_norm,
            dim
        );

        // `decompress_into` rejects mismatched lengths and reproduces the decoded values.
        let meta = v1.meta();
        let v1_ref = DataRef::new(v1.vector(), &meta);
        let dim = v1_ref.len();
        let mut boxed = vec![0f32; dim + 1];

        let pre = v1_ref.decompress_into(&mut boxed);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim + 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim - 1]);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim - 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim]);
        assert!(pre.is_ok());

        boxed
            .iter()
            .zip(original1.iter())
            .for_each(|(x, y)| assert!((*x - *y).abs() <= 1e-6));

        // `read_dimension` recovers `dim` from a canonical buffer and rejects short input.
        let mut bytes = vec![0u8; Data::<NBITS>::canonical_bytes(dim)];
        let mut data = DataMutRef::from_canonical_front_mut(bytes.as_mut_slice(), dim).unwrap();
        data.set_meta(meta);

        let pre = MinMaxCompensation::read_dimension(&bytes);
        assert!(pre.is_ok());
        let read_dim = pre.unwrap();
        assert_eq!(read_dim, dim);

        let pre = MinMaxCompensation::read_dimension(&[0_u8; 2]);
        assert_eq!(pre.unwrap_err(), MetaParseError::NotCanonical(2));
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    macro_rules! test_minmax_compensated {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = bit_scale::<$nbits>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_compensated_vectors::<$nbits, _>(dim, &mut rng);
                    }
                }
            }
        };
    }
    test_minmax_compensated!(unsigned_minmax_compensated_test_u1, 1, 0xa33d5658097a1c35);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);

    /// Checks `kernel` on vectors quantized with different bit widths (`NBITS` x `MBITS`)
    /// against the inner product of the decoded values.
    fn test_minmax_heterogeneous_kernel<const NBITS: usize, const MBITS: usize, R>(
        dim: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, MBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        R: Rng,
    {
        let (v_query, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v_data, original2) = random_minmax_vector::<MBITS>(dim, rng);

        let expected_ip: f32 = original1.iter().zip(&original2).map(|(x, y)| x * y).sum();
        let computed_ip = kernel(v_query.reborrow(), v_data.reborrow(), |v, _, _| v)
            .unwrap()
            .into_inner();
        assert!(
            (expected_ip - computed_ip).abs() / expected_ip.abs().max(1e-10) < 1e-6,
            "Heterogeneous IP ({},{}) failed: expected {}, got {} on dim: {}",
            NBITS,
            MBITS,
            expected_ip,
            computed_ip,
            dim,
        );
    }

    macro_rules! test_minmax_heterogeneous {
        ($name:ident, $N:literal, $M:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                // Bounded by the narrower (`$M`-bit) representation.
                const MAX_DIM: usize = bit_scale::<$M>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_heterogeneous_kernel::<$N, $M, _>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_minmax_heterogeneous!(minmax_heterogeneous_8x4, 8, 4, 0xb7c3d9e5f1a20864);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x2, 8, 2, 0x4e8f2c6a1d3b5079);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x1, 8, 1, 0x1b0f2c614d2a7141);
}