use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
use crate::qjl::QJL;
use crate::turboquant_mse::TurboQuantMSE;
use crate::utils::{inner_product, norm, validate_unit_vector};
/// Output of the two-stage product quantizer: coarse MSE codebook indices
/// for the vector itself plus QJL sign bits (and a norm) for the residual.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct ProdQuantized {
// Stage-1 codebook index per dimension, each using `bit_width - 1` bits.
pub mse_indices: Vec<u8>,
// Stage-2 sign bit per dimension for the QJL-projected residual.
pub qjl_signs: Vec<bool>,
// Euclidean norm of the stage-1 residual; used to rescale the QJL
// reconstruction during dequantization.
pub residual_norm: f64,
// Total bits per dimension across both stages (valid range 2..=8 when
// produced by `TurboQuantProd`).
pub bit_width: u8,
// Dimensionality of the original vector.
pub dim: usize,
}
impl ProdQuantized {
    /// Approximate storage size in bytes: `bit_width - 1` bits per stage-1
    /// index, one sign bit per dimension for stage 2, plus 32 bits for the
    /// residual norm (stored as an `f32`).
    pub fn bytes(&self) -> f64 {
        // All fields are `pub` and the type is `Deserialize`, so a value
        // with `bit_width == 0` is constructible; `saturating_sub` avoids
        // the u8 underflow (debug panic / wrap to 255 in release) that
        // `self.bit_width - 1` would hit in that case.
        let stage1_bits =
            self.mse_indices.len() as f64 * f64::from(self.bit_width.saturating_sub(1));
        let stage2_bits = self.qjl_signs.len() as f64;
        (stage1_bits + stage2_bits + 32.0) / 8.0
    }

    /// Ratio of the uncompressed size (`dim` values at 4 bytes each, i.e.
    /// `f32` per component) to the compressed size reported by [`Self::bytes`].
    /// Always finite: `bytes()` is at least 4 (the 32-bit norm).
    pub fn compression_ratio(&self) -> f64 {
        (self.dim as f64 * 4.0) / self.bytes()
    }
}
/// Two-stage product quantizer for unit vectors: a TurboQuantMSE stage for
/// the bulk of the signal and a QJL sign stage for the residual.
#[derive(Debug, Serialize, Deserialize)]
pub struct TurboQuantProd {
// Stage 1: MSE quantizer configured with `bit_width - 1` bits.
mse_stage: TurboQuantMSE,
// Stage 2: QJL sign quantizer for the stage-1 residual.
qjl_stage: QJL,
// Expected input dimensionality; all vectors must have this length.
pub dim: usize,
// Total bits per dimension (validated to 2..=8 in `new`), split across
// the two stages.
pub bit_width: u8,
}
impl TurboQuantProd {
    /// Verifies that `q` is structurally consistent with this quantizer:
    /// matching dimension and bit width, stage payload lengths equal to
    /// `dim`, and a finite, non-negative residual norm.
    fn validate_quantized(&self, q: &ProdQuantized) -> Result<()> {
        if q.dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: q.dim,
            });
        }
        if q.bit_width != self.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: q.bit_width,
            });
        }
        if q.mse_indices.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantProd MSE indices".into(),
                expected: self.dim,
                got: q.mse_indices.len(),
            });
        }
        if q.qjl_signs.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantProd QJL signs".into(),
                expected: self.dim,
                got: q.qjl_signs.len(),
            });
        }
        if !q.residual_norm.is_finite() || q.residual_norm < 0.0 {
            return Err(TurboQuantError::InvalidValue {
                context: "TurboQuantProd residual norm".into(),
                value: q.residual_norm,
            });
        }
        Ok(())
    }

    /// Creates a product quantizer for `dim`-dimensional unit vectors using
    /// `bit_width` total bits per dimension: `bit_width - 1` bits go to the
    /// MSE stage and the remaining bit is the QJL residual sign.
    ///
    /// # Errors
    /// Returns `InvalidDimension` if `dim == 0`, `InvalidBitWidth` if
    /// `bit_width` is outside `2..=8`, or any error from the stage
    /// constructors.
    pub fn new(dim: usize, bit_width: u8, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        if !(2..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }
        // Stage 1 gets all but one of the bits; validation above guarantees
        // `bit_width >= 2`, so this cannot underflow.
        let mse_bits = bit_width - 1;
        let mse_stage = TurboQuantMSE::new(dim, mse_bits, seed)?;
        // Decorrelate the two stages' randomness with a distinct seed.
        // `wrapping_add` avoids the debug-mode overflow panic that `seed + 1`
        // would hit when `seed == u64::MAX`.
        let qjl_stage = QJL::new(dim, seed.wrapping_add(1))?;
        Ok(Self {
            mse_stage,
            qjl_stage,
            dim,
            bit_width,
        })
    }

    /// Quantizes a unit vector: stage 1 encodes the coarse approximation,
    /// stage 2 encodes the sign pattern of the stage-1 residual together
    /// with the residual's norm.
    ///
    /// # Errors
    /// Fails if `x` has the wrong length, is not a unit vector, or either
    /// stage quantizer rejects its input.
    pub fn quantize(&self, x: &[f64]) -> Result<ProdQuantized> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_unit_vector(x, "TurboQuantProd input")?;
        let q_mse = self.mse_stage.quantize(x)?;
        let x_approx = self.mse_stage.dequantize(&q_mse)?;
        // Residual the second stage must capture: x minus its stage-1
        // reconstruction.
        let residual: Vec<f64> = x
            .iter()
            .zip(x_approx.iter())
            .map(|(xi, xi_a)| xi - xi_a)
            .collect();
        let residual_norm = norm(&residual);
        let q_qjl = self.qjl_stage.quantize(&residual)?;
        Ok(ProdQuantized {
            mse_indices: q_mse.indices,
            qjl_signs: q_qjl.signs,
            residual_norm,
            bit_width: self.bit_width,
            dim: self.dim,
        })
    }

    /// Reconstructs an approximation of the original vector by summing the
    /// stage-1 reconstruction and the norm-scaled stage-2 residual estimate.
    ///
    /// # Errors
    /// Fails if `q` is inconsistent with this quantizer (see
    /// [`Self::validate_quantized`]) or a stage dequantizer fails.
    pub fn dequantize(&self, q: &ProdQuantized) -> Result<Vec<f64>> {
        self.validate_quantized(q)?;
        // After validation, q.bit_width == self.bit_width >= 2 (enforced in
        // `new`), so the subtraction below is safe.
        let x_mse = self
            .mse_stage
            .dequantize_parts(&q.mse_indices, q.bit_width - 1, q.dim)?;
        let x_qjl = self
            .qjl_stage
            .dequantize_parts(&q.qjl_signs, q.residual_norm, q.dim)?;
        let result: Vec<f64> = x_mse.iter().zip(x_qjl.iter()).map(|(a, b)| a + b).collect();
        Ok(result)
    }

    /// Estimates `<x, y>` from the quantized form of `x` without fully
    /// reconstructing it: exact inner product against the stage-1
    /// reconstruction plus the QJL stage's residual estimate.
    ///
    /// # Errors
    /// Fails if `y` has the wrong length or `q` is inconsistent with this
    /// quantizer.
    pub fn estimate_inner_product(&self, q: &ProdQuantized, y: &[f64]) -> Result<f64> {
        if y.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: y.len(),
            });
        }
        self.validate_quantized(q)?;
        let x_mse = self
            .mse_stage
            .dequantize_parts(&q.mse_indices, q.bit_width - 1, q.dim)?;
        let ip_mse = inner_product(&x_mse, y);
        let ip_qjl =
            self.qjl_stage
                .estimate_inner_product_parts(&q.qjl_signs, q.residual_norm, q.dim, y)?;
        Ok(ip_mse + ip_qjl)
    }

    /// Distortion bound for inner-product estimation against a query of
    /// norm `y_norm`: sqrt(3) * pi^2 * ||y||^2 / (d * 4^b). Scales with
    /// `y_norm` squared and shrinks exponentially in the bit width.
    pub fn distortion_bound(&self, y_norm: f64) -> f64 {
        let b = self.bit_width as f64;
        let d = self.dim as f64;
        3.0_f64.sqrt() * std::f64::consts::PI * std::f64::consts::PI * y_norm * y_norm
            / (d * (4.0_f64).powf(b))
    }

    /// Read-only access to the stage-1 MSE quantizer.
    pub fn mse_stage(&self) -> &TurboQuantMSE {
        &self.mse_stage
    }

    /// Read-only access to the stage-2 QJL quantizer.
    pub fn qjl_stage(&self) -> &QJL {
        &self.qjl_stage
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{inner_product, normalize};

    /// Samples `dim` i.i.d. standard normals with a fixed seed and rescales
    /// the result to unit length.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let gaussian = Normal::new(0.0, 1.0).unwrap();
        let raw: Vec<f64> = (0..dim).map(|_| gaussian.sample(&mut rng)).collect();
        normalize(&raw).unwrap()
    }

    #[test]
    fn test_quantize_dequantize_shape() {
        let tq = TurboQuantProd::new(64, 3, 42).unwrap();
        let q = tq.quantize(&random_unit_vector(64, 1)).unwrap();
        assert_eq!(q.mse_indices.len(), 64);
        assert_eq!(q.qjl_signs.len(), 64);
        assert_eq!(tq.dequantize(&q).unwrap().len(), 64);
    }

    #[test]
    fn test_inner_product_estimate_quality() {
        let dim = 256;
        let tq = TurboQuantProd::new(dim, 4, 7).unwrap();
        let n = 10;
        let mut total_err = 0.0;
        for seed in 0..n {
            let x = random_unit_vector(dim, seed);
            let y = random_unit_vector(dim, seed + 50);
            let q = tq.quantize(&x).unwrap();
            let est_ip = tq.estimate_inner_product(&q, &y).unwrap();
            total_err += (inner_product(&x, &y) - est_ip).abs();
        }
        let avg_err = total_err / n as f64;
        assert!(
            avg_err < 0.15,
            "avg inner product error {} too large",
            avg_err
        );
    }

    #[test]
    fn test_invalid_bit_width() {
        // Valid bit widths are 2..=8; both neighbors of the range must fail.
        assert!(TurboQuantProd::new(64, 1, 1).is_err());
        assert!(TurboQuantProd::new(64, 9, 1).is_err());
    }

    #[test]
    fn test_dimension_mismatch_quantize() {
        let tq = TurboQuantProd::new(64, 3, 1).unwrap();
        let too_long = vec![0.0; 128];
        assert!(tq.quantize(&too_long).is_err());
    }

    #[test]
    fn test_dimension_mismatch_estimate_ip() {
        let tq = TurboQuantProd::new(64, 3, 1).unwrap();
        let q = tq.quantize(&random_unit_vector(64, 1)).unwrap();
        let bad_y = vec![0.0; 128];
        assert!(tq.estimate_inner_product(&q, &bad_y).is_err());
    }

    #[test]
    fn test_distortion_bound() {
        let tq = TurboQuantProd::new(128, 4, 1).unwrap();
        let bound = tq.distortion_bound(1.0);
        assert!(bound > 0.0 && bound < 1.0, "bound={}", bound);
        // Doubling the query norm must quadruple the bound.
        let bound2 = tq.distortion_bound(2.0);
        assert!(
            (bound2 / bound - 4.0).abs() < 1e-10,
            "bound should scale as y_norm²"
        );
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let tq = TurboQuantProd::new(64, 3, 1).unwrap();
        // Payload built for a 128-dim quantizer must be rejected by a 64-dim one.
        let bad_q = ProdQuantized {
            mse_indices: vec![0; 128],
            qjl_signs: vec![true; 128],
            residual_norm: 0.1,
            bit_width: 3,
            dim: 128,
        };
        assert!(tq.dequantize(&bad_q).is_err());
    }

    #[test]
    fn test_prod_improves_inner_product_estimate() {
        let dim = 128;
        let n_trials = 20;
        let tq_mse = TurboQuantMSE::new(dim, 3, 1).unwrap();
        let tq_prod = TurboQuantProd::new(dim, 3, 1).unwrap();
        let mut mse_ip_err = 0.0;
        let mut prod_ip_err = 0.0;
        for seed in 0..n_trials {
            let x = random_unit_vector(dim, seed);
            let y = random_unit_vector(dim, seed + 100);
            let true_ip = inner_product(&x, &y);
            // Baseline: single-stage MSE reconstruction.
            let q_mse = tq_mse.quantize(&x).unwrap();
            let x_mse = tq_mse.dequantize(&q_mse).unwrap();
            mse_ip_err += (true_ip - inner_product(&x_mse, &y)).abs();
            // Two-stage product quantizer estimate.
            let q_prod = tq_prod.quantize(&x).unwrap();
            let est_ip = tq_prod.estimate_inner_product(&q_prod, &y).unwrap();
            prod_ip_err += (true_ip - est_ip).abs();
        }
        let avg_mse_err = mse_ip_err / n_trials as f64;
        let avg_prod_err = prod_ip_err / n_trials as f64;
        assert!(avg_mse_err < 0.5, "MSE IP error too large: {}", avg_mse_err);
        assert!(
            avg_prod_err < 0.5,
            "Prod IP error too large: {}",
            avg_prod_err
        );
    }

    #[test]
    fn test_quantize_rejects_non_unit_vector() {
        let tq = TurboQuantProd::new(8, 3, 1).unwrap();
        // Norm is sqrt(8), well away from 1.
        let not_unit = vec![1.0; 8];
        assert!(matches!(
            tq.quantize(&not_unit),
            Err(TurboQuantError::NotUnitVector(_))
        ));
    }

    #[test]
    fn test_dequantize_rejects_wrong_sign_count() {
        let tq = TurboQuantProd::new(8, 3, 1).unwrap();
        // Correct dim/bit_width, but one sign bit short.
        let q = ProdQuantized {
            mse_indices: vec![0; 8],
            qjl_signs: vec![true; 7],
            residual_norm: 0.1,
            bit_width: 3,
            dim: 8,
        };
        assert!(matches!(
            tq.dequantize(&q),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}