use ndarray::{s, Array1, Array2, Array4};

use crate::turboquant::{CompressedVectors, TurboQuant, TurboQuantMSE};
use crate::utils::{pack_indices_batch, unpack_indices_batch};
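
/// A compressed transformer KV cache, produced by
/// [`KVCacheCompressor::compress`] and consumed by
/// [`KVCacheCompressor::decompress`].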
pub struct CompressedKVCache {
    /// Compressed key vectors, indexed `[layer][head]`.
    pub k_compressed: Vec<Vec<CompressedVectors>>,
    /// Bit-packed value quantizer indices, indexed `[layer][head][position]`.
    pub v_indices_packed: Vec<Vec<Vec<Vec<u8>>>>,
    /// Per-vector value norms, stored as `f32`, indexed `[layer][head]`.
    pub v_norms: Vec<Vec<Vec<f32>>>,
    pub num_layers: usize,
    pub num_heads: usize,
    pub seq_len: usize,
    pub head_dim: usize,
    /// Base key bit width; see `k_layer_bits` for per-layer widths.
    pub k_bit_width: usize,
    /// Value bit width, uniform across layers.
    pub v_bit_width: usize,
    /// Key bit width actually used for each layer (boundary layers may be wider).
    pub k_layer_bits: Vec<usize>,
}
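
/// Compresses and decompresses KV caches with TurboQuant-style quantizers.
///
/// Keys and values are quantized independently; optionally, the first and
/// last `boundary_layers` layers quantize keys at a higher bit width, since
/// boundary layers tend to be more sensitive to quantization error.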
pub struct KVCacheCompressor {
pub head_dim: usize,
pub k_bits: usize,
pub v_bits: usize,
    /// Shared value quantizer; key quantizers are built per bit width at
    /// compress/decompress time.
    v_quantizer: TurboQuantMSE,
    seed: u64,
    norm_correction: bool,
    /// How many layers at each end of the stack get `boundary_k_bits` keys.
    boundary_layers: usize,
    boundary_k_bits: Option<usize>,
}
impl KVCacheCompressor {
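    /// Creates a compressor with uniform `k_bits`/`v_bits` widths across all
    /// layers.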
pub fn new(
head_dim: usize,
k_bits: usize,
v_bits: usize,
seed: u64,
norm_correction: bool,
) -> Self {
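        // Offset the seed so the value quantizer is decorrelated from the key
        // quantizers, which use `seed` directly.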
let v_quantizer = TurboQuantMSE::new(head_dim, v_bits, seed + 500, norm_correction);
Self {
head_dim,
k_bits,
v_bits,
v_quantizer,
seed,
norm_correction,
boundary_layers: 0,
boundary_k_bits: None,
}
}
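
    /// Like [`Self::new`], but keys in the first and last `boundary_layers`
    /// layers are quantized at `boundary_k_bits` instead of `k_bits`.
    ///
    /// Panics if `boundary_k_bits < k_bits`.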
pub fn with_boundary_k_bits(
head_dim: usize,
k_bits: usize,
v_bits: usize,
seed: u64,
norm_correction: bool,
boundary_layers: usize,
boundary_k_bits: usize,
) -> Self {
assert!(
boundary_k_bits >= k_bits,
"boundary_k_bits should be >= base k_bits"
);
let mut compressor = Self::new(head_dim, k_bits, v_bits, seed, norm_correction);
compressor.boundary_layers = boundary_layers;
compressor.boundary_k_bits = Some(boundary_k_bits);
compressor
}
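
    /// Returns the key bit width for `layer`. The boundary region is capped
    /// at `num_layers / 2` per side so the front and back regions never
    /// overlap.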
fn layer_k_bits(&self, layer: usize, num_layers: usize) -> usize {
if let Some(boundary_bits) = self.boundary_k_bits {
let n = self.boundary_layers.min(num_layers / 2);
if layer < n || layer >= num_layers.saturating_sub(n) {
return boundary_bits;
}
}
self.k_bits
}
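
    /// Compresses `k_cache` and `v_cache`, both shaped
    /// `[num_layers, num_heads, seq_len, head_dim]`, head by head.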
pub fn compress(&self, k_cache: &Array4<f64>, v_cache: &Array4<f64>) -> CompressedKVCache {
let shape = k_cache.shape();
let (num_layers, num_heads, seq_len, head_dim) = (shape[0], shape[1], shape[2], shape[3]);
assert_eq!(head_dim, self.head_dim, "head_dim mismatch");
assert_eq!(k_cache.shape(), v_cache.shape(), "K and V cache shape mismatch");
let mut result = CompressedKVCache {
k_compressed: Vec::with_capacity(num_layers),
v_indices_packed: Vec::with_capacity(num_layers),
v_norms: Vec::with_capacity(num_layers),
num_layers,
num_heads,
seq_len,
head_dim,
k_bit_width: self.k_bits,
v_bit_width: self.v_bits,
k_layer_bits: Vec::with_capacity(num_layers),
};
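        // Key quantizers are built once up front. Construction is assumed
        // deterministic in (head_dim, bits, seed): decompress rebuilds
        // identical quantizers from the same parameters.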
        let base_k_quantizer =
            TurboQuant::new(self.head_dim, self.k_bits, self.seed, self.norm_correction);
let boundary_k_quantizer = self
.boundary_k_bits
.map(|b| TurboQuant::new(self.head_dim, b, self.seed, self.norm_correction));
for layer in 0..num_layers {
let mut k_layer = Vec::with_capacity(num_heads);
let mut v_layer_idx = Vec::with_capacity(num_heads);
let mut v_layer_norms = Vec::with_capacity(num_heads);
let layer_k_bits = self.layer_k_bits(layer, num_layers);
result.k_layer_bits.push(layer_k_bits);
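            // Pick the prebuilt quantizer that matches this layer's key width.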
let k_quantizer = if layer_k_bits == self.k_bits {
&base_k_quantizer
} else {
boundary_k_quantizer
.as_ref()
.expect("boundary quantizer missing for boundary layer bits")
};
            for head in 0..num_heads {
                // Copy this head's (seq_len, head_dim) slice into owned arrays.
                let k_vecs: Array2<f64> = k_cache.slice(s![layer, head, .., ..]).to_owned();
                let v_vecs: Array2<f64> = v_cache.slice(s![layer, head, .., ..]).to_owned();
let k_compressed = k_quantizer.quantize_batch(&k_vecs);
k_layer.push(k_compressed);
let (v_indices, v_norms) = self.v_quantizer.quantize_batch(&v_vecs);
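                // Bit-pack the value indices; norms are downcast to f32 to
                // halve their storage cost.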
v_layer_idx.push(pack_indices_batch(&v_indices, self.v_bits));
v_layer_norms.push(v_norms.iter().map(|&v| v as f32).collect());
}
result.k_compressed.push(k_layer);
result.v_indices_packed.push(v_layer_idx);
result.v_norms.push(v_layer_norms);
}
result
}
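
    /// Reconstructs approximate K and V caches. Key quantizers are rebuilt
    /// per layer from the recorded `k_layer_bits` and the compressor's seed,
    /// which is why quantizer construction must be deterministic.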
pub fn decompress(&self, compressed: &CompressedKVCache) -> (Array4<f64>, Array4<f64>) {
let mut k_cache = Array4::zeros((
compressed.num_layers,
compressed.num_heads,
compressed.seq_len,
compressed.head_dim,
));
let mut v_cache = Array4::zeros(k_cache.raw_dim());
for layer in 0..compressed.num_layers {
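            // Fall back to the base key width if per-layer widths are absent.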
let layer_k_bits = compressed
.k_layer_bits
.get(layer)
.copied()
.unwrap_or(compressed.k_bit_width);
let k_quantizer =
TurboQuant::new(self.head_dim, layer_k_bits, self.seed, self.norm_correction);
for head in 0..compressed.num_heads {
let k_hat = k_quantizer.dequantize(&compressed.k_compressed[layer][head]);
                k_cache
                    .slice_mut(s![layer, head, .., ..])
                    .assign(&k_hat);
let v_indices = unpack_indices_batch(
&compressed.v_indices_packed[layer][head],
compressed.head_dim,
compressed.v_bit_width,
);
let v_norms = Array1::from_vec(
compressed.v_norms[layer][head]
.iter()
.map(|&v| v as f64)
.collect::<Vec<_>>(),
);
let v_hat = self.v_quantizer.dequantize_batch(&v_indices, &v_norms);
                v_cache
                    .slice_mut(s![layer, head, .., ..])
                    .assign(&v_hat);
}
}
(k_cache, v_cache)
}
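
    /// Estimates original and compressed sizes for a cache of this shape
    /// without compressing anything, assuming an fp16 (2 bytes per value)
    /// baseline for both K and V.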
pub fn memory_stats(
&self,
seq_len: usize,
num_layers: usize,
num_heads: usize,
) -> MemoryStats {
        let n_vectors = num_layers * num_heads * seq_len;
        // Uncompressed baseline: K and V each hold n_vectors fp16 vectors
        // (2 bytes per value), so count both caches to match the compressed
        // total below.
        let original_bytes = 2 * n_vectors * self.head_dim * 2;
        let mut k_bits_total = 0usize;
        for layer in 0..num_layers {
            let layer_k_bits = self.layer_k_bits(layer, num_layers);
            let layer_vectors = num_heads * seq_len;
            // Packed key indices plus 64 bits of per-vector overhead.
            k_bits_total += layer_vectors * (self.head_dim * layer_k_bits + 64);
        }
        // Packed value indices plus one 32-bit f32 norm per vector.
        let v_bits_total = n_vectors * (self.head_dim * self.v_bits + 32);
        let compressed_bytes = (k_bits_total + v_bits_total) / 8;
MemoryStats {
original_mb: original_bytes as f64 / 1024.0 / 1024.0,
compressed_mb: compressed_bytes as f64 / 1024.0 / 1024.0,
compression_ratio: original_bytes as f64 / compressed_bytes as f64,
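            // These report the base widths; boundary layers may use more key
            // bits, which the byte totals above already account for.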
k_bits_per_value: self.k_bits,
v_bits_per_value: self.v_bits,
}
}
}
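
/// Size estimates produced by [`KVCacheCompressor::memory_stats`].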
#[derive(Debug)]
pub struct MemoryStats {
pub original_mb: f64,
pub compressed_mb: f64,
pub compression_ratio: f64,
pub k_bits_per_value: usize,
pub v_bits_per_value: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array4;
#[test]
fn test_compress_decompress_roundtrip() {
let compressor = KVCacheCompressor::new(16, 3, 3, 42, true);
let k_cache = Array4::from_shape_fn((1, 1, 4, 16), |(_, _, s, d)| {
((s * 16 + d) as f64 + 1.0) / 100.0
});
let v_cache = Array4::from_shape_fn((1, 1, 4, 16), |(_, _, s, d)| {
((s * 16 + d) as f64 + 50.0) / 100.0
});
let compressed = compressor.compress(&k_cache, &v_cache);
let (k_hat, v_hat) = compressor.decompress(&compressed);
assert_eq!(k_hat.shape(), k_cache.shape());
assert_eq!(v_hat.shape(), v_cache.shape());
let k_err: f64 = (&k_cache - &k_hat).mapv(|v| v * v).sum();
let k_orig: f64 = k_cache.mapv(|v| v * v).sum();
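        // Relative MSE below 1.0 is a loose sanity bound: it only requires the
        // reconstruction to beat an all-zeros prediction.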
assert!(
k_err / k_orig < 1.0,
"K cache relative MSE: {}",
k_err / k_orig
);
let v_err: f64 = (&v_cache - &v_hat).mapv(|v| v * v).sum();
let v_orig: f64 = v_cache.mapv(|v| v * v).sum();
assert!(
v_err / v_orig < 1.0,
"V cache relative MSE: {}",
v_err / v_orig
);
}
#[test]
fn test_memory_stats() {
let compressor = KVCacheCompressor::new(128, 3, 3, 42, true);
let stats = compressor.memory_stats(1024, 32, 32);
assert!(stats.compression_ratio > 1.0, "Should compress");
assert!(stats.compressed_mb < stats.original_mb, "Should be smaller");
}
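    // Hand-check of the memory formula, assuming the costs documented on
    // memory_stats: an fp16 baseline for both caches, 64 bits of per-vector
    // key overhead, and a 32-bit norm per value vector.
    #[test]
    fn test_memory_stats_matches_formula() {
        let compressor = KVCacheCompressor::new(128, 3, 3, 42, true);
        let stats = compressor.memory_stats(1024, 32, 32);
        let n_vectors = 32usize * 32 * 1024;
        let original = (2 * n_vectors * 128 * 2) as f64;
        let compressed = (n_vectors * ((128 * 3 + 64) + (128 * 3 + 32)) / 8) as f64;
        assert!((stats.compression_ratio - original / compressed).abs() < 1e-9);
    }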
#[test]
fn test_boundary_k_bits_policy() {
let compressor = KVCacheCompressor::with_boundary_k_bits(16, 3, 3, 42, true, 1, 4);
let k_cache = Array4::from_shape_fn((4, 1, 2, 16), |(l, _, s, d)| {
((l * 32 + s * 16 + d) as f64 + 1.0) / 100.0
});
let v_cache = Array4::from_shape_fn((4, 1, 2, 16), |(l, _, s, d)| {
((l * 32 + s * 16 + d) as f64 + 10.0) / 100.0
});
let compressed = compressor.compress(&k_cache, &v_cache);
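        // boundary_layers = 1 of 4, so only the first and last layers use the
        // 4-bit key width.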
assert_eq!(compressed.k_layer_bits, vec![4, 3, 3, 4]);
let (k_hat, v_hat) = compressor.decompress(&compressed);
assert_eq!(k_hat.shape(), k_cache.shape());
assert_eq!(v_hat.shape(), v_cache.shape());
}
}