use super::vectors::{DataMutRef, FullQueryMut, MinMaxCompensation, MinMaxIP, MinMaxL2Squared};
use core::f32;
use crate::{
AsFunctor, CompressInto,
algorithms::Transform,
alloc::{GlobalAllocator, ScopedAllocator},
bits::{Representation, Unsigned},
minmax::{MinMaxCosine, MinMaxCosineNormalized, vectors::FullQueryMeta},
num::Positive,
scalar::{InputContainsNaN, bit_scale},
};
/// Scalar quantizer that encodes transformed vectors onto a uniform grid whose
/// endpoints are derived per-vector from the data's min/max range.
pub struct MinMaxQuantizer {
    // Preprocessing transform applied to every input vector before quantization;
    // it also defines the input/output dimensionality.
    transform: Transform<GlobalAllocator>,
    // Multiplier applied to the half-width of the observed [min, max] range;
    // values > 1.0 widen the quantization grid beyond the data, < 1.0 shrink it.
    grid_scale: Positive<f32>,
}
impl MinMaxQuantizer {
    /// Builds a quantizer from a preprocessing `transform` and a `grid_scale`
    /// that stretches (or shrinks) the quantization interval about its midpoint.
    pub fn new(transform: Transform<GlobalAllocator>, grid_scale: Positive<f32>) -> Self {
        Self {
            transform,
            grid_scale,
        }
    }

    /// Dimensionality expected of input vectors.
    pub fn dim(&self) -> usize {
        self.transform.input_dim()
    }

    /// Dimensionality of transformed (and therefore encoded) vectors.
    pub fn output_dim(&self) -> usize {
        self.transform.output_dim()
    }

    /// Chooses the `(lo, hi)` quantization interval for `vec`.
    ///
    /// For 1-bit codes the endpoints are the means of the elements below and
    /// above the overall mean; for all other bit rates they are the elementwise
    /// min/max. The interval is then rescaled about its midpoint by `grid_scale`.
    fn get_range<const NBITS: usize>(&self, vec: &[f32]) -> (f32, f32) {
        let (min, max) = if NBITS == 1 {
            let mean = vec.iter().sum::<f32>() / (vec.len() as f32);
            let (mut lo_sum, mut lo_count) = (0.0f32, 0.0f32);
            let (mut hi_sum, mut hi_count) = (0.0f32, 0.0f32);
            for x in vec {
                // Branchless partition around the mean: `w` is 1.0 for the
                // "low" side and 0.0 for the "high" side.
                let w = f32::from((*x < mean) as u8);
                lo_sum += w * x;
                lo_count += w;
                hi_sum += (1.0 - w) * x;
                hi_count += 1.0 - w;
            }
            // If a partition is empty (all elements equal), 0/0 yields NaN and
            // `f32::min`/`max` fall back to `mean` since they ignore NaN operands.
            ((lo_sum / lo_count).min(mean), (hi_sum / hi_count).max(mean))
        } else {
            // Seeding with NaN works because `f32::min`/`max` discard a NaN
            // operand, so the first element simply replaces the seed.
            let mut running_min = f32::NAN;
            let mut running_max = f32::NAN;
            for &e in vec {
                running_min = running_min.min(e);
                running_max = running_max.max(e);
            }
            (running_min, running_max)
        };
        let half_width = (max - min) / 2.0;
        let mid = min + half_width;
        let scaled_half = half_width * self.grid_scale.into_inner();
        (mid - scaled_half, mid + scaled_half)
    }

    /// Quantizes `from` into `into`: applies the transform, derives a
    /// per-vector grid via [`Self::get_range`], rounds each component to the
    /// nearest grid point, and stores reconstruction metadata with the codes.
    ///
    /// Returns the squared-L2 reconstruction loss, or `InputContainsNaN` if
    /// any transformed component is NaN. Note that codes and metadata are
    /// still written into `into` before the NaN error is reported.
    fn compress<const NBITS: usize, T>(
        &self,
        from: &[T],
        mut into: DataMutRef<'_, NBITS>,
    ) -> Result<L2Loss, InputContainsNaN>
    where
        T: Copy + Into<f32>,
        Unsigned: Representation<NBITS>,
    {
        let mut codes = into.vector_mut();
        assert_eq!(from.len(), self.dim());
        assert_eq!(self.output_dim(), codes.len());

        let domain = Unsigned::domain_const::<NBITS>();
        let (domain_min, domain_max) = (*domain.start() as f32, *domain.end() as f32);

        // Project the input through the preprocessing transform.
        let widened: Vec<f32> = from.iter().map(|&x| x.into()).collect();
        let mut projected = vec![0.0f32; self.output_dim()];
        #[allow(clippy::unwrap_used)]
        self.transform
            .transform_into(&mut projected, &widened, ScopedAllocator::global())
            .unwrap();

        let (min, max) = self.get_range::<NBITS>(&projected);
        // Grid step size; the epsilon floor keeps a constant vector from
        // producing a zero step (and a division by zero below).
        let step = (max - min).max(1e-8) / bit_scale::<NBITS>();

        let mut norm_squared = 0.0f32;
        let mut code_sum = 0.0f32;
        let mut loss = 0.0f32;
        let mut saw_nan = false;
        for (i, &v) in projected.iter().enumerate() {
            saw_nan |= v.is_nan();
            let code = ((v - min) / step).clamp(domain_min, domain_max).round();
            let reconstructed = (code * step) + min;
            norm_squared += reconstructed * reconstructed;
            code_sum += code;
            loss += (reconstructed - v).powi(2);
            // SAFETY: `i < projected.len()`, and `projected.len() == codes.len()`
            // is enforced by the assert above.
            unsafe {
                codes.set_unchecked(i, code as u8);
            }
        }

        into.set_meta(MinMaxCompensation {
            dim: self.output_dim() as u32,
            b: min,
            a: step,
            n: step * code_sum,
            norm_squared,
        });

        if saw_nan {
            Err(InputContainsNaN)
        } else if let Ok(positive_loss) = Positive::new(loss) {
            Ok(L2Loss::Positive(positive_loss))
        } else {
            Ok(L2Loss::Zero)
        }
    }
}
/// Squared-L2 reconstruction loss reported by the quantizer's compression path.
///
/// Split into variants so that a strictly positive loss can carry a
/// `Positive<f32>` witness.
#[derive(Clone, Copy, Debug)]
pub enum L2Loss {
    /// Lossless encoding: reconstruction error is exactly zero.
    Zero,
    /// Strictly positive reconstruction error.
    Positive(Positive<f32>),
}
impl L2Loss {
    /// Collapses the loss back into a plain `f32` (`0.0` for [`L2Loss::Zero`]).
    pub fn as_f32(&self) -> f32 {
        if let L2Loss::Positive(p) = self {
            p.into_inner()
        } else {
            0.0
        }
    }
}
impl<const NBITS: usize, T> CompressInto<&[T], DataMutRef<'_, NBITS>> for MinMaxQuantizer
where
    T: Copy + Into<f32>,
    Unsigned: Representation<NBITS>,
{
    type Error = InputContainsNaN;
    type Output = L2Loss;

    /// Thin adapter onto the inherent `compress` method.
    fn compress_into(&self, from: &[T], to: DataMutRef<'_, NBITS>) -> Result<L2Loss, Self::Error> {
        // `NBITS` and `T` are inferred from the argument types.
        self.compress(from, to)
    }
}
impl<'a, T> CompressInto<&[T], FullQueryMut<'a>> for MinMaxQuantizer
where
    T: Copy + Into<f32>,
{
    type Error = InputContainsNaN;
    type Output = ();

    /// Produces an uncompressed ("full") query: the input is transformed but
    /// not quantized, and the metadata caches the transformed vector's squared
    /// norm and element sum.
    fn compress_into(&self, from: &[T], mut to: FullQueryMut<'a>) -> Result<(), Self::Error> {
        assert_eq!(from.len(), self.dim());
        assert_eq!(self.output_dim(), to.len());
        let widened: Vec<f32> = from.iter().map(|&x| x.into()).collect();
        // Reject NaN up front, before anything is written into `to`.
        if widened.iter().any(|x| x.is_nan()) {
            return Err(InputContainsNaN);
        }
        #[allow(clippy::unwrap_used)]
        self.transform
            .transform_into(to.vector_mut(), &widened, ScopedAllocator::global())
            .unwrap();
        // Single pass over the transformed vector for both cached statistics.
        let mut norm_squared = 0.0f32;
        let mut sum = 0.0f32;
        for &x in to.vector().iter() {
            norm_squared += x * x;
            sum += x;
        }
        *to.meta_mut() = FullQueryMeta { norm_squared, sum };
        Ok(())
    }
}
/// Implements [`AsFunctor`] for `MinMaxQuantizer` for a stateless distance
/// functor type `$dist`, returning the functor's unit value.
macro_rules! impl_functor {
    ($dist:ident) => {
        impl AsFunctor<$dist> for MinMaxQuantizer {
            fn as_functor(&self) -> $dist {
                $dist
            }
        }
    };
}
// Distance functors supported by the min-max representation.
impl_functor!(MinMaxIP);
impl_functor!(MinMaxL2Squared);
impl_functor!(MinMaxCosine);
impl_functor!(MinMaxCosineNormalized);
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_quantizer_tests {
    use std::num::NonZeroUsize;

    use diskann_utils::{Reborrow, ReborrowMut};
    use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        algorithms::transforms::NullTransform,
        alloc::GlobalAllocator,
        minmax::vectors::{Data, DataRef, FullQuery, FullQueryMut},
    };

    /// Decodes every component of `v` back to `f32` using the stored affine
    /// metadata: `code * a + b`.
    fn reconstruct_minmax<const NBITS: usize>(v: DataRef<'_, NBITS>) -> Vec<f32>
    where
        Unsigned: Representation<NBITS>,
    {
        (0..v.len())
            .map(|i| {
                let m = v.meta();
                v.vector().get(i).unwrap() as f32 * m.a + m.b
            })
            .collect()
    }

    /// Round-trips one random vector through the quantizer and checks:
    /// relative reconstruction error, the reported loss, the compensation
    /// metadata, and the full-query (unquantized) path.
    fn test_quantizer_encoding_random<const NBITS: usize>(
        dim: usize,
        rng: &mut StdRng,
        relative_err: f32,
        scale: f32,
    ) where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>
            + for<'a, 'b> CompressInto<&'a [f32], FullQueryMut<'b>, Output = ()>,
    {
        let distribution = Uniform::new_inclusive::<f32, f32>(-1.0, 1.0).unwrap();
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(scale).unwrap(),
        );
        assert_eq!(quantizer.dim(), dim);

        let vector: Vec<f32> = distribution.sample_iter(rng).take(dim).collect();
        let mut encoded = Data::new_boxed(dim);
        let loss = quantizer
            .compress_into(&*vector, encoded.reborrow_mut())
            .unwrap();

        // The relative reconstruction error must stay within the bit-rate budget.
        let reconstructed = reconstruct_minmax::<NBITS>(encoded.reborrow());
        assert_eq!(reconstructed.len(), dim);
        let reconstruction_error: f32 = SquaredL2::evaluate(&*vector, &*reconstructed);
        let norm = vector.iter().map(|x| x * x).sum::<f32>();
        assert!(
            (reconstruction_error / norm) <= relative_err,
            "Expected vector : {:?} to be reconstructed within error {} but instead got : {:?}, with error {} for dim : {}",
            &vector,
            relative_err,
            &reconstructed,
            reconstruction_error / norm,
            dim,
        );
        // The reported loss must agree with the independently computed error.
        // Use the absolute difference: without `.abs()` a spuriously small
        // loss would slip through.
        assert!((loss.as_f32() - reconstruction_error).abs() <= 1e-4);

        // `meta.n` stores `a * sum(codes)`, so dividing by `a` recovers the
        // code sum.
        let expected_code_sum = (0..dim)
            .map(|i| encoded.vector().get(i).unwrap() as f32)
            .sum::<f32>();
        let code_sum = encoded.reborrow().meta().n / encoded.reborrow().meta().a;
        assert!(
            (code_sum - expected_code_sum).abs() <= 2e-5 * (dim as f32),
            "Encoded vector with dim : {dim} is {:?}, got error : {} for vector : {:?}",
            encoded.reborrow(),
            (code_sum - expected_code_sum).abs(),
            &vector,
        );
        let recon_norm_sq = reconstructed.iter().map(|x| x * x).sum::<f32>();
        assert!((encoded.reborrow().meta().norm_squared - recon_norm_sq).abs() <= 1e-3);

        // The full-query path must be an exact copy (null transform) with
        // correct cached statistics.
        let mut f = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        quantizer
            .compress_into(vector.as_slice(), f.reborrow_mut())
            .unwrap();
        f.vector()
            .iter()
            .enumerate()
            .zip(vector.iter())
            .for_each(|((i, x), y)| {
                assert!(
                    (*x - *y).abs() < 1e-10,
                    "Full Query did not compress dimension {i} with value {} correctly, got {} instead.",
                    *y,
                    *x,
                )
            });
        assert!(
            (f.meta().norm_squared - norm).abs() < 1e-10,
            "Full Query norm in meta should be {norm} but instead got {}",
            f.meta().norm_squared
        );
        let sum = vector.iter().sum::<f32>();
        // Absolute difference here too: the previous `(a - b) < eps` check
        // would accept an arbitrarily wrong (too-small) sum.
        assert!(
            (f.meta().sum - sum).abs() < 1e-10,
            "Full Query sum in meta should be {sum} but instead got {}",
            f.meta().sum
        );
    }

    // This module is compiled out under miri (`#[cfg(not(miri))]` above), so a
    // single constant suffices; the old `cfg_if` miri arm was dead code.
    const TRIALS: usize = 10;

    /// Generates a `#[test]` that sweeps dimensions and grid scales for one
    /// bit rate; `$err` supplies the relative-error budget per scale.
    macro_rules! test_minmax_quantizer_encoding {
        ($name:ident, $dim:literal, $nbits:literal, $seed:literal, $err:expr) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let scales = [1.0, 1.1, 0.9];
                for (s, e) in scales.iter().zip($err) {
                    for d in 10..$dim {
                        for _ in 0..TRIALS {
                            test_quantizer_encoding_random::<$nbits>(d, &mut rng, e, *s);
                        }
                    }
                }
            }
        };
    }

    test_minmax_quantizer_encoding!(
        test_minmax_encoding_1bit,
        100,
        1,
        0xa32d5658097a1c35,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_2bit,
        100,
        2,
        0xf60c0c8d1aadc126,
        vec![0.5, 0.5, 0.5]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_4bit,
        100,
        4,
        0x09fa14c42a9d7d98,
        vec![1.0e-2, 1.0e-2, 3.0e-2]
    );
    test_minmax_quantizer_encoding!(
        test_minmax_encoding_8bit,
        100,
        8,
        0xaedf3d2a223b7b77,
        vec![2.0e-3, 2.0e-3, 7.0e-3]
    );

    /// Expands `$func` over every supported bit rate.
    macro_rules! expand_to_bitrates {
        ($name:ident, $func:ident) => {
            #[test]
            fn $name() {
                $func::<1>();
                $func::<2>();
                $func::<4>();
                $func::<8>();
            }
        };
    }

    /// A constant vector must encode losslessly at every bit rate.
    fn test_all_same_value_vector<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 30;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let constant_value = 42.5f32;
        let vector = vec![constant_value; dim];
        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
        assert!(
            result.is_ok(),
            "Constant-value vector should compress successfully"
        );
        assert!(result.unwrap().as_f32().abs() <= 1e-6);
        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for &val in &reconstructed {
            assert!(
                (val - constant_value).abs() < 1e-3,
                "Reconstructed value {} should be close to original {}. Compressed vector is {:?}",
                val,
                constant_value,
                encoded.meta(),
            );
        }
    }

    /// A vector with exactly two distinct values must be representable
    /// (near-)losslessly, since both values land on grid endpoints.
    fn test_two_distinct_values<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 20;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let val1 = -10.0f32;
        let val2 = 15.0f32;
        let mut vector = vec![val1; dim];
        // Fill the second half with `val2`. (A previous version used
        // `skip(dim)`, which skipped past the end and left the vector
        // constant, so the test never exercised two distinct values.)
        for i in vector.iter_mut().skip(dim / 2) {
            *i = val2;
        }
        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector, encoded.reborrow_mut());
        assert!(
            result.is_ok(),
            "Two-value vector should compress successfully"
        );
        assert!(result.unwrap().as_f32().abs() <= 1e-6);
        let mut codes_used = std::collections::HashSet::new();
        for i in 0..dim {
            codes_used.insert(encoded.vector().get(i).unwrap());
        }
        if NBITS > 1 {
            assert!(
                codes_used.len() <= 2,
                "Should use at most 2 distinct codes for 2-value input, but used: {:?}",
                codes_used
            );
        }
        let reconstructed = reconstruct_minmax(encoded.reborrow());
        for ((i, val), v) in reconstructed.into_iter().enumerate().zip(&vector) {
            assert!(
                (val - v).abs() < 1e-4,
                "Reconstructed value in dim : {i} is {val}, when it should be {v}."
            );
        }
    }

    /// A NaN component must produce `InputContainsNaN`; metadata is still
    /// written (dim field populated) even on the error path.
    fn test_nan_input_error<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS>,
        MinMaxQuantizer:
            for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, NBITS>, Output = L2Loss>,
    {
        let dim = 100;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let mut vector_nan = vec![1.0f32; dim];
        vector_nan[33] = f32::NAN;
        let mut encoded = Data::new_boxed(dim);
        let result = quantizer.compress_into(&vector_nan, encoded.reborrow_mut());
        assert!(result.is_err(), "Vector with NaN should cause an error");
        let meta = encoded.meta();
        assert_eq!(meta.dim as usize, dim);
    }

    expand_to_bitrates!(all_same_values_vector, test_all_same_value_vector);
    expand_to_bitrates!(two_distinct_values, test_two_distinct_values);
    expand_to_bitrates!(nan_input_error, test_nan_input_error);

    /// A wrong-length input must trip the `assert_eq!(from.len(), self.dim())`
    /// precondition in `compress`.
    #[test]
    #[should_panic(expected = "assertion `left == right` failed\n  left: 15\n right: 10")]
    fn test_dimension_mismatch_panic()
    where
        // The `DataMutRef<'b, 8>` bound pins NBITS = 8 for inference in the body.
        Unsigned: Representation<8>,
        MinMaxQuantizer: for<'a, 'b> CompressInto<&'a [f32], DataMutRef<'b, 8>, Output = L2Loss>,
    {
        let expected_dim = 10;
        let quantizer = MinMaxQuantizer::new(
            Transform::Null(NullTransform::new(NonZeroUsize::new(expected_dim).unwrap())),
            Positive::new(1.0).unwrap(),
        );
        let wrong_vector = vec![1.0f32; expected_dim + 5];
        let mut encoded = Data::new_boxed(expected_dim);
        let _ = quantizer.compress_into(&wrong_vector, encoded.reborrow_mut());
    }
}