1use diskann_utils::{Reborrow, ReborrowMut};
115use diskann_vector::{norm::FastL2NormSquared, Norm};
116use diskann_wide::{arch::Target2, Architecture};
117use half::f16;
118use thiserror::Error;
119
120#[cfg(feature = "flatbuffers")]
121use crate::flatbuffers as fb;
122use crate::{
123 alloc::{AllocatorCore, AllocatorError, Poly},
124 bits::{BitSlice, Dense, PermutationStrategy, Representation, Unsigned},
125 distances::{self, InnerProduct, MV},
126 meta,
127};
128
/// Distance metrics that can be used with spherical quantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedMetric {
    /// Squared L2 distance.
    SquaredL2,
    /// Inner product.
    InnerProduct,
    /// Cosine; shares the inner-product metadata term (see `pick`).
    Cosine,
}

#[cfg(test)]
impl SupportedMetric {
    /// Choose the metric-specific metadata term for a test fixture.
    ///
    /// `SquaredL2` uses the square of `shifted_norm`; `InnerProduct` and `Cosine`
    /// both use `inner_product_with_centroid` unchanged.
    fn pick(self, shifted_norm: f32, inner_product_with_centroid: f32) -> f32 {
        match self {
            Self::SquaredL2 => shifted_norm * shifted_norm,
            Self::InnerProduct => inner_product_with_centroid,
            Self::Cosine => inner_product_with_centroid,
        }
    }

    /// Every variant of the enum, in declaration order.
    #[cfg(feature = "flatbuffers")]
    pub(super) fn all() -> [Self; 3] {
        [Self::SquaredL2, Self::InnerProduct, Self::Cosine]
    }
}
155
156impl TryFrom<diskann_vector::distance::Metric> for SupportedMetric {
157 type Error = UnsupportedMetric;
158 fn try_from(metric: diskann_vector::distance::Metric) -> Result<Self, Self::Error> {
159 use diskann_vector::distance::Metric;
160 match metric {
161 Metric::L2 => Ok(Self::SquaredL2),
162 Metric::InnerProduct => Ok(Self::InnerProduct),
163 Metric::Cosine => Ok(Self::Cosine),
164 unsupported => Err(UnsupportedMetric(unsupported)),
165 }
166 }
167}
168
169impl PartialEq<diskann_vector::distance::Metric> for SupportedMetric {
170 fn eq(&self, metric: &diskann_vector::distance::Metric) -> bool {
171 match Self::try_from(*metric) {
172 Ok(m) => *self == m,
173 Err(_) => false,
174 }
175 }
176}
177
/// Error returned when a `diskann_vector::distance::Metric` has no corresponding
/// [`SupportedMetric`]. Carries the rejected metric.
#[derive(Debug, Clone, Copy, Error)]
#[error("metric {0:?} is not supported for spherical quantization")]
pub struct UnsupportedMetric(pub(crate) diskann_vector::distance::Metric);
181
/// Error returned when a serialized (flatbuffer) metric value does not name a
/// [`SupportedMetric`]. Carries the raw `i8` value that was not recognized.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, PartialEq, Error)]
#[error("the value {0} is not recognized as a supported metric")]
pub struct InvalidMetric(i8);
187
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl TryFrom<fb::spherical::SupportedMetric> for SupportedMetric {
    type Error = InvalidMetric;

    /// Decode the flatbuffer representation of a metric.
    ///
    /// Values other than the three known metrics are rejected with
    /// [`InvalidMetric`] holding the raw value.
    fn try_from(value: fb::spherical::SupportedMetric) -> Result<Self, Self::Error> {
        type Serialized = fb::spherical::SupportedMetric;

        let metric = match value {
            Serialized::SquaredL2 => Self::SquaredL2,
            Serialized::InnerProduct => Self::InnerProduct,
            Serialized::Cosine => Self::Cosine,
            other => return Err(InvalidMetric(other.0)),
        };
        Ok(metric)
    }
}
201
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl From<SupportedMetric> for fb::spherical::SupportedMetric {
    /// Encode a [`SupportedMetric`] as its flatbuffer counterpart of the same name.
    fn from(value: SupportedMetric) -> Self {
        match value {
            SupportedMetric::SquaredL2 => Self::SquaredL2,
            SupportedMetric::InnerProduct => Self::InnerProduct,
            SupportedMetric::Cosine => Self::Cosine,
        }
    }
}
213
/// Compact per-vector metadata stored alongside a compressed data vector.
///
/// The two floating point terms are kept as 16-bit floats and the bit sum as a
/// 16-bit integer; [`DataMeta::to_full`] widens all three to `f32`.
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMeta {
    /// Multiplicative correction applied to the integer inner product by the
    /// distance functions in this file.
    pub inner_product_correction: f16,

    /// Additive, metric-dependent term that the distance functions add to their
    /// result. NOTE(review): the tests fill this with the squared norm for
    /// `SquaredL2` and an inner product with the centroid otherwise - confirm
    /// against the quantizer.
    pub metric_specific: f16,

    /// Presumably the sum of the vector's code values (the test fixtures populate
    /// it that way) - confirm against the quantizer.
    pub bit_sum: u16,
}
248
/// Reasons [`DataMeta::new`] can fail: one variant per argument that does not fit
/// into its 16-bit storage. Each variant carries the offending input value.
#[derive(Debug, Error, Clone, Copy, PartialEq)]
pub enum DataMetaError {
    /// The inner product correction is not finite after conversion to `f16`.
    #[error("inner product correction {value} cannot fit in a 16-bit floating point number")]
    InnerProductCorrection { value: f32 },

    /// The metric specific term is not finite after conversion to `f16`.
    #[error("metric specific correction {value} cannot fit in a 16-bit floating point number")]
    MetricSpecific { value: f32 },

    /// The bit sum exceeds the range of `u16`.
    #[error("bit sum {value} cannot fit in a 16-bit unsigned integer")]
    BitSum { value: u32 },
}
260
261impl DataMeta {
262 pub fn new(
266 inner_product_correction: f32,
267 metric_specific: f32,
268 bit_sum: u32,
269 ) -> Result<Self, DataMetaError> {
270 let inner_product_correction_f16 = diskann_wide::cast_f32_to_f16(inner_product_correction);
271 if !inner_product_correction_f16.is_finite() {
272 return Err(DataMetaError::InnerProductCorrection {
273 value: inner_product_correction,
274 });
275 }
276
277 let metric_specific_f16 = diskann_wide::cast_f32_to_f16(metric_specific);
278 if !metric_specific_f16.is_finite() {
279 return Err(DataMetaError::MetricSpecific {
280 value: metric_specific,
281 });
282 }
283
284 let bit_sum_u16: u16 = bit_sum
285 .try_into()
286 .map_err(|_| DataMetaError::BitSum { value: bit_sum })?;
287
288 Ok(Self {
289 inner_product_correction: inner_product_correction_f16,
290 metric_specific: metric_specific_f16,
291 bit_sum: bit_sum_u16,
292 })
293 }
294
295 const fn offset_term<const NBITS: usize>() -> f32 {
313 ((2usize).pow(NBITS as u32) as f32 - 1.0) / 2.0
314 }
315
316 #[inline(always)]
318 pub fn to_full<A>(self, arch: A) -> DataMetaF32
319 where
320 A: Architecture,
321 {
322 use diskann_wide::SIMDVector;
323
324 let pre = [
330 self.metric_specific,
331 self.inner_product_correction,
332 half::f16::default(),
333 half::f16::default(),
334 half::f16::default(),
335 half::f16::default(),
336 half::f16::default(),
337 half::f16::default(),
338 ];
339
340 let post: <A as Architecture>::f32x8 =
341 <A as Architecture>::f16x8::from_array(arch, pre).into();
342 let post = post.to_array();
343
344 DataMetaF32 {
345 metric_specific: post[0],
346 inner_product_correction: post[1],
347 bit_sum: self.bit_sum.into(),
348 }
349 }
350}
351
/// [`DataMeta`] with every field widened to `f32`, as produced by
/// [`DataMeta::to_full`].
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMetaF32 {
    /// Widened `DataMeta::inner_product_correction`.
    pub inner_product_correction: f32,
    /// Widened `DataMeta::metric_specific`.
    pub metric_specific: f32,
    /// `DataMeta::bit_sum` converted to `f32`.
    pub bit_sum: f32,
}
359
/// Borrowed view of a compressed data vector: `NBITS`-bit unsigned codes paired
/// with a [`DataMeta`].
pub type DataRef<'a, const NBITS: usize> = meta::VectorRef<'a, NBITS, Unsigned, DataMeta>;

/// Mutable counterpart of [`DataRef`].
pub type DataMut<'a, const NBITS: usize> = meta::VectorMut<'a, NBITS, Unsigned, DataMeta>;

/// Owning compressed data vector with [`DataMeta`], using the `Dense` layout and
/// allocator `A`.
pub type Data<const NBITS: usize, A> = meta::PolyVector<NBITS, Unsigned, DataMeta, Dense, A>;
368
/// Metadata stored with a quantized query. Unlike [`DataMeta`], all terms are kept
/// at full `f32` precision.
#[derive(Copy, Clone, Default, Debug, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct QueryMeta {
    /// Multiplicative correction applied to the integer inner product.
    pub inner_product_correction: f32,

    /// Presumably the sum of the query's code values, stored as `f32` (the test
    /// fixtures populate it that way) - confirm against the quantizer.
    pub bit_sum: f32,

    /// Offset paired with the query codes in the mixed query/data distance
    /// formulas. NOTE(review): the test fixtures set this to `base / scale` for a
    /// reconstruction `base + scale * code` - confirm against the quantizer.
    pub offset: f32,

    /// Additive, metric-dependent term added to the distance result.
    pub metric_specific: f32,
}
400
/// Owning quantized query: `NBITS`-bit unsigned codes stored with permutation
/// strategy `Perm`, paired with a [`QueryMeta`], using allocator `A`.
pub type Query<const NBITS: usize, Perm, A> = meta::PolyVector<NBITS, Unsigned, QueryMeta, Perm, A>;

/// Borrowed view of a [`Query`].
pub type QueryRef<'a, const NBITS: usize, Perm> =
    meta::VectorRef<'a, NBITS, Unsigned, QueryMeta, Perm>;

/// Mutable borrowed view of a [`Query`].
pub type QueryMut<'a, const NBITS: usize, Perm> =
    meta::VectorMut<'a, NBITS, Unsigned, QueryMeta, Perm>;
411
/// Metadata accompanying a full-precision (`f32`) query.
#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct FullQueryMeta {
    /// Presumably the sum of the query's data elements (the test fixtures populate
    /// it that way); the distance functions multiply it by the data offset.
    pub sum: f32,
    /// Scale multiplied into the corrected inner product. NOTE(review): the name
    /// suggests the norm of the shifted query - confirm against the quantizer.
    pub shifted_norm: f32,
    /// Additive, metric-dependent term added to the distance result.
    pub metric_specific: f32,
}
426
/// A full-precision query: `f32` data allocated with `A` plus its
/// [`FullQueryMeta`].
#[derive(Debug)]
pub struct FullQuery<A>
where
    A: AllocatorCore,
{
    /// The query values.
    pub data: Poly<[f32], A>,
    /// Metadata consumed by the compensated distance functions.
    pub meta: FullQueryMeta,
}
437
438impl<A> FullQuery<A>
439where
440 A: AllocatorCore,
441{
442 pub fn empty(dim: usize, allocator: A) -> Result<Self, AllocatorError> {
444 Ok(Self {
445 data: Poly::broadcast(0.0f32, dim, allocator)?,
446 meta: Default::default(),
447 })
448 }
449}
450
/// Borrowed view of a [`FullQuery`]: an `f32` slice plus its [`FullQueryMeta`].
pub type FullQueryRef<'a> = meta::slice::SliceRef<'a, f32, FullQueryMeta>;

/// Mutable borrowed view of a [`FullQuery`].
pub type FullQueryMut<'a> = meta::slice::SliceMut<'a, f32, FullQueryMeta>;
454
455impl<'short, A> Reborrow<'short> for FullQuery<A>
456where
457 A: AllocatorCore,
458{
459 type Target = FullQueryRef<'short>;
460 fn reborrow(&'short self) -> Self::Target {
461 FullQueryRef::new(&self.data, &self.meta)
462 }
463}
464
465impl<'short, A> ReborrowMut<'short> for FullQuery<A>
466where
467 A: AllocatorCore,
468{
469 type Target = FullQueryMut<'short>;
470 fn reborrow_mut(&'short mut self) -> Self::Target {
471 FullQueryMut::new(&mut self.data, &mut self.meta)
472 }
473}
474
475struct ConstOffset<const NBITS: usize>;
482
483impl<const NBITS: usize> ConstOffset<NBITS> {
484 const OFFSET: f32 = DataMeta::offset_term::<NBITS>();
485 const OFFSET_SQUARED: f32 = DataMeta::offset_term::<NBITS>() * DataMeta::offset_term::<NBITS>();
486}
487
488#[inline(always)]
494fn kernel<A, const NBITS: usize>(
495 arch: A,
496 x: DataRef<'_, NBITS>,
497 y: DataRef<'_, NBITS>,
498 dim: f32,
499) -> distances::Result<f32>
500where
501 A: Architecture,
502 Unsigned: Representation<NBITS>,
503 InnerProduct: for<'a> Target2<
504 A,
505 distances::MathematicalResult<u32>,
506 BitSlice<'a, NBITS, Unsigned>,
507 BitSlice<'a, NBITS, Unsigned>,
508 >,
509{
510 let ip: distances::MathematicalResult<u32> =
515 <_ as Target2<_, _, _, _>>::run(InnerProduct, arch, x.vector(), y.vector());
516
517 let ip = ip?.into_inner() as f32;
518
519 let offset = ConstOffset::<NBITS>::OFFSET;
520 let offset_squared = ConstOffset::<NBITS>::OFFSET_SQUARED;
521
522 let xc = x.meta().to_full(arch);
523 let yc = y.meta().to_full(arch);
524
525 Ok(xc.inner_product_correction
526 * yc.inner_product_correction
527 * (ip - offset * (xc.bit_sum + yc.bit_sum) + offset_squared * dim))
528}
529
530#[derive(Debug, Clone, Copy)]
537pub struct CompensatedSquaredL2 {
538 pub(super) dim: f32,
539}
540
541impl CompensatedSquaredL2 {
542 pub fn new(dim: usize) -> Self {
544 Self { dim: dim as f32 }
545 }
546}
547
548impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedSquaredL2
551where
552 A: Architecture,
553 Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
554{
555 #[inline(always)]
556 fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
557 self.run(arch, x, y).map(|r| r.into_inner())
558 }
559}
560
561impl<A, const NBITS: usize>
571 Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
572 for CompensatedSquaredL2
573where
574 A: Architecture,
575 Unsigned: Representation<NBITS>,
576 InnerProduct: for<'a> Target2<
577 A,
578 distances::MathematicalResult<u32>,
579 BitSlice<'a, NBITS, Unsigned>,
580 BitSlice<'a, NBITS, Unsigned>,
581 >,
582{
583 #[inline(always)]
584 fn run(
585 self,
586 arch: A,
587 x: DataRef<'_, NBITS>,
588 y: DataRef<'_, NBITS>,
589 ) -> distances::MathematicalResult<f32> {
590 let xc = x.meta().to_full(arch);
591 let yc = y.meta().to_full(arch);
592 let result = xc.metric_specific + yc.metric_specific - 2.0 * kernel(arch, x, y, self.dim)?;
593 Ok(MV::new(result))
594 }
595}
596
597impl<A, const Q: usize, const D: usize, Perm>
598 Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
599 for CompensatedSquaredL2
600where
601 A: Architecture,
602 Unsigned: Representation<Q>,
603 Unsigned: Representation<D>,
604 Perm: PermutationStrategy<Q>,
605 for<'a> InnerProduct: Target2<
606 A,
607 distances::MathematicalResult<u32>,
608 BitSlice<'a, Q, Unsigned, Perm>,
609 BitSlice<'a, D, Unsigned>,
610 >,
611{
612 #[inline(always)]
613 fn run(
614 self,
615 arch: A,
616 x: QueryRef<'_, Q, Perm>,
617 y: DataRef<'_, D>,
618 ) -> distances::MathematicalResult<f32> {
619 let ip: distances::MathematicalResult<u32> =
620 arch.run2_inline(InnerProduct, x.vector(), y.vector());
621 let ip = ip?.into_inner() as f32;
622
623 let yc = y.meta().to_full(arch);
624 let xc = x.meta();
625
626 let y_offset: f32 = DataMeta::offset_term::<D>();
627
628 let corrected_ip = yc.inner_product_correction
629 * xc.inner_product_correction
630 * (ip - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
631 - y_offset * xc.offset * self.dim);
632
633 Ok(MV::new(
634 yc.metric_specific + xc.metric_specific - 2.0 * corrected_ip,
635 ))
636 }
637}
638
639impl<A, const NBITS: usize>
644 Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
645 for CompensatedSquaredL2
646where
647 A: Architecture,
648 Unsigned: Representation<NBITS>,
649 InnerProduct: for<'a> Target2<
650 A,
651 distances::MathematicalResult<f32>,
652 &'a [f32],
653 BitSlice<'a, NBITS, Unsigned>,
654 >,
655{
656 #[inline(always)]
657 fn run(
658 self,
659 arch: A,
660 x: FullQueryRef<'_>,
661 y: DataRef<'_, NBITS>,
662 ) -> distances::MathematicalResult<f32> {
663 let s = arch
664 .run2(InnerProduct, x.vector(), y.vector())?
665 .into_inner();
666
667 let xc = x.meta();
668 let yc = y.meta().to_full(arch);
669
670 let offset = ConstOffset::<NBITS>::OFFSET;
671 let ip = s - xc.sum * offset;
672
673 let r = xc.metric_specific + yc.metric_specific
676 - 2.0 * xc.shifted_norm * yc.inner_product_correction * ip;
677 Ok(MV::new(r))
678 }
679}
680
681#[derive(Debug, Clone, Copy)]
688pub struct CompensatedIP {
689 pub(super) squared_shift_norm: f32,
690 pub(super) dim: f32,
691}
692
693impl CompensatedIP {
694 pub fn new(shift: &[f32], dim: usize) -> Self {
696 Self {
697 squared_shift_norm: FastL2NormSquared.evaluate(shift),
698 dim: dim as f32,
699 }
700 }
701}
702
703impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedIP
709where
710 A: Architecture,
711 Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
712{
713 #[inline(always)]
714 fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
715 arch.run2(self, x, y).map(|r| -r.into_inner())
716 }
717}
718
719impl<A, const NBITS: usize>
731 Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
732 for CompensatedIP
733where
734 A: Architecture,
735 Unsigned: Representation<NBITS>,
736 InnerProduct: for<'a> Target2<
737 A,
738 distances::MathematicalResult<u32>,
739 BitSlice<'a, NBITS, Unsigned>,
740 BitSlice<'a, NBITS, Unsigned>,
741 >,
742{
743 #[inline(always)]
744 fn run(
745 self,
746 arch: A,
747 x: DataRef<'_, NBITS>,
748 y: DataRef<'_, NBITS>,
749 ) -> distances::MathematicalResult<f32> {
750 let xc = x.meta().to_full(arch);
751 let yc = y.meta().to_full(arch);
752
753 let result = xc.metric_specific
754 + yc.metric_specific
755 + kernel(arch, x, y, self.dim)?
756 + self.squared_shift_norm;
757 Ok(MV::new(result))
758 }
759}
760
761impl<A, const Q: usize, const D: usize, Perm>
762 Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
763 for CompensatedIP
764where
765 A: Architecture,
766 Unsigned: Representation<Q>,
767 Unsigned: Representation<D>,
768 Perm: PermutationStrategy<Q>,
769 for<'a> InnerProduct: Target2<
770 A,
771 distances::MathematicalResult<u32>,
772 BitSlice<'a, Q, Unsigned, Perm>,
773 BitSlice<'a, D, Unsigned>,
774 >,
775{
776 #[inline(always)]
777 fn run(
778 self,
779 arch: A,
780 x: QueryRef<'_, Q, Perm>,
781 y: DataRef<'_, D>,
782 ) -> distances::MathematicalResult<f32> {
783 let ip: MV<u32> = arch.run2_inline(InnerProduct, x.vector(), y.vector())?;
785
786 let yc = y.meta().to_full(arch);
787 let xc = x.meta();
788
789 let y_offset: f32 = DataMeta::offset_term::<D>();
791
792 let corrected_ip = xc.inner_product_correction
793 * yc.inner_product_correction
794 * (ip.into_inner() as f32 - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
795 - y_offset * xc.offset * self.dim);
796
797 Ok(MV::new(
799 corrected_ip + yc.metric_specific + xc.metric_specific + self.squared_shift_norm,
800 ))
801 }
802}
803
804impl<A, const NBITS: usize>
809 Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
810 for CompensatedIP
811where
812 A: Architecture,
813 Unsigned: Representation<NBITS>,
814 InnerProduct: for<'a> Target2<
815 A,
816 distances::MathematicalResult<f32>,
817 &'a [f32],
818 BitSlice<'a, NBITS, Unsigned>,
819 >,
820{
821 #[inline(always)]
822 fn run(
823 self,
824 arch: A,
825 x: FullQueryRef<'_>,
826 y: DataRef<'_, NBITS>,
827 ) -> distances::MathematicalResult<f32> {
828 let s = arch
829 .run2(InnerProduct, x.vector(), y.vector())?
830 .into_inner();
831
832 let yc = y.meta().to_full(arch);
833 let xc = x.meta();
834
835 let offset = ConstOffset::<NBITS>::OFFSET;
836 let ip = xc.shifted_norm * yc.inner_product_correction * (s - xc.sum * offset);
837
838 Ok(MV::new(
839 ip + xc.metric_specific + yc.metric_specific + self.squared_shift_norm,
840 ))
841 }
842}
843
/// Distance functor for cosine, implemented as a thin wrapper around
/// [`CompensatedIP`]: the mathematical value is the wrapped inner product, and the
/// `f32` distance is `1 - value`.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedCosine {
    /// The inner product functor all computations are delegated to.
    pub(super) inner: CompensatedIP,
}

impl CompensatedCosine {
    /// Wrap an existing [`CompensatedIP`].
    pub fn new(inner: CompensatedIP) -> Self {
        Self { inner }
    }
}
866
867impl<A, T, U> Target2<A, distances::MathematicalResult<f32>, T, U> for CompensatedCosine
868where
869 A: Architecture,
870 CompensatedIP: Target2<A, distances::MathematicalResult<f32>, T, U>,
871{
872 #[inline(always)]
873 fn run(self, arch: A, x: T, y: U) -> distances::MathematicalResult<f32> {
874 self.inner.run(arch, x, y)
875 }
876}
877
878impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedCosine
884where
885 A: Architecture,
886 Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
887{
888 #[inline(always)]
889 fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
890 let r: MV<f32> = self.run(arch, x, y)?;
891 Ok(1.0 - r.into_inner())
892 }
893}
894
895#[cfg(test)]
900mod tests {
901 use diskann_utils::{lazy_format, Reborrow};
902 use diskann_vector::{distance::Metric, norm::FastL2Norm, PureDistanceFunction};
903 use diskann_wide::ARCH;
904 use rand::{
905 distr::{Distribution, Uniform},
906 rngs::StdRng,
907 SeedableRng,
908 };
909 use rand_distr::StandardNormal;
910
911 use super::*;
912 use crate::{
913 alloc::GlobalAllocator,
914 bits::{BitTranspose, Dense},
915 };
916
917 #[derive(Debug, Clone, Copy, PartialEq)]
918 struct Approx {
919 absolute: f32,
920 relative: f32,
921 }
922
923 impl Approx {
924 const fn new(absolute: f32, relative: f32) -> Self {
925 assert!(absolute >= 0.0);
926 assert!(relative >= 0.0);
927 Self { absolute, relative }
928 }
929
930 fn check(&self, got: f32, expected: f32, ctx: Option<&dyn std::fmt::Display>) -> bool {
931 struct Ctx<'a>(Option<&'a dyn std::fmt::Display>);
932
933 impl std::fmt::Display for Ctx<'_> {
934 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
935 match self.0 {
936 None => write!(f, "none"),
937 Some(d) => write!(f, "{}", d),
938 }
939 }
940 }
941
942 let absolute = (got - expected).abs();
943 if absolute <= self.absolute {
944 true
945 } else {
946 let relative = absolute / expected.abs();
947 if relative <= self.relative {
948 true
949 } else {
950 panic!(
951 "got {}, expected {}. Abs/Rel = {}/{} with bounds {}/{}: Ctx: {}",
952 got,
953 expected,
954 absolute,
955 relative,
956 self.absolute,
957 self.relative,
958 Ctx(ctx)
959 );
960 }
961 }
962 }
963 }
964
965 #[test]
970 fn test_data_meta() {
971 let meta = DataMeta::new(1.0, 2.0, 10).unwrap();
973 let expected = DataMetaF32 {
974 inner_product_correction: 1.0,
975 metric_specific: 2.0,
976 bit_sum: 10.0,
977 };
978 assert_eq!(meta.to_full(ARCH), expected);
979
980 let err = DataMeta::new(65600.0, 2.0, 10).unwrap_err();
982 assert_eq!(
983 err.to_string(),
984 "inner product correction 65600 cannot fit in a 16-bit floating point number"
985 );
986
987 let err = DataMeta::new(2.0, 65600.0, 10).unwrap_err();
988 assert_eq!(
989 err.to_string(),
990 "metric specific correction 65600 cannot fit in a 16-bit floating point number"
991 );
992
993 let err = DataMeta::new(2.0, 2.0, 65536).unwrap_err();
994 assert_eq!(
995 err.to_string(),
996 "bit sum 65536 cannot fit in a 16-bit unsigned integer",
997 );
998 }
999
1000 #[test]
1005 fn supported_metric() {
1006 assert_eq!(
1007 SupportedMetric::try_from(Metric::L2).unwrap(),
1008 SupportedMetric::SquaredL2
1009 );
1010 assert_eq!(
1011 SupportedMetric::try_from(Metric::InnerProduct).unwrap(),
1012 SupportedMetric::InnerProduct
1013 );
1014 assert_eq!(
1015 SupportedMetric::try_from(Metric::Cosine).unwrap(),
1016 SupportedMetric::Cosine
1017 );
1018 assert!(matches!(
1019 SupportedMetric::try_from(Metric::CosineNormalized),
1020 Err(UnsupportedMetric(Metric::CosineNormalized))
1021 ));
1022
1023 assert_eq!(SupportedMetric::SquaredL2, Metric::L2);
1024 assert_ne!(SupportedMetric::SquaredL2, Metric::InnerProduct);
1025 assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
1026 assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
1027
1028 assert_ne!(SupportedMetric::InnerProduct, Metric::L2);
1029 assert_eq!(SupportedMetric::InnerProduct, Metric::InnerProduct);
1030 assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
1031 assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
1032 }
1033
    /// A randomly generated compressed object together with the values needed to
    /// compute the expected distance independently of the code under test.
    struct Reference<T> {
        /// The compressed vector or query, with its metadata populated.
        compressed: T,
        /// The `f32` vector the compressed form stands for in the reference
        /// computation.
        reconstructed: Vec<f32>,
        /// Randomly drawn norm used to build the metadata.
        norm: f32,
        /// Randomly drawn inner product with the center used to build the metadata.
        center_ip: f32,
        /// Randomly drawn self inner product; only set for data vectors.
        self_ip: Option<f32>,
    }
1045
    /// Types that can produce a random instance of themselves plus the matching
    /// [`Reference`] values.
    trait GenerateReference: Sized {
        /// Generate a random instance for the given `center` and `metric`, drawing
        /// all randomness from `rng`.
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self>;
    }
1053
1054 impl<const NBITS: usize> GenerateReference for Data<NBITS, GlobalAllocator>
1055 where
1056 Unsigned: Representation<NBITS>,
1057 {
1058 fn generate_reference(
1059 center: &[f32],
1060 metric: SupportedMetric,
1061 rng: &mut StdRng,
1062 ) -> Reference<Self> {
1063 let dim = center.len();
1064
1065 let mut reconstructed = vec![0.0f32; dim];
1066 let mut compressed = Data::<NBITS, _>::new_boxed(dim);
1067
1068 let mut bit_sum = 0;
1069 let dist = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
1070 let offset = (2usize.pow(NBITS as u32) as f32 - 1.0) / 2.0;
1071 for (i, r) in reconstructed.iter_mut().enumerate() {
1072 let b: i64 = dist.sample(rng);
1073 bit_sum += b;
1074 compressed.vector_mut().set(i, b).unwrap();
1075 *r = (b as f32) - offset;
1076 }
1077
1078 let r_norm = FastL2Norm.evaluate(reconstructed.as_slice());
1079 reconstructed.iter_mut().for_each(|i| *i /= r_norm);
1080
1081 let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
1082 let center_ip: f32 = Uniform::new(0.5, 2.5).unwrap().sample(rng);
1083 let self_ip: f32 = Uniform::new(0.5, 1.5).unwrap().sample(rng);
1084
1085 compressed.set_meta(
1086 DataMeta::new(
1087 norm / (self_ip * r_norm),
1088 metric.pick(norm, center_ip),
1089 bit_sum.try_into().unwrap(),
1090 )
1091 .unwrap(),
1092 );
1093
1094 Reference {
1095 compressed,
1096 reconstructed,
1097 norm,
1098 center_ip,
1099 self_ip: Some(self_ip),
1100 }
1101 }
1102 }
1103
1104 impl<const NBITS: usize, Perm> GenerateReference for Query<NBITS, Perm, GlobalAllocator>
1105 where
1106 Unsigned: Representation<NBITS>,
1107 Perm: PermutationStrategy<NBITS>,
1108 {
1109 fn generate_reference(
1110 center: &[f32],
1111 metric: SupportedMetric,
1112 rng: &mut StdRng,
1113 ) -> Reference<Self> {
1114 let dim = center.len();
1115
1116 let mut reconstructed = vec![0.0f32; dim];
1117 let mut compressed = Query::<NBITS, Perm, _>::new_boxed(dim);
1118
1119 let distribution = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
1120
1121 let base: f32 = StandardNormal {}.sample(rng);
1122 let scale: f32 = {
1123 let scale: f32 = StandardNormal {}.sample(rng);
1124 scale.abs()
1125 };
1126
1127 let mut bit_sum = 0;
1128 for (i, r) in reconstructed.iter_mut().enumerate() {
1129 let b = distribution.sample(rng);
1130 compressed.vector_mut().set(i, b).unwrap();
1131 *r = base + scale * (b as f32);
1132 bit_sum += b;
1133 }
1134
1135 let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
1136 let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);
1137
1138 compressed.set_meta(QueryMeta {
1139 inner_product_correction: norm * scale,
1140 bit_sum: bit_sum as f32,
1141 offset: base / scale,
1142 metric_specific: metric.pick(norm, center_ip),
1143 });
1144
1145 Reference {
1146 compressed,
1147 reconstructed,
1148 norm,
1149 center_ip,
1150 self_ip: None,
1151 }
1152 }
1153 }
1154
1155 impl GenerateReference for FullQuery<GlobalAllocator> {
1156 fn generate_reference(
1157 center: &[f32],
1158 metric: SupportedMetric,
1159 rng: &mut StdRng,
1160 ) -> Reference<Self> {
1161 let dim = center.len();
1162
1163 let mut query = FullQuery::empty(dim, GlobalAllocator).unwrap();
1164
1165 let mut sum = 0.0;
1166 let dist = StandardNormal {};
1167 for r in query.data.iter_mut() {
1168 let b: f32 = dist.sample(rng);
1169 sum += b;
1170 *r = b;
1171 }
1172
1173 let r_norm = FastL2Norm.evaluate(&*query.data);
1174 query.data.iter_mut().for_each(|i| *i /= r_norm);
1175
1176 let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
1177 let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);
1178
1179 query.meta = FullQueryMeta {
1180 sum: sum / r_norm,
1181 shifted_norm: norm,
1182 metric_specific: metric.pick(norm, center_ip),
1183 };
1184
1185 let reconstructed = query.data.to_vec();
1186 Reference {
1187 compressed: query,
1188 reconstructed,
1189 norm,
1190 center_ip,
1191 self_ip: None,
1192 }
1193 }
1194 }
1195
1196 fn test_compensated_distance<const NBITS: usize>(
1206 dim: usize,
1207 ntrials: usize,
1208 err_l2: Approx,
1209 err_ip: Approx,
1210 rng: &mut StdRng,
1211 ) where
1212 Unsigned: Representation<NBITS>,
1213 for<'a> CompensatedIP: Target2<
1214 diskann_wide::arch::Current,
1215 distances::Result<f32>,
1216 DataRef<'a, NBITS>,
1217 DataRef<'a, NBITS>,
1218 > + Target2<
1219 diskann_wide::arch::Current,
1220 distances::MathematicalResult<f32>,
1221 DataRef<'a, NBITS>,
1222 DataRef<'a, NBITS>,
1223 >,
1224 for<'a> CompensatedSquaredL2: Target2<
1225 diskann_wide::arch::Current,
1226 distances::Result<f32>,
1227 DataRef<'a, NBITS>,
1228 DataRef<'a, NBITS>,
1229 > + Target2<
1230 diskann_wide::arch::Current,
1231 distances::MathematicalResult<f32>,
1232 DataRef<'a, NBITS>,
1233 DataRef<'a, NBITS>,
1234 >,
1235 {
1236 let mut center = vec![0.0f32; dim];
1237 for trial in 0..ntrials {
1238 center
1240 .iter_mut()
1241 .for_each(|c| *c = StandardNormal {}.sample(rng));
1242
1243 let c_square_norm = FastL2NormSquared.evaluate(&*center);
1244
1245 {
1247 let x = Data::<NBITS, _>::generate_reference(
1248 ¢er,
1249 SupportedMetric::InnerProduct,
1250 rng,
1251 );
1252 let y = Data::<NBITS, _>::generate_reference(
1253 ¢er,
1254 SupportedMetric::InnerProduct,
1255 rng,
1256 );
1257
1258 let kernel_result = {
1259 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1260 &*x.reconstructed,
1261 &*y.reconstructed,
1262 );
1263 x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
1264 };
1265
1266 let reference_ip = kernel_result + x.center_ip + y.center_ip + c_square_norm;
1267 let ip = CompensatedIP::new(¢er, center.len());
1268 let got_ip: distances::MathematicalResult<f32> =
1269 ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1270 let got_ip = got_ip.unwrap();
1271
1272 let ctx = &lazy_format!(
1273 "Inner Product, trial {} of {}, dim = {}",
1274 trial,
1275 ntrials,
1276 dim
1277 );
1278 assert!(err_ip.check(got_ip.into_inner(), reference_ip, Some(ctx)));
1279
1280 let got_ip_f32: distances::Result<f32> =
1281 ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1282
1283 let got_ip_f32 = got_ip_f32.unwrap();
1284
1285 assert_eq!(got_ip_f32, -got_ip.into_inner());
1286
1287 let cosine = CompensatedCosine::new(ip);
1289 let got_cosine: distances::MathematicalResult<f32> =
1290 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1291 let got_cosine = got_cosine.unwrap();
1292 assert_eq!(
1293 got_cosine.into_inner(),
1294 got_ip.into_inner(),
1295 "cosine and IP should be the same"
1296 );
1297
1298 let got_cosine_f32: distances::Result<f32> =
1299 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1300
1301 let got_cosine_f32 = got_cosine_f32.unwrap();
1302
1303 assert_eq!(
1304 got_cosine_f32,
1305 1.0 - got_cosine.into_inner(),
1306 "incorrect transform performed"
1307 );
1308 }
1309
1310 {
1312 let x =
1313 Data::<NBITS, _>::generate_reference(¢er, SupportedMetric::SquaredL2, rng);
1314 let y =
1315 Data::<NBITS, _>::generate_reference(¢er, SupportedMetric::SquaredL2, rng);
1316
1317 let kernel_result = {
1319 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1320 &*x.reconstructed,
1321 &*y.reconstructed,
1322 );
1323 x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
1324 };
1325
1326 let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * kernel_result;
1327 let l2 = CompensatedSquaredL2::new(dim);
1328 let got_l2: distances::MathematicalResult<f32> =
1329 ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1330 let got_l2 = got_l2.unwrap();
1331
1332 let ctx =
1333 &lazy_format!("Squared L2, trial {} of {}, dim = {}", trial, ntrials, dim);
1334 assert!(err_l2.check(got_l2.into_inner(), reference_l2, Some(ctx)));
1335
1336 let got_l2_f32: distances::Result<f32> =
1337 ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1338 let got_l2_f32 = got_l2_f32.unwrap();
1339
1340 assert_eq!(got_l2_f32, got_l2.into_inner());
1341 }
1342 }
1343 }
1344
1345 fn test_mixed_compensated_distance<const Q: usize, const D: usize, Perm>(
1348 dim: usize,
1349 ntrials: usize,
1350 err_l2: Approx,
1351 err_ip: Approx,
1352 rng: &mut StdRng,
1353 ) where
1354 Unsigned: Representation<Q>,
1355 Unsigned: Representation<D>,
1356 Perm: PermutationStrategy<Q>,
1357 for<'a> CompensatedIP: Target2<
1358 diskann_wide::arch::Current,
1359 distances::MathematicalResult<f32>,
1360 QueryRef<'a, Q, Perm>,
1361 DataRef<'a, D>,
1362 >,
1363 for<'a> CompensatedSquaredL2: Target2<
1364 diskann_wide::arch::Current,
1365 distances::MathematicalResult<f32>,
1366 QueryRef<'a, Q, Perm>,
1367 DataRef<'a, D>,
1368 >,
1369 for<'a> CompensatedCosine: Target2<
1370 diskann_wide::arch::Current,
1371 distances::MathematicalResult<f32>,
1372 QueryRef<'a, Q, Perm>,
1373 DataRef<'a, D>,
1374 >,
1375 for<'a> CompensatedIP: Target2<
1376 diskann_wide::arch::Current,
1377 distances::Result<f32>,
1378 QueryRef<'a, Q, Perm>,
1379 DataRef<'a, D>,
1380 >,
1381 for<'a> CompensatedSquaredL2: Target2<
1382 diskann_wide::arch::Current,
1383 distances::Result<f32>,
1384 QueryRef<'a, Q, Perm>,
1385 DataRef<'a, D>,
1386 >,
1387 for<'a> CompensatedCosine: Target2<
1388 diskann_wide::arch::Current,
1389 distances::Result<f32>,
1390 QueryRef<'a, Q, Perm>,
1391 DataRef<'a, D>,
1392 >,
1393 {
1394 let mut center = vec![0.0f32; dim];
1396 for trial in 0..ntrials {
1397 center
1399 .iter_mut()
1400 .for_each(|c| *c = StandardNormal {}.sample(rng));
1401
1402 let c_square_norm = FastL2NormSquared.evaluate(&*center);
1403
1404 {
1406 let x = Query::<Q, Perm, _>::generate_reference(
1407 ¢er,
1408 SupportedMetric::InnerProduct,
1409 rng,
1410 );
1411 let y =
1412 Data::<D, _>::generate_reference(¢er, SupportedMetric::InnerProduct, rng);
1413
1414 let xy = {
1416 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1417 &*x.reconstructed,
1418 &*y.reconstructed,
1419 );
1420 x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1421 };
1422
1423 let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
1424 let ip = CompensatedIP::new(¢er, center.len());
1425 let got_ip: distances::Result<f32> =
1426 ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1427 let got_ip = got_ip.unwrap();
1428
1429 let ctx = &lazy_format!(
1430 "Inner Product, trial = {} of {}, dim = {}",
1431 trial,
1432 ntrials,
1433 dim
1434 );
1435
1436 assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
1437
1438 let cosine = CompensatedCosine::new(ip);
1440 let got_cosine: distances::MathematicalResult<f32> =
1441 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1442
1443 let got_cosine = got_cosine.unwrap();
1444 assert_eq!(
1445 got_cosine.into_inner(),
1446 -got_ip,
1447 "cosine and IP should be the same"
1448 );
1449
1450 let got_cosine_f32: distances::Result<f32> =
1451 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1452
1453 let got_cosine_f32 = got_cosine_f32.unwrap();
1454 assert_eq!(
1455 got_cosine_f32,
1456 1.0 - got_cosine.into_inner(),
1457 "incorrect transform performed"
1458 );
1459 }
1460
1461 {
1463 let x = Query::<Q, Perm, _>::generate_reference(
1464 ¢er,
1465 SupportedMetric::SquaredL2,
1466 rng,
1467 );
1468 let y = Data::<D, _>::generate_reference(¢er, SupportedMetric::SquaredL2, rng);
1469
1470 let xy = {
1472 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1473 &*x.reconstructed,
1474 &*y.reconstructed,
1475 );
1476 x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1477 };
1478 let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
1479 let l2 = CompensatedSquaredL2::new(dim);
1480 let got_l2: distances::Result<f32> =
1481 ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1482 let got_l2 = got_l2.unwrap();
1483
1484 let ctx = &lazy_format!(
1485 "Squared L2, trial = {} of {}, dim = {}",
1486 trial,
1487 ntrials,
1488 dim
1489 );
1490
1491 assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
1492 }
1493 }
1494 }
1495
1496 fn test_full_distances<const NBITS: usize>(
1497 dim: usize,
1498 ntrials: usize,
1499 err_l2: Approx,
1500 err_ip: Approx,
1501 rng: &mut StdRng,
1502 ) where
1503 Unsigned: Representation<NBITS>,
1504 for<'a> CompensatedIP: Target2<
1505 diskann_wide::arch::Current,
1506 distances::MathematicalResult<f32>,
1507 FullQueryRef<'a>,
1508 DataRef<'a, NBITS>,
1509 >,
1510 for<'a> CompensatedSquaredL2: Target2<
1511 diskann_wide::arch::Current,
1512 distances::MathematicalResult<f32>,
1513 FullQueryRef<'a>,
1514 DataRef<'a, NBITS>,
1515 >,
1516 for<'a> CompensatedCosine: Target2<
1517 diskann_wide::arch::Current,
1518 distances::MathematicalResult<f32>,
1519 FullQueryRef<'a>,
1520 DataRef<'a, NBITS>,
1521 >,
1522 for<'a> CompensatedIP: Target2<
1523 diskann_wide::arch::Current,
1524 distances::Result<f32>,
1525 FullQueryRef<'a>,
1526 DataRef<'a, NBITS>,
1527 >,
1528 for<'a> CompensatedSquaredL2: Target2<
1529 diskann_wide::arch::Current,
1530 distances::Result<f32>,
1531 FullQueryRef<'a>,
1532 DataRef<'a, NBITS>,
1533 >,
1534 for<'a> CompensatedCosine: Target2<
1535 diskann_wide::arch::Current,
1536 distances::Result<f32>,
1537 FullQueryRef<'a>,
1538 DataRef<'a, NBITS>,
1539 >,
1540 {
1541 let mut center = vec![0.0f32; dim];
1543 for trial in 0..ntrials {
1544 center
1546 .iter_mut()
1547 .for_each(|c| *c = StandardNormal {}.sample(rng));
1548
1549 let c_square_norm = FastL2NormSquared.evaluate(&*center);
1550
1551 {
1553 let x = FullQuery::generate_reference(¢er, SupportedMetric::InnerProduct, rng);
1554 let y = Data::<NBITS, _>::generate_reference(
1555 ¢er,
1556 SupportedMetric::InnerProduct,
1557 rng,
1558 );
1559
1560 let xy = {
1562 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1563 &*x.reconstructed,
1564 &*y.reconstructed,
1565 );
1566 x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1567 };
1568
1569 let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
1570 let ip = CompensatedIP::new(¢er, center.len());
1571 let got_ip: distances::Result<f32> =
1572 ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
1573 let got_ip = got_ip.unwrap();
1574
1575 let ctx = &lazy_format!(
1576 "Inner Product, trial = {} of {}, dim = {}",
1577 trial,
1578 ntrials,
1579 dim
1580 );
1581
1582 assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
1583
1584 let cosine = CompensatedCosine::new(ip);
1586 let got_cosine: distances::MathematicalResult<f32> =
1587 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1588 let got_cosine = got_cosine.unwrap();
1589 assert_eq!(
1590 got_cosine.into_inner(),
1591 -got_ip,
1592 "cosine and IP should be the same"
1593 );
1594
1595 let got_cosine_f32: distances::Result<f32> =
1596 ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
1597
1598 let got_cosine_f32 = got_cosine_f32.unwrap();
1599 assert_eq!(
1600 got_cosine_f32,
1601 1.0 - got_cosine.into_inner(),
1602 "incorrect transform performed"
1603 );
1604 }
1605
1606 {
1608 let x = FullQuery::generate_reference(¢er, SupportedMetric::SquaredL2, rng);
1609 let y =
1610 Data::<NBITS, _>::generate_reference(¢er, SupportedMetric::SquaredL2, rng);
1611
1612 let xy = {
1614 let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
1615 &*x.reconstructed,
1616 &*y.reconstructed,
1617 );
1618 x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
1619 };
1620
1621 let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
1622 let l2 = CompensatedSquaredL2::new(dim);
1623 let got_l2: distances::Result<f32> =
1624 ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
1625 let got_l2 = got_l2.unwrap();
1626
1627 let ctx = &lazy_format!(
1628 "Squared L2, trial = {} of {}, dim = {}",
1629 trial,
1630 ntrials,
1631 dim
1632 );
1633 assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
1634 }
1635 }
1636 }
1637
/// Exclusive upper bound of the `1..MAX_DIM` dimension sweeps run by the tests below.
///
/// Builds with `cfg(miri)` set use a much smaller bound.
/// NOTE(review): presumably because interpreted execution is slow - confirm.
const MAX_DIM: usize = if cfg!(miri) { 37 } else { 256 };

/// Number of trials executed for each dimension; reduced to a single trial under `cfg(miri)`.
const TRIALS_PER_DIM: usize = if cfg!(miri) { 1 } else { 20 };
1650
/// Runs `test_compensated_distance::<1>` for every dimension in `1..MAX_DIM` (upper bound
/// excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared
/// by the whole sweep.
#[test]
fn test_symmetric_distances_1bit() {
    let mut prng = StdRng::seed_from_u64(0x2a5f79a2469218f6);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(4.0e-3, 3.0e-3);
        let second_bound = Approx::new(1.0e-3, 5.0e-4);
        test_compensated_distance::<1>(d, TRIALS_PER_DIM, first_bound, second_bound, &mut prng);
    });
}
1664
/// Runs `test_compensated_distance::<2>` for every dimension in `1..MAX_DIM` (upper bound
/// excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared
/// by the whole sweep.
#[test]
fn test_symmetric_distances_2bit() {
    let mut prng = StdRng::seed_from_u64(0x68f8f52057f94399);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(3.5e-3, 2.0e-3);
        let second_bound = Approx::new(2.0e-3, 5.0e-4);
        test_compensated_distance::<2>(d, TRIALS_PER_DIM, first_bound, second_bound, &mut prng);
    });
}
1678
/// Runs `test_compensated_distance::<4>` for every dimension in `1..MAX_DIM` (upper bound
/// excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared
/// by the whole sweep.
#[test]
fn test_symmetric_distances_4bit() {
    let mut prng = StdRng::seed_from_u64(0xb88d76ac4c58e923);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(2.0e-3, 2.0e-3);
        let second_bound = Approx::new(2.0e-3, 5.0e-4);
        test_compensated_distance::<4>(d, TRIALS_PER_DIM, first_bound, second_bound, &mut prng);
    });
}
1692
/// Runs `test_compensated_distance::<8>` for every dimension in `1..MAX_DIM` (upper bound
/// excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared
/// by the whole sweep.
#[test]
fn test_symmetric_distances_8bit() {
    let mut prng = StdRng::seed_from_u64(0x1c2b79873ee32626);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(2.0e-3, 2.0e-3);
        let second_bound = Approx::new(2.0e-3, 4.0e-4);
        test_compensated_distance::<8>(d, TRIALS_PER_DIM, first_bound, second_bound, &mut prng);
    });
}
1706
/// Runs `test_mixed_compensated_distance::<4, 1, BitTranspose>` for every dimension in
/// `1..MAX_DIM` (upper bound excluded), with `TRIALS_PER_DIM` trials per dimension and one
/// fixed-seed generator shared by the whole sweep.
///
/// NOTE(review): the generic arguments presumably select query bits, data bits and the query
/// permutation strategy - confirm against the helper's signature.
#[test]
fn test_mixed_distances_4x1() {
    let mut prng = StdRng::seed_from_u64(0x1efb4d87ed0a8ada);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(4.0e-3, 3.0e-3);
        let second_bound = Approx::new(1.3e-2, 8.3e-3);
        test_mixed_compensated_distance::<4, 1, BitTranspose>(
            d,
            TRIALS_PER_DIM,
            first_bound,
            second_bound,
            &mut prng,
        );
    });
}
1720
/// Runs `test_mixed_compensated_distance::<4, 4, Dense>` for every dimension in `1..MAX_DIM`
/// (upper bound excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed
/// generator shared by the whole sweep.
///
/// NOTE(review): the generic arguments presumably select query bits, data bits and the query
/// permutation strategy - confirm against the helper's signature.
#[test]
fn test_mixed_distances_4x4() {
    let mut prng = StdRng::seed_from_u64(0x508554264eb7a51b);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(4.0e-3, 3.0e-3);
        let second_bound = Approx::new(3.0e-4, 8.3e-2);
        test_mixed_compensated_distance::<4, 4, Dense>(
            d,
            TRIALS_PER_DIM,
            first_bound,
            second_bound,
            &mut prng,
        );
    });
}
1734
/// Runs `test_mixed_compensated_distance::<8, 8, Dense>` for every dimension in `1..MAX_DIM`
/// (upper bound excluded), with `TRIALS_PER_DIM` trials per dimension and one fixed-seed
/// generator shared by the whole sweep.
///
/// NOTE(review): the generic arguments presumably select query bits, data bits and the query
/// permutation strategy - confirm against the helper's signature.
#[test]
fn test_mixed_distances_8x8() {
    let mut prng = StdRng::seed_from_u64(0x8acd8e4224c76c43);
    (1..MAX_DIM).for_each(|d| {
        // Forwarded positionally; which metric each tolerance bounds is decided by the
        // helper's parameter order.
        let first_bound = Approx::new(2.0e-3, 6.0e-3);
        let second_bound = Approx::new(1.0e-2, 3.0e-2);
        test_mixed_compensated_distance::<8, 8, Dense>(
            d,
            TRIALS_PER_DIM,
            first_bound,
            second_bound,
            &mut prng,
        );
    });
}
1748
/// Runs `test_full_distances::<1>` for every dimension in `1..MAX_DIM` (upper bound excluded),
/// with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared by the whole
/// sweep.
#[test]
fn test_full_distances_1bit() {
    let mut prng = StdRng::seed_from_u64(0x7f93530559f42d66);
    (1..MAX_DIM).for_each(|d| {
        // `test_full_distances` takes the squared-L2 tolerance first, then the inner-product one.
        let l2_bound = Approx::new(1.0e-3, 2.0e-3);
        let ip_bound = Approx::new(0.0, 5.0e-3);
        test_full_distances::<1>(d, TRIALS_PER_DIM, l2_bound, ip_bound, &mut prng);
    });
}
1763
/// Runs `test_full_distances::<2>` for every dimension in `1..MAX_DIM` (upper bound excluded),
/// with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared by the whole
/// sweep.
#[test]
fn test_full_distances_2bit() {
    let mut prng = StdRng::seed_from_u64(0xa3ad61d3d03a0c5a);
    (1..MAX_DIM).for_each(|d| {
        // `test_full_distances` takes the squared-L2 tolerance first, then the inner-product one.
        let l2_bound = Approx::new(2.0e-3, 1.1e-3);
        let ip_bound = Approx::new(7.0e-4, 1.0e-3);
        test_full_distances::<2>(d, TRIALS_PER_DIM, l2_bound, ip_bound, &mut prng);
    });
}
1777
/// Runs `test_full_distances::<4>` for every dimension in `1..MAX_DIM` (upper bound excluded),
/// with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared by the whole
/// sweep.
#[test]
fn test_full_distances_4bit() {
    let mut prng = StdRng::seed_from_u64(0x3e2f50ed7c64f0c2);
    (1..MAX_DIM).for_each(|d| {
        // `test_full_distances` takes the squared-L2 tolerance first, then the inner-product one.
        let l2_bound = Approx::new(2.0e-3, 1.0e-2);
        let ip_bound = Approx::new(1.0e-3, 5.0e-4);
        test_full_distances::<4>(d, TRIALS_PER_DIM, l2_bound, ip_bound, &mut prng);
    });
}
1791
/// Runs `test_full_distances::<8>` for every dimension in `1..MAX_DIM` (upper bound excluded),
/// with `TRIALS_PER_DIM` trials per dimension and one fixed-seed generator shared by the whole
/// sweep.
#[test]
fn test_full_distances_8bit() {
    let mut prng = StdRng::seed_from_u64(0x95705070e415c6d3);
    (1..MAX_DIM).for_each(|d| {
        // `test_full_distances` takes the squared-L2 tolerance first, then the inner-product one.
        let l2_bound = Approx::new(1.0e-3, 1.0e-3);
        let ip_bound = Approx::new(2.0e-3, 1.0e-4);
        test_full_distances::<8>(d, TRIALS_PER_DIM, l2_bound, ip_bound, &mut prng);
    });
}
1805}