use crate::error::{validate_finite, Result, TurboQuantError};
use crate::turbo::{TurboCode, TurboQuantizer};
/// Construction parameters for a [`KvCacheCompressor`].
///
/// `KvCacheCompressor::new` forwards the four fields, in declaration order,
/// to `TurboQuantizer::new`; their exact semantics are defined there.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub struct KvCacheConfig {
/// Passed as the first argument to `TurboQuantizer::new`.
/// NOTE(review): presumably the length of each key/value vector — confirm.
pub head_dim: usize,
/// Passed as the second argument to `TurboQuantizer::new`.
/// NOTE(review): presumably the quantization bit width — confirm.
pub bits: u8,
/// Passed as the third argument to `TurboQuantizer::new`.
/// NOTE(review): presumably a projection count used by the quantizer — confirm.
pub projections: usize,
/// Seed handed to the quantizer. `MultiHeadKvCache::new` offsets it by the
/// head index (wrapping) so each head is built from a different seed.
pub seed: u64,
}
/// Single-head key/value cache that stores every entry as a `TurboCode`.
///
/// Entry `i` of `keys` and entry `i` of `values` come from the same `push`
/// call; `push` and `clear` always update both vectors together.
#[derive(Debug, Clone)]
pub struct KvCacheCompressor {
/// Quantizer used to encode pushed vectors, decode values and estimate
/// query/key inner products.
quantizer: TurboQuantizer,
/// Encoded keys in insertion order.
keys: Vec<TurboCode>,
/// Encoded values in insertion order.
values: Vec<TurboCode>,
}
impl KvCacheCompressor {
pub fn new(config: &KvCacheConfig) -> Result<Self> {
let quantizer =
TurboQuantizer::new(config.head_dim, config.bits, config.projections, config.seed)?;
Ok(Self { quantizer, keys: Vec::new(), values: Vec::new() })
}
#[cfg_attr(
feature = "tracing-support",
tracing::instrument(
name = "bitpolar::kv_cache::push",
skip(self, key, value),
fields(cache_len = self.keys.len(), dim = self.quantizer.dim())
)
)]
pub fn push(&mut self, key: &[f32], value: &[f32]) -> Result<()> {
validate_finite(key)?;
validate_finite(value)?;
let kc = self.quantizer.encode(key)?;
let vc = self.quantizer.encode(value)?;
self.keys.push(kc);
self.values.push(vc);
Ok(())
}
pub fn len(&self) -> usize {
self.keys.len()
}
pub fn is_empty(&self) -> bool {
self.keys.is_empty()
}
#[cfg_attr(
feature = "tracing-support",
tracing::instrument(
name = "bitpolar::kv_cache::attention_scores",
skip(self, query),
fields(cache_len = self.keys.len(), dim = self.quantizer.dim())
)
)]
pub fn attention_scores(&self, query: &[f32]) -> Result<Vec<f32>> {
validate_finite(query)?;
let scale = 1.0 / crate::compat::math::sqrtf(self.quantizer.dim() as f32);
let scores: Result<Vec<f32>> = self
.keys
.iter()
.map(|kc| self.quantizer.inner_product_estimate(kc, query).map(|ip| ip * scale))
.collect();
scores
}
pub fn decode_values(&self) -> Vec<Vec<f32>> {
self.values.iter().map(|vc| self.quantizer.decode(vc)).collect()
}
pub fn compression_ratio(&self) -> f64 {
let total_codes: usize = self.keys.len() + self.values.len();
if total_codes == 0 {
return 1.0;
}
let original = total_codes * self.quantizer.dim() * core::mem::size_of::<f32>();
let compressed: usize = self
.keys
.iter()
.chain(self.values.iter())
.map(|c| c.size_bytes())
.sum();
if compressed == 0 {
return 1.0;
}
original as f64 / compressed as f64
}
pub fn clear(&mut self) {
self.keys.clear();
self.values.clear();
}
}
/// A set of per-head [`KvCacheCompressor`]s that is filled one token at a time.
#[derive(Debug, Clone)]
pub struct MultiHeadKvCache {
/// One compressor per head; index `h` receives `keys[h]` / `values[h]`
/// from `push_token` and `queries[h]` from `attention_scores`.
heads: Vec<KvCacheCompressor>,
}
impl MultiHeadKvCache {
    /// Creates `num_heads` empty per-head caches.
    ///
    /// Head `h` is built from `config` with its seed replaced by
    /// `config.seed.wrapping_add(h)`, so every head gets a different seed.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::EmptyInput` when `num_heads` is zero, or the
    /// first error reported by `KvCacheCompressor::new`.
    pub fn new(num_heads: usize, config: &KvCacheConfig) -> Result<Self> {
        if num_heads == 0 {
            return Err(TurboQuantError::EmptyInput("num_heads must be > 0"));
        }
        let heads = (0..num_heads)
            .map(|h| {
                let head_config = KvCacheConfig {
                    seed: config.seed.wrapping_add(h as u64),
                    ..*config
                };
                KvCacheCompressor::new(&head_config)
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { heads })
    }

    /// Appends one token: `keys[h]` / `values[h]` go to head `h`.
    ///
    /// The operation is all-or-nothing: if any head rejects its pair, the
    /// heads that had already accepted this token are rolled back, so every
    /// head keeps the same length.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::DimensionMismatch` when `keys` or `values`
    /// does not contain exactly one slice per head; `actual` is the length of
    /// the offending argument (`keys` is checked first). Otherwise returns the
    /// first error produced by a head's `push`.
    pub fn push_token(&mut self, keys: &[&[f32]], values: &[&[f32]]) -> Result<()> {
        let expected = self.heads.len();

        // Report the length of the argument that is actually wrong, rather
        // than always reporting `keys.len()`.
        let actual = if keys.len() != expected { keys.len() } else { values.len() };
        if actual != expected {
            return Err(TurboQuantError::DimensionMismatch { expected, actual });
        }

        // Push head by head, stopping at the first failure and remembering
        // which head it was.
        let failure = self
            .heads
            .iter_mut()
            .zip(keys.iter().zip(values.iter()))
            .enumerate()
            .find_map(|(h, (head, (key, value)))| head.push(key, value).err().map(|err| (h, err)));

        if let Some((failed, err)) = failure {
            // `KvCacheCompressor::push` stores nothing when it fails, so only
            // the heads before `failed` hold an entry for this token.
            for head in &mut self.heads[..failed] {
                head.keys.pop();
                head.values.pop();
            }
            return Err(err);
        }
        Ok(())
    }

    /// Computes the per-head scores of `queries[h]` against head `h`'s keys.
    ///
    /// The outer vector has one entry per head; each inner vector has one
    /// score per cached token.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::DimensionMismatch` when `queries` does not
    /// contain exactly one slice per head, or the first error reported by a
    /// head's `attention_scores`.
    pub fn attention_scores(&self, queries: &[&[f32]]) -> Result<Vec<Vec<f32>>> {
        if queries.len() != self.heads.len() {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.heads.len(),
                actual: queries.len(),
            });
        }
        self.heads
            .iter()
            .zip(queries.iter())
            .map(|(head, q)| head.attention_scores(q))
            .collect()
    }

    /// Number of cached tokens, read from the first head.
    pub fn len(&self) -> usize {
        self.heads.first().map_or(0, |h| h.len())
    }

    /// `true` when no token has been cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Clears every head.
    pub fn clear(&mut self) {
        for h in &mut self.heads {
            h.clear();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length used by every fixture; matches `sample_config().head_dim`.
    const DIM: usize = 8;

    fn sample_config() -> KvCacheConfig {
        KvCacheConfig { head_dim: DIM, bits: 4, projections: 16, seed: 42 }
    }

    /// Fresh single-head cache built from `sample_config`.
    fn new_cache() -> KvCacheCompressor {
        KvCacheCompressor::new(&sample_config()).unwrap()
    }

    /// A `DIM`-long vector with every component equal to `value`.
    fn constant(value: f32) -> Vec<f32> {
        vec![value; DIM]
    }

    /// `heads` copies of `constant(value)`, one per attention head.
    fn per_head(heads: usize, value: f32) -> Vec<Vec<f32>> {
        vec![constant(value); heads]
    }

    /// Borrows each row as a slice, the shape the multi-head API expects.
    fn as_slices(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(|row| row.as_slice()).collect()
    }

    #[test]
    fn test_push_and_len() {
        let mut cache = new_cache();
        assert!(cache.is_empty());

        cache.push(&constant(0.1), &constant(0.2)).unwrap();
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_attention_scores_shape() {
        let mut cache = new_cache();
        for _ in 0..5 {
            cache.push(&constant(0.1), &constant(0.2)).unwrap();
        }

        let scores = cache.attention_scores(&constant(0.1)).unwrap();
        assert_eq!(scores.len(), 5);
    }

    #[test]
    fn test_decode_values_shape() {
        let mut cache = new_cache();
        cache.push(&constant(0.1), &constant(0.2)).unwrap();

        let decoded = cache.decode_values();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].len(), 8);
    }

    #[test]
    fn test_compression_ratio_positive() {
        let mut cache = new_cache();
        cache.push(&constant(0.5), &constant(0.3)).unwrap();
        assert!(cache.compression_ratio() > 0.0);
    }

    #[test]
    fn test_clear() {
        let mut cache = new_cache();
        cache.push(&constant(0.1), &constant(0.2)).unwrap();
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_multi_head_push_and_len() {
        let mut mhc = MultiHeadKvCache::new(4, &sample_config()).unwrap();

        let k_rows = per_head(4, 0.1);
        let v_rows = per_head(4, 0.2);
        mhc.push_token(&as_slices(&k_rows), &as_slices(&v_rows)).unwrap();

        assert_eq!(mhc.len(), 1);
    }

    #[test]
    fn test_multi_head_attention_scores_shape() {
        let mut mhc = MultiHeadKvCache::new(2, &sample_config()).unwrap();

        let k_rows = per_head(2, 0.1);
        let v_rows = per_head(2, 0.2);
        mhc.push_token(&as_slices(&k_rows), &as_slices(&v_rows)).unwrap();

        let q_rows = per_head(2, 0.1);
        let scores = mhc.attention_scores(&as_slices(&q_rows)).unwrap();
        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].len(), 1);
    }
}