1use std::num::NonZeroUsize;
7
8use diskann_utils::{views::MatrixView, ReborrowMut};
9use diskann_vector::{
10 distance::InnerProduct, norm::FastL2Norm, MathematicalValue, Norm, PureDistanceFunction,
11};
12#[cfg(feature = "flatbuffers")]
13use flatbuffers::{FlatBufferBuilder, WIPOffset};
14use rand::{Rng, RngCore};
15use thiserror::Error;
16
17use super::{
18 CompensatedCosine, CompensatedIP, CompensatedSquaredL2, DataMeta, DataMetaError, DataMut,
19 FullQueryMeta, FullQueryMut, QueryMeta, QueryMut, SupportedMetric,
20};
21#[cfg(feature = "flatbuffers")]
22use crate::{
23 algorithms::transforms::TransformError, flatbuffers::spherical, spherical::InvalidMetric,
24};
25use crate::{
26 algorithms::{
27 heap::SliceHeap,
28 transforms::{NewTransformError, Transform, TransformFailed, TransformKind},
29 },
30 alloc::{Allocator, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone},
31 bits::{PermutationStrategy, Representation, Unsigned},
32 num::Positive,
33 utils::{compute_means_and_average_norm, compute_normalized_means, CannotBeEmpty},
34 AsFunctor, CompressIntoWith,
35};
36
/// Quantizer that centers vectors around a (pre-scaled) centroid, normalizes
/// the shifted vector, applies a linear transform, and encodes the result
/// into low-bit codes (see the `CompressIntoWith` implementations).
///
/// Construct via [`SphericalQuantizer::train`] or
/// [`SphericalQuantizer::generate`].
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct SphericalQuantizer<A = GlobalAllocator>
where
    A: Allocator,
{
    /// Centroid subtracted from inputs during preprocessing. Already
    /// multiplied by `pre_scale` (see `generate`); its length is the input
    /// dimension.
    shift: Poly<[f32], A>,

    /// Linear transform applied after shifting/normalizing; determines the
    /// output dimension.
    transform: Transform<A>,

    /// Distance metric this quantizer was trained for.
    metric: SupportedMetric,

    /// Average norm of the training data (fixed to 1.0 for cosine training);
    /// guaranteed positive.
    mean_norm: Positive<f32>,

    /// Scale factor applied to inputs before shifting (1.0 by default).
    pre_scale: Positive<f32>,
}
79
80impl<A> TryClone for SphericalQuantizer<A>
81where
82 A: Allocator,
83{
84 fn try_clone(&self) -> Result<Self, AllocatorError> {
85 Ok(Self {
86 shift: self.shift.try_clone()?,
87 transform: self.transform.try_clone()?,
88 metric: self.metric,
89 mean_norm: self.mean_norm,
90 pre_scale: self.pre_scale,
91 })
92 }
93}
94
/// Errors that can occur while training or generating a
/// [`SphericalQuantizer`].
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum TrainError {
    /// The data/centroid dimension was zero.
    #[error("data dim cannot be zero")]
    DimCannotBeZero,
    /// The training set contained no rows.
    #[error("data cannot be empty")]
    DataCannotBeEmpty,
    /// A supplied pre-scale factor was not strictly positive.
    #[error("pre-scale must be positive")]
    PrescaleNotPositive,
    /// The (computed or supplied) mean norm was not strictly positive.
    #[error("norm must be positive")]
    NormNotPositive,
    /// The computed mean norm was infinite or NaN.
    #[error("computed norm contains infinity or NaN")]
    NormNotFinite,
    /// `1 / mean_norm` was infinite or NaN.
    #[error("reciprocal norm contains infinity or NaN")]
    ReciprocalNormNotFinite,
    /// An allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
113
impl<A> SphericalQuantizer<A>
where
    A: Allocator,
{
    /// Dimension of vectors accepted by this quantizer (the centroid length).
    pub fn input_dim(&self) -> usize {
        self.shift.len()
    }

    /// Dimension of the encoded representation, as determined by the
    /// internal transform.
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

    /// The stored (pre-scaled) centroid that is subtracted from inputs
    /// during preprocessing.
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    /// Average training-data norm (1.0 when trained for cosine).
    pub fn mean_norm(&self) -> Positive<f32> {
        self.mean_norm
    }

    /// Scale factor applied to inputs before shifting.
    pub fn pre_scale(&self) -> Positive<f32> {
        self.pre_scale
    }

    /// Allocator backing this quantizer's storage.
    pub fn allocator(&self) -> &A {
        self.shift.allocator()
    }

    /// Build a quantizer directly from a centroid and mean norm, without
    /// seeing any data.
    ///
    /// The centroid is multiplied by `pre_scale` (default 1.0) in place
    /// before being stored, so the cached `shift` is already pre-scaled.
    ///
    /// # Errors
    ///
    /// * [`TrainError::PrescaleNotPositive`] if `pre_scale` is supplied and
    ///   not strictly positive.
    /// * [`TrainError::DimCannotBeZero`] if `centroid` is empty.
    /// * [`TrainError::NormNotPositive`] if `mean_norm` is not strictly
    ///   positive.
    /// * [`TrainError::AllocatorError`] if constructing the transform fails.
    pub fn generate(
        mut centroid: Poly<[f32], A>,
        mean_norm: f32,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: Option<f32>,
        rng: &mut dyn RngCore,
        allocator: A,
    ) -> Result<Self, TrainError> {
        let pre_scale = match pre_scale {
            Some(v) => Positive::new(v).map_err(|_| TrainError::PrescaleNotPositive)?,
            None => crate::num::POSITIVE_ONE_F32,
        };

        let dim = match NonZeroUsize::new(centroid.len()) {
            Some(dim) => dim,
            None => {
                return Err(TrainError::DimCannotBeZero);
            }
        };

        let mean_norm = Positive::new(mean_norm).map_err(|_| TrainError::NormNotPositive)?;

        // An RNG is always passed, so `RngMissing` cannot occur here.
        let transform = match Transform::new(transform, dim, Some(rng), allocator.clone()) {
            Ok(v) => v,
            Err(NewTransformError::RngMissing(_)) => unreachable!("An Rng was provided"),
            Err(NewTransformError::AllocatorError(err)) => {
                return Err(TrainError::AllocatorError(err));
            }
        };

        // Bake the pre-scale into the stored centroid so preprocessing only
        // needs a single multiply per element.
        centroid
            .iter_mut()
            .for_each(|v| *v *= pre_scale.into_inner());

        Ok(SphericalQuantizer {
            shift: centroid,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }

    /// Metric this quantizer was configured for.
    pub fn metric(&self) -> SupportedMetric {
        self.metric
    }

    /// Train a quantizer from a data matrix: compute the centroid (normalized
    /// means for cosine) and average norm, resolve the pre-scale policy, and
    /// delegate to [`SphericalQuantizer::generate`].
    ///
    /// # Errors
    ///
    /// Propagates [`TrainError`] variants from statistics computation,
    /// pre-scale resolution, allocation, and `generate`.
    pub fn train<T, R>(
        data: MatrixView<T>,
        transform: TransformKind,
        metric: SupportedMetric,
        pre_scale: PreScale,
        rng: &mut R,
        allocator: A,
    ) -> Result<Self, TrainError>
    where
        T: Copy + Into<f64> + Into<f32>,
        R: Rng,
    {
        // Monomorphization firewall: the inner fn takes `&mut dyn RngCore`
        // so the body is instantiated once per (T, A) instead of per R.
        #[inline(never)]
        fn train<T, A>(
            data: MatrixView<T>,
            transform: TransformKind,
            metric: SupportedMetric,
            pre_scale: PreScale,
            rng: &mut dyn RngCore,
            allocator: A,
        ) -> Result<SphericalQuantizer<A>, TrainError>
        where
            T: Copy + Into<f64> + Into<f32>,
            A: Allocator,
        {
            if data.ncols() == 0 {
                return Err(TrainError::DimCannotBeZero);
            }

            // For cosine, rows are normalized before averaging and the mean
            // norm is defined to be 1.0.
            let (centroid, mean_norm) = match metric {
                SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => {
                    compute_means_and_average_norm(data)
                }
                SupportedMetric::Cosine => (
                    compute_normalized_means(data)
                        .map_err(|_: CannotBeEmpty| TrainError::DataCannotBeEmpty)?,
                    1.0,
                ),
            };

            let mean_norm = mean_norm as f32;

            if mean_norm <= 0.0 {
                return Err(TrainError::NormNotPositive);
            }

            if !mean_norm.is_finite() {
                return Err(TrainError::NormNotFinite);
            }

            let pre_scale: Positive<f32> = match pre_scale {
                PreScale::None => crate::num::POSITIVE_ONE_F32,
                PreScale::Some(v) => v,
                PreScale::ReciprocalMeanNorm => {
                    let pre_scale = Positive::new(1.0 / mean_norm)
                        .map_err(|_| TrainError::ReciprocalNormNotFinite)?;

                    // Guard against 1/mean_norm overflowing to infinity even
                    // though it is positive.
                    if !pre_scale.into_inner().is_finite() {
                        return Err(TrainError::ReciprocalNormNotFinite);
                    }

                    pre_scale
                }
            };

            // Narrow the f64 means to f32 for storage.
            let centroid =
                Poly::from_iter(centroid.into_iter().map(|i| i as f32), allocator.clone())?;

            SphericalQuantizer::generate(
                centroid,
                mean_norm,
                transform,
                metric,
                Some(pre_scale.into_inner()),
                rng,
                allocator,
            )
        }

        train(data, transform, metric, pre_scale, rng, allocator)
    }

    /// Rescale `v` in place so its L2 norm equals the training mean norm.
    /// NOTE(review): divides by the current norm — a zero vector would yield
    /// non-finite values; callers presumably avoid that. TODO confirm.
    pub fn rescale(&self, v: &mut [f32]) {
        let norm = FastL2Norm.evaluate(&*v);
        let m = self.mean_norm.into_inner() / norm;
        v.iter_mut().for_each(|i| *i *= m);
    }

    /// Shared preprocessing for all compression paths: scale the input
    /// (unit-normalize for cosine, `pre_scale` otherwise), subtract the
    /// stored centroid, and cache the shifted norm plus — for IP/cosine —
    /// the inner product of the shifted vector with the centroid.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != self.input_dim()`.
    ///
    /// # Errors
    ///
    /// [`CompressionError::InputContainsNaN`] if the shifted norm is not
    /// finite; allocation failures are propagated.
    fn preprocess<'a>(
        &self,
        data: &[f32],
        allocator: ScopedAllocator<'a>,
    ) -> Result<Preprocessed<'a>, CompressionError> {
        assert_eq!(data.len(), self.input_dim(), "Data dimension is incorrect.");

        let scale = self.pre_scale.into_inner();
        // Cosine normalizes by the input norm (treating a zero vector as
        // norm 1 to avoid dividing by zero); other metrics use `pre_scale`.
        let mul: f32 = match self.metric {
            SupportedMetric::Cosine => {
                let norm: f32 = (FastL2Norm).evaluate(data);
                if norm == 0.0 {
                    1.0
                } else {
                    1.0 / norm
                }
            }
            SupportedMetric::SquaredL2 | SupportedMetric::InnerProduct => scale,
        };

        // shifted = mul * data - shift (the shift is already pre-scaled).
        let shifted = Poly::from_iter(
            std::iter::zip(data.iter(), self.shift.iter()).map(|(&f, &s)| mul * f - s),
            allocator,
        )?;

        let shifted_norm = FastL2Norm.evaluate(&*shifted);
        if !shifted_norm.is_finite() {
            return Err(CompressionError::InputContainsNaN);
        }
        let inner_product_with_centroid = match self.metric {
            SupportedMetric::SquaredL2 => None,
            SupportedMetric::InnerProduct | SupportedMetric::Cosine => {
                let ip: MathematicalValue<f32> = InnerProduct::evaluate(&*shifted, &*self.shift);
                Some(ip.into_inner())
            }
        };

        Ok(Preprocessed {
            shifted,
            shifted_norm,
            inner_product_with_centroid,
        })
    }
}
382
/// Policy for choosing the pre-scale factor in [`SphericalQuantizer::train`].
#[derive(Debug, Clone, Copy)]
pub enum PreScale {
    /// No scaling (factor of 1.0).
    None,
    /// Use the provided strictly-positive factor.
    Some(Positive<f32>),
    /// Use `1.0 / mean_norm`, where the mean norm is computed from the
    /// training data.
    ReciprocalMeanNorm,
}
396
/// Errors that can occur while deserializing a [`SphericalQuantizer`] from
/// its flatbuffer representation.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error, PartialEq)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The embedded transform failed to deserialize.
    #[error(transparent)]
    TransformError(#[from] TransformError),
    /// The buffer's file identifier did not match.
    #[error("unrecognized flatbuffer identifier")]
    UnrecognizedIdentifier,
    /// The centroid length disagrees with the transform's input dimension.
    #[error("transform length not equal to centroid")]
    DimMismatch,
    /// The stored mean norm was absent or not strictly positive.
    #[error("norm is missing or is not positive")]
    MissingNorm,
    /// The stored pre-scale was absent or not strictly positive.
    #[error("pre-scale is missing or is not positive")]
    PreScaleNotPositive,

    /// The buffer failed flatbuffer verification.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// The stored metric code is not a supported metric.
    #[error(transparent)]
    InvalidMetric(#[from] InvalidMetric),

    /// An allocation failed while rebuilding the quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
422
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl<A> SphericalQuantizer<A>
where
    A: Allocator + Clone,
{
    /// Serialize this quantizer into `buf`, returning the offset of the
    /// created table. All five fields (centroid, transform, metric,
    /// mean norm, pre-scale) are written.
    pub(crate) fn pack<'a, FA>(
        &self,
        buf: &mut FlatBufferBuilder<'a, FA>,
    ) -> WIPOffset<spherical::SphericalQuantizer<'a>>
    where
        FA: flatbuffers::Allocator + 'a,
    {
        let centroid = buf.create_vector(&self.shift);

        let transform = self.transform.pack(buf);

        spherical::SphericalQuantizer::create(
            buf,
            &spherical::SphericalQuantizerArgs {
                centroid: Some(centroid),
                transform: Some(transform),
                metric: self.metric.into(),
                mean_norm: self.mean_norm.into_inner(),
                pre_scale: self.pre_scale.into_inner(),
            },
        )
    }

    /// Rebuild a quantizer from its flatbuffer table, validating the metric,
    /// the centroid/transform dimension agreement, and the positivity of the
    /// stored mean norm and pre-scale.
    ///
    /// # Errors
    ///
    /// See [`DeserializationError`] for the failure modes.
    pub(crate) fn try_unpack(
        alloc: A,
        proto: spherical::SphericalQuantizer<'_>,
    ) -> Result<Self, DeserializationError> {
        let metric: SupportedMetric = proto.metric().try_into()?;

        let shift = Poly::from_iter(proto.centroid().into_iter(), alloc.clone())?;

        let transform = Transform::try_unpack(alloc, proto.transform())?;

        // A mismatched dimension would make `preprocess` unsound downstream.
        if shift.len() != transform.input_dim() {
            return Err(DeserializationError::DimMismatch);
        }

        let mean_norm =
            Positive::new(proto.mean_norm()).map_err(|_| DeserializationError::MissingNorm)?;

        let pre_scale = Positive::new(proto.pre_scale())
            .map_err(|_| DeserializationError::PreScaleNotPositive)?;

        Ok(Self {
            shift,
            transform,
            metric,
            mean_norm,
            pre_scale,
        })
    }
}
492
/// Intermediate result of [`SphericalQuantizer::preprocess`] shared by all
/// compression paths.
struct Preprocessed<'a> {
    /// Input scaled and shifted by the centroid (later normalized in place
    /// by the callers).
    shifted: Poly<[f32], ScopedAllocator<'a>>,
    /// L2 norm of `shifted` at the time of preprocessing.
    shifted_norm: f32,
    /// Inner product of `shifted` with the centroid; `Some` for IP/cosine,
    /// `None` for squared-L2.
    inner_product_with_centroid: Option<f32>,
}
498
499impl Preprocessed<'_> {
500 fn metric_specific(&self) -> f32 {
505 match self.inner_product_with_centroid {
506 Some(ip) => ip,
507 None => self.shifted_norm * self.shifted_norm,
508 }
509 }
510}
511
512impl<A> AsFunctor<CompensatedSquaredL2> for SphericalQuantizer<A>
517where
518 A: Allocator,
519{
520 fn as_functor(&self) -> CompensatedSquaredL2 {
521 CompensatedSquaredL2::new(self.output_dim())
522 }
523}
524
525impl<A> AsFunctor<CompensatedIP> for SphericalQuantizer<A>
526where
527 A: Allocator,
528{
529 fn as_functor(&self) -> CompensatedIP {
530 CompensatedIP::new(&self.shift, self.output_dim())
531 }
532}
533
534impl<A> AsFunctor<CompensatedCosine> for SphericalQuantizer<A>
535where
536 A: Allocator,
537{
538 fn as_functor(&self) -> CompensatedCosine {
539 CompensatedCosine::new(self.as_functor())
540 }
541}
542
/// Errors that can occur while compressing a vector with a
/// [`SphericalQuantizer`].
#[derive(Debug, Error, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum CompressionError {
    /// The shifted input had a non-finite norm (NaN/inf in the input).
    #[error("input contains NaN")]
    InputContainsNaN,

    /// Source vector length differed from the quantizer's input dimension.
    #[error("expected source vector to have length {expected}")]
    SourceDimensionMismatch { expected: usize },

    /// Destination length differed from the quantizer's output dimension.
    #[error("expected destination vector to have length {expected}")]
    DestinationDimensionMismatch { expected: usize },

    /// Metadata encoding failed (values out of representable range).
    #[error(
        "encoding error - you may need to scale the entire dataset to reduce its dynamic range"
    )]
    EncodingError(#[from] DataMetaError),

    /// A scratch allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
567
568fn check_dims(
569 input: usize,
570 output: usize,
571 from: usize,
572 into: usize,
573) -> Result<(), CompressionError> {
574 if from != input {
575 return Err(CompressionError::SourceDimensionMismatch { expected: input });
576 }
577 if into != output {
578 return Err(CompressionError::DestinationDimensionMismatch { expected: output });
579 }
580 Ok(())
581}
582
/// Per-bit-width finishing step of data compression: given the preprocessed
/// vector and its transformed representation, quantize into `self` and store
/// the metadata. Implemented for each supported `DataMut` bit width.
trait FinishCompressing {
    /// Quantize `transformed` into `self` and set the metadata derived from
    /// `preprocessed` and `transformed_norm`.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;
}
594
impl FinishCompressing for DataMut<'_, 1> {
    /// One-bit encoding: each transformed component becomes its sign bit
    /// (1 for strictly positive, 0 otherwise). No scratch allocation needed.
    fn finish_compressing(
        &mut self,
        preprocessed: &Preprocessed<'_>,
        transformed: &[f32],
        transformed_norm: f32,
        _: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        // Accumulate sum(|r|) — the inner product of `transformed` with its
        // sign vector — and the popcount, while writing out the bits.
        let mut quant_raw_inner_product = 0.0f32;
        let mut bit_sum = 0u32;
        transformed.iter().enumerate().for_each(|(i, &r)| {
            let bit: u8 = if r > 0.0 { 1 } else { 0 };

            quant_raw_inner_product += r.abs();
            bit_sum += <u8 as Into<u32>>::into(bit);

            // SAFETY: `i < transformed.len()`, which the caller sized to the
            // destination length via `check_dims`. TODO confirm invariant
            // holds for all call sites.
            unsafe { self.vector_mut().set_unchecked(i, bit) };
        });

        // Correction factor used at distance time to undo the normalization
        // applied before the transform (factor of 2 specific to the {0,1}
        // bit representation).
        let inner_product_correction =
            2.0 * transformed_norm * preprocessed.shifted_norm / quant_raw_inner_product;
        self.set_meta(DataMeta::new(
            inner_product_correction,
            preprocessed.metric_specific(),
            bit_sum,
        )?);
        Ok(())
    }
}
647
648impl FinishCompressing for DataMut<'_, 2> {
649 fn finish_compressing(
650 &mut self,
651 preprocessed: &Preprocessed<'_>,
652 transformed: &[f32],
653 transformed_norm: f32,
654 allocator: ScopedAllocator<'_>,
655 ) -> Result<(), CompressionError> {
656 compress_via_maximum_cosine(
657 self.reborrow_mut(),
658 preprocessed,
659 transformed,
660 transformed_norm,
661 allocator,
662 )
663 }
664}
665
666impl FinishCompressing for DataMut<'_, 4> {
667 fn finish_compressing(
668 &mut self,
669 preprocessed: &Preprocessed<'_>,
670 transformed: &[f32],
671 transformed_norm: f32,
672 allocator: ScopedAllocator<'_>,
673 ) -> Result<(), CompressionError> {
674 compress_via_maximum_cosine(
675 self.reborrow_mut(),
676 preprocessed,
677 transformed,
678 transformed_norm,
679 allocator,
680 )
681 }
682}
683
684impl FinishCompressing for DataMut<'_, 8> {
685 fn finish_compressing(
686 &mut self,
687 preprocessed: &Preprocessed<'_>,
688 transformed: &[f32],
689 transformed_norm: f32,
690 allocator: ScopedAllocator<'_>,
691 ) -> Result<(), CompressionError> {
692 compress_via_maximum_cosine(
693 self.reborrow_mut(),
694 preprocessed,
695 transformed,
696 transformed_norm,
697 allocator,
698 )
699 }
700}
701
impl<A> CompressIntoWith<&[f32], FullQueryMut<'_>, ScopedAllocator<'_>> for SphericalQuantizer<A>
where
    A: Allocator,
{
    type Error = CompressionError;

    /// Full-precision query path: preprocess, unit-normalize the shifted
    /// vector, apply the transform directly into the destination, and record
    /// the component sum, shifted norm, and metric-specific value.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: FullQueryMut<'_>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Input coincides with the centroid: emit the zero query rather than
        // dividing by a zero norm below.
        if preprocessed.shifted_norm == 0.0 {
            into.vector_mut().fill(0.0);
            *into.meta_mut() = Default::default();
            return Ok(());
        }

        // Normalize in place to a unit vector before transforming.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        // Dimension mismatches were ruled out by `check_dims`, so only
        // allocator failures are recoverable here.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(into.vector_mut(), &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err))
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        *into.meta_mut() = FullQueryMeta {
            sum: into.vector().iter().sum::<f32>(),
            shifted_norm: preprocessed.shifted_norm,
            metric_specific: preprocessed.metric_specific(),
        };
        Ok(())
    }
}
778
impl<const NBITS: usize, A> CompressIntoWith<&[f32], DataMut<'_, NBITS>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    A: Allocator,
    Unsigned: Representation<NBITS>,
    for<'a> DataMut<'a, NBITS>: FinishCompressing,
{
    type Error = CompressionError;

    /// Data (database-side) compression path: preprocess, unit-normalize,
    /// transform into a scratch buffer, then hand off to the bit-width
    /// specific [`FinishCompressing`] implementation.
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: DataMut<'_, NBITS>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Input coincides with the centroid: store default metadata and
        // leave the code vector untouched rather than divide by zero.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(DataMeta::default());
            return Ok(());
        }

        // Scratch buffer for the transformed vector (the destination stores
        // packed bits, so the transform cannot write there directly).
        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        // Dimension mismatches were ruled out by `check_dims`, so only
        // allocator failures are recoverable here.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err))
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        // Norm-preserving transforms keep the unit input at norm 1, so the
        // explicit norm computation can be skipped.
        let transformed_norm = if self.transform.preserves_norms() {
            1.0
        } else {
            (FastL2Norm).evaluate(&*transformed)
        };

        into.finish_compressing(&preprocessed, &transformed, transformed_norm, allocator)?;
        Ok(())
    }
}
854
/// Helper to lift a const-generic `NBITS` into a `NonZeroUsize` at compile
/// time.
struct AsNonZero<const NBITS: usize>;
impl<const NBITS: usize> AsNonZero<NBITS> {
    // Evaluated in a const context: an NBITS of zero fails when the constant
    // is evaluated, rather than panicking at runtime.
    #[allow(clippy::unwrap_used)]
    const NON_ZERO: NonZeroUsize = NonZeroUsize::new(NBITS).unwrap();
}
861
/// Quantize `transformed` into `data` using the scale that maximizes cosine
/// similarity between the float vector and its rounded codes, then store the
/// resulting metadata.
fn compress_via_maximum_cosine<const NBITS: usize>(
    mut data: DataMut<'_, NBITS>,
    preprocessed: &Preprocessed<'_>,
    transformed: &[f32],
    transformed_norm: f32,
    allocator: ScopedAllocator<'_>,
) -> Result<(), CompressionError>
where
    Unsigned: Representation<NBITS>,
{
    assert_eq!(data.len(), transformed.len());

    // Scale that maximizes cosine similarity between `transformed` and its
    // rounded representation (see `maximize_cosine_similarity`).
    let optimal_scale =
        maximize_cosine_similarity(transformed, AsNonZero::<NBITS>::NON_ZERO, allocator)?;

    let domain = Unsigned::domain_const::<NBITS>();
    let min = *domain.start() as f32;
    let max = *domain.end() as f32;
    // Codes are stored unsigned but interpreted as centered about the domain
    // midpoint (`dv` below).
    let offset = max / 2.0;

    let mut self_inner_product = 0.0f32;
    let mut bit_sum = 0u32;
    for (i, t) in transformed.iter().enumerate() {
        // Scale into the code domain, clamp, and round to the nearest code.
        let v = (*t * optimal_scale + offset).clamp(min, max).round();
        let dv = v - offset;
        // Accumulate <transformed, centered codes> with fused multiply-add.
        self_inner_product = dv.mul_add(*t, self_inner_product);

        let v = v as u8;
        bit_sum += <u8 as Into<u32>>::into(v);

        // SAFETY: `i < data.len()` by the assertion above.
        unsafe { data.vector_mut().set_unchecked(i, v) };
    }

    let shifted_norm = preprocessed.shifted_norm;
    // Correction applied at distance time to undo the pre-transform
    // normalization and the quantization scale.
    let inner_product_correction = (transformed_norm * shifted_norm) / self_inner_product;
    data.set_meta(DataMeta::new(
        inner_product_correction,
        preprocessed.metric_specific(),
        bit_sum,
    )?);
    Ok(())
}
910
/// Heap entry for the critical-scale search in
/// `maximize_cosine_similarity`: a candidate scale and the component index
/// it belongs to.
#[derive(Debug, Clone, Copy)]
struct Pair {
    /// The critical scale at which this component's rounded level increments.
    value: f32,
    /// Index of the component in the input vector.
    position: u32,
}
929
impl PartialEq for Pair {
    // Equality considers only `value`; `position` is payload, not identity.
    fn eq(&self, other: &Self) -> bool {
        self.value.eq(&other.value)
    }
}

// Eq is claimed despite f32 fields; NaN values compare as "equal" through
// the Ord impl below, which is sufficient for the heap's purposes.
impl Eq for Pair {}
impl PartialOrd for Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Pair {
    // Reversed comparison (`other` vs `self`) so the smallest `value` ranks
    // greatest — a max-heap over this ordering behaves as a min-heap on
    // `value`. Incomparable values (NaN) are treated as equal.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .value
            .partial_cmp(&self.value)
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
950
/// Find the scale maximizing the cosine similarity between `v` and the
/// vector obtained by rounding each scaled component onto the half-offset
/// grid implied by `num_bits` (levels of magnitude `r - 0.5`; see the 0.25
/// per-component initial square norm below).
///
/// The search sweeps the "critical" scales at which exactly one component's
/// rounded level increments, maintaining the inner product and square norm
/// incrementally and tracking the scale achieving the best similarity. A
/// heap (`SliceHeap` over the reversed `Pair` ordering) yields critical
/// scales in increasing order.
///
/// # Panics
///
/// Panics if `v` is empty (the heap constructor `expect`s a non-empty slice).
fn maximize_cosine_similarity(
    v: &[f32],
    num_bits: NonZeroUsize,
    allocator: ScopedAllocator<'_>,
) -> Result<f32, AllocatorError> {
    // Start with every |component| rounded to the smallest level (0.5):
    // ip = <|v|, 0.5·1> and ||q||^2 = 0.25 * len.
    let mut current_ip = 0.5 * v.iter().map(|i| i.abs() as f64).sum::<f64>();
    let mut current_square_norm = 0.25 * (v.len() as f64);

    // rounded[i] is the integer level of component i; level r corresponds to
    // magnitude (r - 0.5). TODO confirm the off-by-half convention against
    // `compress_via_maximum_cosine`'s offset handling.
    let mut rounded = Poly::broadcast(1u16, v.len(), allocator)?;

    // Nudge each critical scale just past its rounding boundary so the level
    // increment has definitely occurred at that scale.
    let eps = 0.0001f32;
    let one_and_change = 1.0 + eps;
    let mut base = Poly::from_iter(
        v.iter().enumerate().map(|(position, value)| {
            let value = one_and_change / value.abs();
            Pair {
                value,
                position: position as u32,
            }
        }),
        allocator,
    )?;

    #[allow(clippy::expect_used)]
    let mut critical_values =
        SliceHeap::new(&mut base).expect("calling code should not allow the slice to be empty");

    let mut max_similarity = f64::NEG_INFINITY;
    let mut optimal_scale = f32::default();
    // Maximum level per component: half the number of codes, since levels
    // encode magnitudes only.
    let stop = (2usize).pow(num_bits.get() as u32 - 1) as u16;

    loop {
        let mut should_break = false;
        critical_values.update_root(|pair| {
            let Pair { value, position } = *pair;
            // f32::MAX marks an exhausted component; the root being
            // exhausted means every component is.
            if value == f32::MAX {
                should_break = true;
                return;
            }

            let r = &mut rounded[position as usize];
            let vp = &v[position as usize];

            let old_r = *r;
            *r += 1;

            // Incremental update: |v_p| joins the inner product once more...
            current_ip += vp.abs() as f64;

            // ...and the square norm grows by
            // (old_r + 0.5)^2 - (old_r - 0.5)^2 = 2 * old_r.
            current_square_norm += (2 * old_r) as f64;

            // Cosine similarity up to the constant ||v|| factor, which does
            // not affect the argmax.
            let similarity = current_ip / current_square_norm.sqrt();
            if similarity > max_similarity {
                max_similarity = similarity;
                optimal_scale = value;
            }

            // Schedule this component's next critical scale, or mark it
            // exhausted once it has reached the top level.
            if *r < stop {
                *pair = Pair {
                    value: (*r as f32 + eps) / vp.abs(),
                    position,
                };
            } else {
                *pair = Pair {
                    value: f32::MAX,
                    position,
                };
            }
        });
        if should_break {
            break;
        }
    }

    Ok(optimal_scale)
}
1092
impl<const NBITS: usize, Perm, A>
    CompressIntoWith<&[f32], QueryMut<'_, NBITS, Perm>, ScopedAllocator<'_>>
    for SphericalQuantizer<A>
where
    Unsigned: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    A: Allocator,
{
    type Error = CompressionError;

    /// Query-side compression: preprocess, unit-normalize, transform into a
    /// scratch buffer, then uniformly quantize over the observed [min, max]
    /// range (unlike the data path, which maximizes cosine similarity).
    fn compress_into_with(
        &self,
        from: &[f32],
        mut into: QueryMut<'_, NBITS, Perm>,
        allocator: ScopedAllocator<'_>,
    ) -> Result<(), Self::Error> {
        let input_dim = self.shift.len();
        let output_dim = self.output_dim();
        check_dims(input_dim, output_dim, from.len(), into.len())?;

        let mut preprocessed = self.preprocess(from, allocator)?;

        // Input coincides with the centroid: store default metadata rather
        // than divide by a zero norm below.
        if preprocessed.shifted_norm == 0.0 {
            into.set_meta(QueryMeta::default());
            return Ok(());
        }

        // Normalize in place to a unit vector before transforming.
        preprocessed
            .shifted
            .iter_mut()
            .for_each(|v| *v /= preprocessed.shifted_norm);

        let mut transformed = Poly::broadcast(0.0f32, output_dim, allocator)?;

        // Dimension mismatches were ruled out by `check_dims`, so only
        // allocator failures are recoverable here.
        #[expect(clippy::panic, reason = "the dimensions should already be as expected")]
        match self
            .transform
            .transform_into(&mut transformed, &preprocessed.shifted, allocator)
        {
            Ok(()) => {}
            Err(TransformFailed::AllocatorError(err)) => {
                return Err(CompressionError::AllocatorError(err))
            }
            Err(TransformFailed::SourceMismatch { .. })
            | Err(TransformFailed::DestinationMismatch { .. }) => {
                panic!(
                    "The sizes of these arrays should already be checked - this is a logic error"
                );
            }
        }

        // Range of the transformed components, used to stretch them across
        // the full unsigned code domain.
        let (min, max) = transformed
            .iter()
            .fold((f32::MAX, f32::MIN), |(min, max), i| {
                (i.min(min), i.max(max))
            });

        let domain = Unsigned::domain_const::<NBITS>();
        let lo = (*domain.start()) as f32;
        let hi = (*domain.end()) as f32;

        let scale = (max - min) / hi;
        let mut bit_sum: f32 = 0.0;
        transformed.iter().enumerate().for_each(|(i, v)| {
            // Uniformly quantize relative to `min`, clamped into the domain.
            let c = ((v - min) / scale).round().clamp(lo, hi);
            bit_sum += c;

            // The clamp above keeps `c` inside the representable code range.
            #[allow(clippy::unwrap_used)]
            into.vector_mut().set(i, c as i64).unwrap();
        });

        into.set_meta(QueryMeta {
            inner_product_correction: preprocessed.shifted_norm * scale,
            bit_sum,
            offset: min / scale,
            metric_specific: preprocessed.metric_specific(),
        });

        Ok(())
    }
}
1201
1202#[cfg(not(miri))]
1207#[cfg(test)]
1208mod tests {
1209 use super::*;
1210
1211 use std::fmt::Display;
1212
1213 use diskann_utils::{
1214 lazy_format,
1215 views::{self, Matrix},
1216 ReborrowMut,
1217 };
1218 use diskann_vector::{norm::FastL2NormSquared, PureDistanceFunction};
1219 use diskann_wide::ARCH;
1220 use rand::{
1221 distr::{Distribution, Uniform},
1222 rngs::StdRng,
1223 SeedableRng,
1224 };
1225 use rand_distr::StandardNormal;
1226
1227 use crate::{
1228 algorithms::transforms::TargetDim,
1229 alloc::GlobalAllocator,
1230 bits::{BitTranspose, Dense},
1231 spherical::{Data, DataMetaF32, FullQuery, Query},
1232 test_util,
1233 };
1234
1235 #[test]
1237 fn test_cosine_similarity_maximizer() {
1238 let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
1239 let num_trials = 10000;
1240 let num_bits = NonZeroUsize::new(3).unwrap();
1241
1242 let scale_distribution = Uniform::new(0.5f32, 10.0f32).unwrap();
1243
1244 let run_test = |target: [f32; 4]| {
1245 let scale =
1246 maximize_cosine_similarity(&target, num_bits, ScopedAllocator::global()).unwrap();
1247
1248 let mut best: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
1249 let mut best_similarity: f32 = f32::NEG_INFINITY;
1250
1251 let min = -3.5;
1254 for i0 in 0..8 {
1255 for i1 in 0..8 {
1256 for i2 in 0..8 {
1257 for i3 in 0..8 {
1258 let p: [f32; 4] = [
1259 min + (i0 as f32),
1260 min + (i1 as f32),
1261 min + (i2 as f32),
1262 min + (i3 as f32),
1263 ];
1264
1265 let sim: MathematicalValue<f32> =
1266 diskann_vector::distance::Cosine::evaluate(&p, &target);
1267 let sim = sim.into_inner();
1268 if sim > best_similarity {
1269 best_similarity = sim;
1270 best = p.map(|i| i - min);
1272 }
1273 }
1274 }
1275 }
1276 }
1277
1278 let clamped = target.map(|i| (i * scale - min).round().clamp(0.0, 7.0));
1281 let clamped_cosine: MathematicalValue<f32> =
1282 diskann_vector::distance::Cosine::evaluate(&clamped.map(|i| i + min), &target);
1283
1284 let passed = if best == clamped {
1287 true
1288 } else {
1289 let ratio: Vec<f32> = std::iter::zip(best, clamped)
1290 .map(|(b, c)| {
1291 let ratio = (b + min) / (c + min);
1292 assert_ne!(
1293 ratio, 0.0,
1294 "ratio should never be zero because `b` is an integer and \
1295 `min` is not"
1296 );
1297 ratio
1298 })
1299 .collect();
1300
1301 ratio.iter().all(|i| *i == ratio[0])
1302 };
1303
1304 if !passed {
1305 panic!(
1306 "failed for input {:?}.\
1307 Best = {:?}, Found = {:?}\
1308 Best similarity = {}, similarity with clamped = {}",
1309 target,
1310 best,
1311 clamped,
1312 best_similarity,
1313 clamped_cosine.into_inner()
1314 );
1315 }
1316 };
1317
1318 let min = -3.5;
1320 for i0 in (0..8).step_by(2) {
1321 for i1 in (1..9).step_by(2) {
1322 for i2 in (0..8).step_by(2) {
1323 for i3 in (1..9).step_by(2) {
1324 let p: [f32; 4] = [
1325 min + (i0 as f32),
1326 min + (i1 as f32),
1327 min + (i2 as f32),
1328 min + (i3 as f32),
1329 ];
1330 run_test(p)
1331 }
1332 }
1333 }
1334 }
1335
1336 for _ in 0..num_trials {
1337 let this_scale: f32 = scale_distribution.sample(&mut rng);
1338 let v: [f32; 4] = [(); 4].map(|_| {
1339 let v: f32 = StandardNormal {}.sample(&mut rng);
1340 this_scale * v
1341 });
1342 run_test(v);
1343 }
1344 }
1345
1346 #[test]
1347 #[should_panic(expected = "calling code should not allow the slice to be empty")]
1348 fn empty_slice_panics() {
1349 maximize_cosine_similarity(
1350 &[],
1351 NonZeroUsize::new(4).unwrap(),
1352 ScopedAllocator::global(),
1353 )
1354 .unwrap();
1355 }
1356
1357 struct Setup {
1358 transform: TransformKind,
1359 nrows: usize,
1360 ncols: usize,
1361 num_trials: usize,
1362 }
1363
1364 fn get_scale(scale: PreScale, quantizer: &SphericalQuantizer) -> f32 {
1365 match scale {
1366 PreScale::None => 1.0,
1367 PreScale::Some(v) => v.into_inner(),
1368 PreScale::ReciprocalMeanNorm => 1.0 / quantizer.mean_norm().into_inner(),
1369 }
1370 }
1371
1372 fn test_l2<const Q: usize, const D: usize, Perm>(
1373 setup: &Setup,
1374 problem: &test_util::TestProblem,
1375 computed_means: &[f32],
1376 pre_scale: PreScale,
1377 rng: &mut StdRng,
1378 ) where
1379 Unsigned: Representation<Q>,
1380 Unsigned: Representation<D>,
1381 Perm: PermutationStrategy<Q>,
1382 for<'a> SphericalQuantizer:
1383 CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1384 for<'a> SphericalQuantizer:
1385 CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1386 {
1387 assert_eq!(setup.nrows, problem.data.nrows());
1388 assert_eq!(setup.ncols, problem.data.ncols());
1389
1390 let scoped_global = ScopedAllocator::global();
1391 let distribution = Uniform::new(0, setup.nrows).unwrap();
1392 let quantizer = SphericalQuantizer::train(
1393 problem.data.as_view(),
1394 setup.transform,
1395 SupportedMetric::SquaredL2,
1396 pre_scale,
1397 rng,
1398 GlobalAllocator,
1399 )
1400 .unwrap();
1401
1402 let scale = get_scale(pre_scale, &quantizer);
1403
1404 let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1405 let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1406 let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1407
1408 assert_eq!(
1409 quantizer.mean_norm.into_inner(),
1410 problem.mean_norm as f32,
1411 "computed mean norm should not apply scale"
1412 );
1413 let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
1414 assert_eq!(&*scaled_means, quantizer.shift());
1415
1416 let l2: CompensatedSquaredL2 = quantizer.as_functor();
1417 assert_eq!(l2.dim, quantizer.output_dim() as f32);
1418
1419 for _ in 0..setup.num_trials {
1420 let i = distribution.sample(rng);
1421 let v = problem.data.row(i);
1422
1423 quantizer
1424 .compress_into_with(v, b.reborrow_mut(), scoped_global)
1425 .unwrap();
1426 quantizer
1427 .compress_into_with(v, q.reborrow_mut(), scoped_global)
1428 .unwrap();
1429 quantizer
1430 .compress_into_with(v, f.reborrow_mut(), scoped_global)
1431 .unwrap();
1432
1433 let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
1434 .map(|(a, b)| scale * a - b)
1435 .collect();
1436
1437 {
1439 let DataMetaF32 {
1440 inner_product_correction,
1441 bit_sum,
1442 metric_specific,
1443 } = b.meta().to_full(ARCH);
1444
1445 let shifted_square_norm = metric_specific;
1446
1447 let bv = b.vector();
1449 let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1450 assert_eq!(s, bit_sum as usize);
1451
1452 {
1454 let expected = FastL2NormSquared.evaluate(&*shifted);
1455 let err = (shifted_square_norm - expected).abs() / expected.abs();
1456 assert!(
1457 err < 5.0e-4, "failed diff check, got {}, expected {} - relative error = {}",
1459 shifted_square_norm,
1460 expected,
1461 err
1462 );
1463 }
1464
1465 if const { D == 1 } {
1468 let self_inner_product = 2.0 * shifted_square_norm.sqrt()
1469 / (inner_product_correction * (bv.len() as f32).sqrt());
1470 assert!(
1471 (self_inner_product - 0.8).abs() < 0.13,
1472 "self inner-product should be close to 0.8. Instead, it's {}",
1473 self_inner_product
1474 );
1475 }
1476 }
1477
1478 {
1479 let QueryMeta {
1480 inner_product_correction,
1481 bit_sum,
1482 offset,
1483 metric_specific,
1484 } = q.meta();
1485
1486 let shifted_square_norm = metric_specific;
1487 let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1488 preprocessed
1489 .shifted
1490 .iter_mut()
1491 .for_each(|i| *i /= preprocessed.shifted_norm);
1492
1493 let mut transformed = vec![0.0f32; quantizer.output_dim()];
1494 quantizer
1495 .transform
1496 .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1497 .unwrap();
1498
1499 let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1500 let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1501
1502 let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1503
1504 {
1506 let expected = FastL2NormSquared.evaluate(&*shifted);
1507 let err = (shifted_square_norm - expected).abs() / expected.abs();
1508 assert!(
1509 err < 2e-7,
1510 "failed diff check, got {}, expected {} - relative error = {}",
1511 shifted_square_norm,
1512 expected,
1513 err
1514 );
1515 }
1516
1517 {
1519 let expected = shifted_square_norm.sqrt() * scale;
1520 let got = inner_product_correction;
1521
1522 let err = (expected - got).abs();
1523 assert!(
1524 err < 1.0e-7,
1525 "\"innerproduct_scale\": expected {}, got {}, error = {}",
1526 expected,
1527 got,
1528 err
1529 );
1530 }
1531
1532 {
1534 let expected = min / scale;
1535 let got = offset;
1536
1537 let err = (expected - got).abs();
1538 assert!(
1539 err < 1.0e-7,
1540 "\"sum_scale\": expected {}, got {}, error = {}",
1541 expected,
1542 got,
1543 err
1544 );
1545 }
1546
1547 {
1549 let expected = (0..q.len())
1550 .map(|i| q.vector().get(i).unwrap())
1551 .sum::<i64>() as f32;
1552
1553 let got = bit_sum;
1554
1555 let err = (expected - got).abs();
1556 assert!(
1557 err < 1.0e-7,
1558 "\"offset\": expected {}, got {}, error = {}",
1559 expected,
1560 got,
1561 err
1562 );
1563 }
1564 }
1565
1566 {
1568 let s: f32 = f.data.iter().sum::<f32>();
1570 assert_eq!(s, f.meta.sum);
1571
1572 {
1574 let expected = FastL2Norm.evaluate(&*shifted);
1575 let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
1576 assert!(
1577 err < 2e-7,
1578 "failed diff check, got {}, expected {} - relative error = {}",
1579 f.meta.shifted_norm,
1580 expected,
1581 err
1582 );
1583 }
1584
1585 assert_eq!(
1586 f.meta.metric_specific,
1587 f.meta.shifted_norm * f.meta.shifted_norm,
1588 "metric specific data for squared l2 is the square shifted norm",
1589 );
1590 }
1591 }
1592
1593 quantizer
1596 .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
1597 .unwrap();
1598 assert_eq!(b.meta(), DataMeta::default());
1599
1600 quantizer
1601 .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
1602 .unwrap();
1603 assert_eq!(q.meta(), QueryMeta::default());
1604
1605 f.data.fill(f32::INFINITY);
1606 quantizer
1607 .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
1608 .unwrap();
1609 assert!(f.data.iter().all(|&i| i == 0.0));
1610 assert_eq!(f.meta.sum, 0.0);
1611 assert_eq!(f.meta.metric_specific, 0.0);
1612 }
1613
    /// End-to-end check of `SphericalQuantizer` trained with the
    /// `InnerProduct` metric.
    ///
    /// Trains on `problem.data`, compresses randomly sampled rows into all
    /// three encodings — bit-packed data (`Data<D>`), bit-packed query
    /// (`Query<Q, Perm>`), and full-precision query (`FullQuery`) — and
    /// validates every metadata field against a value recomputed here from
    /// the raw row. `ctx` is lazily-formatted context appended to failure
    /// messages.
    fn test_ip<const Q: usize, const D: usize, Perm>(
        setup: &Setup,
        problem: &test_util::TestProblem,
        computed_means: &[f32],
        pre_scale: PreScale,
        rng: &mut StdRng,
        ctx: &dyn Display,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
        for<'a> SphericalQuantizer:
            CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
    {
        assert_eq!(setup.nrows, problem.data.nrows());
        assert_eq!(setup.ncols, problem.data.ncols());

        let scoped_global = ScopedAllocator::global();
        let distribution = Uniform::new(0, setup.nrows).unwrap();
        let quantizer = SphericalQuantizer::train(
            problem.data.as_view(),
            setup.transform,
            SupportedMetric::InnerProduct,
            pre_scale,
            rng,
            GlobalAllocator,
        )
        .unwrap();

        let scale = get_scale(pre_scale, &quantizer);

        let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
        let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
        let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();

        assert_eq!(
            quantizer.mean_norm.into_inner(),
            problem.mean_norm as f32,
            "computed mean norm should not apply scale"
        );
        // The stored shift is the (pre-scaled) per-dimension mean of the data.
        let scaled_means: Vec<_> = computed_means.iter().map(|i| scale * i).collect();
        assert_eq!(&*scaled_means, quantizer.shift());

        let ip: CompensatedIP = quantizer.as_functor();

        assert_eq!(ip.dim, quantizer.output_dim() as f32);
        assert_eq!(
            ip.squared_shift_norm,
            FastL2NormSquared.evaluate(quantizer.shift())
        );

        for _ in 0..setup.num_trials {
            let i = distribution.sample(rng);
            let v = problem.data.row(i);

            // Compress the same row into all three destination encodings.
            quantizer
                .compress_into_with(v, b.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, q.reborrow_mut(), scoped_global)
                .unwrap();
            quantizer
                .compress_into_with(v, f.reborrow_mut(), scoped_global)
                .unwrap();

            // Reference: scaled row minus the stored shift (centroid).
            let shifted: Vec<f32> = std::iter::zip(v.iter(), quantizer.shift().iter())
                .map(|(a, b)| scale * a - b)
                .collect();

            // -- bit-packed data encoding --------------------------------
            {
                let DataMetaF32 {
                    inner_product_correction,
                    bit_sum,
                    metric_specific,
                } = b.meta().to_full(ARCH);

                // For the inner-product metric, `metric_specific` stores
                // <shifted, shift> (see the assertions below).
                let inner_product_with_centroid = metric_specific;

                // `bit_sum` must equal the popcount-style sum of the packed bits.
                let bv = b.vector();
                let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
                assert_eq!(s, bit_sum as usize);

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());

                let diff = (inner_product.into_inner() - inner_product_with_centroid).abs();
                assert!(
                    diff < 1.53e-5,
                    "got a diff of {}. Expected = {}, got = {} -- context: {}",
                    diff,
                    inner_product.into_inner(),
                    inner_product_with_centroid,
                    ctx,
                );

                // Statistical sanity check only valid for 1-bit encodings.
                if const { D == 1 } {
                    let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
                        / (inner_product_correction * (bv.len() as f32).sqrt());
                    assert!(
                        (self_inner_product - 0.8).abs() < 0.12,
                        "self inner-product should be close to 0.8. Instead, it's {}",
                        self_inner_product
                    );
                }
            }

            // -- bit-packed query encoding -------------------------------
            {
                let QueryMeta {
                    inner_product_correction,
                    bit_sum,
                    offset,
                    metric_specific,
                } = q.meta();

                let inner_product_with_centroid = metric_specific;
                // Recompute the query pipeline by hand: shift, normalize,
                // transform, then derive the scalar quantization range.
                let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
                preprocessed
                    .shifted
                    .iter_mut()
                    .for_each(|i| *i /= preprocessed.shifted_norm);

                let mut transformed = vec![0.0f32; quantizer.output_dim()];
                quantizer
                    .transform
                    .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
                    .unwrap();

                let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
                let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));

                // Step size for Q-bit scalar quantization over [min, max].
                let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);

                {
                    let expected = (FastL2Norm).evaluate(&*shifted) * scale;
                    let got = inner_product_correction;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"innerproduct_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let expected = min / scale;
                    let got = offset;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"sum_scale\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    // `bit_sum` equals the sum of the quantized query codes.
                    let expected = (0..q.len())
                        .map(|i| q.vector().get(i).unwrap())
                        .sum::<i64>() as f32;

                    let got = bit_sum;

                    let err = (expected - got).abs();
                    assert!(
                        err < 1.0e-7,
                        "\"offset\": expected {}, got {}, error = {}",
                        expected,
                        got,
                        err
                    );
                }

                {
                    let inner_product: MathematicalValue<f32> =
                        InnerProduct::evaluate(&*shifted, quantizer.shift());
                    assert_eq!(inner_product.into_inner(), inner_product_with_centroid);
                }
            }

            // -- full-precision query encoding ---------------------------
            {
                let s: f32 = f.data.iter().sum::<f32>();
                assert_eq!(s, f.meta.sum);

                {
                    let expected = FastL2Norm.evaluate(&*shifted);
                    let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
                    assert!(
                        err < 2e-7,
                        "failed diff check, got {}, expected {} - relative error = {}",
                        f.meta.shifted_norm,
                        expected,
                        err
                    );
                }

                let inner_product: MathematicalValue<f32> =
                    InnerProduct::evaluate(&*shifted, quantizer.shift());
                assert_eq!(inner_product.into_inner(), f.meta.metric_specific,);
            }
        }

        // Compressing the centroid itself must yield all-default metadata
        // (zero shifted vector) in every encoding.
        quantizer
            .compress_into_with(computed_means, b.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(b.meta(), DataMeta::default());

        quantizer
            .compress_into_with(computed_means, q.reborrow_mut(), scoped_global)
            .unwrap();
        assert_eq!(q.meta(), QueryMeta::default());

        // Poison the buffer first to prove it is fully overwritten.
        f.data.fill(f32::INFINITY);
        quantizer
            .compress_into_with(computed_means, f.reborrow_mut(), scoped_global)
            .unwrap();
        assert!(f.data.iter().all(|&i| i == 0.0));
        assert_eq!(f.meta.sum, 0.0);
        assert_eq!(f.meta.metric_specific, 0.0);
    }
1856
1857 fn test_cosine<const Q: usize, const D: usize, Perm>(
1858 setup: &Setup,
1859 problem: &test_util::TestProblem,
1860 pre_scale: PreScale,
1861 rng: &mut StdRng,
1862 ) where
1863 Unsigned: Representation<Q>,
1864 Unsigned: Representation<D>,
1865 Perm: PermutationStrategy<Q>,
1866 for<'a> SphericalQuantizer:
1867 CompressIntoWith<&'a [f32], DataMut<'a, D>, ScopedAllocator<'a>>,
1868 for<'a> SphericalQuantizer:
1869 CompressIntoWith<&'a [f32], QueryMut<'a, Q, Perm>, ScopedAllocator<'a>>,
1870 {
1871 assert_eq!(setup.nrows, problem.data.nrows());
1872 assert_eq!(setup.ncols, problem.data.ncols());
1873
1874 let scoped_global = ScopedAllocator::global();
1875 let distribution = Uniform::new(0, setup.nrows).unwrap();
1876 let quantizer = SphericalQuantizer::train(
1877 problem.data.as_view(),
1878 setup.transform,
1879 SupportedMetric::Cosine,
1880 pre_scale,
1881 rng,
1882 GlobalAllocator,
1883 )
1884 .unwrap();
1885
1886 let mut b = Data::<D, _>::new_boxed(quantizer.output_dim());
1887 let mut q = Query::<Q, Perm, _>::new_boxed(quantizer.output_dim());
1888 let mut f = FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap();
1889
1890 let cosine: CompensatedCosine = quantizer.as_functor();
1891
1892 assert_eq!(cosine.inner.dim, quantizer.output_dim() as f32);
1893 assert_eq!(
1894 cosine.inner.squared_shift_norm,
1895 FastL2NormSquared.evaluate(quantizer.shift())
1896 );
1897
1898 const IP_BOUND: f32 = 2.6e-3;
1899
1900 let mut test_row = |v: &[f32]| {
1901 let vnorm = (FastL2Norm).evaluate(v);
1902 let v_normalized: Vec<f32> = v
1903 .iter()
1904 .map(|i| if vnorm == 0.0 { 0.0 } else { *i / vnorm })
1905 .collect();
1906
1907 quantizer
1908 .compress_into_with(v, b.reborrow_mut(), scoped_global)
1909 .unwrap();
1910
1911 quantizer
1912 .compress_into_with(v, q.reborrow_mut(), scoped_global)
1913 .unwrap();
1914
1915 quantizer
1916 .compress_into_with(v, f.reborrow_mut(), scoped_global)
1917 .unwrap();
1918
1919 let shifted: Vec<f32> = std::iter::zip(v_normalized.iter(), quantizer.shift().iter())
1920 .map(|(a, b)| a - b)
1921 .collect();
1922
1923 {
1925 let DataMetaF32 {
1926 inner_product_correction,
1927 bit_sum,
1928 metric_specific,
1929 } = b.meta().to_full(ARCH);
1930
1931 let inner_product_with_centroid = metric_specific;
1932
1933 let bv = b.vector();
1935 let s: usize = (0..bv.len()).map(|i| bv.get(i).unwrap() as usize).sum();
1936 assert_eq!(s, bit_sum as usize);
1937
1938 let inner_product: MathematicalValue<f32> =
1941 InnerProduct::evaluate(&*shifted, quantizer.shift());
1942
1943 let abs = (inner_product.into_inner() - inner_product_with_centroid).abs();
1944 let relative = abs / inner_product.into_inner().abs();
1945
1946 assert!(
1947 abs < 1e-7 || relative < IP_BOUND,
1948 "got an abs/rel of {}/{} with a bound of {}/{}",
1949 abs,
1950 relative,
1951 1e-7,
1952 IP_BOUND
1953 );
1954
1955 if const { D == 1 } {
1958 let self_inner_product = 2.0 * (FastL2Norm).evaluate(&*shifted)
1959 / (inner_product_correction * (bv.len() as f32).sqrt());
1960 assert!(
1961 (self_inner_product - 0.8).abs() < 0.11,
1962 "self inner-product should be close to 0.8. Instead, it's {}",
1963 self_inner_product
1964 );
1965 }
1966 }
1967
1968 {
1969 let QueryMeta {
1970 inner_product_correction,
1971 bit_sum,
1972 offset,
1973 metric_specific,
1974 } = q.meta();
1975
1976 let inner_product_with_centroid = metric_specific;
1977 let mut preprocessed = quantizer.preprocess(v, scoped_global).unwrap();
1978 preprocessed
1979 .shifted
1980 .iter_mut()
1981 .for_each(|i| *i /= preprocessed.shifted_norm);
1982
1983 let mut transformed = vec![0.0f32; quantizer.output_dim()];
1984 quantizer
1985 .transform
1986 .transform_into(&mut transformed, &preprocessed.shifted, scoped_global)
1987 .unwrap();
1988
1989 let min = transformed.iter().fold(f32::MAX, |min, &i| min.min(i));
1990 let max = transformed.iter().fold(f32::MIN, |max, &i| max.max(i));
1991
1992 let scale = (max - min) / ((2usize.pow(Q as u32) - 1) as f32);
1993
1994 {
1996 let expected = (FastL2Norm).evaluate(&*shifted) * scale;
1997 let got = inner_product_correction;
1998
1999 let err = (expected - got).abs();
2000 assert!(
2001 err < 1.0e-7,
2002 "\"innerproduct_scale\": expected {}, got {}, error = {}",
2003 expected,
2004 got,
2005 err
2006 );
2007 }
2008
2009 {
2011 let expected = min / scale;
2012 let got = offset;
2013
2014 let err = (expected - got).abs();
2015 assert!(
2016 err < 1.0e-7,
2017 "\"sum_scale\": expected {}, got {}, error = {}",
2018 expected,
2019 got,
2020 err
2021 );
2022 }
2023
2024 {
2026 let expected = (0..q.len())
2027 .map(|i| q.vector().get(i).unwrap())
2028 .sum::<i64>() as f32;
2029
2030 let got = bit_sum;
2031
2032 let err = (expected - got).abs();
2033 assert!(
2034 err < 1.0e-7,
2035 "\"offset\": expected {}, got {}, error = {}",
2036 expected,
2037 got,
2038 err
2039 );
2040 }
2041
2042 {
2044 let inner_product: MathematicalValue<f32> =
2046 InnerProduct::evaluate(&*shifted, quantizer.shift());
2047
2048 let err = (inner_product.into_inner() - inner_product_with_centroid).abs()
2049 / inner_product.into_inner().abs();
2050 assert!(
2051 err < IP_BOUND,
2052 "\"offset\": expected {}, got {}, error = {}",
2053 inner_product.into_inner(),
2054 inner_product_with_centroid,
2055 err
2056 );
2057 }
2058 }
2059
2060 {
2062 let s: f32 = f.data.iter().sum::<f32>();
2064 assert_eq!(s, f.meta.sum);
2065
2066 {
2068 let expected = FastL2Norm.evaluate(&*shifted);
2069 let err = (f.meta.shifted_norm - expected).abs() / expected.abs();
2070 assert!(
2071 err < 2e-7,
2072 "failed diff check, got {}, expected {} - relative error = {}",
2073 f.meta.shifted_norm,
2074 expected,
2075 err
2076 );
2077 }
2078
2079 let inner_product: MathematicalValue<f32> =
2081 InnerProduct::evaluate(&*shifted, quantizer.shift());
2082 let err = (inner_product.into_inner() - f.meta.metric_specific).abs()
2083 / inner_product.into_inner().abs();
2084 assert!(
2085 err < IP_BOUND,
2086 "\"offset\": expected {}, got {}, error = {}",
2087 inner_product.into_inner(),
2088 f.meta.metric_specific,
2089 err
2090 );
2091 }
2092 };
2093
2094 for _ in 0..setup.num_trials {
2095 let i = distribution.sample(rng);
2096 let v = problem.data.row(i);
2097 test_row(v);
2098 }
2099
2100 let zero = vec![0.0f32; quantizer.input_dim()];
2102 test_row(&zero);
2103 }
2104
2105 fn _test_oom_resiliance<T>(quantizer: &SphericalQuantizer, data: &[f32], dst: &mut T)
2106 where
2107 for<'a> T: ReborrowMut<'a>,
2108 for<'a> SphericalQuantizer: CompressIntoWith<
2109 &'a [f32],
2110 <T as ReborrowMut<'a>>::Target,
2111 ScopedAllocator<'a>,
2112 Error = CompressionError,
2113 >,
2114 {
2115 let mut succeeded = false;
2116 let mut failed = false;
2117 for max_allocations in 0..10 {
2118 match quantizer.compress_into_with(
2119 data,
2120 dst.reborrow_mut(),
2121 ScopedAllocator::new(&test_util::LimitedAllocator::new(max_allocations)),
2122 ) {
2123 Ok(()) => {
2124 succeeded = true;
2125 }
2126 Err(CompressionError::AllocatorError(_)) => {
2127 failed = true;
2128 }
2129 Err(other) => {
2130 panic!("received an unexpected error: {:?}", other);
2131 }
2132 }
2133 }
2134 assert!(succeeded);
2135 assert!(failed);
2136 }
2137
2138 fn test_oom_resiliance<const Q: usize, const D: usize, Perm>(
2139 setup: &Setup,
2140 problem: &test_util::TestProblem,
2141 pre_scale: PreScale,
2142 rng: &mut StdRng,
2143 ) where
2144 Unsigned: Representation<Q>,
2145 Unsigned: Representation<D>,
2146 Perm: PermutationStrategy<Q>,
2147 for<'a> SphericalQuantizer: CompressIntoWith<
2148 &'a [f32],
2149 DataMut<'a, D>,
2150 ScopedAllocator<'a>,
2151 Error = CompressionError,
2152 >,
2153 for<'a> SphericalQuantizer: CompressIntoWith<
2154 &'a [f32],
2155 QueryMut<'a, Q, Perm>,
2156 ScopedAllocator<'a>,
2157 Error = CompressionError,
2158 >,
2159 {
2160 assert_eq!(setup.nrows, problem.data.nrows());
2161 assert_eq!(setup.ncols, problem.data.ncols());
2162
2163 let quantizer = SphericalQuantizer::train(
2164 problem.data.as_view(),
2165 setup.transform,
2166 SupportedMetric::SquaredL2,
2167 pre_scale,
2168 rng,
2169 GlobalAllocator,
2170 )
2171 .unwrap();
2172
2173 let data = problem.data.row(0);
2175 _test_oom_resiliance::<Data<D, _>>(
2176 &quantizer,
2177 data,
2178 &mut Data::new_boxed(quantizer.output_dim()),
2179 );
2180 _test_oom_resiliance::<Query<Q, Perm, _>>(
2181 &quantizer,
2182 data,
2183 &mut Query::new_boxed(quantizer.output_dim()),
2184 );
2185 _test_oom_resiliance::<FullQuery<_>>(
2186 &quantizer,
2187 data,
2188 &mut FullQuery::empty(quantizer.output_dim(), GlobalAllocator).unwrap(),
2189 );
2190 }
2191
    /// Generates a fresh random test problem for `setup` and runs the full
    /// metric test suite (L2, inner-product, cosine) over several pre-scale
    /// configurations, followed by the OOM-resilience check.
    ///
    /// Note that `rng` is threaded through every call, so the call order here
    /// determines the random streams seen by each sub-test.
    fn test_quantizer<const Q: usize, const D: usize, Perm>(setup: &Setup, rng: &mut StdRng)
    where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            DataMut<'a, D>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
        for<'a> SphericalQuantizer: CompressIntoWith<
            &'a [f32],
            QueryMut<'a, Q, Perm>,
            ScopedAllocator<'a>,
            Error = CompressionError,
        >,
    {
        let problem = test_util::create_test_problem(setup.nrows, setup.ncols, rng);
        let computed_means_f32: Vec<_> = problem.means.iter().map(|i| *i as f32).collect();

        let scales = [
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            // NOTE(review): this entry duplicates the previous one — possibly
            // a different value (or `PreScale::None`) was intended. Left
            // untouched because changing the number of iterations would shift
            // the shared RNG stream for everything downstream; confirm intent.
            PreScale::Some(Positive::new(1.0 / 1024.0).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        for scale in scales {
            let ctx = &lazy_format!("dim = {}, scale = {:?}", setup.ncols, scale);

            test_l2::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng);
            test_ip::<Q, D, Perm>(setup, &problem, &computed_means_f32, scale, rng, ctx);
            test_cosine::<Q, D, Perm>(setup, &problem, scale, rng);
        }

        test_oom_resiliance::<Q, D, Perm>(setup, &problem, PreScale::ReciprocalMeanNorm, rng);
    }
2229
2230 #[test]
2231 fn test_spherical_quantizer() {
2232 let mut rng = StdRng::seed_from_u64(0xab516aef1ce61640);
2233 for dim in [56, 72, 128, 255] {
2234 let setup = Setup {
2235 transform: TransformKind::PaddingHadamard {
2236 target_dim: TargetDim::Same,
2237 },
2238 nrows: 64,
2239 ncols: dim,
2240 num_trials: 10,
2241 };
2242
2243 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2244 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2245 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2246 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2247
2248 let setup = Setup {
2249 transform: TransformKind::DoubleHadamard {
2250 target_dim: TargetDim::Same,
2251 },
2252 nrows: 64,
2253 ncols: dim,
2254 num_trials: 10,
2255 };
2256 test_quantizer::<4, 1, BitTranspose>(&setup, &mut rng);
2257 test_quantizer::<2, 2, Dense>(&setup, &mut rng);
2258 test_quantizer::<4, 4, Dense>(&setup, &mut rng);
2259 test_quantizer::<8, 8, Dense>(&setup, &mut rng);
2260 }
2261 }
2262
2263 #[test]
2268 fn err_dim_cannot_be_zero() {
2269 let data = Matrix::new(0.0f32, 10, 0);
2270 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2271 let err = SphericalQuantizer::train(
2272 data.as_view(),
2273 TransformKind::DoubleHadamard {
2274 target_dim: TargetDim::Same,
2275 },
2276 SupportedMetric::SquaredL2,
2277 PreScale::None,
2278 &mut rng,
2279 GlobalAllocator,
2280 )
2281 .unwrap_err();
2282 assert_eq!(err.to_string(), "data dim cannot be zero");
2283 }
2284
2285 #[test]
2286 fn err_norm_must_be_positive() {
2287 let data = Matrix::new(0.0f32, 10, 10);
2288 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2289 let err = SphericalQuantizer::train(
2290 data.as_view(),
2291 TransformKind::DoubleHadamard {
2292 target_dim: TargetDim::Same,
2293 },
2294 SupportedMetric::SquaredL2,
2295 PreScale::None,
2296 &mut rng,
2297 GlobalAllocator,
2298 )
2299 .unwrap_err();
2300 assert_eq!(err.to_string(), "norm must be positive");
2301 }
2302
2303 #[test]
2304 fn err_norm_cannot_be_infinity() {
2305 let mut data = Matrix::new(0.0f32, 10, 10);
2306 data[(2, 5)] = f32::INFINITY;
2307
2308 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2309 let err = SphericalQuantizer::train(
2310 data.as_view(),
2311 TransformKind::DoubleHadamard {
2312 target_dim: TargetDim::Same,
2313 },
2314 SupportedMetric::SquaredL2,
2315 PreScale::None,
2316 &mut rng,
2317 GlobalAllocator,
2318 )
2319 .unwrap_err();
2320 assert_eq!(err.to_string(), "computed norm contains infinity or NaN");
2321 }
2322
2323 #[test]
2324 fn err_reciprocal_norm_cannot_be_infinity() {
2325 let mut data = Matrix::new(0.0f32, 10, 10);
2326 data[(2, 5)] = 2.93863e-39;
2327
2328 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2329 let err = SphericalQuantizer::train(
2330 data.as_view(),
2331 TransformKind::DoubleHadamard {
2332 target_dim: TargetDim::Same,
2333 },
2334 SupportedMetric::SquaredL2,
2335 PreScale::ReciprocalMeanNorm,
2336 &mut rng,
2337 GlobalAllocator,
2338 )
2339 .unwrap_err();
2340 assert_eq!(err.to_string(), "reciprocal norm contains infinity or NaN");
2341 }
2342
2343 #[test]
2344 fn err_mean_norm_cannot_be_zero_generate() {
2345 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2346 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2347 let err = SphericalQuantizer::generate(
2348 centroid,
2349 0.0,
2350 TransformKind::DoubleHadamard {
2351 target_dim: TargetDim::Same,
2352 },
2353 SupportedMetric::SquaredL2,
2354 None,
2355 &mut rng,
2356 GlobalAllocator,
2357 )
2358 .unwrap_err();
2359 assert_eq!(err.to_string(), "norm must be positive");
2360 }
2361
2362 #[test]
2363 fn err_scale_cannot_be_zero_generate() {
2364 let centroid = Poly::broadcast(0.0f32, 10, GlobalAllocator).unwrap();
2365 let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
2366 let err = SphericalQuantizer::generate(
2367 centroid,
2368 1.0,
2369 TransformKind::DoubleHadamard {
2370 target_dim: TargetDim::Same,
2371 },
2372 SupportedMetric::SquaredL2,
2373 Some(0.0),
2374 &mut rng,
2375 GlobalAllocator,
2376 )
2377 .unwrap_err();
2378 assert_eq!(err.to_string(), "pre-scale must be positive");
2379 }
2380
    /// Exercises the error paths of `compress_into_with`: non-finite inputs,
    /// dynamic-range overflow of the 16-bit correction terms, and source /
    /// destination dimension mismatches for both data and query encodings.
    #[test]
    fn compression_errors_data() {
        let mut rng = StdRng::seed_from_u64(0xe3e9f42ed9f15883);
        let data = Matrix::<f32>::new(views::Init(|| StandardNormal {}.sample(&mut rng)), 16, 12);

        let quantizer = SphericalQuantizer::train(
            data.as_view(),
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            SupportedMetric::SquaredL2,
            PreScale::None,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        let scoped_global = ScopedAllocator::global();

        // Non-finite values anywhere in the input must be rejected, in every
        // position and for both the data and query encodings.
        {
            let mut query: Vec<f32> = quantizer.shift().to_vec();
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            for i in 0..query.len() {
                // Remember the original value so the vector is restored for
                // the next position.
                let last = query[i];
                for v in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
                    query[i] = v;

                    let err = quantizer
                        .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);

                    let err = quantizer
                        .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                        .unwrap_err();

                    assert_eq!(err.to_string(), "input contains NaN", "failed for {}", v);
                }
                query[i] = last;
            }
        }

        // Finite but enormous values overflow the encoding's correction terms.
        {
            let query: Vec<f32> = vec![1000000.0; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();

            let expected = "encoding error - you may need to scale the entire dataset to reduce its dynamic range";

            assert_eq!(err.to_string(), expected, "failed for {:?}", query);
        }

        // Source slices one element too short or too long must be rejected.
        for len in [quantizer.input_dim() - 1, quantizer.input_dim() + 1] {
            let query = vec![0.0f32; len];
            let mut d = Data::<1, _>::new_boxed(quantizer.output_dim());
            let mut q = Query::<4, BitTranspose, _>::new_boxed(quantizer.output_dim());

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::SourceDimensionMismatch {
                    expected: quantizer.input_dim(),
                }
            );
        }

        // Destination buffers of the wrong dimension must also be rejected.
        for len in [quantizer.output_dim() - 1, quantizer.output_dim() + 1] {
            let query = vec![0.0f32; quantizer.input_dim()];
            let mut d = Data::<1, _>::new_boxed(len);
            let mut q = Query::<4, BitTranspose, _>::new_boxed(len);

            let err = quantizer
                .compress_into_with(&*query, d.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );

            let err = quantizer
                .compress_into_with(&*query, q.reborrow_mut(), scoped_global)
                .unwrap_err();
            assert_eq!(
                err,
                CompressionError::DestinationDimensionMismatch {
                    expected: quantizer.output_dim(),
                }
            );
        }
    }
2494
2495 #[test]
2496 fn centroid_scaling_happens_in_generate() {
2497 let centroid = Poly::from_iter(
2498 [1088.6732f32, 1393.32, 1547.877].into_iter(),
2499 GlobalAllocator,
2500 )
2501 .unwrap();
2502 let mean_norm = 2359.27;
2503 let pre_scale = 1.0 / mean_norm;
2504
2505 let quantizer = SphericalQuantizer::generate(
2506 centroid,
2507 mean_norm,
2508 TransformKind::Null,
2509 SupportedMetric::InnerProduct,
2510 Some(pre_scale),
2511 &mut StdRng::seed_from_u64(10),
2512 GlobalAllocator,
2513 )
2514 .unwrap();
2515
2516 let mut v = Data::<4, _>::new_boxed(quantizer.input_dim());
2517 let data: &[f32] = &[1000.34, 1456.32, 1234.5446];
2518 assert!(quantizer
2519 .compress_into_with(data, v.reborrow_mut(), ScopedAllocator::global())
2520 .is_ok(),
2521 "if this failed, the likely culprit is exceeding the value of the 16-bit correction terms"
2522 );
2523 }
2524}
2525
#[cfg(feature = "flatbuffers")]
#[cfg(test)]
mod test_serialization {
    //! Flatbuffer round-trip and error-path tests for `SphericalQuantizer`
    //! (`pack` / `try_unpack`).

    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{
        algorithms::transforms::TargetDim,
        flatbuffers::{self as fb, to_flatbuffer},
        poly, test_util,
    };

    /// Trains quantizers over the full matrix of transform kinds, metrics,
    /// and pre-scales, serializes each to a flatbuffer, deserializes it, and
    /// asserts the reloaded quantizer compares equal to the original.
    #[test]
    fn test_serialization_happy_path() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        // Override dims both below and above the natural output dimension.
        let low = NonZeroUsize::new(100).unwrap();
        let high = NonZeroUsize::new(150).unwrap();

        let kinds = [
            TransformKind::Null,
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::DoubleHadamard {
                target_dim: TargetDim::Override(high),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Same,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(low),
            },
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Override(high),
            },
            // Random rotations require the `linalg` feature and are too slow
            // under miri.
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Same,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Natural,
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(low),
            },
            #[cfg(all(not(miri), feature = "linalg"))]
            TransformKind::RandomRotation {
                target_dim: TargetDim::Override(high),
            },
        ];

        let pre_scales = [
            PreScale::None,
            PreScale::Some(Positive::new(0.5).unwrap()),
            PreScale::Some(Positive::new(1.0).unwrap()),
            PreScale::Some(Positive::new(1.5).unwrap()),
            PreScale::ReciprocalMeanNorm,
        ];

        let alloc = GlobalAllocator;
        for kind in kinds.into_iter() {
            for metric in SupportedMetric::all() {
                for pre_scale in pre_scales {
                    let quantizer = SphericalQuantizer::train(
                        problem.data.as_view(),
                        kind,
                        metric,
                        pre_scale,
                        &mut rng,
                        alloc,
                    )
                    .unwrap();

                    // Serialize, reparse from raw bytes, and deserialize.
                    let data = to_flatbuffer(|buf| quantizer.pack(buf));
                    let proto =
                        flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
                    let reloaded = SphericalQuantizer::try_unpack(alloc, proto).unwrap();
                    assert_eq!(quantizer, reloaded, "failed on transform {:?}", kind);
                }
            }
        }
    }

    /// Forges invalid serialized payloads by corrupting a trained quantizer's
    /// fields before packing, and asserts `try_unpack` rejects each one with
    /// the matching `DeserializationError` variant.
    #[test]
    fn test_error_checking() {
        let mut rng = StdRng::seed_from_u64(0x070d9ff8cf5e0f8c);
        let problem = test_util::create_test_problem(10, 128, &mut rng);

        let transform = TransformKind::DoubleHadamard {
            target_dim: TargetDim::Same,
        };

        let alloc = GlobalAllocator;
        let mut make_quantizer = || {
            SphericalQuantizer::train(
                problem.data.as_view(),
                transform,
                SupportedMetric::SquaredL2,
                PreScale::None,
                &mut rng,
                alloc,
            )
            .unwrap()
        };

        type E = DeserializationError;

        // Zero mean norm -> rejected as a missing/invalid norm.
        {
            let mut quantizer = make_quantizer();
            // SAFETY(test): deliberately violates the `Positive` invariant to
            // forge an invalid payload the deserializer must reject.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Negative mean norm -> same rejection.
        {
            let mut quantizer = make_quantizer();

            // SAFETY(test): deliberate invariant violation, as above.
            quantizer.mean_norm = unsafe { Positive::new_unchecked(-1.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::MissingNorm);
        }

        // Zero pre-scale -> rejected as non-positive.
        {
            let mut quantizer = make_quantizer();

            // SAFETY(test): deliberate invariant violation, as above.
            quantizer.pre_scale = unsafe { Positive::new_unchecked(0.0) };

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::PreScaleNotPositive);
        }

        // Shift vector whose length disagrees with the transform's dimension.
        {
            let mut quantizer = make_quantizer();
            quantizer.shift = poly!([1.0, 2.0, 3.0], alloc).unwrap();

            let data = to_flatbuffer(|buf| quantizer.pack(buf));
            let proto = flatbuffers::root::<fb::spherical::SphericalQuantizer>(&data).unwrap();
            let err = SphericalQuantizer::try_unpack(alloc, proto).unwrap_err();
            assert_eq!(err, E::DimMismatch);
        }
    }
}