1use diskann_vector::{MathematicalValue, PureDistanceFunction};
7use thiserror::Error;
8
9use super::{
10 bit_scale, inverse_bit_scale,
11 vectors::{
12 CompensatedCosineNormalized, CompensatedIP, CompensatedSquaredL2, Compensation,
13 MutCompensatedVectorRef,
14 },
15};
16use crate::{
17 AsFunctor, CompressInto,
18 bits::{MutBitSlice, PermutationStrategy, Representation, Unsigned},
19};
20
#[derive(Clone, Debug)]
pub struct ScalarQuantizer {
    /// Quantization step scale in the original value domain. `compress`
    /// divides by this (after multiplying by `bit_scale::<NBITS>()`) to map
    /// shifted inputs onto the unsigned code domain.
    scale: f32,

    /// Per-dimension offset subtracted from the input before scaling; its
    /// length defines the quantizer's dimensionality (see `dim`).
    shift: Vec<f32>,

    /// Cached `<shift, shift>` (squared L2 norm of `shift`), computed once in
    /// `new` so `AsFunctor<CompensatedIP>` does not recompute it per call.
    shift_square_norm: f32,

    /// Target norm used by `rescale`; `None` means the quantizer was built
    /// without one and `rescale` returns `MeanNormMissing`.
    mean_norm: Option<f32>,
}
113
impl ScalarQuantizer {
    /// Creates a quantizer with step `scale`, per-dimension `shift`, and an
    /// optional target `mean_norm` consumed by [`Self::rescale`].
    ///
    /// The squared L2 norm of `shift` is precomputed here and cached.
    pub fn new(scale: f32, shift: Vec<f32>, mean_norm: Option<f32>) -> Self {
        // <shift, shift> = squared L2 norm of the shift vector.
        let shift_square_norm: MathematicalValue<f32> =
            diskann_vector::distance::InnerProduct::evaluate(&*shift, &*shift);

        Self {
            scale,
            shift,
            shift_square_norm: shift_square_norm.into_inner(),
            mean_norm,
        }
    }

    /// Dimensionality of vectors this quantizer accepts (length of `shift`).
    pub fn dim(&self) -> usize {
        self.shift.len()
    }

    /// The quantization step scale.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Cached squared L2 norm of the shift vector.
    pub fn shift_square_norm(&self) -> f32 {
        self.shift_square_norm
    }

    /// Per-dimension shift subtracted from inputs before scaling.
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    /// The configured target norm for `rescale`, if any.
    pub fn mean_norm(&self) -> Option<f32> {
        self.mean_norm
    }

    /// Rescales `x` in place so its L2 norm equals the configured mean norm.
    /// Zero vectors are left unchanged (see the free function `rescale`).
    ///
    /// # Errors
    ///
    /// Returns [`MeanNormMissing`] when no mean norm was configured.
    pub fn rescale(&self, x: &mut [f32]) -> Result<(), MeanNormMissing> {
        match self.mean_norm {
            Some(mean_norm) => {
                rescale(x, mean_norm);
                Ok(())
            }
            None => Err(MeanNormMissing),
        }
    }

    /// Quantizes `from` into `into`, invoking `callback(code, index)` with the
    /// pre-cast floating-point code for every dimension (used by the
    /// compensated variant to accumulate `<code, shift>` in the same pass).
    ///
    /// Each element is encoded as `clamp((f - shift[i]) * inverse_scale)`
    /// rounded to the nearest integer, where `inverse_scale` folds in
    /// `bit_scale::<NBITS>()`.
    ///
    /// # Panics
    ///
    /// Panics if `from` or `into` length differs from the quantizer dimension.
    ///
    /// # Errors
    ///
    /// Returns [`InputContainsNaN`] if any input element is NaN. Note that the
    /// whole output slice is still written in that case — the NaN check is
    /// accumulated and reported only after the encoding loop.
    fn compress<const NBITS: usize, T, F, Perm>(
        &self,
        from: &[T],
        mut into: MutBitSlice<'_, NBITS, Unsigned, Perm>,
        mut callback: F,
    ) -> Result<(), InputContainsNaN>
    where
        T: Copy + Into<f32>,
        F: FnMut(f32, usize),
        Unsigned: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let len = self.shift.len();
        assert_eq!(from.len(), len);
        assert_eq!(into.len(), len);

        // Encodable code range for this bit width.
        let domain = Unsigned::domain_const::<NBITS>();
        let min = *domain.start() as f32;
        let max = *domain.end() as f32;
        // Precompute the reciprocal so the loop multiplies instead of divides.
        let inverse_scale = bit_scale::<NBITS>() / (self.scale);
        let mut nan_check = false;

        std::iter::zip(from.iter(), self.shift.iter())
            .enumerate()
            .for_each(|(i, (&f, &s))| {
                let f: f32 = f.into();
                // Branch-free NaN accumulation; reported after the loop.
                nan_check |= f.is_nan();

                let code: f32 = ((f - s) * inverse_scale).clamp(min, max).round();

                callback(code, i);

                // SAFETY: `i < into.len()` is guaranteed by the length
                // assertions above, and `code` was clamped into the valid
                // unsigned NBITS domain before the cast.
                unsafe { into.set_unchecked(i, code as u8) };
            });

        if nan_check {
            Err(InputContainsNaN)
        } else {
            Ok(())
        }
    }

    /// Compares every component of two quantizers, returning the first
    /// mismatch found as a [`SQComparisonError`]. Comparisons are exact
    /// (bitwise float equality), intended for configuration validation.
    pub fn compare(&self, other: &Self) -> Result<(), SQComparisonError> {
        if self.scale != other.scale {
            return Err(SQComparisonError::Scale(self.scale, other.scale));
        }

        if self.shift.len() != other.shift.len() {
            return Err(SQComparisonError::ShiftLength(
                self.shift.len(),
                other.shift.len(),
            ));
        }

        for (i, (a, b)) in self.shift.iter().zip(other.shift.iter()).enumerate() {
            if a != b {
                return Err(SQComparisonError::ShiftElement {
                    index: i,
                    a: *a,
                    b: *b,
                });
            }
        }

        if self.shift_square_norm != other.shift_square_norm {
            return Err(SQComparisonError::ShiftSquareNorm(
                self.shift_square_norm,
                other.shift_square_norm,
            ));
        }

        match (&self.mean_norm, &other.mean_norm) {
            (Some(a), Some(b)) => {
                if a != b {
                    return Err(SQComparisonError::MeanNorm(*a, *b));
                }
            }
            // Both absent: considered equal.
            (None, None) => {
            }
            _ => {
                return Err(SQComparisonError::MeanNormPresence);
            }
        }

        Ok(())
    }
}
289
/// Error returned by [`ScalarQuantizer::rescale`] when the quantizer was
/// constructed without a mean norm.
#[derive(Debug, Error, Clone, Copy)]
#[error("mean norm is missing from the quantizer")]
#[non_exhaustive]
pub struct MeanNormMissing;
294
/// Error returned by the `CompressInto` implementations when at least one
/// input element is NaN. The output codes are still written (NaN detection
/// happens after encoding), so callers should discard them on error.
#[derive(Debug, Error, Clone, Copy)]
#[error("input contains NaN")]
#[non_exhaustive]
pub struct InputContainsNaN;
299
300fn rescale(x: &mut [f32], to_norm: f32) {
301 let norm_square: MathematicalValue<f32> =
302 diskann_vector::distance::InnerProduct::evaluate(&*x, &*x);
303 let norm = norm_square.into_inner().sqrt();
304 if norm == 0.0 {
305 return;
306 }
307
308 let scale = to_norm / norm;
309 x.iter_mut().for_each(|i| (*i) *= scale);
310}
311
312impl AsFunctor<CompensatedSquaredL2> for ScalarQuantizer {
317 fn as_functor(&self) -> CompensatedSquaredL2 {
318 let scale = self.scale();
319 CompensatedSquaredL2::new(scale * scale)
320 }
321}
322
323impl AsFunctor<CompensatedIP> for ScalarQuantizer {
324 fn as_functor(&self) -> CompensatedIP {
325 let scale = self.scale();
326 CompensatedIP::new(scale * scale, self.shift_square_norm())
327 }
328}
329
330impl AsFunctor<CompensatedCosineNormalized> for ScalarQuantizer {
331 fn as_functor(&self) -> CompensatedCosineNormalized {
332 let scale = self.scale();
333 CompensatedCosineNormalized::new(scale * scale)
334 }
335}
336
337impl<const NBITS: usize, T, Perm> CompressInto<&[T], MutBitSlice<'_, NBITS, Unsigned, Perm>>
342 for ScalarQuantizer
343where
344 T: Copy + Into<f32>,
345 Unsigned: Representation<NBITS>,
346 Perm: PermutationStrategy<NBITS>,
347{
348 type Error = InputContainsNaN;
349
350 type Output = ();
351
352 fn compress_into(
370 &self,
371 from: &[T],
372 into: MutBitSlice<'_, NBITS, Unsigned, Perm>,
373 ) -> Result<(), Self::Error> {
374 ScalarQuantizer::compress(self, from, into, |_, _| {})
377 }
378}
379
380impl<const NBITS: usize, T, Perm> CompressInto<&[T], MutCompensatedVectorRef<'_, NBITS, Perm>>
381 for ScalarQuantizer
382where
383 T: Copy + Into<f32>,
384 Unsigned: Representation<NBITS>,
385 Perm: PermutationStrategy<NBITS>,
386{
387 type Error = InputContainsNaN;
388
389 type Output = ();
390
391 fn compress_into(
408 &self,
409 from: &[T],
410 mut into: MutCompensatedVectorRef<'_, NBITS, Perm>,
411 ) -> Result<(), Self::Error> {
412 let mut dot: f32 = 0.0;
417 let result = ScalarQuantizer::compress(
418 self,
419 from,
420 into.vector_mut(),
421 |code: f32, index: usize| {
423 dot = code.mul_add(self.shift[index], dot);
424 },
425 );
426 into.set_meta(Compensation(
427 self.scale * inverse_bit_scale::<NBITS>() * dot,
428 ));
429 result
430 }
431}
432
/// Detailed mismatch report produced by [`ScalarQuantizer::compare`],
/// identifying the first component on which two quantizers differ.
#[derive(Debug, Error, PartialEq)]
pub enum SQComparisonError {
    /// The two quantizers have different scales.
    #[error("Scale mismatch: {0} vs {1}")]
    Scale(f32, f32),

    /// The shift vectors have different lengths (dimensionality mismatch).
    #[error("Shift vector length mismatch: {0} vs {1}")]
    ShiftLength(usize, usize),

    /// The shift vectors differ at `index` (`a` from the left quantizer,
    /// `b` from the right).
    #[error("Shift element at index {index} mismatch: {a} vs {b}")]
    ShiftElement { index: usize, a: f32, b: f32 },

    /// The cached squared shift norms differ.
    #[error("Shift square norm mismatch: {0} vs {1}")]
    ShiftSquareNorm(f32, f32),

    /// Both quantizers have a mean norm, but the values differ.
    #[error("Mean norm mismatch: {0} vs {1}")]
    MeanNorm(f32, f32),

    /// Exactly one of the quantizers has a mean norm configured.
    #[error("Mean norm is missing in one quantizer but present in the other")]
    MeanNormPresence,
}
453
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use diskann_utils::{ReborrowMut, views};

    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
        seq::SliceRandom,
    };
    use rand_distr::Normal;

    use super::*;
    use crate::{
        bits::BoxedBitSlice,
        scalar::{CompensatedVector, inverse_bit_scale},
    };

    /// Exercises both the free `rescale` function and the method wrapper:
    /// random vectors land on the target norm, zero vectors are untouched,
    /// and a missing mean norm is reported as an error.
    #[test]
    fn test_rescale() {
        let dim = 32;
        let to_norm = 25.0;

        let mut rng = StdRng::seed_from_u64(0x64e956ca2eb726ee);
        let distribution = Normal::<f32>::new(0.0, 16.0).unwrap();

        // Free-function path: random vector should land on `to_norm`.
        let mut v: Vec<f32> = distribution.sample_iter(&mut rng).take(dim).collect();
        let norm = v.iter().map(|&i| i * i).sum::<f32>().sqrt();

        rescale(&mut v, to_norm);
        let norm_next = v.iter().map(|&i| i * i).sum::<f32>().sqrt();
        let relative_error = (norm_next - to_norm).abs() / to_norm;

        assert!(
            relative_error <= 1.0e-7,
            "vector was not renormalized, expected {}, got {}, started with {}. Relative error: {}",
            to_norm,
            norm_next,
            norm,
            relative_error,
        );

        // Zero vectors have no direction and must stay zero.
        let mut v: Vec<f32> = vec![0.0; dim];
        rescale(&mut v, 10.0);
        assert!(v.iter().all(|&i| i == 0.0));

        // Method path: the quantizer's scale/shift are irrelevant to rescale.
        let mut quantizer = ScalarQuantizer::new(0.0, vec![0.0; dim], Some(to_norm));

        let mut v: Vec<f32> = distribution.sample_iter(&mut rng).take(dim).collect();
        let norm = v.iter().map(|&i| i * i).sum::<f32>().sqrt();

        quantizer.rescale(&mut v).unwrap();
        let norm_next = v.iter().map(|&i| i * i).sum::<f32>().sqrt();
        let relative_error = (norm_next - to_norm).abs() / to_norm;

        assert!(
            relative_error <= 1.0e-7,
            "vector was not renormalized, expected {}, got {}, started with {}. Relative error: {}",
            to_norm,
            norm_next,
            norm,
            relative_error,
        );

        let mut v: Vec<f32> = vec![0.0; dim];
        quantizer.rescale(&mut v).unwrap();
        assert!(v.iter().all(|&i| i == 0.0));

        // Without a configured mean norm, rescale must fail.
        quantizer.mean_norm = None;
        let r = quantizer.rescale(&mut v);
        assert!(matches!(r, Err(MeanNormMissing)));
    }

    /// End-to-end check of an NBITS quantizer: accessors, functors, clamping
    /// at both ends of the encodable range, reconstruction error within half
    /// a step, full code coverage per dimension, agreement between the plain
    /// and compensated encodings, and NaN rejection.
    fn test_nbit_quantizer<const NBITS: usize>(dim: usize, rng: &mut StdRng)
    where
        Unsigned: Representation<NBITS>,
        ScalarQuantizer: for<'a, 'b> CompressInto<&'a [f32], MutBitSlice<'b, NBITS, Unsigned>>
            + for<'a, 'b> CompressInto<&'a [f32], MutCompensatedVectorRef<'b, NBITS>>,
    {
        let distribution = Uniform::new_inclusive::<i64, i64>(-10, 10).unwrap();
        let shift: Vec<f32> = (0..dim).map(|_| distribution.sample(rng) as f32).collect();
        let scale: f32 = 2.0;
        let mean_norm: f32 = 1.0;

        // The quantizer scale folds in bit_scale so the effective step is `scale`.
        let quantizer =
            ScalarQuantizer::new(scale * bit_scale::<NBITS>(), shift.clone(), Some(mean_norm));

        // Accessor sanity checks.
        assert_eq!(quantizer.dim(), dim);
        assert_eq!(quantizer.scale(), scale * bit_scale::<NBITS>());
        assert_eq!(quantizer.shift(), shift);
        assert_eq!(quantizer.mean_norm().unwrap(), mean_norm);

        let expected_shift_norm: f32 = shift.iter().map(|&i| i * i).sum();
        assert_eq!(quantizer.shift_square_norm(), expected_shift_norm);

        // Functor construction mirrors the quantizer's parameters.
        {
            let l2: CompensatedSquaredL2 = quantizer.as_functor();
            assert_eq!(l2.scale_squared, quantizer.scale() * quantizer.scale());

            let ip: CompensatedIP = quantizer.as_functor();
            assert_eq!(ip.scale_squared, quantizer.scale() * quantizer.scale());
            assert_eq!(ip.shift_square_norm, quantizer.shift_square_norm());
        }

        // Oversample the code book (1.25x + 10) so every code is hit.
        let sample_points: f32 = 1.25 * (2_usize.pow(NBITS as u32) as f32) + 10.0;

        let min_encodable: f32 = 0.0;
        let max_encodable: f32 = (*Unsigned::domain_const::<NBITS>().end() as f32) * scale;

        // Per-dimension offsets sweeping 3 steps beyond both ends of the
        // encodable range, independently shuffled per dimension.
        let dim_offsets: views::Matrix<f32> = {
            let range_min = -min_encodable - 3.0 * scale;
            let range_max = max_encodable + 3.0 * scale;
            let mut base: Vec<f32> = Vec::new();

            let step_size = (range_max - range_min) / sample_points;
            let mut i: f32 = range_min;
            while i < range_max {
                base.push(i);
                i += step_size;
            }
            base.push(i);

            let mut output = views::Matrix::new(0.0, base.len(), dim);
            (0..dim).for_each(|j| {
                base.shuffle(rng);
                for (i, b) in base.iter().enumerate() {
                    output[(i, j)] = *b;
                }
            });
            output
        };
        let ntests = dim_offsets.nrows();
        assert!(ntests as f32 >= sample_points);

        // Track that both saturation branches were actually exercised, and
        // which codes each dimension produced.
        let mut seen_below_min = false;
        let mut seen_above_max = false;
        let mut seen: Vec<HashSet<i64>> = (0..dim).map(|_| HashSet::new()).collect();

        let mut query: Vec<f32> = vec![0.0; dim];
        for test_number in 0..ntests {
            // query = shift + offset, so the encoded value should be offset/scale.
            let offsets = dim_offsets.row(test_number);
            query
                .iter_mut()
                .zip(std::iter::zip(shift.iter(), offsets.iter()))
                .for_each(|(q, (c, o))| {
                    *q = *c + *o;
                });

            let mut bitslice = BoxedBitSlice::<NBITS, _>::new_boxed(dim);
            let mut compensated = CompensatedVector::<NBITS>::new_boxed(dim);

            quantizer
                .compress_into(&*query, bitslice.reborrow_mut())
                .unwrap();
            quantizer
                .compress_into(&*query, compensated.reborrow_mut())
                .unwrap();

            let domain = Unsigned::domain_const::<NBITS>();

            // Recompute the compensation term from the emitted codes.
            let mut computed_compensation: f32 = 0.0;
            for d in 0..dim {
                let code = bitslice.get(d).unwrap();
                computed_compensation = (code as f32).mul_add(shift[d], computed_compensation);

                seen[d].insert(code);

                let offset = offsets[d];
                if offset <= min_encodable {
                    // Below the range: must saturate to the smallest code.
                    assert_eq!(
                        code,
                        *domain.start(),
                        "expected values below threshold to be set to zero \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                    );
                    seen_below_min = true;
                } else if offset >= max_encodable {
                    // Above the range: must saturate to the largest code.
                    assert_eq!(
                        code,
                        *domain.end(),
                        "expected values below threshold to be set to max value \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                    );
                    seen_above_max = true;
                } else {
                    // In range: round-to-nearest keeps the error within half a step.
                    let reconstructed =
                        quantizer.scale() * (code as f32) * inverse_bit_scale::<NBITS>();
                    let error = (offset - reconstructed).abs();
                    assert!(
                        error <= scale / 2.0,
                        "failed reconstruction check: \
                         test_number = {}, dim = {} of {}, offset = {}, scale = {} \
                         code = {}, reconstructed = {}, error = {}",
                        test_number,
                        d,
                        dim,
                        offset,
                        scale,
                        code,
                        reconstructed,
                        error,
                    );
                }

                // The compensated vector must carry the exact same codes.
                assert_eq!(
                    compensated.vector().get(d).unwrap(),
                    code,
                    "compensated disagrees with bitslice"
                );
            }
            assert_eq!(scale * computed_compensation, compensated.meta().0);
        }

        // Both saturation paths and every code per dimension must be covered.
        assert!(seen_below_min);
        assert!(seen_above_max);
        let num_codes = 2usize.pow(NBITS as u32);
        for (i, s) in seen.iter().enumerate() {
            assert_eq!(
                s.len(),
                num_codes,
                "dimension {} did not have full coverage",
                i
            );
        }

        // A NaN at any single position must be detected by both encodings.
        {
            let mut query: Vec<f32> = shift.clone();
            let mut bitslice = BoxedBitSlice::<NBITS, _>::new_boxed(query.len());
            let mut compensated = CompensatedVector::<NBITS>::new_boxed(query.len());
            for i in 0..query.len() {
                let last = query[i];
                query[i] = f32::NAN;

                let err = quantizer
                    .compress_into(&*query, bitslice.reborrow_mut())
                    .unwrap_err();
                assert_eq!(err.to_string(), "input contains NaN");

                let err = quantizer
                    .compress_into(&*query, compensated.reborrow_mut())
                    .unwrap_err();
                assert_eq!(err.to_string(), "input contains NaN");

                query[i] = last;
            }
        }
    }

    // Keep Miri runs fast by shrinking the test dimensionality.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TEST_DIM: usize = 2;
        } else {
            const TEST_DIM: usize = 10;
        }
    }

    // Instantiates the NBITS sweep above with a fixed seed per bit width.
    macro_rules! test_quantizer {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                test_nbit_quantizer::<$nbits>(TEST_DIM, &mut rng);
            }
        };
    }

    test_quantizer!(test_8bit_quantizer, 8, 0xb7b4c124102b9fb9);
    test_quantizer!(test_7bit_quantizer, 7, 0x86d19a821fe934d1);
    test_quantizer!(test_6bit_quantizer, 6, 0x0de9610f0b9be4f7);
    test_quantizer!(test_5bit_quantizer, 5, 0x605ed3e7ed775047);
    test_quantizer!(test_4bit_quantizer, 4, 0x9b66ace7090fa728);
    test_quantizer!(test_3bit_quantizer, 3, 0x0ce424ddc61ebdb0);
    test_quantizer!(test_2bit_quantizer, 2, 0x2ba8e5ef6415d4f0);
    test_quantizer!(test_1bit_quantizer, 1, 0xdcd8c10c4a407956);

    /// Baseline quantizer for the `compare` tests; each test perturbs one field.
    fn base_quantizer() -> ScalarQuantizer {
        ScalarQuantizer {
            scale: 2.0,
            shift: vec![1.0, -1.0, 0.5],
            shift_square_norm: 1.0_f32 * 1.0 + (-1.0_f32) * (-1.0) + 0.5_f32 * 0.5,
            mean_norm: Some(4.13),
        }
    }

    #[test]
    fn test_compare_identical_returns_ok() {
        let q1 = base_quantizer();
        let q2 = base_quantizer();
        assert!(q1.compare(&q2).is_ok());
    }

    #[test]
    fn test_compare_scale_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.scale = 4.0;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::Scale(2.0, 4.0));
    }

    #[test]
    fn test_compare_shift_length_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift.push(0.0);
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(
            err,
            SQComparisonError::ShiftLength(q1.shift.len(), q2.shift.len())
        );
    }

    #[test]
    fn test_compare_shift_element_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift[2] = 0.0;
        let err = q1.compare(&q2).unwrap_err();
        match err {
            SQComparisonError::ShiftElement { index, a, b } => {
                assert_eq!(index, 2);
                assert_eq!(a, 0.5);
                assert_eq!(b, 0.0);
            }
            _ => panic!("Expected ShiftElementMismatch variant"),
        }
    }

    #[test]
    fn test_compare_shift_square_norm_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.shift_square_norm = 9.0;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::ShiftSquareNorm(2.25, 9.0));
    }

    #[test]
    fn test_compare_mean_norm_value_mismatch() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.mean_norm = Some(1.0);
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNorm(4.13, 1.0));
    }

    #[test]
    fn test_compare_mean_norm_presence_mismatch_left_none() {
        let mut q1 = base_quantizer();
        let q2 = base_quantizer();
        q1.mean_norm = None;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNormPresence);
    }

    #[test]
    fn test_compare_mean_norm_presence_mismatch_right_none() {
        let q1 = base_quantizer();
        let mut q2 = base_quantizer();
        q2.mean_norm = None;
        let err = q1.compare(&q2).unwrap_err();
        assert_eq!(err, SQComparisonError::MeanNormPresence);
    }
}