turboquant 0.1.1

Implementation of Google's TurboQuant algorithm for vector quantization.
//! KV Cache Compression Demo
//!
//! Simulates compressing a batch of 64 key/value vectors (dim=128),
//! as would appear in a transformer attention layer.
//!
//! Demonstrates:
//! - PolarQuant for KV cache compression
//! - Compression ratio and reconstruction error
//! - Inner product preservation (critical for attention computation)
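//!
//! Back-of-envelope (added for orientation): at 4 bits per dimension, a
//! 128-dim vector packs into 128 * 4 / 8 = 64 bytes of codes versus
//! 128 * 4 = 512 bytes as f32, i.e. roughly 8x before any per-vector
//! overhead such as scales or norms.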

use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use turboquant::polar::PolarQuant;
use turboquant::turboquant_mse::TurboQuantMSE;
use turboquant::turboquant_prod::TurboQuantProd;
use turboquant::utils::{inner_product, normalize};

fn main() {
    let dim = 128usize;
    let batch_size = 64usize;
    let seed = 99u64;

    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║         TurboQuant KV Cache Compression Demo             ║");
    println!("╚══════════════════════════════════════════════════════════╝");
    println!();
    println!("  Batch size : {} vectors", batch_size);
    println!("  Dimension  : {}", dim);
    println!("  F32 storage: {} bytes per vector", dim * 4);
    println!();

    // ── Generate simulated KV cache vectors ──────────────────────────────
    // In practice these would be attention key/value projections.
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let normal = Normal::new(0.0, 1.0).unwrap();

    let keys: Vec<Vec<f64>> = (0..batch_size)
        .map(|_| {
            let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
            normalize(&raw).unwrap()
        })
        .collect();

    // Values are generated for realism but never scored in this demo,
    // hence the leading underscore.
    let _values: Vec<Vec<f64>> = (0..batch_size)
        .map(|_| {
            let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
            normalize(&raw).unwrap()
        })
        .collect();

    // Query vector (the "current" attention query)
    let raw_query: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
    let query = normalize(&raw_query).unwrap();
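    // Sanity check (added): normalize() is assumed to return unit-norm
    // vectors, so the key/query inner products below are cosine scores.
    debug_assert!((inner_product(&query, &query) - 1.0).abs() < 1e-9);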

    // ── Baseline: true attention scores ──────────────────────────────────
    let true_scores: Vec<f64> = keys.iter().map(|k| inner_product(k, &query)).collect();

    // ── Method 1: PolarQuant ──────────────────────────────────────────────
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│  PolarQuant (Hierarchical Polar Transform)                   │");
    println!("└─────────────────────────────────────────────────────────────┘");

    for bit_width in [4u8, 8] {
        let pq = PolarQuant::new(dim, seed, bit_width).expect("PolarQuant init failed");

        let mut total_mse = 0.0f64;
        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;

        let _quantized_keys: Vec<_> = keys
            .iter()
            .enumerate()
            .map(|(i, k)| {
                let q = pq.quantize(k).expect("quantize failed");
                let recon = pq.dequantize(&q).expect("dequantize failed");

                let mse: f64 = k
                    .iter()
                    .zip(recon.iter())
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f64>()
                    / dim as f64;

                total_mse += mse;
                total_bytes += q.bytes() as f64;

                // True attention score vs reconstructed
                let est_score = inner_product(&recon, &query);
                total_ip_err += (true_scores[i] - est_score).abs();

                q
            })
            .collect();

        let avg_mse = total_mse / batch_size as f64;
        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        let ratio = (dim as f64 * 4.0) / avg_bytes;

        println!("  {}-bit PolarQuant:", bit_width);
        println!("    Avg reconstruction MSE    : {:.6}", avg_mse);
        println!("    Avg attention score error : {:.6}", avg_ip_err);
        println!("    Avg bytes/vector          : {:.1}", avg_bytes);
        println!("    Compression ratio         : {:.2}x", ratio);
        println!();
    }

    // ── Method 2: TurboQuantMSE ───────────────────────────────────────────
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│  TurboQuantMSE (Random Rotation + Lloyd-Max)                │");
    println!("└─────────────────────────────────────────────────────────────┘");

    for bit_width in [2u8, 4] {
        let tq = TurboQuantMSE::new(dim, bit_width, seed).expect("TurboQuantMSE init failed");

        let mut total_mse = 0.0f64;
        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;

        for (i, key) in keys.iter().enumerate() {
            let q = tq.quantize(key).expect("quantize failed");
            let recon = tq.dequantize(&q).expect("dequantize failed");

            let mse: f64 = key
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                / dim as f64;

            total_mse += mse;
            total_bytes += q.bytes() as f64;

            let est_score = inner_product(&recon, &query);
            total_ip_err += (true_scores[i] - est_score).abs();
        }

        let avg_mse = total_mse / batch_size as f64;
        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        let ratio = (dim as f64 * 4.0) / avg_bytes;

        println!("  {}-bit TurboQuantMSE:", bit_width);
        println!("    Avg reconstruction MSE    : {:.6}", avg_mse);
        println!("    Avg attention score error : {:.6}", avg_ip_err);
        println!("    Avg bytes/vector          : {:.1}", avg_bytes);
        println!("    Compression ratio         : {:.2}x", ratio);
        println!();
    }

    // ── Method 3: TurboQuantProd ──────────────────────────────────────────
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│  TurboQuantProd (MSE + QJL Residual, Inner-Product Optimal) │");
    println!("└─────────────────────────────────────────────────────────────┘");

    for bit_width in [3u8, 4] {
        let tq = TurboQuantProd::new(dim, bit_width, seed).expect("TurboQuantProd init failed");

        let mut total_ip_err = 0.0f64;
        let mut total_bytes = 0.0f64;

        for (i, key) in keys.iter().enumerate() {
            let q = tq.quantize(key).expect("quantize failed");
            let est_score = tq
                .estimate_inner_product(&q, &query)
                .expect("estimation failed");
            total_ip_err += (true_scores[i] - est_score).abs();
            total_bytes += q.bytes() as f64;
        }

        let avg_ip_err = total_ip_err / batch_size as f64;
        let avg_bytes = total_bytes / batch_size as f64;
        let ratio = (dim as f64 * 4.0) / avg_bytes;
        let bound = tq.distortion_bound(1.0); // unit query vector

        println!("  {}-bit TurboQuantProd:", bit_width);
        println!("    Avg attention score error : {:.6}", avg_ip_err);
        println!("    Theoretical IP bound      : {:.8}", bound);
        println!("    Avg bytes/vector          : {:.1}", avg_bytes);
        println!("    Compression ratio         : {:.2}x", ratio);
        println!();
    }

    // ── Softmax attention accuracy ────────────────────────────────────────
    println!("┌─────────────────────────────────────────────────────────────┐");
    println!("│  Softmax Attention Accuracy (Top-1 key retrieval)           │");
    println!("└─────────────────────────────────────────────────────────────┘");

    let true_top1 = argmax(&true_scores);

    // TurboQuantProd 4-bit scores
    let tq_prod = TurboQuantProd::new(dim, 4, seed).unwrap();
    let prod_scores: Vec<f64> = keys
        .iter()
        .map(|k| {
            let q = tq_prod.quantize(k).unwrap();
            tq_prod.estimate_inner_product(&q, &query).unwrap()
        })
        .collect();
    let prod_top1 = argmax(&prod_scores);

    println!(
        "  True top-1 key: {}, TurboQuantProd 4-bit top-1: {}  {}",
        true_top1,
        prod_top1,
        if true_top1 == prod_top1 { "✓" } else { "✗" }
    );

    let true_score_top = true_scores[true_top1];
    let prod_score_top = prod_scores[prod_top1];
    println!(
        "  True top-1 score: {:.6}, Estimated top-1 score: {:.6}",
        true_score_top, prod_score_top
    );

    println!();
    println!(
        "  Total KV cache uncompressed: {} bytes",
        batch_size * 2 * dim * 4
    );
    let prod_4bit_bytes = batch_size * 2 * (dim * 4 / 8 + 4); // rough estimate: 4 bits/dim packed, plus ~4 bytes/vector overhead
    println!(
        "  Total KV cache (TurboQuantProd 4-bit, approx): {} bytes",
        prod_4bit_bytes
    );
    println!(
        "  Approx compression: {:.1}x",
        (batch_size * 2 * dim * 4) as f64 / prod_4bit_bytes as f64
    );
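
    // Softmax sketch (added, not part of the original demo): small per-score
    // errors can still shift attention weights, so compare the softmax of
    // true vs. estimated scores via total variation distance.
    let softmax = |scores: &[f64]| -> Vec<f64> {
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f64 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    };
    let true_attn = softmax(&true_scores);
    let est_attn = softmax(&prod_scores);
    let tv: f64 = true_attn
        .iter()
        .zip(est_attn.iter())
        .map(|(a, b)| (a - b).abs())
        .sum::<f64>()
        / 2.0;
    println!("  Softmax total variation distance (4-bit Prod): {:.6}", tv);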
}

fn argmax(v: &[f64]) -> usize {
    v.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(i, _)| i)
        .unwrap_or(0)
}