// turboquant-plus-rs 0.1.0
//
// TurboQuant: KV cache compression via PolarQuant + QJL — Rust port
//! KV Cache integration layer for TurboQuant.
//!
//! Compresses transformer KV cache tensors using TurboQuant (for K cache, inner product
//! preservation) and PolarQuant MSE-only (for V cache, MSE preservation).
//!
//! KV cache shape: (num_layers, num_heads, seq_len, head_dim)
//! Quantization is along head_dim — each (head_dim,) vector is quantized independently.

use ndarray::{s, Array1, Array2, Array4};

use crate::turboquant::{CompressedVectors, TurboQuant, TurboQuantMSE};
use crate::utils::{pack_indices_batch, unpack_indices_batch};

/// Container for a compressed KV cache.
///
/// Layout mirrors the source tensors: the outer `Vec` indexes layers, the next
/// `Vec` indexes heads, and each per-head payload covers all `seq_len` token rows.
pub struct CompressedKVCache {
    /// Per-layer, per-head compressed K vectors.
    pub k_compressed: Vec<Vec<CompressedVectors>>,
    /// Per-layer, per-head compressed V indices (packed, per token row).
    pub v_indices_packed: Vec<Vec<Vec<Vec<u8>>>>,
    /// Per-layer, per-head compressed V norms (float32, per token row).
    pub v_norms: Vec<Vec<Vec<f32>>>,

    /// Number of transformer layers in the source cache.
    pub num_layers: usize,
    /// Number of attention heads per layer.
    pub num_heads: usize,
    /// Number of tokens (rows) per head.
    pub seq_len: usize,
    /// Dimension of each head vector; quantization runs along this axis.
    pub head_dim: usize,
    /// Base bit-width used for K-cache quantization.
    pub k_bit_width: usize,
    /// Bit-width used for V-cache quantization.
    pub v_bit_width: usize,
    /// Per-layer key bit-width used during compression.
    pub k_layer_bits: Vec<usize>,
}

/// Compress and decompress transformer KV cache tensors.
///
/// Uses:
/// - TurboQuant (Algorithm 2) for K cache — inner product preservation matters
///   for attention score computation (Q @ K^T)
/// - TurboQuantMSE (Algorithm 1) for V cache — MSE preservation matters
///   for value reconstruction (attn_weights @ V)
pub struct KVCacheCompressor {
    /// Dimension of each attention head vector.
    pub head_dim: usize,
    /// Base bit-width for K-cache quantization.
    pub k_bits: usize,
    /// Bit-width for V-cache quantization.
    pub v_bits: usize,
    // Shared V quantizer, built once at construction with seed + 500.
    v_quantizer: TurboQuantMSE,
    // Base seed; K quantizers are reconstructed deterministically from it.
    seed: u64,
    // Whether PolarQuant norm correction is enabled.
    norm_correction: bool,
    // Number of layers at each end that get `boundary_k_bits` precision (0 = off).
    boundary_layers: usize,
    // Elevated K bit-width for boundary layers; `None` disables the policy.
    boundary_k_bits: Option<usize>,
}

impl KVCacheCompressor {
    /// Create a new KV cache compressor with uniform per-layer K precision.
    ///
    /// # Arguments
    /// * `head_dim` - Dimension of each attention head vector.
    /// * `k_bits` - Bit-width for K cache (TurboQuant, inner product).
    /// * `v_bits` - Bit-width for V cache (PolarQuant MSE-only).
    /// * `seed` - Random seed.
    /// * `norm_correction` - Whether to apply norm correction in PolarQuant.
    pub fn new(
        head_dim: usize,
        k_bits: usize,
        v_bits: usize,
        seed: u64,
        norm_correction: bool,
    ) -> Self {
        Self {
            head_dim,
            k_bits,
            v_bits,
            // Built once up front; the seed offset decorrelates V from K sketches.
            v_quantizer: TurboQuantMSE::new(head_dim, v_bits, seed + 500, norm_correction),
            seed,
            norm_correction,
            // Boundary-layer protection is off unless `with_boundary_k_bits` is used.
            boundary_layers: 0,
            boundary_k_bits: None,
        }
    }

    /// Create a compressor with boundary-layer K precision protection.
    ///
    /// The first and last `boundary_layers` transformer layers use `boundary_k_bits`
    /// for K-cache quantization, while middle layers use `k_bits`.
    pub fn with_boundary_k_bits(
        head_dim: usize,
        k_bits: usize,
        v_bits: usize,
        seed: u64,
        norm_correction: bool,
        boundary_layers: usize,
        boundary_k_bits: usize,
    ) -> Self {
        assert!(
            boundary_k_bits >= k_bits,
            "boundary_k_bits should be >= base k_bits"
        );
        let mut compressor = Self::new(head_dim, k_bits, v_bits, seed, norm_correction);
        compressor.boundary_layers = boundary_layers;
        compressor.boundary_k_bits = Some(boundary_k_bits);
        compressor
    }

    /// Effective K bit-width for `layer` out of `num_layers` total layers.
    ///
    /// Returns the elevated boundary width for the first/last protected layers
    /// when the policy is configured, and the base `k_bits` everywhere else.
    fn layer_k_bits(&self, layer: usize, num_layers: usize) -> usize {
        match self.boundary_k_bits {
            Some(boundary_bits) => {
                // Cap the protected count so front and back regions never overlap.
                let protected = self.boundary_layers.min(num_layers / 2);
                let in_front = layer < protected;
                let in_back = layer >= num_layers.saturating_sub(protected);
                if in_front || in_back {
                    boundary_bits
                } else {
                    self.k_bits
                }
            }
            None => self.k_bits,
        }
    }

    /// Compress full KV cache tensors.
    ///
    /// # Arguments
    /// * `k_cache` - Key cache, shape (num_layers, num_heads, seq_len, head_dim).
    /// * `v_cache` - Value cache, same shape.
    pub fn compress(&self, k_cache: &Array4<f64>, v_cache: &Array4<f64>) -> CompressedKVCache {
        let shape = k_cache.shape();
        let (num_layers, num_heads, seq_len, head_dim) = (shape[0], shape[1], shape[2], shape[3]);
        assert_eq!(head_dim, self.head_dim, "head_dim mismatch");
        assert_eq!(k_cache.shape(), v_cache.shape(), "K and V cache shape mismatch");

        let mut result = CompressedKVCache {
            k_compressed: Vec::with_capacity(num_layers),
            v_indices_packed: Vec::with_capacity(num_layers),
            v_norms: Vec::with_capacity(num_layers),
            num_layers,
            num_heads,
            seq_len,
            head_dim,
            k_bit_width: self.k_bits,
            v_bit_width: self.v_bits,
            k_layer_bits: Vec::with_capacity(num_layers),
        };

        let base_k_quantizer = TurboQuant::new(self.head_dim, self.k_bits, self.seed, self.norm_correction);
        let boundary_k_quantizer = self
            .boundary_k_bits
            .map(|b| TurboQuant::new(self.head_dim, b, self.seed, self.norm_correction));

        for layer in 0..num_layers {
            let mut k_layer = Vec::with_capacity(num_heads);
            let mut v_layer_idx = Vec::with_capacity(num_heads);
            let mut v_layer_norms = Vec::with_capacity(num_heads);
            let layer_k_bits = self.layer_k_bits(layer, num_layers);
            result.k_layer_bits.push(layer_k_bits);

            let k_quantizer = if layer_k_bits == self.k_bits {
                &base_k_quantizer
            } else {
                boundary_k_quantizer
                    .as_ref()
                    .expect("boundary quantizer missing for boundary layer bits")
            };

            for head in 0..num_heads {
                // Extract (seq_len, head_dim) slice for this layer/head
                let mut k_vecs = Array2::zeros((seq_len, head_dim));
                let mut v_vecs = Array2::zeros((seq_len, head_dim));
                for s in 0..seq_len {
                    for d in 0..head_dim {
                        k_vecs[[s, d]] = k_cache[[layer, head, s, d]];
                        v_vecs[[s, d]] = v_cache[[layer, head, s, d]];
                    }
                }

                // K: batch quantize all seq positions
                let k_compressed = k_quantizer.quantize_batch(&k_vecs);
                k_layer.push(k_compressed);

                // V: MSE quantize
                let (v_indices, v_norms) = self.v_quantizer.quantize_batch(&v_vecs);
                v_layer_idx.push(pack_indices_batch(&v_indices, self.v_bits));
                v_layer_norms.push(v_norms.iter().map(|&v| v as f32).collect());
            }

            result.k_compressed.push(k_layer);
            result.v_indices_packed.push(v_layer_idx);
            result.v_norms.push(v_layer_norms);
        }

        result
    }

    /// Decompress back to full KV cache tensors.
    ///
    /// Returns (k_cache, v_cache) both shape (num_layers, num_heads, seq_len, head_dim).
    pub fn decompress(&self, compressed: &CompressedKVCache) -> (Array4<f64>, Array4<f64>) {
        let mut k_cache = Array4::zeros((
            compressed.num_layers,
            compressed.num_heads,
            compressed.seq_len,
            compressed.head_dim,
        ));
        let mut v_cache = Array4::zeros(k_cache.raw_dim());

        for layer in 0..compressed.num_layers {
            let layer_k_bits = compressed
                .k_layer_bits
                .get(layer)
                .copied()
                .unwrap_or(compressed.k_bit_width);
            let k_quantizer =
                TurboQuant::new(self.head_dim, layer_k_bits, self.seed, self.norm_correction);
            for head in 0..compressed.num_heads {
                // K: dequantize
                let k_hat = k_quantizer.dequantize(&compressed.k_compressed[layer][head]);
                for s in 0..compressed.seq_len {
                    for d in 0..compressed.head_dim {
                        k_cache[[layer, head, s, d]] = k_hat[[s, d]];
                    }
                }

                // V: dequantize
                let v_indices = unpack_indices_batch(
                    &compressed.v_indices_packed[layer][head],
                    compressed.head_dim,
                    compressed.v_bit_width,
                );
                let v_norms = Array1::from_vec(
                    compressed.v_norms[layer][head]
                        .iter()
                        .map(|&v| v as f64)
                        .collect::<Vec<_>>(),
                );
                let v_hat = self.v_quantizer.dequantize_batch(&v_indices, &v_norms);
                for s in 0..compressed.seq_len {
                    for d in 0..compressed.head_dim {
                        v_cache[[layer, head, s, d]] = v_hat[[s, d]];
                    }
                }
            }
        }

        (k_cache, v_cache)
    }

    /// Compute memory usage statistics.
    ///
    /// The baseline assumes fp16 storage (2 bytes per value). The compressed
    /// estimate charges each K vector `head_dim * bits + 64` bits and each V
    /// vector `head_dim * bits + 32` bits (packed indices plus a stored norm).
    pub fn memory_stats(
        &self,
        seq_len: usize,
        num_layers: usize,
        num_heads: usize,
    ) -> MemoryStats {
        let n_vectors = num_layers * num_heads * seq_len;
        let original_bytes = n_vectors * self.head_dim * 2; // fp16 baseline

        // K: per-layer bit-width may differ on boundary layers.
        let k_bits_total: usize = (0..num_layers)
            .map(|layer| {
                num_heads * seq_len * (self.head_dim * self.layer_k_bits(layer, num_layers) + 64)
            })
            .sum();
        // V (MSE): b bits per coord + one 32-bit norm.
        let v_bits_total = n_vectors * (self.head_dim * self.v_bits + 32);

        let compressed_bytes = (k_bits_total + v_bits_total) / 8;
        let mib = |bytes: usize| bytes as f64 / 1024.0 / 1024.0;

        MemoryStats {
            original_mb: mib(original_bytes),
            compressed_mb: mib(compressed_bytes),
            compression_ratio: original_bytes as f64 / compressed_bytes as f64,
            k_bits_per_value: self.k_bits,
            v_bits_per_value: self.v_bits,
        }
    }
}

/// Memory usage statistics for KV cache compression.
#[derive(Debug)]
pub struct MemoryStats {
    pub original_mb: f64,
    pub compressed_mb: f64,
    pub compression_ratio: f64,
    pub k_bits_per_value: usize,
    pub v_bits_per_value: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array4;

    /// Relative squared error between a reference tensor and its reconstruction.
    fn rel_sq_err(reference: &Array4<f64>, recon: &Array4<f64>) -> f64 {
        let err: f64 = (reference - recon).mapv(|x| x * x).sum();
        let norm: f64 = reference.mapv(|x| x * x).sum();
        err / norm
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let compressor = KVCacheCompressor::new(16, 3, 3, 42, true);

        // Tiny cache: 1 layer, 1 head, 4 tokens, head_dim = 16.
        let keys = Array4::from_shape_fn((1, 1, 4, 16), |(_, _, s, d)| {
            ((s * 16 + d) as f64 + 1.0) / 100.0
        });
        let values = Array4::from_shape_fn((1, 1, 4, 16), |(_, _, s, d)| {
            ((s * 16 + d) as f64 + 50.0) / 100.0
        });

        let packed = compressor.compress(&keys, &values);
        let (keys_hat, values_hat) = compressor.decompress(&packed);

        assert_eq!(keys_hat.shape(), keys.shape());
        assert_eq!(values_hat.shape(), values.shape());

        // Reconstruction is lossy but the relative error must stay bounded.
        let k_rel = rel_sq_err(&keys, &keys_hat);
        assert!(k_rel < 1.0, "K cache relative MSE: {}", k_rel);

        let v_rel = rel_sq_err(&values, &values_hat);
        assert!(v_rel < 1.0, "V cache relative MSE: {}", v_rel);
    }

    #[test]
    fn test_memory_stats() {
        let stats = KVCacheCompressor::new(128, 3, 3, 42, true).memory_stats(1024, 32, 32);
        assert!(stats.compression_ratio > 1.0, "Should compress");
        assert!(stats.compressed_mb < stats.original_mb, "Should be smaller");
    }

    #[test]
    fn test_boundary_k_bits_policy() {
        let compressor = KVCacheCompressor::with_boundary_k_bits(16, 3, 3, 42, true, 1, 4);

        let keys = Array4::from_shape_fn((4, 1, 2, 16), |(l, _, s, d)| {
            ((l * 32 + s * 16 + d) as f64 + 1.0) / 100.0
        });
        let values = Array4::from_shape_fn((4, 1, 2, 16), |(l, _, s, d)| {
            ((l * 32 + s * 16 + d) as f64 + 10.0) / 100.0
        });

        let packed = compressor.compress(&keys, &values);
        // First and last layers get the protected 4-bit width; middle layers keep 3.
        assert_eq!(packed.k_layer_bits, vec![4, 3, 3, 4]);

        let (keys_hat, values_hat) = compressor.decompress(&packed);
        assert_eq!(keys_hat.shape(), keys.shape());
        assert_eq!(values_hat.shape(), values.shape());
    }
}