use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use turboquant::turboquant_mse::TurboQuantMSE;
use turboquant::turboquant_prod::TurboQuantProd;
use turboquant::utils::{inner_product, norm, normalize};
/// Demonstrates TurboQuant quantization on random unit vectors in `dim` dimensions.
///
/// The demo prints four sections:
/// 1. `TurboQuantMSE` at 2, 3 and 4 bits: the measured per-coordinate MSE of the
///    reconstruction next to `distortion_bound()`, the compression ratio and the
///    quantized size reported by the quantized value.
/// 2. A component-wise error breakdown of the 4-bit `TurboQuantMSE` reconstruction.
/// 3. `TurboQuantProd` at 2, 3 and 4 bits: estimated vs. exact inner product with
///    a random unit query.
/// 4. Averages over `n_trials` fresh random (vector, query) pairs at 4 bits.
///
/// All randomness comes from seeded `StdRng`s, so the output is reproducible.
///
/// # Panics
/// Panics if any quantizer construction, quantization, dequantization,
/// measurement or estimation call returns an error. This is a demo, so each
/// failure aborts with a message naming the step that failed.
fn main() {
    let dim = 128usize;
    let seed = 42u64;
    let n_trials = 100usize;

    let normal = Normal::new(0.0, 1.0).expect("N(0, 1) has valid parameters");
    // Samples `dim` i.i.d. standard-normal coordinates and normalizes them onto the
    // unit sphere; used for every data vector and query in the demo.
    let random_unit_vector = |rng: &mut rand::rngs::StdRng| {
        let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut *rng)).collect();
        normalize(&raw).expect("normalizing a Gaussian sample failed")
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    // Draw order (x first, then the query) determines which vectors the seed yields.
    let x = random_unit_vector(&mut rng);
    assert!((norm(&x) - 1.0).abs() < 1e-10, "x must be on unit sphere");
    let query = random_unit_vector(&mut rng);
    let true_ip = inner_product(&x, &query);

    println!("╔══════════════════════════════════════════════════════╗");
    println!("║ TurboQuant Basic Quantization Demo ║");
    println!("╚══════════════════════════════════════════════════════╝");
    println!();
    println!(" Dimension : {}", dim);
    println!(" ‖x‖₂ : {:.10}", norm(&x));
    println!(" True ⟨x,q⟩: {:.6}", true_ip);
    println!();

    // --- Section 1: TurboQuantMSE reconstruction quality per bit width ---
    println!("┌─────────────────────────────────────────────────────────────────┐");
    println!("│ TurboQuantMSE: Random Rotation + Lloyd-Max Scalar Quantization │");
    println!("├────────┬──────────────┬──────────────┬────────────┬─────────────┤");
    println!("│ Bits │ Actual MSE │ Theor. Bound │ Ratio │ Storage │");
    println!("├────────┼──────────────┼──────────────┼────────────┼─────────────┤");
    for bits in [2u8, 3, 4] {
        let tq = TurboQuantMSE::new(dim, bits, seed).expect("Failed to create TurboQuantMSE");
        let q = tq.quantize(&x).expect("Quantization failed");
        let x_approx = tq.dequantize(&q).expect("Dequantization failed");
        // Mean squared error per coordinate: ‖x - x̃‖² / dim.
        // NOTE(review): confirm `distortion_bound()` uses the same per-coordinate
        // normalization; if it bounds the total ‖x - x̃‖² instead, these two
        // columns differ by a factor of `dim`.
        let mse: f64 = x
            .iter()
            .zip(x_approx.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f64>()
            / dim as f64;
        let bound = tq.distortion_bound();
        let ratio = q.compression_ratio();
        let storage = q.bytes();
        println!(
            "│ {:2} │ {:.8} │ {:.8} │ {:5.1}x │ {:.1} bytes │",
            bits, mse, bound, ratio, storage
        );
    }
    println!("└────────┴──────────────┴──────────────┴────────────┴─────────────┘");
    println!();

    // --- Section 2: component-wise error of the 4-bit reconstruction ---
    // `tq4` is reused by the statistical summary below (same dim, bits and seed).
    let tq4 = TurboQuantMSE::new(dim, 4, seed).expect("Failed to create TurboQuantMSE");
    let q4 = tq4.quantize(&x).expect("Quantization failed");
    let x4 = tq4.dequantize(&q4).expect("Dequantization failed");
    let error_vec: Vec<f64> = x.iter().zip(x4.iter()).map(|(a, b)| a - b).collect();
    let max_err = error_vec.iter().map(|e| e.abs()).fold(0.0_f64, f64::max);
    let mean_err = error_vec.iter().map(|e| e.abs()).sum::<f64>() / dim as f64;
    println!(" 4-bit analysis:");
    println!(" Max component error : {:.6}", max_err);
    println!(" Mean component error: {:.6}", mean_err);
    println!(" ‖x - x̃‖₂ : {:.6}", norm(&error_vec));
    println!();

    // --- Section 3: TurboQuantProd inner-product estimates per bit width ---
    println!("┌─────────────────────────────────────────────────────────────────┐");
    println!("│ TurboQuantProd: Two-Stage Inner-Product-Optimal Quantization │");
    println!("├────────┬──────────────┬────────────┬──────────────┬─────────────┤");
    println!("│ Bits │ Est. ⟨x,q⟩ │ True ⟨x,q⟩│ Error │ Ratio │");
    println!("├────────┼──────────────┼────────────┼──────────────┼─────────────┤");
    for bits in [2u8, 3, 4] {
        let tq = TurboQuantProd::new(dim, bits, seed).expect("Failed to create TurboQuantProd");
        let q = tq.quantize(&x).expect("Quantization failed");
        let est_ip = tq
            .estimate_inner_product(&q, &query)
            .expect("Estimation failed");
        let err = (true_ip - est_ip).abs();
        println!(
            "│ {:2} │ {:.6} │ {:.6} │ {:.6} │ {:.1}x │",
            bits,
            est_ip,
            true_ip,
            err,
            q.compression_ratio()
        );
    }
    println!("└────────┴──────────────┴────────────┴──────────────┴─────────────┘");
    println!();

    // --- Section 4: averages over many random (vector, query) pairs at 4 bits ---
    println!(
        "Statistical summary ({} random vectors, 4-bit MSE and Prod):",
        n_trials
    );
    let tq_prod4 = TurboQuantProd::new(dim, 4, seed).expect("Failed to create TurboQuantProd");
    // Offset seed so the trial vectors come from a different stream than `x`/`query`.
    let mut trial_rng = rand::rngs::StdRng::seed_from_u64(seed + 100);
    let mut sum_mse = 0.0f64;
    let mut sum_ip_err = 0.0f64;
    for _ in 0..n_trials {
        // Draw order per trial (vector, then query) matches the original demo.
        let v = random_unit_vector(&mut trial_rng);
        let q2 = random_unit_vector(&mut trial_rng);
        sum_mse += tq4.actual_mse(&v).expect("MSE measurement failed");
        let pq = tq_prod4.quantize(&v).expect("Quantization failed");
        let est = tq_prod4
            .estimate_inner_product(&pq, &q2)
            .expect("Estimation failed");
        sum_ip_err += (inner_product(&v, &q2) - est).abs();
    }
    let mean_mse = sum_mse / n_trials as f64;
    let mean_ip_err = sum_ip_err / n_trials as f64;

    println!(" Mean MSE (4-bit TurboQuantMSE) : {:.8}", mean_mse);
    println!(
        " Theor. bound : {:.8}",
        tq4.distortion_bound()
    );
    println!(
        " Mean |IP error| (4-bit TurboQuantProd): {:.6}",
        mean_ip_err
    );
}