1use std::num::NonZeroUsize;
7
8use diskann_utils::{ReborrowMut, views::MatrixView};
9use diskann_vector::{
10 MathematicalValue, Norm, PureDistanceFunction, distance::InnerProduct, norm::FastL2Norm,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18 CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19 FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21use crate::{
22 AsFunctor, CompressIntoWith,
23 algorithms::{
24 heap::SliceHeap,
25 transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
26 },
27 alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
28 bits::{PermutationStrategy, Representation, Unsigned},
29 num::Positive,
30 utils::{CannotBeEmpty, compute_means_and_average_norm, compute_normalized_means},
31};
32#[cfg(feature = "flatbuffers")]
33use crate::{
34 algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
35};
36
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct SphericalQuantizer<A = GlobalAllocator>
where
    A: Allocator,
{
    /// Per-dimension shift subtracted from inputs before quantization.
    /// This is the training centroid with `pre_scale` already multiplied in
    /// (see `generate`); its length defines `input_dim`.
    shift: Poly<[f32], A>,

    /// Transform applied to the normalized, shifted vector; its output
    /// dimension defines `output_dim`.
    transform: Transform<A>,

    /// Distance metric this quantizer was trained for.
    metric: SupportedMetric,

    /// Average norm of the training data (fixed to 1.0 for cosine, which
    /// normalizes rows during training).
    mean_norm: Positive<f32>,

    /// Scalar applied to raw inputs (and baked into `shift`) during
    /// preprocessing for the L2/IP metrics.
    pre_scale: Positive<f32>,
}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82 A: Allocator,
83{
84 fn try_clone(&self) -> Result<Self, AllocatorError> {
85 Ok(Self {
86 shift: self.shift.try_clone()?,
87 transform: self.transform.try_clone()?,
88 metric: self.metric,
89 mean_norm: self.mean_norm,
90 pre_scale: self.pre_scale,
91 })
92 }
93}
94
/// Errors that can occur while training or constructing a
/// [`SphericalQuantizer`].
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum TrainError {
    /// The data (or centroid) has zero columns.
    #[error("data dim cannot be zero")]
    DimCannotBeZero,
    /// The training set has no rows (cosine path only).
    #[error("data cannot be empty")]
    DataCannotBeEmpty,
    /// An explicit pre-scale was supplied but was not strictly positive.
    #[error("pre-scale must be positive")]
    PrescaleNotPositive,
    /// The supplied or computed mean norm was not strictly positive.
    #[error("norm must be positive")]
    NormNotPositive,
    /// The computed mean norm was infinite or NaN.
    #[error("computed norm contains infinity or NaN")]
    NormNotFinite,
    /// `1 / mean_norm` was infinite, NaN, or non-positive.
    #[error("reciprocal norm contains infinity or NaN")]
    ReciprocalNormNotFinite,
    /// Allocation failed while building the centroid or transform.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
113
impl<A> SphericalQuantizer<A>
where
    A: Allocator,
{
    /// Dimension of the raw input vectors (the length of the stored shift).
    pub fn input_dim(&self) -> usize {
        self.shift.len()
    }

    /// Dimension of the quantized output (the transform's output dimension).
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

    /// The per-dimension shift (pre-scaled training centroid) subtracted from
    /// inputs during preprocessing.
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    /// Average norm computed from the training data (1.0 for cosine).
    pub fn mean_norm(&self) -> Positive<f32> {
        self.mean_norm
    }

    /// Scalar applied to raw inputs before shifting (also baked into
    /// [`Self::shift`]).
    pub fn pre_scale(&self) -> Positive<f32> {
        self.pre_scale
    }

    /// The allocator backing this quantizer's heap storage.
    pub fn allocator(&self) -> &A {
        self.shift.allocator()
    }

    /// Build a quantizer directly from a precomputed `centroid` and
    /// `mean_norm`.
    ///
    /// `pre_scale` defaults to 1.0 when `None`; the centroid is scaled by it
    /// in place before being stored as the shift. `rng` seeds the transform's
    /// randomness.
    ///
    /// # Errors
    /// * [`TrainError::PrescaleNotPositive`] — `pre_scale` not strictly positive.
    /// * [`TrainError::DimCannotBeZero`] — empty `centroid`.
    /// * [`TrainError::NormNotPositive`] — non-positive `mean_norm`.
    /// * [`TrainError::AllocatorError`] — transform allocation failed.
    pub fn generate(
        mut centroid: Poly<[f32], A>,
        mean_norm: f32,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: Option<f32>,
        rng: &mut dyn RngCore,
        allocator: A,
    ) -> Result<Self, TrainError> {
        let pre_scale = match pre_scale {
            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
            None => crate::num::POSITIVE_ONE_F32,
        };

        let dim = match NonZeroUsize::new(centroid.len()) {
            Some(dim) => dim,
            None => {
                return Err(TrainError::DimCannotBeZero);
            }
        };

        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;

        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
            Ok(v) => v,
            // An Rng is always passed above, so this variant cannot occur.
            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
            Err(NewTransformError::AllocatorError(err)) => {
                return Err(TrainError::AllocatorError(err));
            }
        };

        // Bake the pre-scale into the stored shift so preprocessing only
        // needs a single multiply-subtract per element.
        centroid
            .iter_mut()
            .for_each(|v| *v *= pre_scale.into_inner());

        Ok(SphericalQuantizer {
            shift: centroid,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }

    /// The metric this quantizer was trained for.
    pub fn metric(&self) -> SupportedMetric {
        self.metric
    }

    /// Train a quantizer from `data`: computes the centroid (normalized means
    /// for cosine, plain means otherwise), the mean norm, resolves the
    /// requested `pre_scale` policy, and delegates to [`Self::generate`].
    ///
    /// # Errors
    /// Propagates the validation errors of [`Self::generate`] plus
    /// [`TrainError::DataCannotBeEmpty`], [`TrainError::NormNotFinite`], and
    /// [`TrainError::ReciprocalNormNotFinite`].
    pub fn train<T, R>(
        data: MatrixView<T>,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: PreScale,
        rng: &mut R,
        allocator: A,
    ) -> Result<Self, TrainError>
    where
        T: Copy + Into<f64> + Into<f32>,
        R: Rng,
    {
        // Inner helper erases the concrete `Rng` behind `&mut dyn RngCore`
        // and is never inlined, limiting monomorphization bloat.
        #[inline(never)]
        fn train<T, A>(
            data: MatrixView<T>,
            transform: TransformKind,
            metric: SupportedMetric,
            pre_scale: PreScale,
            rng: &mut dyn RngCore,
            allocator: A,
        ) -> Result<SphericalQuantizer<A>, TrainError>
        where
            T: Copy + Into<f64> + Into<f32>,
            A: Allocator,
        {
            if data.ncols() == 0 {
                return Err(TrainError::DimCannotBeZero);
            }

            // Cosine normalizes each row first, so its mean norm is 1.0.
            let (centroid, mean_norm) = match metric {
                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
                    compute_means_and_average_norm(data)
                }
                SupportedMetric::Cosine => (
                    compute_normalized_means(data)
                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
                    1.0,
                ),
            };

            let mean_norm = mean_norm as f32;

            if mean_norm <= 0.0 {
                return Err(TrainError::NormNotPositive);
            }

            if !mean_norm.is_finite() {
                return Err(TrainError::NormNotFinite);
            }

            let pre_scale: Positive<f32> = match pre_scale {
                PreScale::None => crate::num::POSITIVE_ONE_F32,
                PreScale::Some(v) => v,
                PreScale::ReciprocalMeanNorm => {
                    let pre_scale = Positive::new(1.0 / mean_norm)
                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;

                    // A tiny mean norm can make the reciprocal overflow to
                    // infinity while still being "positive"; reject that too.
                    if !pre_scale.into_inner().is_finite() {
                        return Err(TrainError::ReciprocalNormNotFinite);
                    }

                    pre_scale
                }
            };

            let centroid =
                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;

            SphericalQuantizer::generate(
                centroid,
                mean_norm,
                transform,
                metric,
                Some(pre_scale.into_inner()),
                rng,
                allocator,
            )
        }

        train(data, transform, metric, pre_scale, rng, allocator)
    }

    /// Rescale `v` in place so its L2 norm equals the trained mean norm.
    pub fn rescale(&self, v: &mut [f32]) {
        let norm = FastL2Norm.evaluate(&*v);
        let m = self.mean_norm.into_inner() / norm;
        v.iter_mut().for_each(|i| *i *= m);
    }

    /// Scale/normalize and shift `data` ahead of quantization.
    ///
    /// For cosine the input is divided by its own norm (a zero-norm input is
    /// left unscaled); for L2/IP it is multiplied by `pre_scale`. The shifted
    /// vector's norm and — for IP/cosine — its inner product with the
    /// centroid are recorded for later metric corrections.
    ///
    /// # Panics
    /// Panics if `data.len() != self.input_dim()`.
    ///
    /// # Errors
    /// [`CompressionError::InputContainsNaN`] when the shifted norm is not
    /// finite; allocation failures are propagated.
    fn preprocess<'a>(
        &self,
        data: &[f32],
        allocator: ScopedAllocator<'a>,
    ) -> Result<Preprocessed<'a>, CompressionError> {
        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");

        let scale = self.pre_scale.into_inner();
        let mul: f32 = match self.metric {
            SupportedMetric::Cosine => {
                let norm: f32 = (FastL2Norm).evaluate(data);
                if norm == 0.0 { 1.0 } else { 1.0 / norm }
            }
            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
        };

        let shifted = Poly::from_iter(
            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
            allocator,
        )?;

        // A non-finite norm indicates NaN/Inf somewhere in the input.
        let shifted_norm = FastL2Norm.evaluate(&*shifted);
        if !shifted_norm.is_finite() {
            return Err(CompressionError::InputContainsNaN);
        }
        let inner_product_correction = match self.metric {
            SupportedMetric::SquaredL2 => None,
            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
                Some(ip.into_inner())
            }
        };

        Ok(Preprocessed {
            shifted,
            shifted_norm,
            inner_product_with_centroid: inner_product_correction,
        })
    }
}
378
/// Policy for the scalar applied to inputs during training.
#[derive(Debug, Clone, Copy)]
pub enum PreScale {
    /// No pre-scaling (scale of 1.0).
    None,
    /// Use an explicit, strictly-positive scale.
    Some(Positive<f32>),
    /// Use `1 / mean_norm` computed from the training data.
    ReciprocalMeanNorm,
}
392
/// Errors raised while decoding a [`SphericalQuantizer`] from its flatbuffer
/// representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The embedded transform failed to decode.
    #[error(transparent)]
    TransformError(#[from] TransformError),
    /// The buffer's file identifier did not match.
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,
    /// The transform's input dimension disagrees with the centroid length.
    #[error("transform length not equal to centroid")]
    DimMismatch,
    /// The stored mean norm is absent or not strictly positive.
    #[error("norm is missing or is not positive")]
    MissingNorm,
    /// The stored pre-scale is absent or not strictly positive.
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    /// The buffer failed flatbuffers verification.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// The stored metric code is not a known [`SupportedMetric`].
    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    /// Allocation failed while rebuilding heap-backed members.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
418
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Serialize this quantizer into `buf`, returning the offset of the
    /// resulting `spherical::SphericalQuantizer` table.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let centroid = buf.create_vector(&self.shift);

        let transform = self.transform.pack(buf);

        spherical::SphericalQuantizer::create(
            buf,
            &spherical::SphericalQuantizerArgs {
                centroid: Some(centroid),
                transform: Some(transform),
                metric: self.metric.into(),
                mean_norm: self.mean_norm.into_inner(),
                pre_scale: self.pre_scale.into_inner(),
            },
        )
    }

    /// Reconstruct a quantizer from a decoded flatbuffer table, allocating
    /// heap members from `alloc` and validating every stored field.
    ///
    /// # Errors
    /// See [`DeserializationError`] — invalid metric, dimension mismatch,
    /// missing/non-positive norm or pre-scale, or allocation failure.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;

        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // The transform must accept vectors of the centroid's dimension.
        if shift.len() != transform.input_dim() {
            return Err(DeserializationError::DimMismatch);
        }

        let mean_norm =
            Positive::new(proto.mean_norm()).map_err(|_| DeserializationError::MissingNorm)?;

        let pre_scale = Positive::new(proto.pre_scale())
            .map_err(|_| DeserializationError::PreScaleNotPositive)?;

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
488
/// Result of [`SphericalQuantizer::preprocess`]: the scaled/shifted input
/// plus the quantities needed for metric corrections.
struct Preprocessed<'a> {
    /// Input after scaling (or normalization) and centroid subtraction.
    shifted: Poly<[f32], ScopedAllocator<'a>>,
    /// L2 norm of `shifted` (guaranteed finite by `preprocess`).
    shifted_norm: f32,
    /// Inner product of `shifted` with the centroid; `None` for squared L2.
    inner_product_with_centroid: Option<f32>,
}
494
495impl Preprocessed<'_> {
496 fn metric_specific(&self) -> f32 {
501 match self.inner_product_with_centroid {
502 Some(ip) => ip,
503 None => self.shifted_norm * self.shifted_norm,
504 }
505 }
506}
507
/// Squared-L2 distance functor sized for this quantizer's output dimension.
impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedSquaredL2 {
        CompensatedSquaredL2::new(self.output_dim())
    }
}
520
/// Inner-product functor carrying the centroid shift and output dimension.
impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedIP {
        CompensatedIP::new(&self.shift, self.output_dim())
    }
}
529
/// Cosine functor, layered on top of the inner-product functor.
impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedCosine {
        CompensatedCosine::new(self.as_functor())
    }
}
538
/// Errors that can occur while compressing a vector with a
/// [`SphericalQuantizer`].
#[derive(Debug, Error, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CompressionError {
    /// The shifted input's norm was not finite.
    #[error("input contains NaN")]
    InputContainsNaN,

    /// The source slice's length differs from the quantizer's input dim.
    #[error("expected source vector to have length {expected}")]
    SourceDimensionMismatch { expected: usize },

    /// The destination's length differs from the quantizer's output dim.
    #[error("expected destination vector to have length {expected}")]
    DestinationDimensionMismatch { expected: usize },

    /// The computed metadata could not be encoded.
    #[error(
        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
    )]
    EncodingError(#[from] DataMetaError),

    /// A scratch or output allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
563
564fn check_dims(
565 input: usize,
566 output: usize,
567 from: usize,
568 into: usize,
569) -> Result<(), CompressionError> {
570 if from != input {
571 return Err(CompressionError::SourceDimensionMismatch { expected: input });
572 }
573 if into != output {
574 return Err(CompressionError::DestinationDimensionMismatch { expected: output });
575 }
576 Ok(())
577}
578
/// Final stage of data compression: given the preprocessed input and its
/// transformed image, write the quantized codes and metadata into `self`.
trait FinishCompressing {
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;
}
590
impl FinishCompressing for DataMut<'_, 1> {
    /// One-bit encoding: each output bit is the sign of the corresponding
    /// transformed component (1 for positive, 0 otherwise).
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        _: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        // Running sum of |r_i| (inner product of the sign code with the
        // transformed vector, up to a constant) and of the set bits.
        let mut quant_raw_inner_product = 0.0f32;
        let mut bit_sum = 0u32;
        transformed.iter().enumerate().for_each(|(i, &r)| {
            let bit: u8 = if r > 0.0 { 1 } else { 0 };

            quant_raw_inner_product += r.abs();
            bit_sum += <u8 as Into<u32>>::into(bit);

            // SAFETY: `i < transformed.len()`; the caller validated that the
            // destination length matches via `check_dims`.
            unsafe { self.vector_mut().set_unchecked(i, bit) };
        });

        // Correction factor restoring true inner products from bit-space
        // products; scaled by both the transformed and shifted norms.
        let inner_product_correction =
            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
        self.set_meta(DataMeta::new(
            inner_product_correction,
            preprocessed.metric_specific(),
            bit_sum,
        )?);
        Ok(())
    }
}
643
impl FinishCompressing for DataMut<'_, 2> {
    /// Two-bit encoding: delegate to the cosine-maximizing scalar quantizer.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
661
impl FinishCompressing for DataMut<'_, 4> {
    /// Four-bit encoding: delegate to the cosine-maximizing scalar quantizer.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
679
impl FinishCompressing for DataMut<'_, 8> {
    /// Eight-bit encoding: delegate to the cosine-maximizing scalar quantizer.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
697
impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
where
    A: Allocator,
{
    type Error = CompressionError;

    /// Compress `from` into a full-precision query: the normalized, shifted,
    /// transformed vector in f32 plus its metadata.
    ///
    /// # Errors
    /// Dimension mismatch, NaN input, or allocation failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: FullQueryMut<'_>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // A zero shifted vector has no direction to encode: emit all zeros
        // with default metadata.
        if preprocessed.shifted_norm == 0.0 {
            into.vector_mut().fill(0.0);
            *into.meta_mut() = Default::default();
            return Ok(());
        }

        // Project onto the unit sphere before transforming.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        // `check_dims` already validated both lengths, so mismatch variants
        // indicate an internal logic error and panic.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
            #[cfg(feature = "linalg")]
            Err(TransformFailed::SgemmError(_)) => {
                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
            }
        }

        *into.meta_mut() = FullQueryMeta {
            sum: into.vector().iter().sum::<f32>(),
            shifted_norm: preprocessed.shifted_norm,
            metric_specific: preprocessed.metric_specific(),
        };
        Ok(())
    }
}
778
impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    A: Allocator,
    Unsigned: Representation<NBITS>,
    for<'a> DataMut<'a, NBITS>: FinishCompressing,
{
    type Error = CompressionError;

    /// Compress `from` into an `NBITS`-per-dimension data encoding:
    /// preprocess, normalize, transform, then hand off to the bit-width
    /// specific [`FinishCompressing`] implementation.
    ///
    /// # Errors
    /// Dimension mismatch, NaN input, metadata encoding failure, or
    /// allocation failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: DataMut<'_, NBITS>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // A zero shifted vector has no direction: record default metadata.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(DataMeta::default());
            return Ok(());
        }

        // Scratch buffer for the transformed vector.
        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
        // Project onto the unit sphere before transforming.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        // `check_dims` already validated both lengths, so mismatch variants
        // indicate an internal logic error and panic.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
            #[cfg(feature = "linalg")]
            Err(TransformFailed::SgemmError(_)) => {
                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
            }
        }

        // Norm-preserving transforms map the unit-norm input to unit norm,
        // so the evaluation can be skipped.
        let transformed_norm = if self.transform.preserves_norms() {
            1.0
        } else {
            (FastL2Norm).evaluate(&*transformed)
        };

        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
        Ok(())
    }
}
858
/// Lifts a const-generic `NBITS` into a `NonZeroUsize` at compile time.
struct AsNonZero<const NBITS: usize>;
impl<const NBITS: usize> AsNonZero<NBITS> {
    // Evaluated in const context: a zero `NBITS` fails the build rather than
    // panicking at runtime.
    #[allow(clippy::unwrap_used)]
    const NON_ZERO: NonZeroUsize = NonZeroUsize::new(NBITS).unwrap();
}
865
/// Quantize `transformed` onto the unsigned `NBITS` code domain using the
/// scale that maximizes cosine similarity, writing codes and correction
/// metadata into `data`.
///
/// # Panics
/// Panics if `data` and `transformed` have different lengths.
fn compress_via_maximum_cosine<const NBITS: usize>(
    mut data: DataMut<'_, NBITS>,
    preprocessed: &Preprocessed<'_>,
    transformed: &[f32],
    transformed_norm: f32,
    allocator: ScopedAllocator<'_>,
) -> Result<(), CompressionError>
where
    Unsigned: Representation<NBITS>,
{
    assert_eq!(data.len(), transformed.len());

    let optimal_scale =
        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;

    // Unsigned code domain [min, max]; codes are centered about `offset`.
    let domain = Unsigned::domain_const::<NBITS>();
    let min = *domain.start() as f32;
    let max = *domain.end() as f32;
    let offset = max / 2.0;

    let mut self_inner_product = 0.0f32;
    let mut bit_sum = 0u32;
    for (i, t) in transformed.iter().enumerate() {
        // Scale, recenter, round, and clamp to the representable range.
        let v = (*t * optimal_scale + offset).clamp(min, max).round();
        // `dv` is the signed (centered) value the code represents; accumulate
        // its inner product with the transformed vector.
        let dv = v - offset;
        self_inner_product = dv.mul_add(*t, self_inner_product);

        let v = v as u8;
        bit_sum += <u8 as Into<u32>>::into(v);

        // SAFETY: `i < transformed.len() == data.len()` by the assert above.
        unsafe { data.vector_mut().set_unchecked(i, v) };
    }

    // Correction factor restoring true inner products from code-space
    // products.
    let shifted_norm = preprocessed.shifted_norm;
    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
    data.set_meta(DataMeta::new(
        inner_product_correction,
        preprocessed.metric_specific(),
        bit_sum,
    )?);
    Ok(())
}
914
/// A critical scale value paired with the coordinate it belongs to; used as
/// the heap element in `maximize_cosine_similarity`.
#[derive(Debug, Clone, Copy)]
struct Pair {
    /// The critical scale at which this coordinate's rounding increments.
    value: f32,
    /// Index of the coordinate in the input vector.
    position: u32,
}

impl PartialEq for Pair {
    // Equality considers only the scale value, matching the ordering below.
    fn eq(&self, other: &Self) -> bool {
        self.value.eq(&other.value)
    }
}

impl Eq for Pair {}
impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Pair {
    // Deliberately reversed (`other` vs `self`) so that smaller scale values
    // rank higher: the heap then surfaces the smallest critical value first.
    // NaN values compare as equal rather than panicking.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .value
            .partial_cmp(&self.value)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
954
/// Find the scale maximizing the cosine similarity between `v` and the
/// rounding of `scale * v` onto the symmetric `num_bits` quantization grid.
///
/// Walks the "critical" scale values — thresholds at which some coordinate's
/// rounded magnitude increments — in increasing order via a heap, while
/// incrementally maintaining the inner product and squared norm of the
/// rounded vector.
///
/// # Panics
/// Panics if `v` is empty (the heap over critical values cannot be built).
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // Start with every coordinate at the first grid step (r = 1, magnitude
    // 0.5): IP = 0.5 * sum|v_i| and squared norm = 0.25 * len.
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    let mut current_square_norm = 0.25 * (v.len() as f64);

    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // Nudge each threshold slightly past the exact crossing point.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            // First critical scale for this coordinate: where |scale * v_i|
            // first rounds up.
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // Largest rounded magnitude representable with `num_bits` (half domain).
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            // f32::MAX marks an exhausted coordinate; once it reaches the
            // root, every coordinate is saturated and the search is done.
            if value == f32::MAX {
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            *r += 1;

            // Magnitudes sit at half-integers (r - 0.5): incrementing r adds
            // |v_p| to the IP and (r+0.5)^2 - (r-0.5)^2 = 2r to the norm.
            current_ip += vp.abs() as f64;

            current_square_norm += (2 * old_r) as f64;

            // ||v|| is a constant factor, so it can be dropped for argmax.
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            if *r < stop {
                // Next critical scale for this coordinate.
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                // Saturated: park at f32::MAX so it sorts last.
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1096
impl<const NBITS: usize, Perm, A>
    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    Unsigned: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    A: Allocator,
{
    type Error = CompressionError;

    /// Compress `from` into an `NBITS` query encoding: preprocess, normalize,
    /// transform, then min/max scalar-quantize onto the unsigned domain.
    ///
    /// # Errors
    /// Dimension mismatch, NaN input, or allocation failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: QueryMut<'_, NBITS, Perm>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // A zero shifted vector has no direction: record default metadata.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(QueryMeta::default());
            return Ok(());
        }

        // Project onto the unit sphere before transforming.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;

        // `check_dims` already validated both lengths, so mismatch variants
        // indicate an internal logic error and panic.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
            #[cfg(feature = "linalg")]
            Err(TransformFailed::SgemmError(_)) => {
                panic!("SGEMM should not fail with valid dimensions - this is a logic error");
            }
        }

        // Range of the transformed components for min/max quantization.
        let (min, max) = transformed
            .iter()
            .fold((f32::MAX, f32::MIN), |(min, max), i| {
                (i.min(min), i.max(max))
            });

        let domain = Unsigned::domain_const::<NBITS>();
        let lo = (*domain.start()) as f32;
        let hi = (*domain.end()) as f32;

        // NOTE(review): if every transformed component is equal, `scale` is 0
        // and the divisions below yield non-finite values — confirm upstream
        // excludes this case.
        let scale = (max - min) / hi;
        let mut bit_sum: f32 = 0.0;
        transformed.iter().enumerate().for_each(|(i, v)| {
            let c = ((v - min) / scale).round().clamp(lo, hi);
            bit_sum += c;

            // `c` is clamped into the representable domain, so `set` cannot
            // fail on value range; `i` is within the checked output length.
            #[allow(clippy::unwrap_used)]
            into.vector_mut().set(i, c as i64).unwrap();
        });

        into.set_meta(QueryMeta {
            inner_product_correction: preprocessed.shifted_norm * scale,
            bit_sum,
            offset: min / scale,
            metric_specific: preprocessed.metric_specific(),
        });

        Ok(())
    }
}
1209
1210#[cfg(not(miri))]
1215#[cfg(test)]
1216mod tests {
1217 use super::*;
1218
1219 use std::fmt::Display;
1220
1221 use diskann_utils::{
1222 ReborrowMut, lazy_format,
1223 views::{self, Matrix},
1224 };
1225 use diskann_vector::{PureDistanceFunction, norm::FastL2NormSquared};
1226 use diskann_wide::ARCH;
1227 use rand::{
1228 SeedableRng,
1229 distr::{Distribution, Uniform},
1230 rngs::StdRng,
1231 };
1232 use rand_distr::StandardNormal;
1233
1234 use crate::{
1235 algorithms::transforms::TargetDim,
1236 alloc::GlobalAllocator,
1237 bits::{BitTranspose, Dense},
1238 spherical::{Data, DataMetaF32, FullQuery, Query},
1239 test_util,
1240 };
1241
    /// Checks that `maximize_cosine_similarity` produces a scale whose
    /// rounded/clamped 3-bit encoding matches the brute-force best encoding
    /// (or a uniform rescaling of it) over a deterministic grid and many
    /// random targets.
    #[test]
    fn test_cosine_similarity_maximizer() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let num_trials = 10000;
        let num_bits = NonZeroUsize::new(3).unwrap();

        let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();

        let run_test = |target: [f32; 4]| {
            let scale =
                maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();

            let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
            let mut best_similarity: f32 = f32::NEG_INFINITY;

            // Brute force all 8^4 three-bit codes centered on -3.5..=3.5.
            let min = -3.5;
            for i0 in 0..8 {
                for i1 in 0..8 {
                    for i2 in 0..8 {
                        for i3 in 0..8 {
                            let p: [f32; 4] = [
                                min + (i0 as f32),
                                min + (i1 as f32),
                                min + (i2 as f32),
                                min + (i3 as f32),
                            ];

                            let sim: MathematicalValue<f32> =
                                diskann_vector::distance::Cosine::evaluate(&p, &target);
                            let sim = sim.into_inner();
                            if sim > best_similarity {
                                best_similarity = sim;
                                // Store as an unsigned code (shifted by -min).
                                best = p.map(|i| i - min);
                            }
                        }
                    }
                }
            }

            // Encoding the way `compress_via_maximum_cosine` would.
            let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
            let clamped_cosine: MathematicalValue<f32> =
                diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);

            // Accept an exact match, or a uniform rescaling of the best code
            // (cosine similarity is scale-invariant).
            let passed = if best == clamped {
                true
            } else {
                let ratio: Vec<f32> = std::iter::zip(best, clamped)
                    .map(|(b, c)| {
                        let ratio = (b + min) / (c + min);
                        assert_ne!(
                            ratio, 0.0,
                            "ratio should never be zero because `b` is an integer and \
                             `min` is not"
                        );
                        ratio
                    })
                    .collect();

                ratio.iter().all(|i| *i == ratio[0])
            };

            if !passed {
                panic!(
                    "failed for input {:?}.\
                     Best = {:?}, Found = {:?}\
                     Best similarity = {}, similarity with clamped = {}",
                    target,
                    best,
                    clamped,
                    best_similarity,
                    clamped_cosine.into_inner()
                );
            }
        };

        // Deterministic grid of half-integer targets.
        let min = -3.5;
        for i0 in (0..8).step_by(2) {
            for i1 in (1..9).step_by(2) {
                for i2 in (0..8).step_by(2) {
                    for i3 in (1..9).step_by(2) {
                        let p: [f32; 4] = [
                            min + (i0 as f32),
                            min + (i1 as f32),
                            min + (i2 as f32),
                            min + (i3 as f32),
                        ];
                        run_test(p)
                    }
                }
            }
        }

        // Randomized Gaussian targets at varying scales.
        for _ in 0..num_trials {
            let this_scale: f32 = scale_distribution.sample(&mut rng);
            let v: [f32; 4] = [(); 4].map(|_| {
                let v: f32 = StandardNormal {}.sample(&mut rng);
                this_scale * v
            });
            run_test(v);
        }
    }
1352
    /// An empty input must panic via the `SliceHeap::new` expect message in
    /// `maximize_cosine_similarity`.
    #[test]
    #[should_panic(expected = "calling code should not allow the slice to be empty")]
    fn empty_slice_panics() {
        maximize_cosine_similarity(
            &[],
            NonZeroUsize::new(4).unwrap(),
            ScopedAllocator::global(),
        )
        .unwrap();
    }
1363
    /// Parameters shared by the end-to-end quantizer tests.
    struct Setup {
        transform: TransformKind, // transform kind to train with
        nrows: usize,             // number of training rows
        ncols: usize,             // data dimensionality
        num_trials: usize,        // number of random rows to round-trip
    }
1370
1371 fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1372 match scale {
1373 PreScale::None => 1.0,
1374 PreScale::Some(v) => v.into_inner(),
1375 PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1376 }
1377 }
1378
1379 fn test_l2<const Q: usize, const D: usize, Perm>(
1380 setup: &Setup,
1381 problem: &test_util::TestProblem,
1382 computed_means: &[f32],
1383 pre_scale: PreScale,
1384 rng: &mut StdRng,
1385 ) where
1386 Unsigned: Representation<Q>,
1387 Unsigned: Representation<D>,
1388 Perm: PermutationStrategy<Q>,
1389 for<'a> SphericalQuantizer:
1390 CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1391 for<'a> SphericalQuantizer:
1392 CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1393 {
1394 assert_eq!(setup.nrows, problem.data.nrows());
1395 assert_eq!(setup.ncols, problem.data.ncols());
1396
1397 let scoped_global = ScopedAllocator::global();
1398 let distribution = Uniform::new(0, setup.nrows).unwrap();
1399 let quantizer = SphericalQuantizer::train(
1400 problem.data.as_view(),
1401 setup.transform,
1402 SupportedMetric::SquaredL2,
1403 pre_scale,
1404 rng,
1405 GlobalAllocator,
1406 )
1407 .unwrap();
1408
1409 let scale = get_scale(pre_scale, &quantizer);
1410
1411 let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1412 let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1413 let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1414
1415 assert_eq!(
1416 quantizer.mean_norm.into_inner(),
1417 problem.mean_norm as f32,
1418 "computed mean norm should not apply scale"
1419 );
1420 let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1421 assert_eq!(&*scaled_means, quantizer.shift());
1422
1423 let l2: CompensatedSquaredL2 = quantizer.as_functor();
1424 assert_eq!(l2.dim, quantizer.output_dim() as f32);
1425
1426 for _ in 0..setup.num_trials {
1427 let i = distribution.sample(rng);
1428 let v = problem.data.row(i);
1429
1430 quantizer
1431 .compress_into_with(v, b.reborrow_mut(), scoped_global)
1432 .unwrap();
1433 quantizer
1434 .compress_into_with(v, q.reborrow_mut(), scoped_global)
1435 .unwrap();
1436 quantizer
1437 .compress_into_with(v, f.reborrow_mut(), scoped_global)
1438 .unwrap();
1439
1440 let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1441 .map(|(a, b)| scale * a - b)
1442 .collect();
1443
1444 {
1446 let DataMetaF32 {
1447 inner_product_correction,
1448 bit_sum,
1449 metric_specific,
1450 } = b.meta().to_full(ARCH);
1451
1452 let shifted_square_norm = metric_specific;
1453
1454 let bv = b.vector();
1456 let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1457 assert_eq!(s, bit_sum as usize);
1458
1459 {
1461 let expected = FastL2NormSquared.evaluate(&*shifted);
1462 let err = (shifted_square_norm - expected).abs() / expected.abs();
1463 assert!(
1464 err < 5.0e-4, "failed diff check, got {}, expected {} - relative error = {}",
1466 shifted_square_norm,
1467 expected,
1468 err
1469 );
1470 }
1471
1472 if const { D == 1 } {
1475 let self_inner_product = 2.0 * shifted_square_norm.sqrt()
1476 / (inner_product_correction * (bv.len() as f32).sqrt());
1477 assert!(
1478 (self_inner_product - 0.8).abs() < 0.13,
1479 "self inner-product should be close to 0.8. Instead, it's {}",
1480 self_inner_product
1481 );
1482 }
1483 }
1484
1485 {
1486 let QueryMeta {
1487 inner_product_correction,
1488 bit_sum,
1489 offset,
1490 metric_specific,
1491 } = q.meta();
1492
1493 let shifted_square_norm = metric_specific;
1494 let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1495 preprocessed
1496 .shifted
1497 .iter_mut()
1498 .for_each(|i| *i /= preprocessed.shifted_norm);
1499
1500 let mut transformed = vec![0.0f32; quantizer.output_dim()];
1501 quantizer
1502 .transform
1503 .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1504 .unwrap();
1505
1506 let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1507 let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1508
1509 let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1510
1511 {
1513 let expected = FastL2NormSquared.evaluate(&*shifted);
1514 let err = (shifted_square_norm - expected).abs() / expected.abs();
1515 assert!(
1516 err < 2e-7,
1517 "failed diff check, got {}, expected {} - relative error = {}",
1518 shifted_square_norm,
1519 expected,
1520 err
1521 );
1522 }
1523
1524 {
1526 let expected = shifted_square_norm.sqrt() * scale;
1527 let got = inner_product_correction;
1528
1529 let err = (expected - got).abs();
1530 assert!(
1531 err < 1.0e-7,
1532 "\"innerproduct_scale\": expected {}, got {}, error = {}",
1533 expected,
1534 got,
1535 err
1536 );
1537 }
1538
1539 {
1541 let expected = min / scale;
1542 let got = offset;
1543
1544 let err = (expected - got).abs();
1545 assert!(
1546 err < 1.0e-7,
1547 "\"sum_scale\": expected {}, got {}, error = {}",
1548 expected,
1549 got,
1550 err
1551 );
1552 }
1553
1554 {
1556 let expected = (0..q.len())
1557 .map(|i| q.vector().get(i).unwrap())
1558 .sum::<i64>() as f32;
1559
1560 let got = bit_sum;
1561
1562 let err = (expected - got).abs();
1563 assert!(
1564 err < 1.0e-7,
1565 "\"offset\": expected {}, got {}, error = {}",
1566 expected,
1567 got,
1568 err
1569 );
1570 }
1571 }
1572
1573 {
1575 let s: f32 = f.data.iter().sum::<f32>();
1577 assert_eq!(s, f.meta.sum);
1578
1579 {
1581 let expected = FastL2Norm.evaluate(&*shifted);
1582 let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1583 assert!(
1584 err < 2e-7,
1585 "failed diff check, got {}, expected {} - relative error = {}",
1586 f.meta.shifted_norm,
1587 expected,
1588 err
1589 );
1590 }
1591
1592 assert_eq!(
1593 f.meta.metric_specific,
1594 f.meta.shifted_norm * f.meta.shifted_norm,
1595 "metric specific data for squared l2 is the square shifted norm",
1596 );
1597 }
1598 }
1599
1600 quantizer
1603 .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1604 .unwrap();
1605 assert_eq!(b.meta(), DataMeta::default());
1606
1607 quantizer
1608 .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1609 .unwrap();
1610 assert_eq!(q.meta(), QueryMeta::default());
1611
1612 f.data.fill(f32::INFINITY);
1613 quantizer
1614 .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1615 .unwrap();
1616 assert!(f.data.iter().all(|&i| i == 0.0));
1617 assert_eq!(f.meta.sum, 0.0);
1618 assert_eq!(f.meta.metric_specific, 0.0);
1619 }
1620
    /// End-to-end check of compression under `SupportedMetric::InnerProduct`.
    ///
    /// Trains a quantizer on `problem.data` with the given `pre_scale`, then for
    /// `setup.num_trials` randomly chosen rows compresses into the three
    /// destination kinds (bit-packed `Data`, quantized `Query`, and uncompressed
    /// `FullQuery`) and validates each destination's stored metadata against
    /// values recomputed here from scratch. Finally compresses the training means
    /// themselves, which must yield all-default (zero) metadata since the shifted
    /// vector is then exactly zero.
    ///
    /// `Q` and `D` are the query/data bit widths, `Perm` the bit-permutation
    /// strategy, and `ctx` is extra context appended to failure messages.
    fn test_ip<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        computed_means: &[f32],
        pre_scale: PreScale,
        rng: &mut StdRng,
        ctx: &dyn Display,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::InnerProduct,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        // Effective multiplicative scale implied by `pre_scale` for this quantizer.
        let scale = get_scale(pre_scale, &quantizer);

        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        assert_eq!(
            quantizer.mean_norm.into_inner(),
            problem.mean_norm as f32,
            "computed mean norm should not apply scale"
        );
        // The stored shift must be the per-dimension means with the pre-scale
        // already applied.
        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
        assert_eq!(&*scaled_means, quantizer.shift());

        // Sanity-check the compensated inner-product functor derived from the
        // trained quantizer.
        let ip: CompensatedIP = quantizer.as_functor();

        assert_eq!(ip.dim, quantizer.output_dim() as f32);
        assert_eq!(
            ip.squared_shift_norm,
            FastL2NormSquared.evaluate(quantizer.shift())
        );

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);

            // Compress the same row into all three destination kinds.
            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference shifted vector: pre-scaled row minus the stored shift.
            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
                .map(|(a, b)| scale * a - b)
                .collect();

            // --- Data (bit-packed) metadata checks ---
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                // For InnerProduct, `metric_specific` holds <shifted, centroid>.
                let inner_product_with_centroid = metric_specific;

                // The recorded bit_sum must match the actual sum of encoded bits.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());

                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
                assert!(
                    diff < 1.53e-5,
                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
                    diff,
                    inner_product.into_inner(),
                    inner_product_with_centroid,
                    ctx,
                );

                // For 1-bit encodings, the normalized self inner-product is
                // expected to land near 0.8 (empirical bound — see tolerance).
                if const { D == 1 } {
                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.12,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // --- Query (multi-bit) metadata checks ---
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let inner_product_with_centroid = metric_specific;
                // Re-derive the quantization grid manually: normalize the shifted
                // vector, apply the transform, and compute min/max/scale.
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                // Quantization step size for Q-bit codes (shadows the outer
                // pre-scale `scale` within this block).
                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // bit_sum must equal the sum of the quantized codes.
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // The query path stores the centroid inner product exactly.
                    let inner_product: MathematicalValue<f32> =
                        InnerProduct::evaluate(&*shifted, quantizer.shift());
                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
                }
            }

            // --- FullQuery (uncompressed) metadata checks ---
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());
                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
            }
        }

        // Compressing the means themselves yields a zero shifted vector, so all
        // metadata must come out as the default values.
        quantizer
            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(b.meta(), DataMeta::default());

        quantizer
            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(q.meta(), QueryMeta::default());

        // Poison the FullQuery buffer first to prove it is fully overwritten.
        f.data.fill(f32::INFINITY);
        quantizer
            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
            .unwrap();
        assert!(f.data.iter().all(|&i| i == 0.0));
        assert_eq!(f.meta.sum, 0.0);
        assert_eq!(f.meta.metric_specific, 0.0);
    }
1863
    /// End-to-end check of compression under `SupportedMetric::Cosine`.
    ///
    /// Same structure as `test_ip`, but the reference "shifted" vector is built
    /// from the *L2-normalized* input (cosine normalizes inputs before shifting,
    /// and no pre-scale factor appears in the expected values). Also exercises
    /// the all-zero input row, which normalizes to zero rather than dividing
    /// by a zero norm.
    fn test_cosine<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        pre_scale: PreScale,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::Cosine,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        // Cosine wraps the inner-product functor; sanity-check its fields.
        let cosine: CompensatedCosine = quantizer.as_functor();

        assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
        assert_eq!(
            cosine.inner.squared_shift_norm,
            FastL2NormSquared.evaluate(quantizer.shift())
        );

        // Relative-error bound used for all centroid inner-product checks below.
        const IP_BOUND: f32 = 2.6e-3;

        let mut test_row = |v: &[f32]| {
            // Normalize the row; an all-zero row maps to the zero vector
            // instead of producing NaNs from 0/0.
            let vnorm = (FastL2Norm).evaluate(v);
            let v_normalized: Vec<f32> = v
                .iter()
                .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
                .collect();

            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();

            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();

            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference shifted vector: normalized row minus the stored shift
            // (no pre-scale factor in the cosine path).
            let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
                .map(|(a, b)| a - b)
                .collect();

            // --- Data (bit-packed) metadata checks ---
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                // `metric_specific` holds <shifted, centroid>.
                let inner_product_with_centroid = metric_specific;

                // The recorded bit_sum must match the actual sum of encoded bits.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());

                // Accept either a tiny absolute error (covers near-zero values,
                // e.g. the all-zero row) or a small relative error.
                let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
                let relative = abs / inner_product.into_inner().abs();

                assert!(
                    abs < 1e-7 || relative < IP_BOUND,
                    "got an abs/rel of {}/{} with a bound of {}/{}",
                    abs,
                    relative,
                    1e-7,
                    IP_BOUND
                );

                // For 1-bit encodings, the normalized self inner-product is
                // expected to land near 0.8 (empirical bound — see tolerance).
                if const { D == 1 } {
                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.11,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // --- Query (multi-bit) metadata checks ---
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let inner_product_with_centroid = metric_specific;
                // Re-derive the quantization grid manually: normalize the shifted
                // vector, apply the transform, and compute min/max/scale.
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                // Quantization step size for Q-bit codes.
                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // bit_sum must equal the sum of the quantized codes.
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let inner_product: MathematicalValue<f32> =
                        InnerProduct::evaluate(&*shifted, quantizer.shift());

                    let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
                        / inner_product.into_inner().abs();
                    assert!(
                        err < IP_BOUND,
                        "\"offset\": expected {}, got {}, error = {}",
                        inner_product.into_inner(),
                        inner_product_with_centroid,
                        err
                    );
                }
            }

            // --- FullQuery (uncompressed) metadata checks ---
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());
                let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
                    / inner_product.into_inner().abs();
                assert!(
                    err < IP_BOUND,
                    "\"offset\": expected {}, got {}, error = {}",
                    inner_product.into_inner(),
                    f.meta.metric_specific,
                    err
                );
            }
        };

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);
            test_row(v);
        }

        // Edge case: the zero vector must survive normalization and compression.
        let zero = vec![0.0f32; quantizer.input_dim()];
        test_row(&zero);
    }
2111
2112 fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2113 where
2114 for<'a> T: ReborrowMut<'a>,
2115 for<'a> SphericalQuantizer: CompressIntoWith<
2116 &'a [f32],
2117 <T as ReborrowMut<'a>>::Target,
2118 ScopedAllocator<'a>,
2119 Error = CompressionError,
2120 >,
2121 {
2122 let mut succeeded = false;
2123 let mut failed = false;
2124 for max_allocations in 0..10 {
2125 match quantizer.compress_into_with(
2126 data,
2127 dst.reborrow_mut(),
2128 ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2129 ) {
2130 Ok(()) => {
2131 succeeded = true;
2132 }
2133 Err(CompressionError::AllocatorError(_)) => {
2134 failed = true;
2135 }
2136 Err(other) => {
2137 panic!("received an unexpected error: {:?}", other);
2138 }
2139 }
2140 }
2141 assert!(succeeded);
2142 assert!(failed);
2143 }
2144
    /// Verifies that compression degrades gracefully under allocator pressure
    /// (returning `AllocatorError` rather than panicking) for all three
    /// destination kinds; the per-destination logic lives in
    /// `_test_oom_resiliance`.
    fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        pre_scale: PreScale,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            DataMut<'a, D>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            QueryMut<'a, Q, Perm>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::SquaredL2,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        // Any row works; OOM behavior does not depend on the input values.
        let data = problem.data.row(0);
        _test_oom_resiliance::<Data<D, _>>(
            &quantizer,
            data,
            &mut Data::new_boxed(quantizer.output_dim()),
        );
        _test_oom_resiliance::<Query<Q, Perm, _>>(
            &quantizer,
            data,
            &mut Query::new_boxed(quantizer.output_dim()),
        );
        _test_oom_resiliance::<FullQuery<_>>(
            &quantizer,
            data,
            &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
        );
    }
2198
    /// Runs the full metric battery (`test_l2`, `test_ip`, `test_cosine`) for one
    /// (Q, D, Perm) encoding on a freshly generated random problem, sweeping
    /// several pre-scale settings, then checks OOM resilience.
    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
    where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            DataMut<'a, D>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            QueryMut<'a, Q, Perm>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
    {
        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();

        // NOTE(review): the first two entries are identical — this looks like a
        // copy-paste slip and one was likely meant to be a different setting
        // (e.g. `PreScale::None` or a different magnitude). Confirm intent
        // before changing: removing an entry shifts the RNG stream for the
        // remaining runs.
        let scales = [
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        for scale in scales {
            // Lazily formatted context string appended to failure messages.
            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);

            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
        }

        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
    }
2236
2237 #[test]
2238 fn test_spherical_quantizer() {
2239 let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2240 for dim in [56, 72, 128, 255] {
2241 let setup = Setup {
2242 transform: TransformKind::PaddingHadamard {
2243 target_dim: TargetDim::Same,
2244 },
2245 nrows: 64,
2246 ncols: dim,
2247 num_trials: 10,
2248 };
2249
2250 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2251 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2252 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2253 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2254
2255 let setup = Setup {
2256 transform: TransformKind::DoubleHadamard {
2257 target_dim: TargetDim::Same,
2258 },
2259 nrows: 64,
2260 ncols: dim,
2261 num_trials: 10,
2262 };
2263 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2264 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2265 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2266 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2267 }
2268 }
2269
2270 #[test]
2275 fn err_dim_cannot_be_zero() {
2276 let data = Matrix::new(0.0f32, 10, 0);
2277 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2278 let err = SphericalQuantizer::train(
2279 data.as_view(),
2280 TransformKind::DoubleHadamard {
2281 target_dim: TargetDim::Same,
2282 },
2283 SupportedMetric::SquaredL2,
2284 PreScale::None,
2285 &mut rng,
2286 GlobalAllocator,
2287 )
2288 .unwrap_err();
2289 assert_eq!(err.to_string(), "data dim cannot be zero");
2290 }
2291
2292 #[test]
2293 fn err_norm_must_be_positive() {
2294 let data = Matrix::new(0.0f32, 10, 10);
2295 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2296 let err = SphericalQuantizer::train(
2297 data.as_view(),
2298 TransformKind::DoubleHadamard {
2299 target_dim: TargetDim::Same,
2300 },
2301 SupportedMetric::SquaredL2,
2302 PreScale::None,
2303 &mut rng,
2304 GlobalAllocator,
2305 )
2306 .unwrap_err();
2307 assert_eq!(err.to_string(), "norm must be positive");
2308 }
2309
2310 #[test]
2311 fn err_norm_cannot_be_infinity() {
2312 let mut data = Matrix::new(0.0f32, 10, 10);
2313 data[(2, 5)] = f32::INFINITY;
2314
2315 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2316 let err = SphericalQuantizer::train(
2317 data.as_view(),
2318 TransformKind::DoubleHadamard {
2319 target_dim: TargetDim::Same,
2320 },
2321 SupportedMetric::SquaredL2,
2322 PreScale::None,
2323 &mut rng,
2324 GlobalAllocator,
2325 )
2326 .unwrap_err();
2327 assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2328 }
2329
2330 #[test]
2331 fn err_reciprocal_norm_cannot_be_infinity() {
2332 let mut data = Matrix::new(0.0f32, 10, 10);
2333 data[(2, 5)] = 2.93863e-39;
2334
2335 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2336 let err = SphericalQuantizer::train(
2337 data.as_view(),
2338 TransformKind::DoubleHadamard {
2339 target_dim: TargetDim::Same,
2340 },
2341 SupportedMetric::SquaredL2,
2342 PreScale::ReciprocalMeanNorm,
2343 &mut rng,
2344 GlobalAllocator,
2345 )
2346 .unwrap_err();
2347 assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2348 }
2349
2350 #[test]
2351 fn err_mean_norm_cannot_be_zero_generate() {
2352 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2353 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2354 let err = SphericalQuantizer::generate(
2355 centroid,
2356 0.0,
2357 TransformKind::DoubleHadamard {
2358 target_dim: TargetDim::Same,
2359 },
2360 SupportedMetric::SquaredL2,
2361 None,
2362 &mut rng,
2363 GlobalAllocator,
2364 )
2365 .unwrap_err();
2366 assert_eq!(err.to_string(), "norm must be positive");
2367 }
2368
2369 #[test]
2370 fn err_scale_cannot_be_zero_generate() {
2371 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2372 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2373 let err = SphericalQuantizer::generate(
2374 centroid,
2375 1.0,
2376 TransformKind::DoubleHadamard {
2377 target_dim: TargetDim::Same,
2378 },
2379 SupportedMetric::SquaredL2,
2380 Some(0.0),
2381 &mut rng,
2382 GlobalAllocator,
2383 )
2384 .unwrap_err();
2385 assert_eq!(err.to_string(), "pre-scale must be positive");
2386 }
2387
    /// Exhaustively checks the error paths of `compress_into_with`:
    /// non-finite inputs, inputs with excessive dynamic range, and source /
    /// destination dimension mismatches, for both `Data` and `Query` targets.
    #[test]
    fn compression_errors_data() {
        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);

        let quantizer = SphericalQuantizer::train(
            data.as_view(),
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            SupportedMetric::SquaredL2,
            PreScale::None,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        let scoped_global = ScopedAllocator::global();

        // Non-finite values anywhere in the input must be rejected. Each
        // position is poisoned in turn with NaN / +inf / -inf and restored.
        {
            let mut query: Vec<f32> = quantizer.shift().to_vec();
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            for i in 0..query.len() {
                let last = query[i];
                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
                    query[i] = v;

                    let err = quantizer
                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);

                    let err = quantizer
                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
                }
                // Restore before poisoning the next position.
                query[i] = last;
            }
        }

        // Values far outside the trained dynamic range must fail with an
        // encoding error rather than silently producing garbage.
        {
            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();

            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";

            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
        }

        // Source vectors one element too short or too long must be rejected.
        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
            let query = vec![0.0f32; len];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );
        }

        // Destination buffers one element too short or too long must also fail.
        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
            let query = vec![0.0f32; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(len);
            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );
        }
    }
2501
2502 #[test]
2503 fn centroid_scaling_happens_in_generate() {
2504 let centroid = Poly::from_iter(
2505 [1088.6732f32, 1393.32, 1547.877].into_iter(),
2506 GlobalAllocator,
2507 )
2508 .unwrap();
2509 let mean_norm = 2359.27;
2510 let pre_scale = 1.0 / mean_norm;
2511
2512 let quantizer = SphericalQuantizer::generate(
2513 centroid,
2514 mean_norm,
2515 TransformKind::Null,
2516 SupportedMetric::InnerProduct,
2517 Some(pre_scale),
2518 &mut StdRng::seed_from_u64(10),
2519 GlobalAllocator,
2520 )
2521 .unwrap();
2522
2523 let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
2524 let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
2525 assert!(
2526 quantizer
2527 .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
2528 .is_ok(),
2529 "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
2530 );
2531 }
2532}
2533
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    //! Flatbuffers round-trip tests for `SphericalQuantizer`: `pack` followed
    //! by `try_unpack` must reproduce the quantizer exactly, and `try_unpack`
    //! must reject payloads encoding invalid internal state.

    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Every transform kind × metric × pre-scale combination must survive a
    /// pack/unpack round trip and compare equal to the original.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        // Override dimensions below (100) and above (150) the input dim of 128.
        let low = NonZeroUsize::new(100).unwrap();
        let high = NonZeroUsize::new(150).unwrap();

        let kinds = [
            TransformKind::Null,
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(high),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Random rotations need the `linalg` feature and are too slow
            // under miri.
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Same,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Natural,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(low),
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(high),
            },
        ];

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        let alloc = GlobalAllocator;
        for kind in kinds.into_iter() {
            for metric in SupportedMetric::all() {
                for pre_scale in pre_scales {
                    let quantizer = SphericalQuantizer::train(
                        problem.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    // Serialize, reparse, deserialize, and compare for equality.
                    let data = to_flatbuffer(|buf| quantizer.pack(buf));
                    let proto =
                        flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
                    let reloaded = SphericalQuantizer::try_unpack(alloc, proto).unwrap();
                    assert_eq!(quantizer, reloaded, "failed on transform {:?}", kind);
                }
            }
        }
    }

    /// `try_unpack` must reject serialized quantizers whose fields violate the
    /// type's invariants. Invalid values are forged with
    /// `Positive::new_unchecked` (deliberately breaking the invariant the type
    /// normally guarantees) so they survive `pack`.
    #[test]
    fn test_error_checking() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        let alloc = GlobalAllocator;
        let mut make_quantizer = || {
            SphericalQuantizer::train(
                problem.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        type E = DeserializationError;

        // Zero mean norm must be rejected on unpack.
        {
            let mut quantizer = make_quantizer();
            // SAFETY-note: intentionally violates `Positive`'s invariant to
            // exercise the deserializer's validation path.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Negative mean norm must be rejected on unpack.
        {
            let mut quantizer = make_quantizer();

            // SAFETY-note: intentionally invalid, see above.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(-1.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Non-positive pre-scale must be rejected on unpack.
        {
            let mut quantizer = make_quantizer();

            // SAFETY-note: intentionally invalid, see above.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::PreScaleNotPositive);
        }

        // A shift whose length disagrees with the transform dimension must be
        // rejected on unpack.
        {
            let mut quantizer = make_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::DimMismatch);
        }
    }
}