use super::vectors::{DataMutRef, FullQueryMut, MinMaxCompensation, MinMaxIP, MinMaxL2Squared};
use core::f32;

use crate::{
    AsFunctor, CompressInto,
    algorithms::Transform,
    alloc::{GlobalAllocator, ScopedAllocator},
    bits::{Representation, Unsigned},
    minmax::{MinMaxCosine, MinMaxCosineNormalized, vectors::FullQueryMeta},
    num::Positive,
    scalar::{InputContainsNaN, bit_scale},
};

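/// Scalar quantizer that encodes each dimension of an (optionally
/// transformed) vector as an `NBITS`-bit code on a uniform grid spanning a
/// per-vector value range.
///
/// A minimal usage sketch, assuming the `NullTransform` and `Data` helpers
/// exercised by the tests at the bottom of this file; the dimension, bit
/// width, and `Data::<4>` turbofish are illustrative:
///
/// ```ignore
/// let dim = 8;
/// let quantizer = MinMaxQuantizer::new(
///     Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
///     Positive::new(1.0).unwrap(),
/// );
/// let vector = vec![0.25f32; dim];
/// let mut encoded = Data::<4>::new_boxed(dim); // container for 4-bit codes
/// let loss = quantizer
///     .compress_into(vector.as_slice(), encoded.reborrow_mut())
///     .unwrap();
/// ```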
pub struct MinMaxQuantizer {
    /// Preprocessing transform applied to every vector before quantization.
    transform: Transform<GlobalAllocator>,

    /// Multiplier on the half-width of the observed value range: values
    /// above `1.0` widen the quantization grid, values below narrow it.
    grid_scale: Positive<f32>,
}

impl MinMaxQuantizer {
    pub fn new(transform: Transform<GlobalAllocator>, grid_scale: Positive<f32>) -> Self {
        Self {
            transform,
            grid_scale,
        }
    }

    /// Dimensionality of the input vectors accepted by this quantizer.
    pub fn dim(&self) -> usize {
        self.transform.input_dim()
    }

    /// Dimensionality of the transformed (and therefore encoded) vectors.
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

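    /// Computes the quantization interval for `vec`.
    ///
    /// For `NBITS == 1` the interval runs from the mean of the values below
    /// the overall mean to the mean of the values above it (one step of a
    /// two-means split), which represents the data better with a single bit
    /// than the raw extremes do. All other bit rates use the true min/max.
    /// The interval is then scaled about its midpoint by `grid_scale`.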
    fn get_range<const NBITS: usize>(&self, vec: &[f32]) -> (f32, f32) {
        let (min, max) = match NBITS {
            1 => {
                let (mut min, mut min_count) = (0.0f32, 0.0f32);
                let (mut max, mut max_count) = (0.0f32, 0.0f32);

                let mean = vec.iter().sum::<f32>() / (vec.len() as f32);

                vec.iter().for_each(|x| {
                    let m = f32::from((*x < mean) as u8);
                    min += m * x;
                    min_count += m;
                    max += (1.0 - m) * x;
                    max_count += 1.0 - m;
                });

                // An empty partition yields `0.0 / 0.0 = NaN`, and
                // `f32::min`/`f32::max` return the non-NaN operand, so a
                // constant vector degenerates to `(mean, mean)`.
                ((min / min_count).min(mean), (max / max_count).max(mean))
            }
            _ => {
                // Seeding the fold with NaN is safe for the same reason: the
                // first element always replaces it.
                vec.iter().fold((f32::NAN, f32::NAN), |(cmin, cmax), &e| {
                    (cmin.min(e), cmax.max(e))
                })
            }
        };

        let width = (max - min) / 2.0;
        let mid = min + width;

        (
            mid - width * self.grid_scale.into_inner(),
            mid + width * self.grid_scale.into_inner(),
        )
    }

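    /// Transforms `from` and quantizes each value to an `NBITS`-bit code.
    ///
    /// With `step = (max - min) / bit_scale::<NBITS>()`, each transformed
    /// value `v` is mapped to `(v - min) / step`, clamped to the code
    /// domain, and rounded; it reconstructs as `code * step + min`. The
    /// returned `L2Loss` is the squared error between the transformed and
    /// reconstructed vectors.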
    fn compress<const NBITS: usize, T>(
        &self,
        from: &[T],
        mut into: DataMutRef<'_, NBITS>,
    ) -> Result<L2Loss, InputContainsNaN>
    where
        T: Copy + Into<f32>,
        Unsigned: Representation<NBITS>,
    {
        let mut into_vec = into.vector_mut();

        assert_eq!(from.len(), self.dim());
        assert_eq!(self.output_dim(), into_vec.len());

        let domain = Unsigned::domain_const::<NBITS>();
        let domain_min = *domain.start() as f32;
        let domain_max = *domain.end() as f32;

        let mut vec = vec![f32::default(); self.output_dim()];

        #[allow(clippy::unwrap_used)]
        self.transform
            .transform_into(
                &mut vec,
                &from.iter().map(|&x| x.into()).collect::<Vec<f32>>(),
                ScopedAllocator::global(),
            )
            .unwrap();

        let (min, max) = self.get_range::<NBITS>(&vec);

        // Grid step size; the `1e-8` floor avoids dividing by zero when the
        // range collapses (e.g. a constant vector).
        let inverse_scale = (max - min).max(1e-8) / bit_scale::<NBITS>();

        let mut norm_squared: f32 = 0.0;
        let mut code_sum: f32 = 0.0;
        let mut loss: f32 = 0.0;

        let mut nan_check = false;

        vec.iter().enumerate().for_each(|(i, &v)| {
            nan_check |= v.is_nan();

            let code = ((v - min) / inverse_scale)
                .clamp(domain_min, domain_max)
                .round();

            // Reconstructed value; used for the loss and the cached norm.
            let v_r = (code * inverse_scale) + min;
            norm_squared += v_r * v_r;
            code_sum += code;
            loss += (v_r - v).powi(2);

            unsafe {
                into_vec.set_unchecked(i, code as u8);
            }
        });

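        // Compensation metadata: `b` is the grid offset, `a` the step size,
        // and `n = a * sum(codes)`, so the reconstructed coordinate sum is
        // `n + dim * b`; `norm_squared` is the squared norm of the
        // reconstructed vector.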
        let meta = MinMaxCompensation {
            dim: self.output_dim() as u32,
            b: min,
            a: inverse_scale,
            n: inverse_scale * code_sum,
            norm_squared,
        };

        into.set_meta(meta);

        if nan_check {
            Err(InputContainsNaN)
        } else {
            Ok(match Positive::new(loss) {
                Ok(p) => L2Loss::Positive(p),
                Err(_) => L2Loss::Zero,
            })
        }
    }
}

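/// Squared reconstruction error reported when compressing a vector.
///
/// `Zero` is used when the accumulated loss is not strictly positive (for
/// example, a vector whose values all land exactly on grid points);
/// otherwise the loss is wrapped in `Positive`.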
#[derive(Clone, Copy, Debug)]
pub enum L2Loss {
    Zero,
    Positive(Positive<f32>),
}

impl L2Loss {
    pub fn as_f32(&self) -> f32 {
        match self {
            L2Loss::Zero => 0.0,
            L2Loss::Positive(p) => p.into_inner(),
        }
    }
}

impl<const NBITS: usize, T> CompressInto<&[T], DataMutRef<'_, NBITS>> for MinMaxQuantizer
where
    T: Copy + Into<f32>,
    Unsigned: Representation<NBITS>,
{
    type Error = InputContainsNaN;

    type Output = L2Loss;

    fn compress_into(&self, from: &[T], to: DataMutRef<'_, NBITS>) -> Result<L2Loss, Self::Error> {
        self.compress::<NBITS, T>(from, to)
    }
}

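/// Full-precision query path: the input is transformed but not quantized;
/// only its squared norm and coordinate sum are cached as metadata for
/// distance computations against quantized data vectors.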
impl<'a, T> CompressInto<&[T], FullQueryMut<'a>> for MinMaxQuantizer
where
    T: Copy + Into<f32>,
{
    type Error = InputContainsNaN;

    type Output = ();

    fn compress_into(&self, from: &[T], mut to: FullQueryMut<'a>) -> Result<(), Self::Error> {
        assert_eq!(from.len(), self.dim());
        assert_eq!(self.output_dim(), to.len());

        let from: Vec<f32> = from.iter().map(|&x| x.into()).collect();
        if from.iter().any(|x| x.is_nan()) {
            return Err(InputContainsNaN);
        }

        #[allow(clippy::unwrap_used)]
        self.transform
            .transform_into(to.vector_mut(), &from, ScopedAllocator::global())
            .unwrap();

        let norm_squared = to.vector().iter().map(|x| *x * *x).sum::<f32>();
        let sum = to.vector().iter().sum::<f32>();

        *to.meta_mut() = FullQueryMeta { norm_squared, sum };

        Ok(())
    }
}

// Each min-max distance functor is a unit struct, so `as_functor` simply
// constructs it by value.
macro_rules! impl_functor {
    ($dist:ident) => {
        impl AsFunctor<$dist> for MinMaxQuantizer {
            fn as_functor(&self) -> $dist {
                $dist
            }
        }
    };
}

impl_functor!(MinMaxIP);
impl_functor!(MinMaxL2Squared);
impl_functor!(MinMaxCosine);
impl_functor!(MinMaxCosineNormalized);

#[cfg(test)]
#[cfg(not(miri))]
mod minmax_quantizer_tests {
    use std::num::NonZeroUsize;

    use diskann_utils::{Reborrow, ReborrowMut};
    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        algorithms::transforms::NullTransform,
        alloc::GlobalAllocator,
        minmax::vectors::{Data, DataRef, FullQuery, FullQueryMut},
    };

    fn reconstruct_minmax<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
    where
        Unsigned: Representation<NBITS>,
    {
        (0..v.len())
            .map(|i| {
                let m = v.meta();
                v.vector().get(i).unwrap() as f32 * m.a + m.b
            })
            .collect()
    }

    fn test_quantizer_encoding_random<const NBITS: usize>(
        dim: usize,
        rng: &mut StdRng,
        relative_err: f32,
        scale: f32,
    ) where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>
            + for<'a, 'b> CompressInto<&'a [f32], FullQueryMut<'b>, Output = ()>,
    {
        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();

        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(scale).unwrap(),
        );

        assert_eq!(quantizer.dim(), dim);

        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();

        let mut encoded = Data::new_boxed(dim);
        let loss = quantizer
            .compress_into(&*vector, encoded.reborrow_mut())
            .unwrap();

        let reconstructed = reconstruct_minmax::<NBITS>(encoded.reborrow());
        assert_eq!(reconstructed.len(), dim);

        let reconstruction_error: f32 = SquaredL2::evaluate(&*vector, &*reconstructed);
        let norm = vector.iter().map(|x| x * x).sum::<f32>();
        assert!(
            (reconstruction_error / norm) <= relative_err,
            "Expected vector {:?} to be reconstructed within error {} but instead got {:?}, with error {} for dim {}",
            &vector,
            relative_err,
            &reconstructed,
            reconstruction_error / norm,
            dim,
        );

        assert!((loss.as_f32() - reconstruction_error).abs() <= 1e-4);

        let expected_code_sum = (0..dim)
            .map(|i| encoded.vector().get(i).unwrap() as f32)
            .sum::<f32>();
        let code_sum = encoded.reborrow().meta().n / encoded.reborrow().meta().a;
        assert!(
            (code_sum - expected_code_sum).abs() <= 2e-5 * (dim as f32),
            "Encoded vector with dim {dim} is {:?}, got error {} for vector {:?}",
            encoded.reborrow(),
            (code_sum - expected_code_sum).abs(),
            &vector,
        );
        let recon_norm_sq = reconstructed.iter().map(|x| x * x).sum::<f32>();
        assert!((encoded.reborrow().meta().norm_squared - recon_norm_sq).abs() <= 1e-3);

        let mut f = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        quantizer
            .compress_into(vector.as_slice(), f.reborrow_mut())
            .unwrap();

        f.vector()
            .iter()
            .enumerate()
            .zip(vector.iter())
            .for_each(|((i, x), y)| {
                assert!(
                    (*x - *y).abs() < 1e-10,
                    "Full Query did not compress dimension {i} with value {} correctly, got {} instead.",
                    *y,
                    *x,
                )
            });

        assert!(
            (f.meta().norm_squared - norm).abs() < 1e-10,
            "Full Query norm in meta should be {norm} but instead got {}",
            f.meta().norm_squared
        );

        let sum = vector.iter().sum::<f32>();
        assert!(
            (f.meta().sum - sum).abs() < 1e-10,
            "Full Query sum in meta should be {sum} but instead got {}",
            f.meta().sum
        );
    }

    // Note: the enclosing module is gated on `#[cfg(not(miri))]`, so the
    // Miri branch below is currently inert.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    macro_rules! test_minmax_quantizer_encoding {
        ($name:ident, $dim:literal, $nbits:literal, $seed:literal, $err:expr) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let scales = [1.0, 1.1, 0.9];
                for (s, e) in scales.iter().zip($err) {
                    for d in 10..$dim {
                        for _ in 0..TRIALS {
                            test_quantizer_encoding_random::<$nbits>(d, &mut rng, e, *s);
                        }
                    }
                }
            }
        };
    }

    test_minmax_quantizer_encoding!(
        test_minmax_encoding_1bit,
        100,
        1,
        0xa32d5658097a1c35,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_2bit,
        100,
        2,
        0xf60c0c8d1aadc126,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_4bit,
        100,
        4,
        0x09fa14c42a9d7d98,
        vec![1.0e-2, 1.0e-2, 3.0e-2]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_8bit,
        100,
        8,
        0xaedf3d2a223b7b77,
        vec![2.0e-3, 2.0e-3, 7.0e-3]
    );

    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    fn test_all_same_value_vector<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 30;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let constant_value = 42.5f32;
        let vector = vec![constant_value; dim];

        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());

        assert!(
            result.is_ok(),
            "Constant-value vector should compress successfully"
        );

        assert!(result.unwrap().as_f32().abs() <= 1e-6);

        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for &val in &reconstructed {
            assert!(
                (val - constant_value).abs() < 1e-3,
                "Reconstructed value {} should be close to original {}. Compressed vector is {:?}",
                val,
                constant_value,
                encoded.meta(),
            );
        }
    }

    fn test_two_distinct_values<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 20;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        let val1 = -10.0f32;
        let val2 = 15.0f32;
        // Fill the second half with `val2`; `skip(dim)` would skip every
        // element and leave the vector constant.
        let mut vector = vec![val1; dim];
        for i in vector.iter_mut().skip(dim / 2) {
            *i = val2;
        }

        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());

        assert!(
            result.is_ok(),
            "Two-value vector should compress successfully"
        );

        assert!(result.unwrap().as_f32().abs() <= 1e-6);

        let mut codes_used = std::collections::HashSet::new();
        for i in 0..dim {
            codes_used.insert(encoded.vector().get(i).unwrap());
        }

        if NBITS > 1 {
            assert!(
                codes_used.len() <= 2,
                "Should use at most 2 distinct codes for 2-value input, but used: {:?}",
                codes_used
            );
        }

        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for ((i, val), v) in reconstructed.into_iter().enumerate().zip(&vector) {
            assert!(
                (val - v).abs() < 1e-4,
                "Reconstructed value in dim {i} is {val}, when it should be {v}."
            );
        }
    }

    fn test_nan_input_error<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 100;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        let mut vector_nan = vec![1.0f32; dim];
        vector_nan[33] = f32::NAN;
        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector_nan, encoded.reborrow_mut());
        assert!(result.is_err(), "Vector with NaN should cause an error");

        // The compensation metadata is written even when the input contains
        // NaN, since `compress` sets it before reporting the error.
        let meta = encoded.meta();
        assert_eq!(meta.dim as usize, dim);
    }

    expand_to_bitrates!(all_same_values_vector, test_all_same_value_vector);
    expand_to_bitrates!(two_distinct_values, test_two_distinct_values);
    expand_to_bitrates!(nan_input_error, test_nan_input_error);

    #[test]
    #[should_panic(expected = "assertion `left == right` failed\n  left: 15\n right: 10")]
    fn test_dimension_mismatch_panic()
    where
        Unsigned: Representation<8>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, 8>, Output = L2Loss>,
    {
        let expected_dim = 10;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(expected_dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );

        // 15 elements against an expected input dimension of 10.
        let wrong_vector = vec![1.0f32; expected_dim + 5];
        let mut encoded = Data::new_boxed(expected_dim);

        let _ = quantizer.compress_into(&wrong_vector, encoded.reborrow_mut());
    }
}