use crate::QuantError;
use crate::codebook::Codebook;
use crate::pack;
use crate::rotation::Rotation;
/// MSE-oriented vector quantizer.
///
/// Vectors are normalized to unit length, passed through a seeded [`Rotation`],
/// scaled by [`TurboQuantMse::scale`](#structfield.scale), and then each coordinate is
/// mapped independently to an index in a fixed scalar [`Codebook`].
pub struct TurboQuantMse {
    // Rotation applied forward before scalar quantization and inverted after decoding;
    // also the source of truth for the vector dimension.
    rotation: Rotation,
    // Static scalar codebook selected by `Codebook::for_bits(bits)`.
    codebook: &'static Codebook,
    // Bits per coordinate used when packing indices.
    bits: u8,
    // `sqrt(dimension)`, multiplied into rotated coordinates before quantization
    // and divided back out when decoding.
    scale: f32,
}
/// A compressed vector produced by [`TurboQuantMse::quantize`].
///
/// Quantization operates on the unit-normalized input, so the original L2 norm
/// is stored alongside the bit-packed codebook indices and reapplied on decode.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// Per-coordinate codebook indices, `bits` bits each, packed by `pack::pack_indices`.
    pub packed_indices: Vec<u8>,
    /// L2 norm of the original input (`0.0` for an all-zero input).
    pub norm: f32,
    /// Bits per coordinate used to pack `packed_indices`.
    pub bits: u8,
    /// Number of coordinates encoded in `packed_indices`.
    pub dimension: usize,
}
impl TurboQuantMse {
pub fn new(dimension: usize, bits: u8, seed: u64) -> Result<Self, QuantError> {
let codebook = Codebook::for_bits(bits)?;
let rotation = Rotation::new(dimension, seed);
let scale = (dimension as f32).sqrt();
Ok(Self {
rotation,
codebook,
bits,
scale,
})
}
pub fn dimension(&self) -> usize {
self.rotation.dimension()
}
pub fn bits(&self) -> u8 {
self.bits
}
pub fn seed(&self) -> u64 {
self.rotation.seed()
}
pub fn quantize(&self, x: &[f32]) -> Result<QuantizedVector, QuantError> {
let dim = self.rotation.dimension();
if x.len() != dim {
return Err(QuantError::DimensionMismatch {
expected: dim,
got: x.len(),
});
}
let norm: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
let mut y = if norm > 0.0 {
x.iter().map(|v| v / norm).collect::<Vec<_>>()
} else {
vec![0.0; dim]
};
self.rotation.forward(&mut y);
for val in &mut y {
*val *= self.scale;
}
let indices: Vec<u8> = y
.iter()
.map(|&v| self.codebook.quantize_scalar(v))
.collect();
let packed_indices = pack::pack_indices(&indices, self.bits)?;
Ok(QuantizedVector {
packed_indices,
norm,
bits: self.bits,
dimension: dim,
})
}
pub fn dequantize(&self, q: &QuantizedVector) -> Result<Vec<f32>, QuantError> {
let dim = q.dimension;
let indices = pack::unpack_indices(&q.packed_indices, q.bits, dim)?;
let mut y: Vec<f32> = indices
.iter()
.map(|&idx| self.codebook.dequantize_scalar(idx))
.collect();
let inv_scale = 1.0 / self.scale;
for val in &mut y {
*val *= inv_scale;
}
self.rotation.inverse(&mut y);
for val in &mut y {
*val *= q.norm;
}
Ok(y)
}
pub fn dequantize_into(&self, q: &QuantizedVector, out: &mut [f32]) -> Result<(), QuantError> {
let result = self.dequantize(q)?;
out.copy_from_slice(&result);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws a standard-normal vector from a seeded RNG and scales it to unit L2 norm.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = StdRng::seed_from_u64(seed);
        let normal = StandardNormal;
        let mut v: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    /// Total squared reconstruction error `||a - b||^2`, summed in index order.
    ///
    /// For unit-norm `a` this is the normalized distortion the MSE tests bound.
    fn squared_error(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(u, v)| (u - v) * (u - v)).sum()
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let dim = 128;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();
        let x = random_unit_vector(dim, 7);
        let q = quant.quantize(&x).unwrap();
        assert_eq!(q.bits, 2);
        assert_eq!(q.dimension, dim);
        let x_hat = quant.dequantize(&q).unwrap();
        assert_eq!(x_hat.len(), dim);
        let mse = squared_error(&x, &x_hat);
        assert!(mse < 0.5, "MSE too high: {mse} (expected < 0.5 for 2-bit)");
    }

    #[test]
    fn mse_decreases_with_bits() {
        let dim = 256;
        let x = random_unit_vector(dim, 13);
        let mut prev_mse = f32::MAX;
        for bits in 1..=4 {
            let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
            let q = quant.quantize(&x).unwrap();
            let x_hat = quant.dequantize(&q).unwrap();
            let mse = squared_error(&x, &x_hat);
            assert!(
                mse < prev_mse,
                "{bits}-bit MSE ({mse}) not less than {}-bit ({prev_mse})",
                bits - 1
            );
            prev_mse = mse;
        }
    }

    #[test]
    fn preserves_norm() {
        let dim = 64;
        let quant = TurboQuantMse::new(dim, 3, 42).unwrap();
        let x: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let norm_orig: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
        let q = quant.quantize(&x).unwrap();
        let x_hat = quant.dequantize(&q).unwrap();
        let norm_hat: f32 = x_hat.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm_orig - norm_hat).abs() / norm_orig < 0.3,
            "norm diverged: {norm_orig} → {norm_hat}"
        );
    }

    #[test]
    fn zero_vector() {
        let dim = 32;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();
        let x = vec![0.0f32; dim];
        let q = quant.quantize(&x).unwrap();
        assert_eq!(q.norm, 0.0);
        let x_hat = quant.dequantize(&q).unwrap();
        for v in &x_hat {
            assert_eq!(*v, 0.0);
        }
    }

    #[test]
    fn dimension_mismatch() {
        let quant = TurboQuantMse::new(32, 2, 42).unwrap();
        let x = vec![1.0; 64];
        assert!(quant.quantize(&x).is_err());
    }

    #[test]
    fn accessors_report_construction_parameters() {
        let quant = TurboQuantMse::new(64, 3, 42).unwrap();
        assert_eq!(quant.dimension(), 64);
        assert_eq!(quant.bits(), 3);
    }

    #[test]
    fn dequantize_into_matches_dequantize() {
        let dim = 64;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();
        let x = random_unit_vector(dim, 3);
        let q = quant.quantize(&x).unwrap();
        let expected = quant.dequantize(&q).unwrap();
        let mut out = vec![0.0f32; dim];
        quant.dequantize_into(&q, &mut out).unwrap();
        assert_eq!(out, expected);
    }

    #[test]
    fn average_mse_matches_theory() {
        let dim = 256;
        let bits = 2;
        let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
        let n_trials = 100;
        let total_mse: f32 = (0..n_trials)
            .map(|seed| {
                let x = random_unit_vector(dim, seed + 1000);
                let q = quant.quantize(&x).unwrap();
                let x_hat = quant.dequantize(&q).unwrap();
                squared_error(&x, &x_hat)
            })
            .sum();
        let avg_mse = total_mse / n_trials as f32;
        assert!(
            avg_mse < 0.35,
            "average MSE = {avg_mse}, expected < 0.35 for 2-bit"
        );
    }
}