use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use turboquant::polar::PolarQuant;
use turboquant::turboquant_mse::TurboQuantMSE;
use turboquant::turboquant_prod::TurboQuantProd;
use turboquant::utils::{inner_product, normalize};
/// Demo: quantize a synthetic KV cache with three TurboQuant schemes and
/// report reconstruction MSE, attention-score (inner-product) error, and
/// the resulting compression ratio versus raw f32 storage.
fn main() {
    // Synthetic workload parameters.
    let dim = 128usize;
    let batch_size = 64usize;
    let seed = 99u64;
    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║ TurboQuant KV Cache Compression Demo ║");
    println!("╚══════════════════════════════════════════════════════════╝");
    println!();
    println!(" Batch size : {} vectors", batch_size);
    println!(" Dimension : {}", dim);
    println!(" F32 storage: {} bytes per vector", dim * 4);
    println!();
    // Seeded RNG so the demo is reproducible run-to-run.
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let normal = Normal::new(0.0, 1.0).unwrap();
    // Unit-norm Gaussian vectors standing in for cached attention keys.
    let keys: Vec<Vec<f64>> = (0..batch_size)
        .map(|_| {
            let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
            normalize(&raw).unwrap()
        })
        .collect();
    // Values are generated to model a full KV cache (and to keep the RNG
    // stream, hence the query below, unchanged) but are never scored.
    let _values: Vec<Vec<f64>> = (0..batch_size)
        .map(|_| {
            let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
            normalize(&raw).unwrap()
        })
        .collect();
    let raw_query: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
    let query = normalize(&raw_query).unwrap();
    // Ground-truth attention logits: <key, query> for every key.
    let true_scores: Vec<f64> = keys.iter().map(|k| inner_product(k, &query)).collect();
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│ PolarQuant (Hierarchical Polar Transform) │");
    println!("└─────────────────────────────────────────────────────────────┘");
    for bit_width in [4u8, 8] {
        let pq = PolarQuant::new(dim, seed, bit_width).expect("PolarQuant init failed");
        let mut total_mse = 0.0f64;
        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;
        // Quantize -> dequantize each key and accumulate error statistics.
        // A plain loop (matching the sections below) — the quantized codes
        // themselves are not needed afterwards, so no Vec is collected.
        for (i, k) in keys.iter().enumerate() {
            let q = pq.quantize(k).expect("quantize failed");
            let recon = pq.dequantize(&q).expect("dequantize failed");
            let mse: f64 = k
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                / dim as f64;
            total_mse += mse;
            total_bytes += q.bytes();
            let est_score = inner_product(&recon, &query);
            total_ip_err += (true_scores[i] - est_score).abs();
        }
        let avg_mse = total_mse / batch_size as f64;
        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        // Ratio versus uncompressed f32 storage (4 bytes per dimension).
        let ratio = (dim as f64 * 4.0) / avg_bytes;
        println!(" {}-bit PolarQuant:", bit_width);
        println!(" Avg reconstruction MSE : {:.6}", avg_mse);
        println!(" Avg attention score error : {:.6}", avg_ip_err);
        println!(" Avg bytes/vector : {:.1}", avg_bytes);
        println!(" Compression ratio : {:.2}x", ratio);
        println!();
    }
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│ TurboQuantMSE (Random Rotation + Lloyd-Max) │");
    println!("└─────────────────────────────────────────────────────────────┘");
    for bit_width in [2u8, 4] {
        let tq = TurboQuantMSE::new(dim, bit_width, seed).expect("TurboQuantMSE init failed");
        let mut total_mse = 0.0f64;
        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;
        for (i, key) in keys.iter().enumerate() {
            let q = tq.quantize(key).expect("quantize failed");
            let recon = tq.dequantize(&q).expect("dequantize failed");
            let mse: f64 = key
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                / dim as f64;
            total_mse += mse;
            total_bytes += q.bytes();
            let est_score = inner_product(&recon, &query);
            total_ip_err += (true_scores[i] - est_score).abs();
        }
        let avg_mse = total_mse / batch_size as f64;
        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        let ratio = (dim as f64 * 4.0) / avg_bytes;
        println!(" {}-bit TurboQuantMSE:", bit_width);
        println!(" Avg reconstruction MSE : {:.6}", avg_mse);
        println!(" Avg attention score error : {:.6}", avg_ip_err);
        println!(" Avg bytes/vector : {:.1}", avg_bytes);
        println!(" Compression ratio : {:.2}x", ratio);
        println!();
    }
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│ TurboQuantProd (MSE + QJL Residual, Inner-Product Optimal) │");
    println!("└─────────────────────────────────────────────────────────────┘");
    for bit_width in [3u8, 4] {
        let tq = TurboQuantProd::new(dim, bit_width, seed).expect("TurboQuantProd init failed");
        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;
        // No dequantize here: this scheme estimates inner products directly
        // from the quantized representation.
        for (i, key) in keys.iter().enumerate() {
            let q = tq.quantize(key).expect("quantize failed");
            let est_score = tq
                .estimate_inner_product(&q, &query)
                .expect("estimation failed");
            total_ip_err += (true_scores[i] - est_score).abs();
            total_bytes += q.bytes();
        }
        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        let ratio = (dim as f64 * 4.0) / avg_bytes;
        // Theoretical distortion bound for unit-norm inputs (argument 1.0).
        let bound = tq.distortion_bound(1.0);
        println!(" {}-bit TurboQuantProd:", bit_width);
        println!(" Avg attention score error : {:.6}", avg_ip_err);
        println!(" Theoretical IP bound : {:.8}", bound);
        println!(" Avg bytes/vector : {:.1}", avg_bytes);
        println!(" Compression ratio : {:.2}x", ratio);
        println!();
    }
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│ Softmax Attention Accuracy (Top-1 key retrieval) │");
    println!("└─────────────────────────────────────────────────────────────┘");
    // Does the 4-bit product quantizer retrieve the same top-1 key as the
    // exact scores?
    let true_top1 = argmax(&true_scores);
    let tq_prod = TurboQuantProd::new(dim, 4, seed).unwrap();
    let prod_scores: Vec<f64> = keys
        .iter()
        .map(|k| {
            let q = tq_prod.quantize(k).unwrap();
            tq_prod.estimate_inner_product(&q, &query).unwrap()
        })
        .collect();
    let prod_top1 = argmax(&prod_scores);
    println!(
        " True top-1 key: {}, TurboQuantProd 4-bit top-1: {} {}",
        true_top1,
        prod_top1,
        if true_top1 == prod_top1 { "✓" } else { "✗" }
    );
    let true_score_top = true_scores[true_top1];
    let prod_score_top = prod_scores[prod_top1];
    println!(
        " True score: {:.6}, Estimated: {:.6}",
        true_score_top, prod_score_top
    );
    println!();
    // Whole-cache totals: batch_size keys + batch_size values (the factor 2),
    // each dim f32s uncompressed.
    println!(
        " Total KV cache uncompressed: {} bytes",
        batch_size * 2 * dim * 4
    );
    // Approximation: 4 bits per dimension packed (dim * 4 / 8 bytes) plus
    // ~4 bytes of per-vector metadata — TODO confirm against actual q.bytes().
    let prod_4bit_bytes = batch_size * 2 * (dim * 4 / 8 + 4);
    println!(
        " Total KV cache (TurboQuantProd 4-bit, approx): {} bytes",
        prod_4bit_bytes
    );
    println!(
        " Approx compression: {:.1}x",
        (batch_size * 2 * dim * 4) as f64 / prod_4bit_bytes as f64
    );
}
/// Returns the index of the largest value in `v`, or 0 for an empty slice.
///
/// Uses `f64::total_cmp` (IEEE 754 total order) instead of
/// `partial_cmp(..).unwrap()`, so a NaN entry no longer panics; NaN values
/// are simply ordered by the total order. Ties resolve to the last maximal
/// index, matching `Iterator::max_by` semantics.
fn argmax(v: &[f64]) -> usize {
    v.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap_or(0)
}