1use diskann_vector::{MathematicalValue, PureDistanceFunction};
7use thiserror::Error;
8
9use crate::{
10 alloc::GlobalAllocator,
11 bits::{BitSlice, Dense, Representation, Unsigned},
12 distances,
13 distances::{InnerProduct, MV},
14 meta::{self, slice},
15};
16
/// Per-vector metadata for MinMax (affine) scalar quantization.
///
/// A stored code `c_i` reconstructs to the value `a * c_i + b` (this is exactly what
/// `DataRef::decompress_into` computes). The other fields cache aggregate quantities so
/// that the distance functions in this module can work on codes without decompressing.
///
/// The type is `#[repr(C)]` + `Pod`, so field order is part of the byte layout;
/// `MinMaxCompensation::read_dimension` relies on `dim` occupying the first four bytes.
#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct MinMaxCompensation {
    /// Number of dimensions of the vector this record describes.
    pub dim: u32,
    /// Additive offset `b` of the reconstruction `a * c_i + b`.
    pub b: f32,
    /// Expected to be `a * sum(c_i)`; the inner-product kernel multiplies it by the
    /// other vector's `b` to form the cross terms.
    pub n: f32,
    /// Scale `a` of the reconstruction `a * c_i + b`.
    pub a: f32,
    /// Expected to be the squared L2 norm of the reconstructed vector; consumed by the
    /// L2 and cosine distances.
    pub norm_squared: f32,
}
52
/// Size in bytes of a `MinMaxCompensation` record; `MinMaxCompensation::read_dimension`
/// rejects any buffer shorter than this.
const META_BYTES: usize = std::mem::size_of::<MinMaxCompensation>();

/// Error produced when a byte buffer cannot be interpreted as a `MinMaxCompensation`
/// record.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum MetaParseError {
    /// The buffer holds fewer than `META_BYTES` bytes; the payload is the buffer's
    /// actual length.
    #[error("Invalid size: {0}, must contain at least {META_BYTES} bytes")]
    NotCanonical(usize),
}
62
63impl MinMaxCompensation {
64 #[inline(always)]
80 pub fn read_dimension(bytes: &[u8]) -> Result<usize, MetaParseError> {
81 if bytes.len() < META_BYTES {
82 return Err(MetaParseError::NotCanonical(bytes.len()));
83 }
84
85 let dim_bytes: [u8; 4] = bytes.get(..4).map_or_else(
87 || Err(MetaParseError::NotCanonical(bytes.len())),
88 |slice| {
89 slice
90 .try_into()
91 .map_err(|_| MetaParseError::NotCanonical(bytes.len()))
92 },
93 )?;
94
95 let dim = u32::from_le_bytes(dim_bytes) as usize;
96
97 Ok(dim)
98 }
99}
100
/// A MinMax-quantized vector: `NBITS`-bit unsigned codes (with the `Dense` bit layout)
/// paired with a [`MinMaxCompensation`] metadata record.
pub type Data<const NBITS: usize> = meta::Vector<NBITS, Unsigned, MinMaxCompensation, Dense>;

/// Borrowed view of a [`Data`] vector with lifetime `'a`; this is the quantized operand
/// type accepted by the distance functions in this module.
pub type DataRef<'a, const NBITS: usize> =
    meta::VectorRef<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
111
/// Error returned by `DataRef::decompress_into`.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum DecompressError {
    /// Source and destination lengths differ. Payload is
    /// `(source vector length, destination slice length)`.
    #[error("expected src and dst length to be identical, instead src is {0}, and dst is {1}")]
    LengthMismatch(usize, usize),
}
117impl<const NBITS: usize> DataRef<'_, NBITS>
118where
119 Unsigned: Representation<NBITS>,
120{
121 pub fn decompress_into(&self, dst: &mut [f32]) -> Result<(), DecompressError> {
137 if dst.len() != self.len() {
138 return Err(DecompressError::LengthMismatch(self.len(), dst.len()));
139 }
140 let meta = self.meta();
141
142 dst.iter_mut().enumerate().for_each(|(i, d)| unsafe {
145 *d = self.vector().get_unchecked(i) as f32 * meta.a + meta.b
146 });
147 Ok(())
148 }
149}
150
/// Mutable borrowed view of a [`Data`] vector, e.g. for writing codes or metadata into
/// an existing canonical byte buffer.
pub type DataMutRef<'a, const NBITS: usize> =
    meta::VectorMut<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
156
/// Metadata attached to a full-precision (`f32`) query that is compared against
/// MinMax-quantized data.
#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct FullQueryMeta {
    /// Expected to be the sum of the query components; multiplied by the data vector's
    /// offset `b` when compensating the inner product.
    pub sum: f32,
    /// Expected to be the squared L2 norm of the query; consumed by the L2 and cosine
    /// distances.
    pub norm_squared: f32,
}
186
/// A full-precision query: `f32` components plus [`FullQueryMeta`], allocated with `A`.
pub type FullQuery<A = GlobalAllocator> = slice::PolySlice<f32, FullQueryMeta, A>;

/// Immutable borrowed view of a [`FullQuery`].
pub type FullQueryRef<'a> = slice::SliceRef<'a, f32, FullQueryMeta>;

/// Mutable borrowed view of a [`FullQuery`].
pub type FullQueryMut<'a> = slice::SliceMut<'a, f32, FullQueryMeta>;
201
202#[inline(always)]
206fn kernel<const NBITS: usize, const MBITS: usize, F>(
207 x: DataRef<'_, NBITS>,
208 y: DataRef<'_, MBITS>,
209 f: F,
210) -> distances::MathematicalResult<f32>
211where
212 Unsigned: Representation<NBITS> + Representation<MBITS>,
213 InnerProduct: for<'a, 'b> PureDistanceFunction<
214 BitSlice<'a, NBITS, Unsigned>,
215 BitSlice<'b, MBITS, Unsigned>,
216 distances::MathematicalResult<u32>,
217 >,
218 F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
219{
220 let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?;
221 let (xm, ym) = (x.meta(), y.meta());
222 let term0 = xm.a * ym.a * raw_product.into_inner() as f32;
223 let term1_x = xm.n * ym.b;
224 let term1_y = ym.n * xm.b;
225 let term2 = xm.b * ym.b * (x.len() as f32);
226
227 let v = term0 + term1_x + term1_y + term2;
228 Ok(MV::new(f(v, &xm, &ym)))
229}
230
/// Inner-product metric for MinMax-quantized vectors.
///
/// Evaluated to a `distances::MathematicalResult<f32>` it yields the similarity `<x, y>`
/// of the reconstructed vectors; evaluated to a `distances::Result<f32>` it yields the
/// distance `-<x, y>`.
pub struct MinMaxIP;
232
233impl<const NBITS: usize, const MBITS: usize>
234 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
235 for MinMaxIP
236where
237 Unsigned: Representation<NBITS> + Representation<MBITS>,
238 InnerProduct: for<'a, 'b> PureDistanceFunction<
239 BitSlice<'a, NBITS, Unsigned>,
240 BitSlice<'b, MBITS, Unsigned>,
241 distances::MathematicalResult<u32>,
242 >,
243{
244 fn evaluate(
245 x: DataRef<'_, NBITS>,
246 y: DataRef<'_, MBITS>,
247 ) -> distances::MathematicalResult<f32> {
248 kernel(x, y, |v, _, _| v)
249 }
250}
251
252impl<const NBITS: usize, const MBITS: usize>
253 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
254 for MinMaxIP
255where
256 Unsigned: Representation<NBITS> + Representation<MBITS>,
257 InnerProduct: for<'a, 'b> PureDistanceFunction<
258 BitSlice<'a, NBITS, Unsigned>,
259 BitSlice<'b, MBITS, Unsigned>,
260 distances::MathematicalResult<u32>,
261 >,
262{
263 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
264 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
265 Ok(-v?.into_inner())
266 }
267}
268
269impl<const NBITS: usize>
270 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
271 for MinMaxIP
272where
273 Unsigned: Representation<NBITS>,
274 InnerProduct: for<'a, 'b> PureDistanceFunction<
275 &'a [f32],
276 BitSlice<'b, NBITS, Unsigned>,
277 distances::MathematicalResult<f32>,
278 >,
279{
280 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
281 let raw_product: f32 = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
282 Ok(MathematicalValue::new(
283 raw_product * y.meta().a + x.meta().sum * y.meta().b,
284 ))
285 }
286}
287
288impl<const NBITS: usize>
289 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxIP
290where
291 Unsigned: Representation<NBITS>,
292 InnerProduct: for<'a, 'b> PureDistanceFunction<
293 &'a [f32],
294 BitSlice<'b, NBITS, Unsigned>,
295 distances::MathematicalResult<f32>,
296 >,
297{
298 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
299 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
300 Ok(-v?.into_inner())
301 }
302}
303
/// Squared Euclidean distance between reconstructed MinMax vectors, computed as
/// `||x||^2 + ||y||^2 - 2 <x, y>` from the `norm_squared` metadata and the compensated
/// inner product. Both the `MathematicalResult` and `Result` forms return the same value.
pub struct MinMaxL2Squared;
305
306impl<const NBITS: usize, const MBITS: usize>
307 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
308 for MinMaxL2Squared
309where
310 Unsigned: Representation<NBITS> + Representation<MBITS>,
311 InnerProduct: for<'a, 'b> PureDistanceFunction<
312 BitSlice<'a, NBITS, Unsigned>,
313 BitSlice<'b, MBITS, Unsigned>,
314 distances::MathematicalResult<u32>,
315 >,
316{
317 fn evaluate(
318 x: DataRef<'_, NBITS>,
319 y: DataRef<'_, MBITS>,
320 ) -> distances::MathematicalResult<f32> {
321 kernel(x, y, |v, xm, ym| {
322 -2.0 * v + xm.norm_squared + ym.norm_squared
323 })
324 }
325}
326
327impl<const NBITS: usize, const MBITS: usize>
328 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
329 for MinMaxL2Squared
330where
331 Unsigned: Representation<NBITS> + Representation<MBITS>,
332 InnerProduct: for<'a, 'b> PureDistanceFunction<
333 BitSlice<'a, NBITS, Unsigned>,
334 BitSlice<'b, MBITS, Unsigned>,
335 distances::MathematicalResult<u32>,
336 >,
337{
338 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
339 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
340 Ok(v?.into_inner())
341 }
342}
343
344impl<const NBITS: usize>
345 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
346 for MinMaxL2Squared
347where
348 Unsigned: Representation<NBITS>,
349 InnerProduct: for<'a, 'b> PureDistanceFunction<
350 &'a [f32],
351 BitSlice<'b, NBITS, Unsigned>,
352 distances::MathematicalResult<f32>,
353 >,
354{
355 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
356 let raw_product = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
357
358 let ym = y.meta();
359 let compensated_ip = raw_product * ym.a + x.meta().sum * ym.b;
360 Ok(MV::new(
361 x.meta().norm_squared + ym.norm_squared - 2.0 * compensated_ip,
362 ))
363 }
364}
365
366impl<const NBITS: usize>
367 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
368 for MinMaxL2Squared
369where
370 Unsigned: Representation<NBITS>,
371 InnerProduct: for<'a, 'b> PureDistanceFunction<
372 &'a [f32],
373 BitSlice<'b, NBITS, Unsigned>,
374 distances::MathematicalResult<f32>,
375 >,
376{
377 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
378 let v: distances::MathematicalResult<f32> = Self::evaluate(x, y);
379 Ok(v?.into_inner())
380 }
381}
382
/// Cosine distance `1 - <x, y> / (sqrt(norm_squared_x) * sqrt(norm_squared_y))` over
/// reconstructed MinMax vectors.
///
/// NOTE(review): no guard for zero norms — if either `norm_squared` is zero the
/// denominator is zero and the result is non-finite.
pub struct MinMaxCosine;
388
389impl<const NBITS: usize, const MBITS: usize>
390 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
391 for MinMaxCosine
392where
393 Unsigned: Representation<NBITS> + Representation<MBITS>,
394 MinMaxIP: for<'a, 'b> PureDistanceFunction<
395 DataRef<'a, NBITS>,
396 DataRef<'b, MBITS>,
397 distances::MathematicalResult<f32>,
398 >,
399{
400 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
402 let ip: MV<f32> = MinMaxIP::evaluate(x, y)?;
403 let (xm, ym) = (x.meta(), y.meta());
404 Ok(1.0 - ip.into_inner() / (xm.norm_squared.sqrt() * ym.norm_squared.sqrt()))
405 }
406}
407
408impl<const NBITS: usize>
409 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
410 for MinMaxCosine
411where
412 Unsigned: Representation<NBITS>,
413 MinMaxIP: for<'a, 'b> PureDistanceFunction<
414 FullQueryRef<'a>,
415 DataRef<'b, NBITS>,
416 distances::MathematicalResult<f32>,
417 >,
418{
419 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
420 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
421 let (xm, ym) = (x.meta().norm_squared, y.meta());
422 Ok(1.0 - ip.into_inner() / (xm.sqrt() * ym.norm_squared.sqrt()))
423 }
425}
426
/// Cosine distance `1 - <x, y>` that skips the norm division.
///
/// The result only equals the true cosine distance when the reconstructed vectors are
/// already unit length; this is assumed, not checked.
pub struct MinMaxCosineNormalized;
428
429impl<const NBITS: usize, const MBITS: usize>
430 PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
431 for MinMaxCosineNormalized
432where
433 Unsigned: Representation<NBITS> + Representation<MBITS>,
434 MinMaxIP: for<'a, 'b> PureDistanceFunction<
435 DataRef<'a, NBITS>,
436 DataRef<'b, MBITS>,
437 distances::MathematicalResult<f32>,
438 >,
439{
440 fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
441 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
442 Ok(1.0 - ip.into_inner()) }
444}
445
446impl<const NBITS: usize>
447 PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
448 for MinMaxCosineNormalized
449where
450 Unsigned: Representation<NBITS>,
451 MinMaxIP: for<'a, 'b> PureDistanceFunction<
452 FullQueryRef<'a>,
453 DataRef<'b, NBITS>,
454 distances::MathematicalResult<f32>,
455 >,
456{
457 fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
458 let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
459 Ok(1.0 - ip.into_inner()) }
461}
462
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_vector_tests {
    use diskann_utils::Reborrow;
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{alloc::GlobalAllocator, scalar::bit_scale};

    /// Builds a random `NBITS`-bit MinMax vector of length `dim` plus its reconstructed
    /// `f32` values.
    ///
    /// Codes are drawn uniformly from the representable domain, and `a`, `b` uniformly
    /// from `[0, 2]`. The metadata (`n = a * sum(codes)`, `norm_squared`, `dim`) is derived
    /// from those draws.
    fn random_minmax_vector<const NBITS: usize>(
        dim: usize,
        rng: &mut impl Rng,
    ) -> (Data<NBITS>, Vec<f32>)
    where
        Unsigned: Representation<NBITS>,
    {
        let mut v = Data::<NBITS>::new_boxed(dim);

        let domain = Unsigned::domain_const::<NBITS>();
        let code_dist = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        {
            let mut bs = v.vector_mut();
            for i in 0..dim {
                bs.set(i, code_dist.sample(rng)).unwrap();
            }
        }

        let a: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
        let b: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);

        let original: Vec<f32> = (0..dim)
            .map(|i| a * v.vector().get(i).unwrap() as f32 + b)
            .collect();

        let code_sum: f32 = (0..dim).map(|i| v.vector().get(i).unwrap() as f32).sum();
        let norm_squared: f32 = original.iter().map(|x| x * x).sum();

        v.set_meta(MinMaxCompensation {
            a,
            b,
            n: a * code_sum,
            norm_squared,
            dim: dim as u32,
        });

        (v, original)
    }

    /// Compares every metric in this module against values computed directly from the
    /// reconstructed `f32` vectors. This covers both quantized/quantized and
    /// full-query/quantized operands. It then exercises `decompress_into` and
    /// `MinMaxCompensation::read_dimension`.
    fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
        R: Rng,
    {
        assert!(dim <= bit_scale::<NBITS>() as usize);

        let (v1, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v2, original2) = random_minmax_vector::<NBITS>(dim, rng);

        let norm1_squared = v1.meta().norm_squared;
        let norm2_squared = v2.meta().norm_squared;

        let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();

        // Inner product; the `distances::Result` form is the negated similarity.
        let computed_ip_f32: distances::Result<f32> =
            MinMaxIP::evaluate(v1.reborrow(), v2.reborrow());
        let computed_ip_f32 = computed_ip_f32.unwrap();
        assert!(
            (expected_ip - (-computed_ip_f32)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            computed_ip_f32,
            dim
        );

        // Squared L2.
        let expected_l2 = (0..dim)
            .map(|i| original1[i] - original2[i])
            .map(|x| x.powf(2.0))
            .sum::<f32>();

        let computed_l2_f32: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v1.reborrow(), v2.reborrow());
        let computed_l2_f32 = computed_l2_f32.unwrap();
        assert!(
            ((computed_l2_f32 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            computed_l2_f32,
            dim
        );

        // Cosine.
        let expected_cosine = 1.0 - expected_ip / (norm1_squared.sqrt() * norm2_squared.sqrt());

        let computed_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(v1.reborrow(), v2.reborrow());
        let computed_cosine = computed_cosine.unwrap();

        {
            let passed = (computed_cosine - expected_cosine).abs() < 1e-6
                || ((computed_cosine - expected_cosine).abs() / expected_cosine) < 1e-3;

            assert!(
                passed,
                "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
                expected_cosine, computed_cosine, dim
            );
        }

        // Cosine without norm division.
        let cosine_normalized: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(v1.reborrow(), v2.reborrow());
        let cosine_normalized = cosine_normalized.unwrap();
        let expected_cos_normalized = 1.0 - expected_ip;
        assert!(
            ((expected_cos_normalized - cosine_normalized).abs() / expected_cos_normalized.abs())
                < 1e-6,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            cosine_normalized,
            dim
        );

        // Full-precision query (the reconstruction of `v1`) against quantized `v2`.
        let mut fp_query = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        fp_query.vector_mut().copy_from_slice(&original1);
        *fp_query.meta_mut() = FullQueryMeta {
            norm_squared: norm1_squared,
            sum: original1.iter().sum::<f32>(),
        };

        let fp_ip: distances::Result<f32> = MinMaxIP::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_ip = fp_ip.unwrap();
        assert!(
            (expected_ip - (-fp_ip)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            fp_ip,
            dim
        );

        let fp_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_l2 = fp_l2.unwrap();
        // Report `fp_l2` — the value actually under test — when this check fails.
        assert!(
            ((fp_l2 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            fp_l2,
            dim
        );

        let fp_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cosine = fp_cosine.unwrap();
        let diff = (fp_cosine - expected_cosine).abs();
        assert!(
            (diff / expected_cosine) < 1e-3 || diff <= 1e-6,
            "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cosine,
            fp_cosine,
            dim
        );

        let fp_cos_norm: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cos_norm = fp_cos_norm.unwrap();
        assert!(
            (((1.0 - expected_ip) - fp_cos_norm).abs() / (1.0 - expected_ip)) < 1e-3,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            (1.0 - expected_ip),
            fp_cos_norm,
            dim
        );

        // Decompression: wrong-length destinations are rejected; the exact length
        // reproduces the reconstructed values.
        let meta = v1.meta();
        let v1_ref = DataRef::new(v1.vector(), &meta);
        let dim = v1_ref.len();
        let mut boxed = vec![0f32; dim + 1];

        let pre = v1_ref.decompress_into(&mut boxed);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim + 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim - 1]);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim - 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim]);
        assert!(pre.is_ok());

        boxed
            .iter()
            .zip(original1.iter())
            .for_each(|(x, y)| assert!((*x - *y).abs() <= 1e-6));

        // `read_dimension` on a canonical buffer, and on one that is too short.
        let mut bytes = vec![0u8; Data::canonical_bytes(dim)];
        let mut data = DataMutRef::from_canonical_front_mut(bytes.as_mut_slice(), dim).unwrap();
        data.set_meta(meta);

        let pre = MinMaxCompensation::read_dimension(&bytes);
        assert!(pre.is_ok());
        let read_dim = pre.unwrap();
        assert_eq!(read_dim, dim);

        let pre = MinMaxCompensation::read_dimension(&[0_u8; 2]);
        assert_eq!(pre.unwrap_err(), MetaParseError::NotCanonical(2));
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    /// Generates a `#[test]` that runs `test_minmax_compensated_vectors` for every
    /// dimension up to `bit_scale::<$nbits>()`, `TRIALS` times each, with a fixed seed.
    macro_rules! test_minmax_compensated {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = (bit_scale::<$nbits>() as usize);
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_compensated_vectors::<$nbits, _>(dim, &mut rng);
                    }
                }
            }
        };
    }
    test_minmax_compensated!(unsigned_minmax_compensated_test_u1, 1, 0xa33d5658097a1c35);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);

    /// Checks `kernel`'s reconstructed inner product for operands of different bit widths
    /// (`NBITS` vs `MBITS`) against the directly computed `f32` inner product.
    fn test_minmax_heterogeneous_kernel<const NBITS: usize, const MBITS: usize, R>(
        dim: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
        R: Rng,
    {
        let (v_query, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v_data, original2) = random_minmax_vector::<MBITS>(dim, rng);

        let expected_ip: f32 = original1.iter().zip(&original2).map(|(x, y)| x * y).sum();
        let computed_ip = kernel(v_query.reborrow(), v_data.reborrow(), |v, _, _| v)
            .unwrap()
            .into_inner();
        assert!(
            (expected_ip - computed_ip).abs() / expected_ip.abs().max(1e-10) < 1e-6,
            "Heterogeneous IP ({},{}) failed: expected {}, got {} on dim: {}",
            NBITS,
            MBITS,
            expected_ip,
            computed_ip,
            dim,
        );
    }

    /// Generates a `#[test]` running the heterogeneous kernel check for every dimension up
    /// to `bit_scale::<$M>()`, `TRIALS` times each, with a fixed seed.
    macro_rules! test_minmax_heterogeneous {
        ($name:ident, $N:literal, $M:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = bit_scale::<$M>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_heterogeneous_kernel::<$N, $M, _>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_minmax_heterogeneous!(minmax_heterogeneous_8x4, 8, 4, 0xb7c3d9e5f1a20864);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x2, 8, 2, 0x4e8f2c6a1d3b5079);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x1, 8, 1, 0x1b0f2c614d2a7141);
}