use crate::backend::{sum_squared_error, ExecutionBackend};
use serde::{Deserialize, Serialize};
use crate::codebook::{generate_codebook, Codebook};
use crate::error::{Result, TurboQuantError};
use crate::rotation::RandomRotation;
use crate::scalar_quant::ScalarQuantizer;
use crate::utils::validate_unit_vector;
/// Quantized representation of a vector, as returned by [`TurboQuantMSE::quantize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct QuantizedVector {
/// Quantization indices; `TurboQuantMSE` only accepts a vector whose index
/// count equals the quantizer's dimension.
pub indices: Vec<u8>,
/// Number of bits each index is counted as when computing the storage size.
pub bit_width: u8,
/// Dimension recorded by the quantizer that produced this vector.
pub dim: usize,
}
impl QuantizedVector {
    /// Size of the index payload in bytes, assuming each index occupies exactly
    /// `bit_width` bits. The result may be fractional.
    pub fn bytes(&self) -> f64 {
        let index_count = self.indices.len() as f64;
        let bits_per_index = f64::from(self.bit_width);
        (index_count * bits_per_index) / 8.0
    }

    /// Ratio between a 4-bytes-per-coordinate baseline (`dim * 4`) and
    /// [`bytes`](Self::bytes).
    pub fn compression_ratio(&self) -> f64 {
        let baseline_bytes = self.dim as f64 * 4.0;
        baseline_bytes / self.bytes()
    }
}
/// Vector quantizer that rotates its input with a [`RandomRotation`] and then
/// quantizes the rotated coordinates with a [`ScalarQuantizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantMSE {
/// Rotation applied before quantization and inverted after dequantization.
rotation: RandomRotation,
/// Scalar quantizer (built from a codebook) applied to the rotated coordinates.
quantizer: ScalarQuantizer,
/// Dimension of the vectors this quantizer accepts.
pub dim: usize,
/// Bit width stamped on every vector this quantizer produces.
pub bit_width: u8,
}
impl TurboQuantMSE {
    /// Rejects construction parameters outside the supported range.
    ///
    /// `dim` must be non-zero and `bit_width` must lie in `1..=8`. Shared by
    /// every constructor so they all enforce the same invariants.
    fn validate_params(dim: usize, bit_width: u8) -> Result<()> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        if !(1..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }
        Ok(())
    }

    /// Checks that the borrowed parts of a quantized vector are compatible with
    /// this quantizer.
    ///
    /// Works on borrowed data so callers never have to copy the index buffer
    /// just to have it validated.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `dim` differs from `self.dim`.
    /// * `BitWidthMismatch` if `bit_width` differs from `self.bit_width`.
    /// * `LengthMismatch` if `indices.len()` differs from `self.dim`.
    /// * Whatever `ScalarQuantizer::validate_indices` reports for the indices.
    fn validate_parts(&self, indices: &[u8], bit_width: u8, dim: usize) -> Result<()> {
        if dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: dim,
            });
        }
        if bit_width != self.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: bit_width,
            });
        }
        if indices.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantMSE quantized vector indices".into(),
                expected: self.dim,
                got: indices.len(),
            });
        }
        self.quantizer.validate_indices(indices)
    }

    /// Reconstructs a vector from the individual parts of a quantized vector.
    ///
    /// The parts are validated first, then the indices are dequantized and the
    /// rotation is inverted.
    ///
    /// # Errors
    ///
    /// Returns an error if the parts do not match this quantizer (see
    /// `validate_parts`) or if the inverse rotation fails.
    pub(crate) fn dequantize_parts(
        &self,
        indices: &[u8],
        bit_width: u8,
        dim: usize,
    ) -> Result<Vec<f64>> {
        self.validate_parts(indices, bit_width, dim)?;
        let y_approx = self.quantizer.dequantize_batch(indices);
        self.rotation.rotate_inverse(&y_approx)
    }

    /// Creates a quantizer for `dim`-dimensional vectors with a freshly
    /// generated codebook.
    ///
    /// `seed` is forwarded to [`RandomRotation::new`].
    ///
    /// # Errors
    ///
    /// * `InvalidDimension` if `dim` is zero.
    /// * `InvalidBitWidth` if `bit_width` is not in `1..=8`.
    /// * Any error from building the rotation or generating the codebook.
    pub fn new(dim: usize, bit_width: u8, seed: u64) -> Result<Self> {
        Self::validate_params(dim, bit_width)?;
        let rotation = RandomRotation::new(dim, seed)?;
        // NOTE(review): the third argument is presumably an iteration budget for
        // codebook generation; confirm against `generate_codebook`.
        let codebook = generate_codebook(dim, bit_width, 100)?;
        let quantizer = ScalarQuantizer::from_codebook(codebook);
        Ok(Self {
            rotation,
            quantizer,
            dim,
            bit_width,
        })
    }

    /// Creates a quantizer that reuses an existing `codebook`.
    ///
    /// The bit width is taken from the codebook. The same parameter checks as
    /// in [`TurboQuantMSE::new`] are applied, so both constructors enforce the
    /// same invariants.
    ///
    /// # Errors
    ///
    /// * `InvalidDimension` if `dim` is zero.
    /// * `InvalidBitWidth` if the codebook's bit width is not in `1..=8`.
    /// * Any error from building the rotation.
    pub fn with_codebook(dim: usize, codebook: Codebook, seed: u64) -> Result<Self> {
        let bit_width = codebook.bit_width;
        Self::validate_params(dim, bit_width)?;
        let rotation = RandomRotation::new(dim, seed)?;
        let quantizer = ScalarQuantizer::from_codebook(codebook);
        Ok(Self {
            rotation,
            quantizer,
            dim,
            bit_width,
        })
    }

    /// Quantizes the unit vector `x`.
    ///
    /// The input is rotated and each rotated coordinate is mapped to a
    /// quantization index.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `x.len()` differs from `self.dim`.
    /// * Whatever `validate_unit_vector` reports for `x`.
    /// * Any error from applying the rotation.
    pub fn quantize(&self, x: &[f64]) -> Result<QuantizedVector> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_unit_vector(x, "TurboQuantMSE input")?;
        let y = self.rotation.rotate(x)?;
        let indices = self.quantizer.quantize_batch(&y);
        Ok(QuantizedVector {
            indices,
            bit_width: self.bit_width,
            dim: self.dim,
        })
    }

    /// Reconstructs an approximation of the original vector from `q`.
    ///
    /// # Errors
    ///
    /// Returns an error if `q` is not compatible with this quantizer
    /// (dimension, bit width, index count or index values) or if the inverse
    /// rotation fails.
    pub fn dequantize(&self, q: &QuantizedVector) -> Result<Vec<f64>> {
        self.dequantize_parts(&q.indices, q.bit_width, q.dim)
    }

    /// Evaluates the closed-form expression `sqrt(3) * pi / 2 * 4^(-bit_width)`.
    ///
    /// NOTE(review): [`TurboQuantMSE::actual_mse`] divides the squared error by
    /// `dim`; confirm whether this bound refers to the per-coordinate or the
    /// whole-vector error before comparing the two.
    pub fn distortion_bound(&self) -> f64 {
        let b = f64::from(self.bit_width);
        3.0_f64.sqrt() * std::f64::consts::PI / 2.0 * (0.25_f64).powf(b)
    }

    /// Quantizes and dequantizes `x`, returning the squared reconstruction
    /// error divided by `dim`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TurboQuantMSE::quantize`] or
    /// [`TurboQuantMSE::dequantize`].
    pub fn actual_mse(&self, x: &[f64]) -> Result<f64> {
        let q = self.quantize(x)?;
        let x_approx = self.dequantize(&q)?;
        let mse = sum_squared_error(ExecutionBackend::default(), x, &x_approx) / self.dim as f64;
        Ok(mse)
    }

    /// Returns the codebook used by the scalar quantizer.
    pub fn codebook(&self) -> &Codebook {
        &self.quantizer.codebook
    }

    /// Returns the rotation used by this quantizer.
    // Without the `gpu` feature the dead-code lint is silenced here; presumably
    // the only in-crate callers live behind that feature.
    #[cfg_attr(not(feature = "gpu"), allow(dead_code))]
    pub(crate) fn rotation(&self) -> &RandomRotation {
        &self.rotation
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::normalize;
    use rand::SeedableRng;
    use rand_distr::{Distribution, Normal};

    /// Draws `dim` standard-normal samples from a seeded RNG and normalizes them.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let gaussian = Normal::new(0.0, 1.0).unwrap();
        let samples: Vec<f64> = std::iter::repeat_with(|| gaussian.sample(&mut rng))
            .take(dim)
            .collect();
        normalize(&samples).unwrap()
    }

    /// A quantize/dequantize round trip keeps the vector length and bit width.
    #[test]
    fn test_quantize_dequantize_shape() {
        let quantizer = TurboQuantMSE::new(64, 4, 42).unwrap();
        let input = random_unit_vector(64, 1);

        let encoded = quantizer.quantize(&input).unwrap();
        assert_eq!(encoded.indices.len(), 64);
        assert_eq!(encoded.bit_width, 4);

        let decoded = quantizer.dequantize(&encoded).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    /// More bits per coordinate must give a smaller reconstruction error.
    #[test]
    fn test_mse_decreases_with_bits() {
        let dim = 256;
        let x = random_unit_vector(dim, 999);
        let mse_at = |bits: u8| {
            TurboQuantMSE::new(dim, bits, 42)
                .unwrap()
                .actual_mse(&x)
                .unwrap()
        };

        let mse2 = mse_at(2);
        let mse4 = mse_at(4);
        assert!(
            mse4 < mse2,
            "4-bit MSE {} should < 2-bit MSE {}",
            mse4,
            mse2
        );
    }

    /// The 4-bit bound is strictly positive and below 0.1.
    #[test]
    fn test_distortion_bound() {
        let bound = TurboQuantMSE::new(512, 4, 42).unwrap().distortion_bound();
        assert!(0.0 < bound && bound < 0.1, "bound={}", bound);
    }

    /// Quantizing a vector of the wrong length fails.
    #[test]
    fn test_dimension_mismatch() {
        let quantizer = TurboQuantMSE::new(64, 2, 1).unwrap();
        let too_long = vec![0.0; 128];
        let outcome = quantizer.quantize(&too_long);
        assert!(outcome.is_err());
    }

    /// 4-bit indices give a ratio of 8 against the 4-bytes-per-coordinate baseline.
    #[test]
    fn test_compression_ratio() {
        let quantizer = TurboQuantMSE::new(128, 4, 1).unwrap();
        let input = random_unit_vector(128, 2);
        let ratio = quantizer.quantize(&input).unwrap().compression_ratio();
        assert!((ratio - 8.0).abs() < 0.01);
    }

    /// A quantizer built from a shared codebook inherits its bit width and works end to end.
    #[test]
    fn test_with_codebook() {
        let dim = 64;
        let shared_codebook = TurboQuantMSE::new(dim, 4, 42).unwrap().codebook().clone();

        let tq_shared = TurboQuantMSE::with_codebook(dim, shared_codebook, 99).unwrap();
        assert_eq!(tq_shared.dim, dim);
        assert_eq!(tq_shared.bit_width, 4);

        let input = random_unit_vector(dim, 1);
        let encoded = tq_shared.quantize(&input).unwrap();
        assert_eq!(encoded.indices.len(), dim);

        let decoded = tq_shared.dequantize(&encoded).unwrap();
        assert_eq!(decoded.len(), dim);
    }

    /// Dequantizing a vector recorded with a different dimension fails.
    #[test]
    fn test_dequantize_dimension_mismatch() {
        let quantizer = TurboQuantMSE::new(64, 4, 42).unwrap();
        let foreign = QuantizedVector {
            indices: vec![0; 128],
            bit_width: 4,
            dim: 128,
        };
        let outcome = quantizer.dequantize(&foreign);
        assert!(outcome.is_err());
    }

    /// A zero dimension is rejected at construction time.
    #[test]
    fn test_invalid_dimension_zero() {
        let outcome = TurboQuantMSE::new(0, 4, 1);
        assert!(outcome.is_err());
    }

    /// Bit widths just outside `1..=8` are rejected at construction time.
    #[test]
    fn test_invalid_bit_width() {
        for bits in [0u8, 9] {
            assert!(TurboQuantMSE::new(64, bits, 1).is_err());
        }
    }

    /// An input that is not unit length is reported as `NotUnitVector`.
    #[test]
    fn test_quantize_rejects_non_unit_vector() {
        let quantizer = TurboQuantMSE::new(8, 4, 1).unwrap();
        let all_ones = vec![1.0; 8];
        let outcome = quantizer.quantize(&all_ones);
        assert!(matches!(outcome, Err(TurboQuantError::NotUnitVector(_))));
    }

    /// Index 4 is out of range for a 2-bit quantizer and is rejected.
    #[test]
    fn test_dequantize_rejects_invalid_index() {
        let quantizer = TurboQuantMSE::new(8, 2, 1).unwrap();
        let out_of_range = QuantizedVector {
            indices: vec![4; 8],
            bit_width: 2,
            dim: 8,
        };
        let outcome = quantizer.dequantize(&out_of_range);
        assert!(matches!(
            outcome,
            Err(TurboQuantError::InvalidQuantizationIndex { .. })
        ));
    }
}