1use std::num::NonZeroUsize;
7
8use diskann_utils::{ReborrowMut, views::MatrixView};
9use diskann_vector::{
10 MathematicalValue, Norm, PureDistanceFunction, distance::InnerProduct, norm::FastL2Norm,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18 CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19 FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21use crate::{
22 AsFunctor, CompressIntoWith,
23 algorithms::{
24 heap::SliceHeap,
25 transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
26 },
27 alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
28 bits::{PermutationStrategy, Representation, Unsigned},
29 num::Positive,
30 utils::{CannotBeEmpty, compute_means_and_average_norm, compute_normalized_means},
31};
32#[cfg(feature = "flatbuffers")]
33use crate::{
34 algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
35};
36
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
/// A quantizer that compresses vectors by shifting them by a (pre-scaled)
/// centroid, normalizing onto the unit sphere, applying a dimensionality
/// transform, and rounding the result onto a small per-coordinate bit budget.
pub struct SphericalQuantizer<A = GlobalAllocator>
where
    A: Allocator,
{
    // Per-dimension shift subtracted before normalization. `generate`
    // multiplies the centroid by `pre_scale`, so this stores the
    // already-pre-scaled centroid.
    shift: Poly<[f32], A>,

    // Dimensionality transform applied to the normalized, shifted vector.
    transform: Transform<A>,

    // Distance metric this quantizer was trained for.
    metric: SupportedMetric,

    // Average norm of the training data, stored WITHOUT `pre_scale` applied
    // (the tests assert this explicitly).
    mean_norm: Positive<f32>,

    // Scalar applied to inputs (and to the stored centroid) before shifting;
    // defaults to 1.0 when none is supplied.
    pre_scale: Positive<f32>,
}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82 A: Allocator,
83{
84 fn try_clone(&self) -> Result<Self, AllocatorError> {
85 Ok(Self {
86 shift: self.shift.try_clone()?,
87 transform: self.transform.try_clone()?,
88 metric: self.metric,
89 mean_norm: self.mean_norm,
90 pre_scale: self.pre_scale,
91 })
92 }
93}
94
/// Errors that can occur while training or generating a [`SphericalQuantizer`].
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum TrainError {
    #[error("data dim cannot be zero")]
    DimCannotBeZero,
    #[error("data cannot be empty")]
    DataCannotBeEmpty,
    #[error("pre-scale must be positive")]
    PrescaleNotPositive,
    #[error("norm must be positive")]
    NormNotPositive,
    #[error("computed norm contains infinity or NaN")]
    NormNotFinite,
    #[error("reciprocal norm contains infinity or NaN")]
    ReciprocalNormNotFinite,
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
113
impl<A> SphericalQuantizer<A>
where
    A: Allocator,
{
    /// Dimensionality of vectors accepted by this quantizer (length of `shift`).
    pub fn input_dim(&self) -> usize {
        self.shift.len()
    }

    /// Dimensionality of the compressed representation (transform output).
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

    /// The stored (pre-scaled) centroid subtracted from inputs.
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    /// Average norm of the training data (without `pre_scale` applied).
    pub fn mean_norm(&self) -> Positive<f32> {
        self.mean_norm
    }

    /// Scalar applied to inputs before shifting.
    pub fn pre_scale(&self) -> Positive<f32> {
        self.pre_scale
    }

    /// The allocator backing this quantizer's buffers.
    pub fn allocator(&self) -> &A {
        self.shift.allocator()
    }

    /// Build a quantizer from an already-computed centroid and mean norm.
    ///
    /// The `centroid` is scaled in place by `pre_scale` (defaulting to 1.0)
    /// before being stored as `shift`. `rng` seeds the random transform.
    ///
    /// # Errors
    /// Fails when the centroid is empty, when `mean_norm` or `pre_scale` is
    /// not positive, or when allocation for the transform fails.
    pub fn generate(
        mut centroid: Poly<[f32], A>,
        mean_norm: f32,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: Option<f32>,
        rng: &mut dyn RngCore,
        allocator: A,
    ) -> Result<Self, TrainError> {
        let pre_scale = match pre_scale {
            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
            None => crate::num::POSITIVE_ONE_F32,
        };

        let dim = match NonZeroUsize::new(centroid.len()) {
            Some(dim) => dim,
            None => {
                return Err(TrainError::DimCannotBeZero);
            }
        };

        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;

        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
            Ok(v) => v,
            // An Rng was passed above, so this variant cannot occur.
            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
            Err(NewTransformError::AllocatorError(err)) => {
                return Err(TrainError::AllocatorError(err));
            }
        };

        // Bake the pre-scale into the stored centroid so `preprocess` can
        // compute `pre_scale * x - shift == pre_scale * (x - centroid)`.
        centroid
            .iter_mut()
            .for_each(|v| *v *= pre_scale.into_inner());

        Ok(SphericalQuantizer {
            shift: centroid,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }

    /// The metric this quantizer was trained for.
    pub fn metric(&self) -> SupportedMetric {
        self.metric
    }

    /// Train a quantizer from a data matrix: compute the centroid and mean
    /// norm (normalized means for Cosine), resolve the pre-scale policy, and
    /// delegate to [`SphericalQuantizer::generate`].
    ///
    /// # Errors
    /// Fails for empty/zero-dimensional data, non-positive or non-finite
    /// norms, or allocation failure.
    pub fn train<T, R>(
        data: MatrixView<T>,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: PreScale,
        rng: &mut R,
        allocator: A,
    ) -> Result<Self, TrainError>
    where
        T: Copy + Into<f64> + Into<f32>,
        R: Rng,
    {
        // Non-generic-over-`R` inner function (taking `&mut dyn RngCore`)
        // so the body is not duplicated for every concrete Rng type.
        #[inline(never)]
        fn train<T, A>(
            data: MatrixView<T>,
            transform: TransformKind,
            metric: SupportedMetric,
            pre_scale: PreScale,
            rng: &mut dyn RngCore,
            allocator: A,
        ) -> Result<SphericalQuantizer<A>, TrainError>
        where
            T: Copy + Into<f64> + Into<f32>,
            A: Allocator,
        {
            if data.ncols() == 0 {
                return Err(TrainError::DimCannotBeZero);
            }

            // Cosine normalizes each row before averaging, so the mean norm
            // is 1.0 by construction.
            let (centroid, mean_norm) = match metric {
                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
                    compute_means_and_average_norm(data)
                }
                SupportedMetric::Cosine => (
                    compute_normalized_means(data)
                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
                    1.0,
                ),
            };

            let mean_norm = mean_norm as f32;

            if mean_norm <= 0.0 {
                return Err(TrainError::NormNotPositive);
            }

            if !mean_norm.is_finite() {
                return Err(TrainError::NormNotFinite);
            }

            let pre_scale: Positive<f32> = match pre_scale {
                PreScale::None => crate::num::POSITIVE_ONE_F32,
                PreScale::Some(v) => v,
                PreScale::ReciprocalMeanNorm => {
                    let pre_scale = Positive::new(1.0 / mean_norm)
                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;

                    // `1.0 / mean_norm` can overflow to infinity for tiny
                    // norms even though it is positive; reject that too.
                    if !pre_scale.into_inner().is_finite() {
                        return Err(TrainError::ReciprocalNormNotFinite);
                    }

                    pre_scale
                }
            };

            let centroid =
                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;

            SphericalQuantizer::generate(
                centroid,
                mean_norm,
                transform,
                metric,
                Some(pre_scale.into_inner()),
                rng,
                allocator,
            )
        }

        train(data, transform, metric, pre_scale, rng, allocator)
    }

    /// Rescale `v` in place so its L2 norm equals `mean_norm`.
    ///
    /// NOTE(review): a zero-norm `v` yields an infinite multiplier and NaNs
    /// in the output — confirm callers never pass the zero vector.
    pub fn rescale(&self, v: &mut [f32]) {
        let norm = FastL2Norm.evaluate(&*v);
        let m = self.mean_norm.into_inner() / norm;
        v.iter_mut().for_each(|i| *i *= m);
    }

    /// Shared first stage of compression: scale and shift `data`, compute the
    /// shifted vector's norm, and (for IP/Cosine) its inner product with the
    /// stored centroid.
    ///
    /// # Panics
    /// Panics when `data.len() != self.input_dim()`.
    ///
    /// # Errors
    /// Returns `InputContainsNaN` when the shifted norm is not finite, and
    /// propagates allocation failures.
    fn preprocess<'a>(
        &self,
        data: &[f32],
        allocator: ScopedAllocator<'a>,
    ) -> Result<Preprocessed<'a>, CompressionError> {
        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");

        let scale = self.pre_scale.into_inner();
        // Cosine normalizes the raw input instead of applying the pre-scale;
        // a zero-norm input is passed through unscaled.
        let mul: f32 = match self.metric {
            SupportedMetric::Cosine => {
                let norm: f32 = (FastL2Norm).evaluate(data);
                if norm == 0.0 { 1.0 } else { 1.0 / norm }
            }
            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
        };

        // `shift` already contains the pre-scaled centroid, so for L2/IP this
        // is `pre_scale * (data - centroid)`.
        let shifted = Poly::from_iter(
            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
            allocator,
        )?;

        let shifted_norm = FastL2Norm.evaluate(&*shifted);
        if !shifted_norm.is_finite() {
            return Err(CompressionError::InputContainsNaN);
        }
        // SquaredL2 does not need the centroid inner product; the other
        // metrics store it as their metric-specific correction term.
        let inner_product_with_centroid = match self.metric {
            SupportedMetric::SquaredL2 => None,
            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
                Some(ip.into_inner())
            }
        };

        Ok(Preprocessed {
            shifted,
            shifted_norm,
            inner_product_with_centroid,
        })
    }
}
378
/// Policy for the scalar applied to inputs before centering.
#[derive(Debug, Clone, Copy)]
pub enum PreScale {
    /// No pre-scaling (multiplier of 1.0).
    None,
    /// Use the given positive multiplier.
    Some(Positive<f32>),
    /// Use `1 / mean_norm`, computed during training.
    ReciprocalMeanNorm,
}
392
/// Errors produced when reconstructing a [`SphericalQuantizer`] from its
/// flatbuffer representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    #[error(transparent)]
    TransformError(#[from] TransformError),
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,
    #[error("transform length not equal to centroid")]
    DimMismatch,
    #[error("norm is missing or is not positive")]
    MissingNorm,
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
418
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Serialize this quantizer into `buf`, returning the offset of the
    /// written table. All fields are stored; `shift` is written under the
    /// schema's `centroid` field name.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let centroid = buf.create_vector(&self.shift);

        let transform = self.transform.pack(buf);

        spherical::SphericalQuantizer::create(
            buf,
            &spherical::SphericalQuantizerArgs {
                centroid: Some(centroid),
                transform: Some(transform),
                metric: self.metric.into(),
                mean_norm: self.mean_norm.into_inner(),
                pre_scale: self.pre_scale.into_inner(),
            },
        )
    }

    /// Reconstruct a quantizer from a decoded flatbuffer table, validating
    /// the metric, the centroid/transform dimension agreement, and the
    /// positivity of `mean_norm` and `pre_scale`.
    ///
    /// # Errors
    /// See [`DeserializationError`] for the failure modes.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;

        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // The stored centroid and transform must agree on input dimension.
        if shift.len() != transform.input_dim() {
            return Err(DeserializationError::DimMismatch);
        }

        let mean_norm =
            Positive::new(proto.mean_norm()).map_err(|_| DeserializationError::MissingNorm)?;

        let pre_scale = Positive::new(proto.pre_scale())
            .map_err(|_| DeserializationError::PreScaleNotPositive)?;

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
488
/// Intermediate result of [`SphericalQuantizer::preprocess`]: the scaled and
/// centered vector plus the scalars needed for metadata.
struct Preprocessed<'a> {
    // `mul * data - shift` (see `preprocess` for `mul`'s definition).
    shifted: Poly<[f32], ScopedAllocator<'a>>,
    // L2 norm of `shifted`; guaranteed finite.
    shifted_norm: f32,
    // `<shifted, shift>` for IP/Cosine; `None` for SquaredL2.
    inner_product_with_centroid: Option<f32>,
}
494
495impl Preprocessed<'_> {
496 fn metric_specific(&self) -> f32 {
501 match self.inner_product_with_centroid {
502 Some(ip) => ip,
503 None => self.shifted_norm * self.shifted_norm,
504 }
505 }
506}
507
/// Build the compensated squared-L2 distance functor for this quantizer's
/// output dimensionality.
impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedSquaredL2 {
        CompensatedSquaredL2::new(self.output_dim())
    }
}
520
/// Build the compensated inner-product functor; it needs the centroid shift
/// in addition to the output dimensionality.
impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedIP {
        CompensatedIP::new(&self.shift, self.output_dim())
    }
}
529
/// Build the compensated cosine functor by wrapping the inner-product one.
impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
where
    A: Allocator,
{
    fn as_functor(&self) -> CompensatedCosine {
        CompensatedCosine::new(self.as_functor())
    }
}
538
/// Errors that can occur while compressing a vector with a
/// [`SphericalQuantizer`].
#[derive(Debug, Error, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CompressionError {
    #[error("input contains NaN")]
    InputContainsNaN,

    #[error("expected source vector to have length {expected}")]
    SourceDimensionMismatch { expected: usize },

    #[error("expected destination vector to have length {expected}")]
    DestinationDimensionMismatch { expected: usize },

    #[error(
        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
    )]
    EncodingError(#[from] DataMetaError),

    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
563
564fn check_dims(
565 input: usize,
566 output: usize,
567 from: usize,
568 into: usize,
569) -> Result<(), CompressionError> {
570 if from != input {
571 return Err(CompressionError::SourceDimensionMismatch { expected: input });
572 }
573 if into != output {
574 return Err(CompressionError::DestinationDimensionMismatch { expected: output });
575 }
576 Ok(())
577}
578
/// Final stage of data-vector compression: quantize the already-transformed
/// vector into the destination's bit representation and fill in its metadata.
/// Implemented per bit width so the 1-bit case can use a cheaper sign-based
/// path than the multi-bit widths.
trait FinishCompressing {
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;
}
590
impl FinishCompressing for DataMut<'_, 1> {
    /// 1-bit quantization: each coordinate keeps only its sign (1 for
    /// positive, 0 otherwise), with no scale search needed.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        _: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        // `quant_raw_inner_product` accumulates Σ|r_i|, i.e. the inner
        // product of `transformed` with its own sign pattern.
        let mut quant_raw_inner_product = 0.0f32;
        let mut bit_sum = 0u32;
        transformed.iter().enumerate().for_each(|(i, &r)| {
            let bit: u8 = if r > 0.0 { 1 } else { 0 };

            quant_raw_inner_product += r.abs();
            bit_sum += <u8 as Into<u32>>::into(bit);

            // SAFETY: `i < transformed.len()` and the destination has the
            // same length (dimensions checked by the caller); `bit` fits in
            // the 1-bit representation.
            unsafe { self.vector_mut().set_unchecked(i, bit) };
        });

        // Correction factor applied at query time; the 2.0 presumably
        // matches the ±0.5 level encoding used by the distance kernels —
        // TODO(review): confirm against the kernel implementation.
        let inner_product_correction =
            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
        self.set_meta(DataMeta::new(
            inner_product_correction,
            preprocessed.metric_specific(),
            bit_sum,
        )?);
        Ok(())
    }
}
643
impl FinishCompressing for DataMut<'_, 2> {
    /// 2-bit quantization delegates to the shared scale-search path.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
661
impl FinishCompressing for DataMut<'_, 4> {
    /// 4-bit quantization delegates to the shared scale-search path.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
679
impl FinishCompressing for DataMut<'_, 8> {
    /// 8-bit quantization delegates to the shared scale-search path.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        compress_via_maximum_cosine(
            self.reborrow_mut(),
            preprocessed,
            transformed,
            transformed_norm,
            allocator,
        )
    }
}
697
impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
where
    A: Allocator,
{
    type Error = CompressionError;

    /// Compress `from` into an uncompressed ("full") query: shift and
    /// normalize, apply the transform directly into the destination, and
    /// record the metadata (`sum`, `shifted_norm`, metric-specific scalar).
    ///
    /// # Errors
    /// Dimension mismatches, NaN input, or allocation failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: FullQueryMut<'_>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Zero shifted norm: the input coincides with the centroid. Emit an
        // all-zero vector with default metadata instead of dividing by zero.
        if preprocessed.shifted_norm == 0.0 {
            into.vector_mut().fill(0.0);
            *into.meta_mut() = Default::default();
            return Ok(());
        }

        // Project the shifted vector onto the unit sphere.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            // Both mismatch variants are ruled out by `check_dims` above.
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        *into.meta_mut() = FullQueryMeta {
            sum: into.vector().iter().sum::<f32>(),
            shifted_norm: preprocessed.shifted_norm,
            metric_specific: preprocessed.metric_specific(),
        };
        Ok(())
    }
}
774
impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    A: Allocator,
    Unsigned: Representation<NBITS>,
    for<'a> DataMut<'a, NBITS>: FinishCompressing,
{
    type Error = CompressionError;

    /// Compress `from` into an NBITS-per-coordinate data vector: shift,
    /// normalize, transform into a scratch buffer, then hand off to the
    /// bit-width-specific [`FinishCompressing`] implementation.
    ///
    /// # Errors
    /// Dimension mismatches, NaN input, encoding failures, or allocation
    /// failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: DataMut<'_, NBITS>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Input equals the centroid: store default metadata and stop.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(DataMeta::default());
            return Ok(());
        }

        // Scratch buffer for the transform output.
        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
        // Project the shifted vector onto the unit sphere.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            // Both mismatch variants are ruled out by `check_dims` above.
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        // Norm-preserving transforms keep the unit input at norm 1, so the
        // evaluation can be skipped.
        let transformed_norm = if self.transform.preserves_norms() {
            1.0
        } else {
            (FastL2Norm).evaluate(&*transformed)
        };

        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
        Ok(())
    }
}
850
/// Lifts a const generic `NBITS` into a `NonZeroUsize` at compile time.
struct AsNonZero<const NBITS: usize>;
impl<const NBITS: usize> AsNonZero<NBITS> {
    // `unwrap` in a const is evaluated at compile time, so `NBITS == 0`
    // becomes a compile error rather than a runtime panic.
    #[allow(clippy::unwrap_used)]
    const NON_ZERO: NonZeroUsize = NonZeroUsize::new(NBITS).unwrap();
}
857
/// Shared multi-bit quantization path: find the scale that maximizes cosine
/// similarity between `transformed` and its rounded representation, then
/// round every coordinate into the unsigned NBITS domain and write metadata.
///
/// # Panics
/// Panics when `data` and `transformed` disagree on length.
fn compress_via_maximum_cosine<const NBITS: usize>(
    mut data: DataMut<'_, NBITS>,
    preprocessed: &Preprocessed<'_>,
    transformed: &[f32],
    transformed_norm: f32,
    allocator: ScopedAllocator<'_>,
) -> Result<(), CompressionError>
where
    Unsigned: Representation<NBITS>,
{
    assert_eq!(data.len(), transformed.len());

    let optimal_scale =
        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;

    let domain = Unsigned::domain_const::<NBITS>();
    let min = *domain.start() as f32;
    let max = *domain.end() as f32;
    // Center of the unsigned code domain: codes are stored unsigned but
    // represent the symmetric levels `v - offset`.
    let offset = max / 2.0;

    let mut self_inner_product = 0.0f32;
    let mut bit_sum = 0u32;
    for (i, t) in transformed.iter().enumerate() {
        // Scale, shift into the unsigned domain, clamp, round.
        let v = (*t * optimal_scale + offset).clamp(min, max).round();
        // The signed level this code represents.
        let dv = v - offset;
        self_inner_product = dv.mul_add(*t, self_inner_product);

        let v = v as u8;
        bit_sum += <u8 as Into<u32>>::into(v);

        // SAFETY: `i < data.len()` (asserted equal to `transformed.len()`
        // above) and `v` was clamped into the NBITS domain.
        unsafe { data.vector_mut().set_unchecked(i, v) };
    }

    let shifted_norm = preprocessed.shifted_norm;
    // Correction factor applied at query time to undo the normalization and
    // quantization scaling.
    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
    data.set_meta(DataMeta::new(
        inner_product_correction,
        preprocessed.metric_specific(),
        bit_sum,
    )?);
    Ok(())
}
906
/// A critical scale value paired with the coordinate index it belongs to,
/// used as the element type of the heap in `maximize_cosine_similarity`.
#[derive(Debug, Clone, Copy)]
struct Pair {
    // The critical scale at which this coordinate's rounded level increments.
    value: f32,
    // Index of the coordinate in the input vector.
    position: u32,
}
925
926impl PartialEq for Pair {
927 fn eq(&self, other: &Self) -> bool {
928 self.value.eq(&other.value)
929 }
930}
931
// `value` is never NaN in practice here, so reflexive equality holds.
impl Eq for Pair {}
impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to the total order below, per the PartialOrd/Ord contract.
        Some(self.cmp(other))
    }
}
impl Ord for Pair {
    /// Reversed comparison on `value` (`other` against `self`), so a
    /// max-heap of `Pair`s surfaces the SMALLEST critical scale first.
    /// Incomparable values (NaN) are treated as equal.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .value
            .partial_cmp(&self.value)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
946
/// Find the scale `s` maximizing the cosine similarity between `v` and its
/// per-coordinate rounding of `|v| * s` onto the half-integer level grid
/// `{0.5, 1.5, ..., 2^(num_bits - 1) - 0.5}` (signs do not affect the
/// similarity, so only magnitudes are tracked).
///
/// The search visits only the "critical" scales at which some coordinate's
/// rounded level increments, maintaining the inner product and squared norm
/// incrementally via a heap of per-coordinate next-critical-scale values.
///
/// # Panics
/// Panics when `v` is empty (via `SliceHeap::new`).
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // All levels start at `r = 1`, i.e. the represented level is
    // `r - 0.5 = 0.5` per coordinate, so:
    //   ip          = Σ |v_i| * 0.5
    //   square_norm = Σ 0.5²     = 0.25 * len
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    let mut current_square_norm = 0.25 * (v.len() as f64);

    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // `eps` nudges each critical scale just past the rounding threshold so
    // the level increment has actually taken effect at that scale.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            // First critical scale for this coordinate: where `|v_i| * s`
            // crosses 1 and the level moves from 0.5 to 1.5.
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // Level counter saturates here: the top level is `stop - 0.5`.
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        // `Pair`'s reversed `Ord` makes the heap root the smallest remaining
        // critical scale, so scales are visited in increasing order.
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            if value == f32::MAX {
                // Every coordinate is saturated; the sweep is finished.
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            *r += 1;

            // The level rose by one, so the inner product gains |v_i|.
            current_ip += vp.abs() as f64;

            // (old_r + 0.5)² - (old_r - 0.5)² == 2 * old_r.
            current_square_norm += (2 * old_r) as f64;

            // Cosine similarity up to the constant factor 1 / ||v||.
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            if *r < stop {
                // Schedule this coordinate's next critical scale.
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                // Saturated: park the coordinate at the sentinel value.
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1088
impl<const NBITS: usize, Perm, A>
    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    Unsigned: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    A: Allocator,
{
    type Error = CompressionError;

    /// Compress `from` into an NBITS query vector. Unlike the data path,
    /// queries are quantized with a simple min/max affine mapping onto the
    /// unsigned domain rather than the cosine-maximizing scale search.
    ///
    /// # Errors
    /// Dimension mismatches, NaN input, or allocation failure.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: QueryMut<'_, NBITS, Perm>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Input equals the centroid: store default metadata and stop.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(QueryMeta::default());
            return Ok(());
        }

        // Project the shifted vector onto the unit sphere.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;

        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err));
            }
            // Both mismatch variants are ruled out by `check_dims` above.
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        // Single pass for both extrema of the transformed vector.
        let (min, max) = transformed
            .iter()
            .fold((f32::MAX, f32::MIN), |(min, max), i| {
                (i.min(min), i.max(max))
            });

        let domain = Unsigned::domain_const::<NBITS>();
        let lo = (*domain.start()) as f32;
        let hi = (*domain.end()) as f32;

        // Affine mapping of [min, max] onto [lo, hi].
        // NOTE(review): when max == min this scale is 0 and `min / scale`
        // below divides by zero — confirm that case cannot reach here.
        let scale = (max - min) / hi;
        let mut bit_sum: f32 = 0.0;
        transformed.iter().enumerate().for_each(|(i, v)| {
            let c = ((v - min) / scale).round().clamp(lo, hi);
            bit_sum += c;

            // `c` was clamped into the representable domain, so the set
            // cannot fail.
            #[allow(clippy::unwrap_used)]
            into.vector_mut().set(i, c as i64).unwrap();
        });

        into.set_meta(QueryMeta {
            inner_product_correction: preprocessed.shifted_norm * scale,
            bit_sum,
            offset: min / scale,
            metric_specific: preprocessed.metric_specific(),
        });

        Ok(())
    }
}
1197
1198#[cfg(not(miri))]
1203#[cfg(test)]
1204mod tests {
1205 use super::*;
1206
1207 use std::fmt::Display;
1208
1209 use diskann_utils::{
1210 ReborrowMut, lazy_format,
1211 views::{self, Matrix},
1212 };
1213 use diskann_vector::{PureDistanceFunction, norm::FastL2NormSquared};
1214 use diskann_wide::ARCH;
1215 use rand::{
1216 SeedableRng,
1217 distr::{Distribution, Uniform},
1218 rngs::StdRng,
1219 };
1220 use rand_distr::StandardNormal;
1221
1222 use crate::{
1223 algorithms::transforms::TargetDim,
1224 alloc::GlobalAllocator,
1225 bits::{BitTranspose, Dense},
1226 spherical::{Data, DataMetaF32, FullQuery, Query},
1227 test_util,
1228 };
1229
    /// Exhaustive + randomized check of `maximize_cosine_similarity` for
    /// 4-dimensional inputs at 3 bits: compare the scale it returns against
    /// a brute-force search over all 8^4 half-integer level assignments.
    #[test]
    fn test_cosine_similarity_maximizer() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let num_trials = 10000;
        let num_bits = NonZeroUsize::new(3).unwrap();

        let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();

        let run_test = |target: [f32; 4]| {
            let scale =
                maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();

            let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
            let mut best_similarity: f32 = f32::NEG_INFINITY;

            // Brute force: evaluate cosine similarity for every 3-bit level
            // combination on the half-integer grid [-3.5, 3.5].
            let min = -3.5;
            for i0 in 0..8 {
                for i1 in 0..8 {
                    for i2 in 0..8 {
                        for i3 in 0..8 {
                            let p: [f32; 4] = [
                                min + (i0 as f32),
                                min + (i1 as f32),
                                min + (i2 as f32),
                                min + (i3 as f32),
                            ];

                            let sim: MathematicalValue<f32> =
                                diskann_vector::distance::Cosine::evaluate(&p, &target);
                            let sim = sim.into_inner();
                            if sim > best_similarity {
                                best_similarity = sim;
                                best = p.map(|i| i - min);
                            }
                        }
                    }
                }
            }

            // Quantize `target` with the returned scale the same way the
            // compressor does, then compare with the brute-force optimum.
            let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
            let clamped_cosine: MathematicalValue<f32> =
                diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);

            // Accept either an exact match or a codeword that is a uniform
            // scaling of the optimum (equal cosine similarity).
            let passed = if best == clamped {
                true
            } else {
                let ratio: Vec<f32> = std::iter::zip(best, clamped)
                    .map(|(b, c)| {
                        let ratio = (b + min) / (c + min);
                        assert_ne!(
                            ratio, 0.0,
                            "ratio should never be zero because `b` is an integer and \
                            `min` is not"
                        );
                        ratio
                    })
                    .collect();

                ratio.iter().all(|i| *i == ratio[0])
            };

            if !passed {
                panic!(
                    "failed for input {:?}.\
                    Best = {:?}, Found = {:?}\
                    Best similarity = {}, similarity with clamped = {}",
                    target,
                    best,
                    clamped,
                    best_similarity,
                    clamped_cosine.into_inner()
                );
            }
        };

        // Deterministic sweep over a lattice of inputs.
        let min = -3.5;
        for i0 in (0..8).step_by(2) {
            for i1 in (1..9).step_by(2) {
                for i2 in (0..8).step_by(2) {
                    for i3 in (1..9).step_by(2) {
                        let p: [f32; 4] = [
                            min + (i0 as f32),
                            min + (i1 as f32),
                            min + (i2 as f32),
                            min + (i3 as f32),
                        ];
                        run_test(p)
                    }
                }
            }
        }

        // Randomized trials over scaled Gaussian inputs.
        for _ in 0..num_trials {
            let this_scale: f32 = scale_distribution.sample(&mut rng);
            let v: [f32; 4] = [(); 4].map(|_| {
                let v: f32 = StandardNormal {}.sample(&mut rng);
                this_scale * v
            });
            run_test(v);
        }
    }
1340
    /// An empty input must hit the `expect` inside
    /// `maximize_cosine_similarity` (empty `SliceHeap`).
    #[test]
    #[should_panic(expected = "calling code should not allow the slice to be empty")]
    fn empty_slice_panics() {
        maximize_cosine_similarity(
            &[],
            NonZeroUsize::new(4).unwrap(),
            ScopedAllocator::global(),
        )
        .unwrap();
    }
1351
    /// Shared test configuration: the transform under test, the generated
    /// problem's shape, and how many random rows to spot-check.
    struct Setup {
        transform: TransformKind,
        nrows: usize,
        ncols: usize,
        num_trials: usize,
    }
1358
1359 fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1360 match scale {
1361 PreScale::None => 1.0,
1362 PreScale::Some(v) => v.into_inner(),
1363 PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1364 }
1365 }
1366
    /// End-to-end check of the SquaredL2 pipeline for one (Q, D, Perm)
    /// configuration: trains a quantizer, compresses random rows into data /
    /// query / full-query forms, and validates every metadata field against
    /// values recomputed from first principles. Finally verifies the
    /// zero-shifted-norm (input == centroid) special cases.
    fn test_l2<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        computed_means: &[f32],
        pre_scale: PreScale,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::SquaredL2,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        let scale = get_scale(pre_scale, &quantizer);

        // One destination of each kind, reused across trials.
        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        assert_eq!(
            quantizer.mean_norm.into_inner(),
            problem.mean_norm as f32,
            "computed mean norm should not apply scale"
        );
        // The stored shift is the pre-scaled centroid.
        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
        assert_eq!(&*scaled_means, quantizer.shift());

        let l2: CompensatedSquaredL2 = quantizer.as_functor();
        assert_eq!(l2.dim, quantizer.output_dim() as f32);

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);

            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference shifted vector computed independently.
            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
                .map(|(a, b)| scale * a - b)
                .collect();

            // --- data-vector metadata checks ---
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                let shifted_square_norm = metric_specific;

                // Stored bit_sum must equal the sum of stored codes.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                {
                    let expected = FastL2NormSquared.evaluate(&*shifted);
                    let err = (shifted_square_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 5.0e-4, "failed diff check, got {}, expected {} - relative error = {}",
                        shifted_square_norm,
                        expected,
                        err
                    );
                }

                // For 1-bit codes the normalized self inner product should
                // concentrate near sqrt(2/pi) ~= 0.8 for Gaussian-like data.
                if const { D == 1 } {
                    let self_inner_product = 2.0 * shifted_square_norm.sqrt()
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.13,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // --- query metadata checks (recompute the affine mapping) ---
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let shifted_square_norm = metric_specific;
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = FastL2NormSquared.evaluate(&*shifted);
                    let err = (shifted_square_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        shifted_square_norm,
                        expected,
                        err
                    );
                }

                {
                    let expected = shifted_square_norm.sqrt() * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }
            }

            // --- full-query metadata checks ---
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                assert_eq!(
                    f.meta.metric_specific,
                    f.meta.shifted_norm * f.meta.shifted_norm,
                    "metric specific data for squared l2 is the square shifted norm",
                );
            }
        }

        // Compressing the centroid itself hits the zero-shifted-norm path
        // in every destination kind and must yield default metadata.
        quantizer
            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(b.meta(), DataMeta::default());

        quantizer
            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(q.meta(), QueryMeta::default());

        f.data.fill(f32::INFINITY);
        quantizer
            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
            .unwrap();
        assert!(f.data.iter().all(|&i| i == 0.0));
        assert_eq!(f.meta.sum, 0.0);
        assert_eq!(f.meta.metric_specific, 0.0);
    }
1608
    /// Exercises a quantizer trained with `SupportedMetric::InnerProduct`.
    ///
    /// For random rows of the training data, compresses into all three
    /// representations (bit-packed `Data`, multi-bit `Query`, uncompressed
    /// `FullQuery`) and validates the metadata each one records against values
    /// recomputed here by hand. For the IP metric, the `metric_specific`
    /// metadata slot holds the inner product of the shifted vector with the
    /// centroid (`quantizer.shift()`).
    fn test_ip<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        computed_means: &[f32],
        pre_scale: PreScale,
        rng: &mut StdRng,
        ctx: &dyn Display,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::InnerProduct,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        let scale = get_scale(pre_scale, &quantizer);

        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        // The stored mean norm is the raw (unscaled) training statistic.
        assert_eq!(
            quantizer.mean_norm.into_inner(),
            problem.mean_norm as f32,
            "computed mean norm should not apply scale"
        );
        // The stored shift, however, *is* scaled by the pre-scale factor.
        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
        assert_eq!(&*scaled_means, quantizer.shift());

        let ip: CompensatedIP = quantizer.as_functor();

        assert_eq!(ip.dim, quantizer.output_dim() as f32);
        assert_eq!(
            ip.squared_shift_norm,
            FastL2NormSquared.evaluate(quantizer.shift())
        );

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);

            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference "shifted" vector: pre-scaled input minus the stored shift.
            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
                .map(|(a, b)| scale * a - b)
                .collect();

            // --- Bit-packed data representation ---
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                let inner_product_with_centroid = metric_specific;

                // `bit_sum` must equal the sum of the encoded code values.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());

                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
                assert!(
                    diff < 1.53e-5,
                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
                    diff,
                    inner_product.into_inner(),
                    inner_product_with_centroid,
                    ctx,
                );

                if const { D == 1 } {
                    // For 1-bit codes the self inner-product should concentrate
                    // near 0.8 (per the assertion message's expectation).
                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.12,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // --- Multi-bit query representation ---
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let inner_product_with_centroid = metric_specific;
                // Recreate the preprocessing pipeline by hand: shift, normalize
                // by the shifted norm, transform, then derive the Q-bit scalar
                // quantization grid from the transformed range.
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                // Uniform quantization step over [min, max] with 2^Q - 1 levels.
                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // `bit_sum` should equal the sum of the quantized codes.
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // The query path stores the centroid inner product exactly.
                    let inner_product: MathematicalValue<f32> =
                        InnerProduct::evaluate(&*shifted, quantizer.shift());
                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
                }
            }

            // --- Uncompressed full-query representation ---
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());
                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
            }
        }

        // Compressing the means themselves lands exactly on the centroid
        // (shift == scale * means, asserted above), so all metadata must be
        // the default values.
        quantizer
            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(b.meta(), DataMeta::default());

        quantizer
            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(q.meta(), QueryMeta::default());

        // Pre-poison the buffer to prove compression fully overwrites it.
        f.data.fill(f32::INFINITY);
        quantizer
            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
            .unwrap();
        assert!(f.data.iter().all(|&i| i == 0.0));
        assert_eq!(f.meta.sum, 0.0);
        assert_eq!(f.meta.metric_specific, 0.0);
    }
1851
    /// Exercises a quantizer trained with `SupportedMetric::Cosine`.
    ///
    /// Cosine compression L2-normalizes the input before shifting, so the
    /// reference "shifted" vector is built from the normalized input (and no
    /// pre-scale factor is applied, unlike the IP test). Also probes the
    /// all-zero input, which must compress cleanly.
    fn test_cosine<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        pre_scale: PreScale,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::Cosine,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        let cosine: CompensatedCosine = quantizer.as_functor();

        assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
        assert_eq!(
            cosine.inner.squared_shift_norm,
            FastL2NormSquared.evaluate(quantizer.shift())
        );

        // Relative tolerance for inner-product-with-centroid comparisons.
        const IP_BOUND: f32 = 2.6e-3;

        let mut test_row = |v: &[f32]| {
            let vnorm = (FastL2Norm).evaluate(v);
            // L2-normalize, mapping the zero vector to zero instead of NaN.
            let v_normalized: Vec<f32> = v
                .iter()
                .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
                .collect();

            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();

            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();

            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference "shifted" vector: normalized input minus the shift.
            let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
                .map(|(a, b)| a - b)
                .collect();

            // --- Bit-packed data representation ---
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                let inner_product_with_centroid = metric_specific;

                // `bit_sum` must equal the sum of the encoded code values.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());

                // Accept a tiny absolute error (for near-zero products) or a
                // small relative error.
                let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
                let relative = abs / inner_product.into_inner().abs();

                assert!(
                    abs < 1e-7 || relative < IP_BOUND,
                    "got an abs/rel of {}/{} with a bound of {}/{}",
                    abs,
                    relative,
                    1e-7,
                    IP_BOUND
                );

                if const { D == 1 } {
                    // For 1-bit codes the self inner-product should concentrate
                    // near 0.8 (per the assertion message's expectation).
                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.11,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // --- Multi-bit query representation ---
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let inner_product_with_centroid = metric_specific;
                // Recreate the preprocessing pipeline by hand, as in test_ip.
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                // Uniform quantization step over [min, max] with 2^Q - 1 levels.
                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // `bit_sum` should equal the sum of the quantized codes.
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let inner_product: MathematicalValue<f32> =
                        InnerProduct::evaluate(&*shifted, quantizer.shift());

                    let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
                        / inner_product.into_inner().abs();
                    assert!(
                        err < IP_BOUND,
                        "\"offset\": expected {}, got {}, error = {}",
                        inner_product.into_inner(),
                        inner_product_with_centroid,
                        err
                    );
                }
            }

            // --- Uncompressed full-query representation ---
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());
                let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
                    / inner_product.into_inner().abs();
                assert!(
                    err < IP_BOUND,
                    "\"offset\": expected {}, got {}, error = {}",
                    inner_product.into_inner(),
                    f.meta.metric_specific,
                    err
                );
            }
        };

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);
            test_row(v);
        }

        // The all-zero vector must be handled without NaNs.
        let zero = vec![0.0f32; quantizer.input_dim()];
        test_row(&zero);
    }
2099
2100 fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2101 where
2102 for<'a> T: ReborrowMut<'a>,
2103 for<'a> SphericalQuantizer: CompressIntoWith<
2104 &'a [f32],
2105 <T as ReborrowMut<'a>>::Target,
2106 ScopedAllocator<'a>,
2107 Error = CompressionError,
2108 >,
2109 {
2110 let mut succeeded = false;
2111 let mut failed = false;
2112 for max_allocations in 0..10 {
2113 match quantizer.compress_into_with(
2114 data,
2115 dst.reborrow_mut(),
2116 ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2117 ) {
2118 Ok(()) => {
2119 succeeded = true;
2120 }
2121 Err(CompressionError::AllocatorError(_)) => {
2122 failed = true;
2123 }
2124 Err(other) => {
2125 panic!("received an unexpected error: {:?}", other);
2126 }
2127 }
2128 }
2129 assert!(succeeded);
2130 assert!(failed);
2131 }
2132
2133 fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
2134 setup: &Setup,
2135 problem: &test_util::TestProblem,
2136 pre_scale: PreScale,
2137 rng: &mut StdRng,
2138 ) where
2139 Unsigned: Representation<Q>,
2140 Unsigned: Representation<D>,
2141 Perm: PermutationStrategy<Q>,
2142 for<'a> SphericalQuantizer: CompressIntoWith<
2143 &'a [f32],
2144 DataMut<'a, D>,
2145 ScopedAllocator<'a>,
2146 Error = CompressionError,
2147 >,
2148 for<'a> SphericalQuantizer: CompressIntoWith<
2149 &'a [f32],
2150 QueryMut<'a, Q, Perm>,
2151 ScopedAllocator<'a>,
2152 Error = CompressionError,
2153 >,
2154 {
2155 assert_eq!(setup.nrows, problem.data.nrows());
2156 assert_eq!(setup.ncols, problem.data.ncols());
2157
2158 let quantizer = SphericalQuantizer::train(
2159 problem.data.as_view(),
2160 setup.transform,
2161 SupportedMetric::SquaredL2,
2162 pre_scale,
2163 rng,
2164 GlobalAllocator,
2165 )
2166 .unwrap();
2167
2168 let data = problem.data.row(0);
2170 _test_oom_resiliance::<Data<D, _>>(
2171 &quantizer,
2172 data,
2173 &mut Data::new_boxed(quantizer.output_dim()),
2174 );
2175 _test_oom_resiliance::<Query<Q, Perm, _>>(
2176 &quantizer,
2177 data,
2178 &mut Query::new_boxed(quantizer.output_dim()),
2179 );
2180 _test_oom_resiliance::<FullQuery<_>>(
2181 &quantizer,
2182 data,
2183 &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
2184 );
2185 }
2186
    /// Runs the full battery of metric-specific tests (`test_l2`, `test_ip`,
    /// `test_cosine`) over one freshly generated random problem for several
    /// pre-scale configurations, then finishes with the OOM-resilience check.
    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
    where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            DataMut<'a, D>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            QueryMut<'a, Q, Perm>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
    {
        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
        // The helper's means are stored at higher precision; cast to f32 to
        // match the quantizer's shift type.
        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();

        // NOTE(review): the first two entries are identical (1.0 / 1024.0
        // twice) — possibly one was meant to be a different fixed scale.
        // Confirm intent before deduplicating: removing an iteration also
        // shifts the rng stream consumed by the tests below.
        let scales = [
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        for scale in scales {
            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);

            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
        }

        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
    }
2224
2225 #[test]
2226 fn test_spherical_quantizer() {
2227 let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2228 for dim in [56, 72, 128, 255] {
2229 let setup = Setup {
2230 transform: TransformKind::PaddingHadamard {
2231 target_dim: TargetDim::Same,
2232 },
2233 nrows: 64,
2234 ncols: dim,
2235 num_trials: 10,
2236 };
2237
2238 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2239 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2240 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2241 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2242
2243 let setup = Setup {
2244 transform: TransformKind::DoubleHadamard {
2245 target_dim: TargetDim::Same,
2246 },
2247 nrows: 64,
2248 ncols: dim,
2249 num_trials: 10,
2250 };
2251 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2252 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2253 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2254 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2255 }
2256 }
2257
2258 #[test]
2263 fn err_dim_cannot_be_zero() {
2264 let data = Matrix::new(0.0f32, 10, 0);
2265 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2266 let err = SphericalQuantizer::train(
2267 data.as_view(),
2268 TransformKind::DoubleHadamard {
2269 target_dim: TargetDim::Same,
2270 },
2271 SupportedMetric::SquaredL2,
2272 PreScale::None,
2273 &mut rng,
2274 GlobalAllocator,
2275 )
2276 .unwrap_err();
2277 assert_eq!(err.to_string(), "data dim cannot be zero");
2278 }
2279
2280 #[test]
2281 fn err_norm_must_be_positive() {
2282 let data = Matrix::new(0.0f32, 10, 10);
2283 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2284 let err = SphericalQuantizer::train(
2285 data.as_view(),
2286 TransformKind::DoubleHadamard {
2287 target_dim: TargetDim::Same,
2288 },
2289 SupportedMetric::SquaredL2,
2290 PreScale::None,
2291 &mut rng,
2292 GlobalAllocator,
2293 )
2294 .unwrap_err();
2295 assert_eq!(err.to_string(), "norm must be positive");
2296 }
2297
2298 #[test]
2299 fn err_norm_cannot_be_infinity() {
2300 let mut data = Matrix::new(0.0f32, 10, 10);
2301 data[(2, 5)] = f32::INFINITY;
2302
2303 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2304 let err = SphericalQuantizer::train(
2305 data.as_view(),
2306 TransformKind::DoubleHadamard {
2307 target_dim: TargetDim::Same,
2308 },
2309 SupportedMetric::SquaredL2,
2310 PreScale::None,
2311 &mut rng,
2312 GlobalAllocator,
2313 )
2314 .unwrap_err();
2315 assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2316 }
2317
2318 #[test]
2319 fn err_reciprocal_norm_cannot_be_infinity() {
2320 let mut data = Matrix::new(0.0f32, 10, 10);
2321 data[(2, 5)] = 2.93863e-39;
2322
2323 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2324 let err = SphericalQuantizer::train(
2325 data.as_view(),
2326 TransformKind::DoubleHadamard {
2327 target_dim: TargetDim::Same,
2328 },
2329 SupportedMetric::SquaredL2,
2330 PreScale::ReciprocalMeanNorm,
2331 &mut rng,
2332 GlobalAllocator,
2333 )
2334 .unwrap_err();
2335 assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2336 }
2337
2338 #[test]
2339 fn err_mean_norm_cannot_be_zero_generate() {
2340 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2341 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2342 let err = SphericalQuantizer::generate(
2343 centroid,
2344 0.0,
2345 TransformKind::DoubleHadamard {
2346 target_dim: TargetDim::Same,
2347 },
2348 SupportedMetric::SquaredL2,
2349 None,
2350 &mut rng,
2351 GlobalAllocator,
2352 )
2353 .unwrap_err();
2354 assert_eq!(err.to_string(), "norm must be positive");
2355 }
2356
2357 #[test]
2358 fn err_scale_cannot_be_zero_generate() {
2359 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2360 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2361 let err = SphericalQuantizer::generate(
2362 centroid,
2363 1.0,
2364 TransformKind::DoubleHadamard {
2365 target_dim: TargetDim::Same,
2366 },
2367 SupportedMetric::SquaredL2,
2368 Some(0.0),
2369 &mut rng,
2370 GlobalAllocator,
2371 )
2372 .unwrap_err();
2373 assert_eq!(err.to_string(), "pre-scale must be positive");
2374 }
2375
    /// Error paths of `compress_into_with`: non-finite inputs, inputs whose
    /// dynamic range overflows the encoding, and source/destination slices of
    /// the wrong length.
    #[test]
    fn compression_errors_data() {
        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);

        let quantizer = SphericalQuantizer::train(
            data.as_view(),
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            SupportedMetric::SquaredL2,
            PreScale::None,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        let scoped_global = ScopedAllocator::global();

        // A non-finite value in any position must be rejected for both the
        // data and query compression paths.
        {
            let mut query: Vec<f32> = quantizer.shift().to_vec();
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            for i in 0..query.len() {
                let last = query[i];
                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
                    query[i] = v;

                    let err = quantizer
                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);

                    let err = quantizer
                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
                }
                // Restore the element before poisoning the next position.
                query[i] = last;
            }
        }

        // Inputs with a huge dynamic range must fail with the encoding error.
        {
            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();

            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";

            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
        }

        // Source slices one element too short or too long are rejected.
        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
            let query = vec![0.0f32; len];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );
        }

        // Destination buffers one element too short or too long are rejected.
        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
            let query = vec![0.0f32; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(len);
            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );
        }
    }
2489
    /// Regression test: `generate` must apply the pre-scale to the supplied
    /// centroid. The centroid and input below have large-magnitude entries
    /// paired with a pre-scale of 1 / mean_norm; per the assertion message,
    /// if scaling were skipped the 16-bit correction terms would overflow and
    /// compression would fail.
    #[test]
    fn centroid_scaling_happens_in_generate() {
        let centroid = Poly::from_iter(
            [1088.6732f32, 1393.32, 1547.877].into_iter(),
            GlobalAllocator,
        )
        .unwrap();
        let mean_norm = 2359.27;
        let pre_scale = 1.0 / mean_norm;

        let quantizer = SphericalQuantizer::generate(
            centroid,
            mean_norm,
            TransformKind::Null,
            SupportedMetric::InnerProduct,
            Some(pre_scale),
            &mut StdRng::seed_from_u64(10),
            GlobalAllocator,
        )
        .unwrap();

        let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
        let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
        assert!(
            quantizer
                .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
                .is_ok(),
            "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
        );
    }
2520}
2521
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    //! Flatbuffer round-trip (`pack` / `try_unpack`) tests for
    //! `SphericalQuantizer`, plus checks that corrupted payloads are rejected
    //! with the matching `DeserializationError` variant.

    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Packs and unpacks a quantizer for every combination of transform kind,
    /// metric, and pre-scale, asserting the reloaded value compares equal to
    /// the original.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        // Target-dim overrides below (100) and above (150) the 128-dim input.
        let low = NonZeroUsize::new(100).unwrap();
        let high = NonZeroUsize::new(150).unwrap();

        let kinds = [
            TransformKind::Null,
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(high),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Random rotations require the `linalg` feature and are excluded
            // under miri.
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Same,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Natural,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(low),
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(high),
            },
        ];

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        let alloc = GlobalAllocator;
        for kind in kinds.into_iter() {
            for metric in SupportedMetric::all() {
                for pre_scale in pre_scales {
                    let quantizer = SphericalQuantizer::train(
                        problem.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    // Pack to bytes, reparse the buffer, unpack, and compare.
                    let data = to_flatbuffer(|buf| quantizer.pack(buf));
                    let proto =
                        flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
                    let reloaded = SphericalQuantizer::try_unpack(alloc, proto).unwrap();
                    assert_eq!(quantizer, reloaded, "failed on transform {:?}", kind);
                }
            }
        }
    }

    /// Each case builds a valid quantizer, corrupts one field, serializes it,
    /// and asserts that `try_unpack` reports the corresponding error.
    #[test]
    fn test_error_checking() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        let alloc = GlobalAllocator;
        let mut make_quantizer = || {
            SphericalQuantizer::train(
                problem.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        type E = DeserializationError;

        // Zero mean norm -> MissingNorm.
        {
            let mut quantizer = make_quantizer();
            // SAFETY: the `Positive` invariant is broken on purpose here to
            // craft an invalid payload for the deserializer to reject.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Negative mean norm -> MissingNorm.
        {
            let mut quantizer = make_quantizer();

            // SAFETY: invariant intentionally violated, as above.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(-1.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Zero pre-scale -> PreScaleNotPositive.
        {
            let mut quantizer = make_quantizer();

            // SAFETY: invariant intentionally violated, as above.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::PreScaleNotPositive);
        }

        // Shift length inconsistent with the transform dims -> DimMismatch.
        {
            let mut quantizer = make_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::DimMismatch);
        }
    }
}