use serde::{Deserialize, Serialize};
use crate::backend::{weighted_sum_in_place, ExecutionBackend};
use crate::batch::{
batch_attention_scores_mse, batch_dequantize_mse, batch_estimate_inner_products,
batch_quantize_mse, batch_quantize_prod, BatchQuantizedMSE, BatchQuantizedProd,
};
use crate::error::{Result, TurboQuantError};
use crate::turboquant_mse::TurboQuantMSE;
use crate::turboquant_prod::TurboQuantProd;
use crate::utils::{validate_finite_vector, validate_unit_vector};
/// Quantization strategy used for cached keys.
///
/// Values are always quantized with the MSE scheme; this choice only affects
/// how keys are stored and how attention scores are estimated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantStrategy {
    /// Minimize reconstruction (mean-squared) error of the key vectors.
    MSE,
    /// Optimize for inner-product estimation against incoming queries.
    Prod,
}
/// Configuration for a single-head [`QuantizedKVCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVCacheConfig {
    /// Dimensionality of every key, value, and query vector. Must be nonzero.
    pub dim: usize,
    /// Bit width used by the key quantizer.
    pub key_bits: u8,
    /// Bit width used by the (always MSE) value quantizer.
    pub value_bits: u8,
    /// Quantization scheme applied to keys.
    pub key_strategy: QuantStrategy,
    /// RNG seed for quantizer construction (values use `seed + 2`).
    pub seed: u64,
    /// Sliding-window capacity; `0` disables eviction (unlimited).
    pub max_tokens: usize,
}
impl KVCacheConfig {
pub fn new(dim: usize) -> Self {
Self {
dim,
key_bits: 4,
value_bits: 4,
key_strategy: QuantStrategy::Prod,
seed: 42,
max_tokens: 0,
}
}
pub fn with_key_bits(mut self, bits: u8) -> Self {
self.key_bits = bits;
self
}
pub fn with_value_bits(mut self, bits: u8) -> Self {
self.value_bits = bits;
self
}
pub fn with_key_strategy(mut self, strategy: QuantStrategy) -> Self {
self.key_strategy = strategy;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
}
/// Internal storage for quantized keys; the active variant must always match
/// the configured [`QuantStrategy`].
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
enum QuantizedKeys {
    MSE(BatchQuantizedMSE),
    Prod(BatchQuantizedProd),
}
/// A single attention head's KV cache holding quantized keys and values.
///
/// Exactly one of the two key quantizers is `Some`, chosen by
/// `config.key_strategy`; values always use the MSE quantizer.
#[derive(Debug)]
pub struct QuantizedKVCache {
    config: KVCacheConfig,
    // Populated iff key_strategy == MSE.
    key_quantizer_mse: Option<TurboQuantMSE>,
    // Populated iff key_strategy == Prod.
    key_quantizer_prod: Option<TurboQuantProd>,
    value_quantizer: TurboQuantMSE,
    // None until the first non-empty append.
    keys: Option<QuantizedKeys>,
    values: Option<BatchQuantizedMSE>,
    num_tokens: usize,
}
impl QuantizedKVCache {
    /// Validates one append batch without mutating the cache: key and value
    /// counts must match, and every vector must have the configured dimension
    /// and pass `validate_unit_vector`.
    fn validate_append_inputs(&self, keys: &[Vec<f64>], values: &[Vec<f64>]) -> Result<()> {
        if keys.len() != values.len() {
            return Err(TurboQuantError::LengthMismatch {
                context: "KV cache append key/value count".into(),
                expected: keys.len(),
                got: values.len(),
            });
        }
        for key in keys {
            if key.len() != self.config.dim {
                return Err(TurboQuantError::DimensionMismatch {
                    expected: self.config.dim,
                    got: key.len(),
                });
            }
            validate_unit_vector(key, "KV cache key")?;
        }
        for value in values {
            if value.len() != self.config.dim {
                return Err(TurboQuantError::DimensionMismatch {
                    expected: self.config.dim,
                    got: value.len(),
                });
            }
            validate_unit_vector(value, "KV cache value")?;
        }
        Ok(())
    }

    /// Creates an empty cache.
    ///
    /// The value quantizer is always MSE-based and is seeded with
    /// `config.seed + 2` so it never shares randomness with the key
    /// quantizer. Exactly one key quantizer (MSE or Prod) is constructed,
    /// matching `config.key_strategy`.
    ///
    /// # Errors
    /// Returns `InvalidDimension` when `config.dim == 0`, or any error from
    /// quantizer construction.
    pub fn new(config: KVCacheConfig) -> Result<Self> {
        if config.dim == 0 {
            return Err(TurboQuantError::InvalidDimension(config.dim));
        }
        let value_quantizer = TurboQuantMSE::new(config.dim, config.value_bits, config.seed + 2)?;
        let (key_quantizer_mse, key_quantizer_prod) = match config.key_strategy {
            QuantStrategy::MSE => {
                let kq = TurboQuantMSE::new(config.dim, config.key_bits, config.seed)?;
                (Some(kq), None)
            }
            QuantStrategy::Prod => {
                let kq = TurboQuantProd::new(config.dim, config.key_bits, config.seed)?;
                (None, Some(kq))
            }
        };
        Ok(Self {
            config,
            key_quantizer_mse,
            key_quantizer_prod,
            value_quantizer,
            keys: None,
            values: None,
            num_tokens: 0,
        })
    }

    /// Quantizes and appends a batch of key/value pairs, then evicts the
    /// oldest tokens if the sliding window (`max_tokens`) is exceeded.
    ///
    /// Validation and quantization of BOTH batches happen before any cache
    /// state is touched, so a validation or quantization failure leaves the
    /// cache unchanged.
    // NOTE(review): if the value `extend` failed after keys were already
    // extended, key/value counts could diverge; presumably `extend` cannot
    // fail for batches produced by the same quantizer — confirm.
    pub fn append(&mut self, keys: &[Vec<f64>], values: &[Vec<f64>]) -> Result<()> {
        if keys.is_empty() {
            return Ok(());
        }
        self.validate_append_inputs(keys, values)?;
        let new_keys = match self.config.key_strategy {
            QuantStrategy::MSE => {
                let kq = self.key_quantizer_mse.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("MSE key quantizer not initialized".into())
                })?;
                QuantizedKeys::MSE(batch_quantize_mse(kq, keys)?)
            }
            QuantStrategy::Prod => {
                let kq = self.key_quantizer_prod.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("Prod key quantizer not initialized".into())
                })?;
                QuantizedKeys::Prod(batch_quantize_prod(kq, keys)?)
            }
        };
        let new_vbatch = batch_quantize_mse(&self.value_quantizer, values)?;
        match (&mut self.keys, new_keys) {
            (Some(QuantizedKeys::MSE(existing)), QuantizedKeys::MSE(new_batch)) => {
                existing.extend(&new_batch)?;
            }
            (Some(QuantizedKeys::Prod(existing)), QuantizedKeys::Prod(new_batch)) => {
                existing.extend(&new_batch)?;
            }
            (None, new_batch) => self.keys = Some(new_batch),
            // Defensive: existing storage variant disagrees with the
            // configured strategy; should be unreachable in practice.
            (Some(_), _) => {
                return Err(TurboQuantError::Internal(
                    "key cache state does not match configured quantization strategy".into(),
                ));
            }
        }
        match &mut self.values {
            Some(existing) => existing.extend(&new_vbatch)?,
            None => self.values = Some(new_vbatch),
        }
        self.num_tokens += keys.len();
        self.evict_if_needed();
        Ok(())
    }

    /// Appends a single token; convenience wrapper over [`Self::append`].
    /// Dimension checks here let callers get an early error without the
    /// `to_vec` allocation being wasted on unit-norm validation later.
    pub fn push_token(&mut self, key: &[f64], value: &[f64]) -> Result<()> {
        if key.len() != self.config.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.config.dim,
                got: key.len(),
            });
        }
        if value.len() != self.config.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.config.dim,
                got: value.len(),
            });
        }
        self.append(&[key.to_vec()], &[value.to_vec()])
    }

    /// Drops the oldest tokens so at most `max_tokens` remain.
    /// `max_tokens == 0` means unlimited capacity. `drain_front` removes
    /// from the front of each batch, so the newest tokens survive.
    fn evict_if_needed(&mut self) {
        let max = self.config.max_tokens;
        if max == 0 || self.num_tokens <= max {
            return;
        }
        let excess = self.num_tokens - max;
        match &mut self.keys {
            Some(QuantizedKeys::MSE(batch)) => batch.drain_front(excess),
            Some(QuantizedKeys::Prod(batch)) => batch.drain_front(excess),
            None => {}
        }
        if let Some(vals) = &mut self.values {
            vals.drain_front(excess);
        }
        self.num_tokens = max;
    }

    /// Configured sliding-window capacity (`0` = unlimited).
    pub fn max_tokens(&self) -> usize {
        self.config.max_tokens
    }

    /// Estimated inner products between `query` and every cached key, in
    /// insertion order. Returns an empty vector when the cache is empty.
    ///
    /// # Errors
    /// `DimensionMismatch` on a wrong-sized query; `InvalidValue` (from
    /// `validate_finite_vector`) on non-finite components.
    pub fn attention_scores(&self, query: &[f64]) -> Result<Vec<f64>> {
        if query.len() != self.config.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.config.dim,
                got: query.len(),
            });
        }
        validate_finite_vector(query, "KV cache query")?;
        if self.num_tokens == 0 {
            return Ok(Vec::new());
        }
        match &self.keys {
            Some(QuantizedKeys::MSE(batch)) => {
                let kq = self.key_quantizer_mse.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("MSE key quantizer not initialized".into())
                })?;
                batch_attention_scores_mse(kq, batch, query)
            }
            Some(QuantizedKeys::Prod(batch)) => {
                let kq = self.key_quantizer_prod.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("Prod key quantizer not initialized".into())
                })?;
                batch_estimate_inner_products(kq, batch, query)
            }
            None => Ok(Vec::new()),
        }
    }

    /// Dequantizes all cached values (lossy reconstruction), in insertion
    /// order. Empty cache yields an empty vector.
    pub fn reconstruct_values(&self) -> Result<Vec<Vec<f64>>> {
        match &self.values {
            Some(batch) => batch_dequantize_mse(&self.value_quantizer, batch),
            None => Ok(Vec::new()),
        }
    }

    /// Dequantizes all cached keys (lossy reconstruction), in insertion order.
    ///
    /// The Prod branch rebuilds a `ProdQuantized` per row from the packed
    /// batch; a missing packed row or residual norm surfaces as an
    /// `Internal` error rather than a panic.
    pub fn reconstruct_keys(&self) -> Result<Vec<Vec<f64>>> {
        match &self.keys {
            Some(QuantizedKeys::MSE(batch)) => {
                let quantizer = self.key_quantizer_mse.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("MSE key quantizer not initialized".into())
                })?;
                batch_dequantize_mse(quantizer, batch)
            }
            Some(QuantizedKeys::Prod(batch)) => {
                let quantizer = self.key_quantizer_prod.as_ref().ok_or_else(|| {
                    TurboQuantError::Internal("Prod key quantizer not initialized".into())
                })?;
                let mut keys = Vec::with_capacity(batch.len());
                for index in 0..batch.len() {
                    let quantized = crate::turboquant_prod::ProdQuantized {
                        mse_indices: batch.unpack_mse_indices(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing packed Prod MSE row {index} in key reconstruction"
                            ))
                        })?,
                        qjl_signs: batch.unpack_qjl_signs(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing packed Prod QJL row {index} in key reconstruction"
                            ))
                        })?,
                        residual_norm: batch.residual_norm(index).ok_or_else(|| {
                            TurboQuantError::Internal(format!(
                                "missing residual norm row {index} in key reconstruction"
                            ))
                        })?,
                        bit_width: batch.bit_width,
                        dim: batch.dim,
                    };
                    keys.push(quantizer.dequantize(&quantized)?);
                }
                Ok(keys)
            }
            None => Ok(Vec::new()),
        }
    }

    /// Softmax-weighted sum of reconstructed values against `query`.
    ///
    /// Scores are scaled by `temperature`; any `temperature <= 0.0` selects
    /// the conventional `1/sqrt(dim)` scaling. The softmax subtracts the max
    /// score for numerical stability (so `sum_exp >= 1`, never zero).
    /// An empty cache returns a zero vector of length `dim`.
    // NOTE(review): the empty-cache early return runs before the query
    // dimension check (which lives in `attention_scores`), so a wrong-sized
    // query against an empty cache returns Ok(zeros) — confirm intended.
    pub fn attention_output(&self, query: &[f64], temperature: f64) -> Result<Vec<f64>> {
        if self.num_tokens == 0 {
            return Ok(vec![0.0; self.config.dim]);
        }
        if !temperature.is_finite() {
            return Err(TurboQuantError::InvalidValue {
                context: "attention temperature".into(),
                value: temperature,
            });
        }
        let scores = self.attention_scores(query)?;
        let values = self.reconstruct_values()?;
        if scores.len() != values.len() {
            return Err(TurboQuantError::LengthMismatch {
                context: "attention score/value count".into(),
                expected: scores.len(),
                got: values.len(),
            });
        }
        let temp = if temperature <= 0.0 {
            1.0 / (self.config.dim as f64).sqrt()
        } else {
            temperature
        };
        let scaled: Vec<f64> = scores.iter().map(|&s| s * temp).collect();
        let max_score = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_scores: Vec<f64> = scaled.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f64 = exp_scores.iter().sum();
        let weights: Vec<f64> = exp_scores.iter().map(|&e| e / sum_exp).collect();
        let mut output = vec![0.0; self.config.dim];
        for (w, v) in weights.iter().zip(values.iter()) {
            weighted_sum_in_place(ExecutionBackend::default(), &mut output, *w, v);
        }
        Ok(output)
    }

    /// Number of tokens currently cached (after any eviction).
    pub fn num_tokens(&self) -> usize {
        self.num_tokens
    }

    /// Configured vector dimensionality.
    pub fn dim(&self) -> usize {
        self.config.dim
    }

    /// Memory/compression statistics for the current cache contents.
    pub fn stats(&self) -> CacheStats {
        let key_bytes = match &self.keys {
            Some(QuantizedKeys::MSE(b)) => b.total_bytes(),
            Some(QuantizedKeys::Prod(b)) => b.total_bytes(),
            None => 0,
        };
        let value_bytes = self.values.as_ref().map_or(0, |b| b.total_bytes());
        // Baseline: 4 bytes per element (fp32) times 2 tensors (K and V) per
        // token. NOTE(review): inputs are f64 here, so this presumably models
        // an fp32 serving baseline rather than in-memory f64 — confirm.
        let uncompressed = self.num_tokens * self.config.dim * 4 * 2;
        CacheStats {
            num_tokens: self.num_tokens,
            dim: self.config.dim,
            key_bits: self.config.key_bits,
            value_bits: self.config.value_bits,
            key_strategy: self.config.key_strategy,
            key_bytes,
            value_bytes,
            total_bytes: key_bytes + value_bytes,
            uncompressed_bytes: uncompressed,
            // 0.0 sentinel for an empty cache (avoids division by zero).
            compression_ratio: if key_bytes + value_bytes > 0 {
                uncompressed as f64 / (key_bytes + value_bytes) as f64
            } else {
                0.0
            },
        }
    }

    /// Removes all cached tokens; the quantizers and config are retained.
    pub fn clear(&mut self) {
        self.keys = None;
        self.values = None;
        self.num_tokens = 0;
    }
}
/// Snapshot of a single head's cache size and compression figures,
/// produced by [`QuantizedKVCache::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    pub num_tokens: usize,
    pub dim: usize,
    pub key_bits: u8,
    pub value_bits: u8,
    pub key_strategy: QuantStrategy,
    pub key_bytes: usize,
    pub value_bytes: usize,
    /// `key_bytes + value_bytes`.
    pub total_bytes: usize,
    /// Size of the unquantized baseline for the same token count.
    pub uncompressed_bytes: usize,
    /// `uncompressed_bytes / total_bytes`; `0.0` when the cache is empty.
    pub compression_ratio: f64,
}
/// Configuration for a [`MultiHeadKVCache`]: a head count plus one
/// per-head cache configuration shared (modulo seed) by all heads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadConfig {
    pub num_heads: usize,
    pub head_config: KVCacheConfig,
}
impl MultiHeadConfig {
pub fn new(num_heads: usize, head_config: KVCacheConfig) -> Self {
Self {
num_heads,
head_config,
}
}
}
/// A collection of independent per-head quantized KV caches.
///
/// All heads share the same shape parameters but use distinct seeds (see
/// `MultiHeadKVCache::new`), so their quantizers are decorrelated.
#[derive(Debug)]
pub struct MultiHeadKVCache {
    heads: Vec<QuantizedKVCache>,
    config: MultiHeadConfig,
}
impl MultiHeadKVCache {
    /// Creates one independent cache per head.
    ///
    /// Head `h` uses the base seed offset by `h * 1000` so heads do not share
    /// quantizer randomness.
    ///
    /// # Errors
    /// `InvalidDimension(0)` when `num_heads` is zero, or any per-head
    /// construction error.
    pub fn new(config: MultiHeadConfig) -> Result<Self> {
        if config.num_heads == 0 {
            return Err(TurboQuantError::InvalidDimension(0));
        }
        let heads = (0..config.num_heads)
            .map(|h| {
                let mut head_config = config.head_config.clone();
                head_config.seed = config.head_config.seed.wrapping_add(h as u64 * 1000);
                QuantizedKVCache::new(head_config)
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { heads, config })
    }

    /// Appends one batch per head; `keys[h]`/`values[h]` belong to head `h`.
    ///
    /// Head counts, per-head token counts, and every vector are validated
    /// up front (before any head is mutated) so a bad batch leaves all heads
    /// unchanged.
    pub fn append_all(&mut self, keys: &[Vec<Vec<f64>>], values: &[Vec<Vec<f64>>]) -> Result<()> {
        // Every head must contribute the same number of tokens; checks one
        // side ("key" or "value") of the batch against that count.
        fn check_token_counts(
            sides: &[Vec<Vec<f64>>],
            expected: usize,
            kind: &str,
        ) -> Result<()> {
            for (head_index, tokens) in sides.iter().enumerate() {
                if tokens.len() != expected {
                    return Err(TurboQuantError::LengthMismatch {
                        context: format!("multi-head {kind} token count for head {head_index}"),
                        expected,
                        got: tokens.len(),
                    });
                }
            }
            Ok(())
        }

        let num_heads = self.config.num_heads;
        if keys.len() != num_heads {
            return Err(TurboQuantError::LengthMismatch {
                context: "multi-head key head count".into(),
                expected: num_heads,
                got: keys.len(),
            });
        }
        if values.len() != num_heads {
            return Err(TurboQuantError::LengthMismatch {
                context: "multi-head value head count".into(),
                expected: num_heads,
                got: values.len(),
            });
        }
        let expected_tokens = keys.first().map_or(0, Vec::len);
        check_token_counts(keys, expected_tokens, "key")?;
        check_token_counts(values, expected_tokens, "value")?;
        // Pre-validate every head's vectors before touching any state.
        for ((head, head_keys), head_values) in self.heads.iter().zip(keys).zip(values) {
            head.validate_append_inputs(head_keys, head_values)?;
        }
        for ((head, head_keys), head_values) in self.heads.iter_mut().zip(keys).zip(values) {
            head.append(head_keys, head_values)?;
        }
        Ok(())
    }

    /// Per-head attention scores for `queries[h]` against head `h`.
    pub fn attention_scores_all(&self, queries: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if queries.len() != self.config.num_heads {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.config.num_heads,
                got: queries.len(),
            });
        }
        self.heads
            .iter()
            .zip(queries)
            .map(|(head, query)| head.attention_scores(query))
            .collect()
    }

    /// Per-head attention outputs; see [`QuantizedKVCache::attention_output`]
    /// for the temperature semantics.
    pub fn attention_output_all(
        &self,
        queries: &[Vec<f64>],
        temperature: f64,
    ) -> Result<Vec<Vec<f64>>> {
        if queries.len() != self.config.num_heads {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.config.num_heads,
                got: queries.len(),
            });
        }
        self.heads
            .iter()
            .zip(queries)
            .map(|(head, query)| head.attention_output(query, temperature))
            .collect()
    }

    /// Concatenates per-head outputs into one `num_heads * head_dim` vector,
    /// validating the count and each head's dimensionality.
    pub fn concat_outputs(&self, outputs: &[Vec<f64>]) -> Result<Vec<f64>> {
        if outputs.len() != self.config.num_heads {
            return Err(TurboQuantError::LengthMismatch {
                context: "multi-head output count".into(),
                expected: self.config.num_heads,
                got: outputs.len(),
            });
        }
        let head_dim = self.config.head_config.dim;
        let mut concatenated = Vec::with_capacity(self.config.num_heads * head_dim);
        for (head_index, output) in outputs.iter().enumerate() {
            if output.len() != head_dim {
                return Err(TurboQuantError::LengthMismatch {
                    context: format!("output dimension for head {head_index}"),
                    expected: head_dim,
                    got: output.len(),
                });
            }
            concatenated.extend_from_slice(output);
        }
        Ok(concatenated)
    }

    /// Number of attention heads.
    pub fn num_heads(&self) -> usize {
        self.config.num_heads
    }

    /// Per-head vector dimensionality.
    pub fn head_dim(&self) -> usize {
        self.config.head_config.dim
    }

    /// Token count, read from the first head (all heads stay in lockstep).
    pub fn num_tokens(&self) -> usize {
        self.heads.first().map_or(0, QuantizedKVCache::num_tokens)
    }

    /// Borrow a single head's cache, if `index` is in range.
    pub fn head(&self, index: usize) -> Option<&QuantizedKVCache> {
        self.heads.get(index)
    }

    /// Aggregated statistics across all heads.
    pub fn stats(&self) -> MultiHeadCacheStats {
        let head_stats: Vec<CacheStats> = self.heads.iter().map(QuantizedKVCache::stats).collect();
        let total_bytes: usize = head_stats.iter().map(|s| s.total_bytes).sum();
        let uncompressed_bytes: usize = head_stats.iter().map(|s| s.uncompressed_bytes).sum();
        // 0.0 sentinel when nothing is cached (avoids division by zero).
        let compression_ratio = if total_bytes > 0 {
            uncompressed_bytes as f64 / total_bytes as f64
        } else {
            0.0
        };
        MultiHeadCacheStats {
            num_heads: self.config.num_heads,
            num_tokens: self.num_tokens(),
            head_dim: self.config.head_config.dim,
            total_bytes,
            uncompressed_bytes,
            compression_ratio,
            head_stats,
        }
    }

    /// Clears every head's cached tokens.
    pub fn clear(&mut self) {
        self.heads.iter_mut().for_each(QuantizedKVCache::clear);
    }

    /// Reconstructs (dequantizes) all keys, one `Vec<Vec<f64>>` per head.
    pub fn reconstruct_keys_all(&self) -> Result<Vec<Vec<Vec<f64>>>> {
        self.heads
            .iter()
            .map(|head| head.reconstruct_keys())
            .collect()
    }

    /// Reconstructs (dequantizes) all values, one `Vec<Vec<f64>>` per head.
    pub fn reconstruct_values_all(&self) -> Result<Vec<Vec<Vec<f64>>>> {
        self.heads
            .iter()
            .map(|head| head.reconstruct_values())
            .collect()
    }
}
/// Aggregate statistics across all heads, produced by
/// [`MultiHeadKVCache::stats`]; `head_stats` holds the per-head breakdown.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiHeadCacheStats {
    pub num_heads: usize,
    pub num_tokens: usize,
    pub head_dim: usize,
    pub total_bytes: usize,
    pub uncompressed_bytes: usize,
    /// `0.0` when no bytes are cached.
    pub compression_ratio: f64,
    pub head_stats: Vec<CacheStats>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{inner_product, normalize};
    use serde_json;

    /// Deterministic test fixture: `count` unit-norm Gaussian vectors of
    /// dimension `dim`, seeded so runs are reproducible.
    fn random_unit_vectors(dim: usize, count: usize, seed: u64) -> Vec<Vec<f64>> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        (0..count)
            .map(|_| {
                let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
                normalize(&raw).unwrap()
            })
            .collect()
    }

    // MSE keys: scores approximate the true inner products within 0.2.
    #[test]
    fn test_kv_cache_basic_mse() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        assert_eq!(cache.num_tokens(), 0);
        let keys = random_unit_vectors(64, 8, 1);
        let values = random_unit_vectors(64, 8, 2);
        cache.append(&keys, &values).unwrap();
        assert_eq!(cache.num_tokens(), 8);
        let query = &random_unit_vectors(64, 1, 3)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert_eq!(scores.len(), 8);
        for (i, &score) in scores.iter().enumerate() {
            let true_ip = inner_product(&keys[i], query);
            assert!(
                (true_ip - score).abs() < 0.2,
                "key {}: true={:.4}, est={:.4}",
                i,
                true_ip,
                score
            );
        }
    }

    // Prod keys: smoke test that scoring works and returns one score per key.
    #[test]
    fn test_kv_cache_basic_prod() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(3)
            .with_key_strategy(QuantStrategy::Prod);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 8, 10);
        let values = random_unit_vectors(64, 8, 20);
        cache.append(&keys, &values).unwrap();
        let query = &random_unit_vectors(64, 1, 30)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert_eq!(scores.len(), 8);
    }

    // Token counts accumulate across multiple appends.
    #[test]
    fn test_kv_cache_incremental_append() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys1 = random_unit_vectors(64, 4, 1);
        let vals1 = random_unit_vectors(64, 4, 2);
        cache.append(&keys1, &vals1).unwrap();
        assert_eq!(cache.num_tokens(), 4);
        let keys2 = random_unit_vectors(64, 3, 3);
        let vals2 = random_unit_vectors(64, 3, 4);
        cache.append(&keys2, &vals2).unwrap();
        assert_eq!(cache.num_tokens(), 7);
    }

    // attention_output produces a finite vector of the right dimension.
    #[test]
    fn test_kv_cache_attention_output() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 8, 1);
        let values = random_unit_vectors(64, 8, 2);
        cache.append(&keys, &values).unwrap();
        let query = &random_unit_vectors(64, 1, 3)[0];
        // temperature 0.0 selects the default 1/sqrt(dim) scaling.
        let output = cache.attention_output(query, 0.0).unwrap();
        assert_eq!(output.len(), 64);
        for &v in &output {
            assert!(v.is_finite(), "output contains non-finite value");
        }
    }

    // Stats report a genuine compression over the uncompressed baseline.
    #[test]
    fn test_kv_cache_stats() {
        let config = KVCacheConfig::new(128)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(128, 16, 1);
        let values = random_unit_vectors(128, 16, 2);
        cache.append(&keys, &values).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.num_tokens, 16);
        assert_eq!(stats.dim, 128);
        assert!(stats.compression_ratio > 1.0);
        assert!(stats.total_bytes > 0);
        assert!(stats.total_bytes < stats.uncompressed_bytes);
    }

    // clear() empties the cache and scoring returns an empty vector.
    #[test]
    fn test_kv_cache_clear() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 4, 1);
        let values = random_unit_vectors(64, 4, 2);
        cache.append(&keys, &values).unwrap();
        assert_eq!(cache.num_tokens(), 4);
        cache.clear();
        assert_eq!(cache.num_tokens(), 0);
        let query = &random_unit_vectors(64, 1, 3)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert!(scores.is_empty());
    }

    // Empty cache: empty scores, zero attention output.
    #[test]
    fn test_kv_cache_empty_attention() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let cache = QuantizedKVCache::new(config).unwrap();
        let query = &random_unit_vectors(64, 1, 1)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert!(scores.is_empty());
        let output = cache.attention_output(query, 0.0).unwrap();
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|&v| v == 0.0));
    }

    // Wrong-sized query is rejected.
    #[test]
    fn test_kv_cache_dimension_mismatch() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 4, 1);
        let values = random_unit_vectors(64, 4, 2);
        cache.append(&keys, &values).unwrap();
        let bad_query = vec![0.0; 32];
        assert!(cache.attention_scores(&bad_query).is_err());
    }

    // Value round-trip keeps per-coordinate MSE small.
    #[test]
    fn test_kv_cache_reconstruct_values() {
        let config = KVCacheConfig::new(64)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 4, 1);
        let values = random_unit_vectors(64, 4, 2);
        cache.append(&keys, &values).unwrap();
        let recon_values = cache.reconstruct_values().unwrap();
        assert_eq!(recon_values.len(), 4);
        for (orig, recon) in values.iter().zip(recon_values.iter()) {
            let mse: f64 = orig
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                / 64.0;
            assert!(mse < 0.1, "Value reconstruction MSE {} too large", mse);
        }
    }

    // Bulk append and incremental appends must produce bit-identical scores.
    #[test]
    fn test_kv_cache_incremental_scores_match() {
        let dim = 64;
        let config = KVCacheConfig::new(dim)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE)
            .with_seed(42);
        let keys1 = random_unit_vectors(dim, 4, 1);
        let vals1 = random_unit_vectors(dim, 4, 2);
        let keys2 = random_unit_vectors(dim, 3, 3);
        let vals2 = random_unit_vectors(dim, 3, 4);
        let mut bulk = QuantizedKVCache::new(config.clone()).unwrap();
        let mut all_keys = keys1.clone();
        all_keys.extend_from_slice(&keys2);
        let mut all_vals = vals1.clone();
        all_vals.extend_from_slice(&vals2);
        bulk.append(&all_keys, &all_vals).unwrap();
        let mut incr = QuantizedKVCache::new(config).unwrap();
        incr.append(&keys1, &vals1).unwrap();
        incr.append(&keys2, &vals2).unwrap();
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let bulk_scores = bulk.attention_scores(query).unwrap();
        let incr_scores = incr.attention_scores(query).unwrap();
        assert_eq!(bulk_scores.len(), incr_scores.len());
        for (b, i) in bulk_scores.iter().zip(incr_scores.iter()) {
            assert!(
                (b - i).abs() < 1e-10,
                "bulk={}, incr={}, diff={}",
                b,
                i,
                (b - i).abs()
            );
        }
    }

    // Appending an empty batch is a no-op, before and after real appends.
    #[test]
    fn test_kv_cache_append_empty() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        cache.append(&[], &[]).unwrap();
        assert_eq!(cache.num_tokens(), 0);
        let keys = random_unit_vectors(64, 2, 1);
        let vals = random_unit_vectors(64, 2, 2);
        cache.append(&keys, &vals).unwrap();
        assert_eq!(cache.num_tokens(), 2);
        cache.append(&[], &[]).unwrap();
        assert_eq!(cache.num_tokens(), 2);
    }

    // Multi-head: per-head append and scoring.
    #[test]
    fn test_multi_head_basic() {
        let config = MultiHeadConfig::new(
            4,
            KVCacheConfig::new(32)
                .with_key_bits(4)
                .with_value_bits(4)
                .with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        assert_eq!(cache.num_heads(), 4);
        assert_eq!(cache.head_dim(), 32);
        assert_eq!(cache.num_tokens(), 0);
        let keys: Vec<Vec<Vec<f64>>> = (0..4)
            .map(|h| random_unit_vectors(32, 6, 100 + h as u64))
            .collect();
        let values: Vec<Vec<Vec<f64>>> = (0..4)
            .map(|h| random_unit_vectors(32, 6, 200 + h as u64))
            .collect();
        cache.append_all(&keys, &values).unwrap();
        assert_eq!(cache.num_tokens(), 6);
        let queries: Vec<Vec<f64>> = (0..4)
            .map(|h| {
                random_unit_vectors(32, 1, 300 + h as u64)
                    .into_iter()
                    .next()
                    .unwrap()
            })
            .collect();
        let scores = cache.attention_scores_all(&queries).unwrap();
        assert_eq!(scores.len(), 4);
        for head_scores in &scores {
            assert_eq!(head_scores.len(), 6);
        }
    }

    // Multi-head attention outputs are finite and concatenate correctly.
    #[test]
    fn test_multi_head_attention_output() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32)
                .with_key_bits(4)
                .with_value_bits(4)
                .with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 4, 10 + h as u64))
            .collect();
        let values: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 4, 20 + h as u64))
            .collect();
        cache.append_all(&keys, &values).unwrap();
        let queries: Vec<Vec<f64>> = (0..2)
            .map(|h| {
                random_unit_vectors(32, 1, 30 + h as u64)
                    .into_iter()
                    .next()
                    .unwrap()
            })
            .collect();
        let outputs = cache.attention_output_all(&queries, 0.0).unwrap();
        assert_eq!(outputs.len(), 2);
        for output in &outputs {
            assert_eq!(output.len(), 32);
            for &v in output {
                assert!(v.is_finite());
            }
        }
        let concat = cache.concat_outputs(&outputs).unwrap();
        assert_eq!(concat.len(), 64);
    }

    // Aggregated multi-head stats are consistent with the per-head view.
    #[test]
    fn test_multi_head_stats() {
        let config = MultiHeadConfig::new(
            4,
            KVCacheConfig::new(64)
                .with_key_bits(4)
                .with_value_bits(4)
                .with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys: Vec<Vec<Vec<f64>>> = (0..4)
            .map(|h| random_unit_vectors(64, 8, h as u64))
            .collect();
        let values: Vec<Vec<Vec<f64>>> = (0..4)
            .map(|h| random_unit_vectors(64, 8, 100 + h as u64))
            .collect();
        cache.append_all(&keys, &values).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.num_heads, 4);
        assert_eq!(stats.num_tokens, 8);
        assert_eq!(stats.head_dim, 64);
        assert!(stats.compression_ratio > 1.0);
        assert_eq!(stats.head_stats.len(), 4);
    }

    // Multi-head token counts accumulate across appends.
    #[test]
    fn test_multi_head_incremental_append() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32)
                .with_key_bits(4)
                .with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys1: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 3, h as u64))
            .collect();
        let vals1: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 3, 50 + h as u64))
            .collect();
        cache.append_all(&keys1, &vals1).unwrap();
        assert_eq!(cache.num_tokens(), 3);
        let keys2: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 5, 100 + h as u64))
            .collect();
        let vals2: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 5, 150 + h as u64))
            .collect();
        cache.append_all(&keys2, &vals2).unwrap();
        assert_eq!(cache.num_tokens(), 8);
    }

    // clear() empties every head.
    #[test]
    fn test_multi_head_clear() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 4, h as u64))
            .collect();
        let values: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 4, 10 + h as u64))
            .collect();
        cache.append_all(&keys, &values).unwrap();
        assert_eq!(cache.num_tokens(), 4);
        cache.clear();
        assert_eq!(cache.num_tokens(), 0);
    }

    // Multi-head with the Prod key strategy also scores correctly shaped.
    #[test]
    fn test_multi_head_prod_strategy() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(64)
                .with_key_bits(3)
                .with_key_strategy(QuantStrategy::Prod),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(64, 4, h as u64))
            .collect();
        let values: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(64, 4, 10 + h as u64))
            .collect();
        cache.append_all(&keys, &values).unwrap();
        let queries: Vec<Vec<f64>> = (0..2)
            .map(|h| {
                random_unit_vectors(64, 1, 20 + h as u64)
                    .into_iter()
                    .next()
                    .unwrap()
            })
            .collect();
        let scores = cache.attention_scores_all(&queries).unwrap();
        assert_eq!(scores.len(), 2);
        for head_scores in &scores {
            assert_eq!(head_scores.len(), 4);
        }
    }

    // Single-token pushes behave like size-1 appends.
    #[test]
    fn test_push_token() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        assert_eq!(cache.num_tokens(), 0);
        let keys = random_unit_vectors(64, 5, 1);
        let values = random_unit_vectors(64, 5, 2);
        for i in 0..5 {
            cache.push_token(&keys[i], &values[i]).unwrap();
        }
        assert_eq!(cache.num_tokens(), 5);
        let query = &random_unit_vectors(64, 1, 99)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert_eq!(scores.len(), 5);
    }

    // Appending past max_tokens evicts down to the cap.
    #[test]
    fn test_sliding_window_eviction() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE)
            .with_max_tokens(5);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        assert_eq!(cache.max_tokens(), 5);
        let keys = random_unit_vectors(64, 5, 1);
        let vals = random_unit_vectors(64, 5, 2);
        cache.append(&keys, &vals).unwrap();
        assert_eq!(cache.num_tokens(), 5);
        let keys2 = random_unit_vectors(64, 3, 3);
        let vals2 = random_unit_vectors(64, 3, 4);
        cache.append(&keys2, &vals2).unwrap();
        assert_eq!(cache.num_tokens(), 5);
        let query = &random_unit_vectors(64, 1, 99)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert_eq!(scores.len(), 5);
    }

    // The cap is enforced after every single-token push.
    #[test]
    fn test_sliding_window_token_by_token() {
        let config = KVCacheConfig::new(32)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE)
            .with_max_tokens(3);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(32, 10, 1);
        let vals = random_unit_vectors(32, 10, 2);
        for i in 0..10 {
            cache.push_token(&keys[i], &vals[i]).unwrap();
            assert!(
                cache.num_tokens() <= 3,
                "token {}: count={}, expected <=3",
                i,
                cache.num_tokens()
            );
        }
        assert_eq!(cache.num_tokens(), 3);
        let recon = cache.reconstruct_values().unwrap();
        assert_eq!(recon.len(), 3);
    }

    // Eviction also works with Prod-quantized keys.
    #[test]
    fn test_sliding_window_prod_strategy() {
        let config = KVCacheConfig::new(64)
            .with_key_bits(3)
            .with_key_strategy(QuantStrategy::Prod)
            .with_max_tokens(4);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 6, 1);
        let vals = random_unit_vectors(64, 6, 2);
        cache.append(&keys, &vals).unwrap();
        assert_eq!(cache.num_tokens(), 4);
        let query = &random_unit_vectors(64, 1, 99)[0];
        let scores = cache.attention_scores(query).unwrap();
        assert_eq!(scores.len(), 4);
    }

    // max_tokens == 0 means no eviction at all.
    #[test]
    fn test_unlimited_capacity() {
        let config = KVCacheConfig::new(32)
            .with_key_strategy(QuantStrategy::MSE)
            .with_max_tokens(0);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(32, 100, 1);
        let vals = random_unit_vectors(32, 100, 2);
        cache.append(&keys, &vals).unwrap();
        assert_eq!(cache.num_tokens(), 100);
    }

    // push_token rejects wrong-sized keys and values independently.
    #[test]
    fn test_push_token_dimension_mismatch() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let bad_key = vec![0.0; 32];
        let good_val = random_unit_vectors(64, 1, 1).into_iter().next().unwrap();
        assert!(cache.push_token(&bad_key, &good_val).is_err());
        let good_key = random_unit_vectors(64, 1, 1).into_iter().next().unwrap();
        let bad_val = vec![0.0; 32];
        assert!(cache.push_token(&good_key, &bad_val).is_err());
    }

    // Zero heads is a construction error.
    #[test]
    fn test_multi_head_zero_heads() {
        let config = MultiHeadConfig::new(
            0,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        assert!(MultiHeadKVCache::new(config).is_err());
    }

    // dim == 0 is a construction error.
    #[test]
    fn test_kv_cache_invalid_dimension() {
        let config = KVCacheConfig::new(0);
        assert!(QuantizedKVCache::new(config).is_err());
    }

    // Mismatched key/value counts are rejected.
    #[test]
    fn test_kv_cache_keys_values_length_mismatch() {
        let config = KVCacheConfig::new(64).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(64, 3, 1);
        let vals = random_unit_vectors(64, 5, 2);
        assert!(cache.append(&keys, &vals).is_err());
    }

    // A non-unit value must fail the whole append and leave the cache empty.
    #[test]
    fn test_append_is_transactional_on_value_failure() {
        let config = KVCacheConfig::new(8).with_key_strategy(QuantStrategy::MSE);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(8, 2, 1);
        let mut values = random_unit_vectors(8, 2, 2);
        values[1][0] = 10.0;
        assert!(matches!(
            cache.append(&keys, &values),
            Err(TurboQuantError::NotUnitVector(_))
        ));
        assert_eq!(cache.num_tokens(), 0);
        assert!(cache
            .attention_scores(&random_unit_vectors(8, 1, 3)[0])
            .unwrap()
            .is_empty());
    }

    // Wrong number of per-head batches is rejected.
    #[test]
    fn test_multi_head_head_count_mismatch() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys: Vec<Vec<Vec<f64>>> = (0..3)
            .map(|h| random_unit_vectors(32, 2, h as u64))
            .collect();
        let vals: Vec<Vec<Vec<f64>>> = (0..2)
            .map(|h| random_unit_vectors(32, 2, 10 + h as u64))
            .collect();
        assert!(cache.append_all(&keys, &vals).is_err());
    }

    // Heads with differing token counts are rejected before any mutation.
    #[test]
    fn test_multi_head_requires_matching_token_counts() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        let mut cache = MultiHeadKVCache::new(config).unwrap();
        let keys = vec![random_unit_vectors(32, 2, 1), random_unit_vectors(32, 3, 2)];
        let values = vec![random_unit_vectors(32, 2, 3), random_unit_vectors(32, 3, 4)];
        assert!(matches!(
            cache.append_all(&keys, &values),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
        assert_eq!(cache.num_tokens(), 0);
    }

    // Debug formatting does not panic.
    #[test]
    fn test_kv_cache_debug_trait() {
        let config = KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE);
        let cache = QuantizedKVCache::new(config).unwrap();
        let _debug_str = format!("{:?}", cache);
    }

    // Debug formatting for the multi-head cache does not panic.
    #[test]
    fn test_multi_head_debug_trait() {
        let config = MultiHeadConfig::new(
            2,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        let cache = MultiHeadKVCache::new(config).unwrap();
        let _debug_str = format!("{:?}", cache);
    }

    // QuantStrategy round-trips through serde_json.
    #[test]
    fn test_quant_strategy_serde() {
        let strategy = QuantStrategy::Prod;
        let json = serde_json::to_string(&strategy).unwrap();
        let deser: QuantStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(strategy, deser);
    }

    // head() is bounds-checked.
    #[test]
    fn test_multi_head_head_access() {
        let config = MultiHeadConfig::new(
            4,
            KVCacheConfig::new(32).with_key_strategy(QuantStrategy::MSE),
        );
        let cache = MultiHeadKVCache::new(config).unwrap();
        assert!(cache.head(0).is_some());
        assert!(cache.head(3).is_some());
        assert!(cache.head(4).is_none());
    }

    // Eviction keeps the NEWEST tokens: an evicted cache must score exactly
    // like a fresh cache built from only the surviving (last) tokens.
    #[test]
    fn test_sliding_window_values_are_newest() {
        let dim = 32;
        let config = KVCacheConfig::new(dim)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE)
            .with_max_tokens(2);
        let mut cache = QuantizedKVCache::new(config).unwrap();
        let keys = random_unit_vectors(dim, 3, 1);
        let vals = random_unit_vectors(dim, 3, 2);
        for i in 0..3 {
            cache.push_token(&keys[i], &vals[i]).unwrap();
        }
        assert_eq!(cache.num_tokens(), 2);
        let config2 = KVCacheConfig::new(dim)
            .with_key_bits(4)
            .with_value_bits(4)
            .with_key_strategy(QuantStrategy::MSE);
        let mut cache2 = QuantizedKVCache::new(config2).unwrap();
        cache2.append(&keys[1..3], &vals[1..3]).unwrap();
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let scores1 = cache.attention_scores(query).unwrap();
        let scores2 = cache2.attention_scores(query).unwrap();
        assert_eq!(scores1.len(), scores2.len());
        for (a, b) in scores1.iter().zip(scores2.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "evicted cache={}, fresh cache={}, diff={}",
                a,
                b,
                (a - b).abs()
            );
        }
    }
}