1use super::vectors::{DataMutRef, MinMaxCompensation, MinMaxIP, MinMaxL2Squared};
7use core::f32;
8use diskann_utils::views::MutDenseData;
9
10use crate::{
11 algorithms::Transform,
12 alloc::{GlobalAllocator, ScopedAllocator},
13 bits::{Representation, Unsigned},
14 minmax::{vectors::FullQueryMeta, FullQuery, MinMaxCosine, MinMaxCosineNormalized},
15 num::Positive,
16 scalar::{bit_scale, InputContainsNaN},
17 AsFunctor, CompressInto,
18};
19
/// Uniform scalar ("min/max") quantizer: input vectors are passed through a
/// linear `transform` and each output coordinate is quantized onto a uniform
/// grid spanning a per-vector min/max range.
pub struct MinMaxQuantizer {
    // Linear transform applied before quantization; also fixes the
    // input (`dim`) and output (`output_dim`) dimensionality.
    transform: Transform<GlobalAllocator>,

    // Multiplier applied symmetrically to the half-width of the per-vector
    // min/max interval (1.0 keeps the raw range).
    grid_scale: Positive<f32>,
}
61
impl MinMaxQuantizer {
    /// Creates a quantizer that applies `transform` and then uniformly
    /// quantizes each output coordinate over a min/max grid widened by
    /// `grid_scale`.
    pub fn new(transform: Transform<GlobalAllocator>, grid_scale: Positive<f32>) -> Self {
        Self {
            transform,
            grid_scale,
        }
    }

    /// Dimensionality of the input vectors accepted for compression.
    pub fn dim(&self) -> usize {
        self.transform.input_dim()
    }

    /// Dimensionality of the transformed (and therefore quantized) vectors.
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

    /// Computes the quantization interval `(lo, hi)` for `vec`.
    ///
    /// For `NBITS == 1` the endpoints are the means of the values below /
    /// above the global mean (clamped to the mean itself), placing the two
    /// codes near the two cluster centroids. For all other bit widths the
    /// raw min/max of `vec` is used. The interval is then widened (or
    /// shrunk) symmetrically about its midpoint by `grid_scale`.
    fn get_range<const NBITS: usize>(&self, vec: &[f32]) -> (f32, f32) {
        let (min, max) = match NBITS {
            1 => {
                let (mut min, mut min_count) = (0.0f32, 0.0f32);
                let (mut max, mut max_count) = (0.0f32, 0.0f32);

                let mean = vec.iter().sum::<f32>() / (vec.len() as f32);

                // Branch-free accumulation: `m` is 1.0 for values strictly
                // below the mean and 0.0 otherwise.
                vec.iter().for_each(|x| {
                    let m = f32::from((*x < mean) as u8);
                    min += m * x;
                    min_count += m;
                    max += (1.0 - m) * x;
                    max_count += 1.0 - m;
                });

                // When one side is empty, `min / min_count` (or the max
                // counterpart) is 0.0 / 0.0 = NaN; `f32::min` / `f32::max`
                // then return the other operand, collapsing that endpoint
                // to `mean`.
                ((min / min_count).min(mean), (max / max_count).max(mean))
            }
            _ => {
                // Seeding the fold with NaN works because `f32::min` /
                // `f32::max` return the non-NaN operand, so the first
                // element replaces both accumulators.
                vec.iter()
                    .fold((f32::NAN, f32::NAN), |(cmin, cmax), &e| {
                        (cmin.min(e), cmax.max(e))
                    })
            }
        };

        let width = (max - min) / 2.0;
        let mid = min + width;

        // Widen symmetrically around the midpoint by `grid_scale`.
        (
            mid - width * self.grid_scale.into_inner(),
            mid + width * self.grid_scale.into_inner(),
        )
    }

    /// Transforms `from` and quantizes each coordinate to an `NBITS`-bit
    /// unsigned code, writing the codes and the [`MinMaxCompensation`]
    /// metadata into `into`.
    ///
    /// Returns the accumulated squared-L2 reconstruction loss on success,
    /// or [`InputContainsNaN`] if any transformed coordinate is NaN. Note
    /// that codes and metadata are still written in the NaN case.
    ///
    /// Panics if `from.len() != self.dim()` or the destination length does
    /// not match `self.output_dim()`.
    fn compress<const NBITS: usize, T>(
        &self,
        from: &[T],
        mut into: DataMutRef<'_, NBITS>,
    ) -> Result<L2Loss, InputContainsNaN>
    where
        T: Copy + Into<f32>,
        Unsigned: Representation<NBITS>,
    {
        let mut into_vec = into.vector_mut();

        assert_eq!(from.len(), self.dim());
        assert_eq!(self.output_dim(), into_vec.len());

        let domain = Unsigned::domain_const::<NBITS>();
        let domain_min = *domain.start() as f32;
        let domain_max = *domain.end() as f32;

        let mut vec = vec![f32::default(); self.output_dim()];

        // Transforming into a correctly-sized scratch buffer is expected to
        // be infallible; a failure here would be a programming error.
        #[allow(clippy::unwrap_used)]
        self.transform
            .transform_into(
                &mut vec,
                &from.iter().map(|&x| x.into()).collect::<Vec<f32>>(),
                ScopedAllocator::global(),
            )
            .unwrap();

        let (min, max) = self.get_range::<NBITS>(&vec);

        // Step size of the quantization grid. The `.max(1e-8)` guards
        // against a zero-width range (e.g. constant input).
        let inverse_scale = (max - min).max(1e-8) / bit_scale::<NBITS>();
        let mut norm_squared: f32 = 0.0;
        let mut code_sum: f32 = 0.0;
        let mut loss: f32 = 0.0;

        let mut nan_check = false;

        vec.iter().enumerate().for_each(|(i, &v)| {
            nan_check |= v.is_nan();

            // Nearest grid point, clamped into the representable domain.
            let code = ((v - min) / inverse_scale)
                .clamp(domain_min, domain_max)
                .round();

            // `v_r` is the value a decoder reconstructs for this code
            // (`code * a + b` with the metadata below).
            let v_r = (code * inverse_scale) + min;
            norm_squared += v_r * v_r;
            code_sum += code;
            loss += (v_r - v).powi(2);

            // SAFETY: `i < into_vec.len()` is established by the
            // `output_dim` assertion above, and `code` lies within the
            // `NBITS`-bit domain after the clamp.
            unsafe {
                into_vec.set_unchecked(i, code as u8);
            }
        });

        // Metadata needed to decode and to compensate distance computations.
        let meta = MinMaxCompensation {
            dim: self.output_dim() as u32,
            b: min,
            a: inverse_scale,
            n: inverse_scale * code_sum,
            norm_squared,
        };

        into.set_meta(meta);

        if nan_check {
            Err(InputContainsNaN)
        } else {
            // `Positive::new` rejects non-positive values; a loss of exactly
            // zero maps to the dedicated `Zero` variant.
            Ok(match Positive::new(loss) {
                Ok(p) => L2Loss::Positive(p),
                Err(_) => L2Loss::Zero,
            })
        }
    }
}
212
/// Squared-L2 reconstruction loss reported by compression.
///
/// A loss that is not strictly positive is represented by `Zero`; otherwise
/// the value is carried as a [`Positive`].
#[derive(Clone, Copy, Debug)]
pub enum L2Loss {
    // No (strictly positive) reconstruction error.
    Zero,
    // Strictly positive reconstruction error.
    Positive(Positive<f32>),
}
226
227impl L2Loss {
228 pub fn as_f32(&self) -> f32 {
230 match self {
231 L2Loss::Zero => 0.0,
232 L2Loss::Positive(p) => p.into_inner(),
233 }
234 }
235}
236
237impl<const NBITS: usize, T> CompressInto<&[T], DataMutRef<'_, NBITS>> for MinMaxQuantizer
238where
239 T: Copy + Into<f32>,
240 Unsigned: Representation<NBITS>,
241{
242 type Error = InputContainsNaN;
243
244 type Output = L2Loss;
245
246 fn compress_into(&self, from: &[T], to: DataMutRef<'_, NBITS>) -> Result<L2Loss, Self::Error> {
263 self.compress::<NBITS, T>(from, to)
264 }
265}
266
267impl<T> CompressInto<&[T], &mut FullQuery> for MinMaxQuantizer
268where
269 T: Copy + Into<f32>,
270{
271 type Error = InputContainsNaN;
272
273 type Output = ();
274
275 fn compress_into(&self, from: &[T], to: &mut FullQuery) -> Result<(), Self::Error> {
292 assert_eq!(from.len(), self.dim());
293 assert_eq!(self.output_dim(), to.len());
294
295 let from: Vec<f32> = from.iter().map(|&x| x.into()).collect();
297 if from.iter().any(|x| x.is_nan()) {
298 return Err(InputContainsNaN);
299 }
300
301 #[allow(clippy::unwrap_used)]
303 self.transform
304 .transform_into(to.data.as_mut_slice(), &from, ScopedAllocator::global())
305 .unwrap();
306
307 let norm_squared = to.data.iter().map(|x| *x * *x).sum::<f32>();
308 let sum = to.data.iter().sum::<f32>();
309
310 to.meta = FullQueryMeta { norm_squared, sum };
311
312 Ok(())
313 }
314}
315
/// Implements [`AsFunctor`] for a unit-struct distance functor, returning
/// the functor by value.
macro_rules! impl_functor {
    ($dist:ident) => {
        impl AsFunctor<$dist> for MinMaxQuantizer {
            fn as_functor(&self) -> $dist {
                $dist
            }
        }
    };
}

// Distance functors usable with min/max-quantized data.
impl_functor!(MinMaxIP);
impl_functor!(MinMaxL2Squared);
impl_functor!(MinMaxCosine);
impl_functor!(MinMaxCosineNormalized);
335
336#[cfg(test)]
340mod minmax_quantizer_tests {
341 use std::num::NonZeroUsize;
342
343 use diskann_utils::{Reborrow, ReborrowMut};
344 use diskann_vector::{distance::SquaredL2, PureDistanceFunction};
345 use rand::{
346 distr::{Distribution, Uniform},
347 rngs::StdRng,
348 SeedableRng,
349 };
350
351 use super::*;
352 use crate::{
353 algorithms::transforms::NullTransform,
354 minmax::vectors::{Data, DataRef},
355 };
356
357 fn reconstruct_minmax<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
358 where
359 Unsigned: Representation<NBITS>,
360 {
361 (0..v.len())
362 .map(|i| {
363 let m = v.meta();
364 v.vector().get(i).unwrap() as f32 * m.a + m.b
365 })
366 .collect()
367 }
368
369 fn test_quantizer_encoding_random<const NBITS: usize>(
370 dim: usize,
371 rng: &mut StdRng,
372 relative_err: f32,
373 scale: f32,
374 ) where
375 Unsigned: Representation<NBITS>,
376 MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>
377 + for<'a, 'b> CompressInto<&'a [f32], &'b mut FullQuery, Output = ()>,
378 {
379 let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();
380
381 let quantizer = MinMaxQuantizer::new(
382 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
383 Positive::new(scale).unwrap(),
384 );
385
386 assert_eq!(quantizer.dim(), dim);
387
388 let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();
389
390 let mut encoded = Data::new_boxed(dim);
391 let loss = quantizer
392 .compress_into(&*vector, encoded.reborrow_mut())
393 .unwrap();
394
395 let reconstructed = reconstruct_minmax::<NBITS>(encoded.reborrow());
396 assert_eq!(reconstructed.len(), dim);
397
398 let reconstruction_error: f32 = SquaredL2::evaluate(&*vector, &*reconstructed);
399 let norm = vector.iter().map(|x| x * x).sum::<f32>();
400 assert!(
401 (reconstruction_error / norm) <= relative_err,
402 "Expected vector : {:?} to be reconstructed within error {} but instead got : {:?}, with error {} for dim : {}",
403 &vector,
404 relative_err,
405 &reconstructed,
406 reconstruction_error / norm,
407 dim,
408 );
409
410 assert!((loss.as_f32() - reconstruction_error) <= 1e-4);
411
412 let expected_code_sum = (0..dim)
413 .map(|i| encoded.vector().get(i).unwrap() as f32)
414 .sum::<f32>();
415 let code_sum = encoded.reborrow().meta().n / encoded.reborrow().meta().a;
416 assert!(
417 (code_sum - expected_code_sum).abs() <= 2e-5 * (dim as f32),
418 "Encoded vector with dim : {dim} is {:?}, got error : {} for vector : {:?}",
419 encoded.reborrow(),
420 (code_sum - expected_code_sum).abs(),
421 &vector,
422 );
423 let recon_norm_sq = reconstructed.iter().map(|x| x * x).sum::<f32>();
424 assert!((encoded.reborrow().meta().norm_squared - recon_norm_sq).abs() <= 1e-3);
425
426 let mut f = FullQuery::empty(dim);
428 quantizer
429 .compress_into(vector.as_slice(), f.reborrow_mut())
430 .unwrap();
431
432 f.data
433 .iter()
434 .enumerate()
435 .zip(vector.iter())
436 .for_each(|((i, x), y)| {
437 assert!(
438 (*x - *y).abs() < 1e-10,
439 "Full Query did not compress dimension {i} with value {} correctly, got {} instead.",
440 *y,
441 *x,
442 )
443 });
444
445 assert!(
446 (f.meta.norm_squared - norm).abs() < 1e-10,
447 "Full Query norm in meta should be {norm} but instead got {}",
448 f.meta.norm_squared
449 );
450
451 let sum = vector.iter().sum::<f32>();
452 assert!(
453 (f.meta.sum - sum) < 1e-10,
454 "Full Query norm in meta should be {sum} but instead got {}",
455 f.meta.sum
456 );
457 }
458
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Miri is orders of magnitude slower than native execution;
            // keep the randomized sweeps short there.
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }
469
    /// Generates a `#[test]` that sweeps dimensions `10..$dim` at `$nbits`
    /// bits, pairing grid scales 1.0 / 1.1 / 0.9 with the relative error
    /// bounds in `$err`, `TRIALS` random vectors each.
    macro_rules! test_minmax_quantizer_encoding {
        ($name:ident, $dim:literal, $nbits:literal, $seed:literal, $err:expr) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let scales = [1.0, 1.1, 0.9];
                for (s, e) in scales.iter().zip($err) {
                    for d in 10..$dim {
                        for _ in 0..TRIALS {
                            test_quantizer_encoding_random::<$nbits>(d, &mut rng, e, *s);
                        }
                    }
                }
            }
        };
    }
    // Higher bit widths get proportionally tighter relative error bounds.
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_1bit,
        100,
        1,
        0xa32d5658097a1c35,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_2bit,
        100,
        2,
        0xf60c0c8d1aadc126,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_4bit,
        100,
        4,
        0x09fa14c42a9d7d98,
        vec![1.0e-2, 1.0e-2, 3.0e-2]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_8bit,
        100,
        8,
        0xaedf3d2a223b7b77,
        vec![2.0e-3, 2.0e-3, 7.0e-3]
    );
514
    /// Generates a `#[test]` that runs `$func` at every supported bit rate.
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }
526
527 fn test_all_same_value_vector<const NBITS: usize>()
529 where
530 Unsigned: Representation<NBITS>,
531 MinMaxQuantizer:
532 for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
533 {
534 let dim = 30;
535 let quantizer = MinMaxQuantizer::new(
536 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
537 Positive::new(1.0).unwrap(),
538 );
539 let constant_value = 42.5f32;
540 let vector = vec![constant_value; dim];
541
542 let mut encoded = Data::new_boxed(dim);
543 let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
544
545 assert!(
546 result.is_ok(),
547 "Constant-value vector should compress successfully"
548 );
549
550 assert!(result.unwrap().as_f32().abs() <= 1e-6);
551
552 let reconstructed = reconstruct_minmax(encoded.reborrow());
554 for &val in &reconstructed {
555 assert!(
556 (val - constant_value).abs() < 1e-3,
557 "Reconstructed value {} should be close to original {}. Compressed vector is {:?}",
558 val,
559 constant_value,
560 encoded.meta(),
561 );
562 }
563 }
564
565 fn test_two_distinct_values<const NBITS: usize>()
567 where
568 Unsigned: Representation<NBITS>,
569 MinMaxQuantizer:
570 for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
571 {
572 let dim = 20;
573 let quantizer = MinMaxQuantizer::new(
574 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
575 Positive::new(1.0).unwrap(),
576 );
577
578 let val1 = -10.0f32;
579 let val2 = 15.0f32;
580 let mut vector = vec![val1; dim];
581 for i in vector.iter_mut().skip(dim) {
583 *i = val2;
584 }
585
586 let mut encoded = Data::new_boxed(dim);
587 let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
588
589 assert!(
590 result.is_ok(),
591 "Two-value vector should compress successfully"
592 );
593
594 assert!(result.unwrap().as_f32().abs() <= 1e-6);
595
596 let mut codes_used = std::collections::HashSet::new();
598 for i in 0..dim {
599 codes_used.insert(encoded.vector().get(i).unwrap());
600 }
601
602 if NBITS > 1 {
604 assert!(
605 codes_used.len() <= 2,
606 "Should use at most 2 distinct codes for 2-value input, but used: {:?}",
607 codes_used
608 );
609 }
610
611 let reconstructed = reconstruct_minmax(encoded.reborrow());
613 for ((i, val), v) in reconstructed.into_iter().enumerate().zip(&vector) {
614 assert!(
616 (val - v).abs() < 1e-4,
617 "Reconstructed value in dim : {i} is {val}, when it should be {v}."
618 );
619 }
620 }
621
622 fn test_nan_input_error<const NBITS: usize>()
625 where
626 Unsigned: Representation<NBITS>,
627 MinMaxQuantizer:
628 for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
629 {
630 let dim = 100;
631 let quantizer = MinMaxQuantizer::new(
632 Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
633 Positive::new(1.0).unwrap(),
634 );
635
636 let mut vector_nan = vec![1.0f32; dim];
638 vector_nan[33] = f32::NAN;
639 let mut encoded = Data::new_boxed(dim);
640 let result = quantizer.compress_into(&vector_nan, encoded.reborrow_mut());
641 assert!(result.is_err(), "Vector with NaN should cause an error");
642
643 let meta = encoded.meta();
644 assert_eq!(meta.dim as usize, dim);
645 }
646
    // Run each edge-case scenario at 1, 2, 4, and 8 bits.
    expand_to_bitrates!(all_same_values_vector, test_all_same_value_vector);
    expand_to_bitrates!(two_distinct_values, test_two_distinct_values);
    expand_to_bitrates!(nan_input_error, test_nan_input_error);
650
651 #[test]
653 #[should_panic(expected = "assertion `left == right` failed\n left: 15\n right: 10")]
654 fn test_dimension_mismatch_panic()
655 where
656 Unsigned: Representation<8>,
657 MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, 8>, Output = L2Loss>,
658 {
659 let expected_dim = 10;
660 let quantizer = MinMaxQuantizer::new(
661 Transform::Null(NullTransform::new(NonZeroUsize::new(expected_dim).unwrap())),
662 Positive::new(1.0).unwrap(),
663 );
664
665 let wrong_vector = vec![1.0f32; expected_dim + 5]; let mut encoded = Data::new_boxed(expected_dim);
668
669 let _ = quantizer.compress_into(&wrong_vector, encoded.reborrow_mut());
671 }
672}