/// Key/value cache that stores one flat `f32` key buffer and one flat `f32` value buffer
/// per layer.
///
/// Each [`append`](Self::append) adds exactly [`entry_size`](Self::entry_size)
/// (`num_kv_heads * head_dim`) values to the chosen layer's key and value buffers. When every
/// layer receives exactly one `append` per position, position `p` of a layer occupies the range
/// `p * entry_size .. (p + 1) * entry_size`.
///
/// The logical length is tracked separately from the buffers. Callers append to each layer and
/// then call [`advance`](Self::advance) once to count the position.
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Key buffers, one flat vector per layer.
    k: Vec<Vec<f32>>,
    /// Value buffers, one flat vector per layer.
    v: Vec<Vec<f32>>,
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// Number of positions counted via `advance` since creation or the last `clear`.
    len: usize,
}

impl KVCache {
    /// Creates an empty cache with `num_layers` layers and no preallocated storage.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        Self {
            k: (0..num_layers).map(|_| Vec::new()).collect(),
            v: (0..num_layers).map(|_| Vec::new()).collect(),
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Creates an empty cache whose per-layer buffers are preallocated for `max_seq_len`
    /// positions, so appending up to that many entries per layer does not reallocate.
    ///
    /// # Panics
    ///
    /// Panics if `max_seq_len * num_kv_heads * head_dim` overflows `usize`. Like
    /// `Vec::with_capacity`, it also panics if the requested capacity exceeds `isize::MAX`
    /// bytes.
    pub fn with_capacity(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        // Checked arithmetic: a plain `*` wraps on overflow in release builds and would
        // silently reserve a meaningless capacity.
        let capacity = num_kv_heads
            .checked_mul(head_dim)
            .and_then(|entry_size| entry_size.checked_mul(max_seq_len))
            .expect("KVCache::with_capacity: capacity overflows usize");
        let make_buffers = || -> Vec<Vec<f32>> {
            (0..num_layers)
                .map(|_| Vec::with_capacity(capacity))
                .collect()
        };
        Self {
            k: make_buffers(),
            v: make_buffers(),
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Appends one key entry and one value entry to `layer`.
    ///
    /// This does not change [`len`](Self::len); call [`advance`](Self::advance) after every
    /// layer has been appended for the current position.
    ///
    /// # Panics
    ///
    /// Panics if `k_data` or `v_data` does not have exactly
    /// [`entry_size`](Self::entry_size) elements, or if `layer >= num_layers()`.
    pub fn append(&mut self, layer: usize, k_data: &[f32], v_data: &[f32]) {
        let entry_size = self.entry_size();
        // Enforced in release builds too: a short or long entry would silently shift the
        // offsets of every later position stored in this layer's flat buffer.
        assert_eq!(
            k_data.len(),
            entry_size,
            "KVCache::append: k_data length mismatch"
        );
        assert_eq!(
            v_data.len(),
            entry_size,
            "KVCache::append: v_data length mismatch"
        );
        self.k[layer].extend_from_slice(k_data);
        self.v[layer].extend_from_slice(v_data);
    }

    /// Increments the logical sequence length by one.
    ///
    /// This does not check that every layer was appended to for the new position.
    pub fn advance(&mut self) {
        self.len += 1;
    }

    /// Returns every key value stored for `layer`, including entries appended but not yet
    /// counted by [`advance`](Self::advance).
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers()`.
    pub fn k(&self, layer: usize) -> &[f32] {
        &self.k[layer]
    }

    /// Returns every value stored for `layer`, including entries appended but not yet
    /// counted by [`advance`](Self::advance).
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers()`.
    pub fn v(&self, layer: usize) -> &[f32] {
        &self.v[layer]
    }

    /// Number of positions counted via [`advance`](Self::advance).
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no position has been counted.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of layers this cache was created with.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Removes all cached keys and values and resets the length to zero.
    ///
    /// The buffers keep their allocated capacity (`Vec::clear` semantics), so the cache can be
    /// reused without reallocating.
    pub fn clear(&mut self) {
        self.k
            .iter_mut()
            .chain(self.v.iter_mut())
            .for_each(Vec::clear);
        self.len = 0;
    }

    /// Number of `f32` values each `append` adds to one layer's key (or value) buffer:
    /// `num_kv_heads * head_dim`.
    pub fn entry_size(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Basic construction plus one appended position across two layers.
    #[test]
    fn create_and_append() {
        let mut cache = KVCache::new(2, 4, 16);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.entry_size(), 64);
        let kv_data = vec![1.0f32; 64];
        cache.append(0, &kv_data, &kv_data);
        cache.append(1, &kv_data, &kv_data);
        cache.advance();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(cache.v(1).len(), 64);
    }

    // Successive positions are laid out contiguously; entry_size here is 2 * 8 = 16.
    #[test]
    fn append_multiple_tokens() {
        let mut cache = KVCache::new(1, 2, 8);
        for token_idx in 0..5 {
            let data = vec![token_idx as f32; 16];
            cache.append(0, &data, &data);
            cache.advance();
        }
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 80);
        assert_eq!(cache.v(0).len(), 80);
        assert_eq!(cache.k(0)[0], 0.0);
        assert_eq!(cache.v(0)[0], 0.0);
        // Offset 64 = token 4 * entry_size 16: the first element of the last position.
        assert_eq!(cache.k(0)[64], 4.0);
        assert_eq!(cache.v(0)[64], 4.0);
    }

    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(2, 4, 16);
        let data = vec![1.0f32; 64];
        cache.append(0, &data, &data);
        cache.append(1, &data, &data);
        cache.advance();
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.k(0).is_empty());
    }

    #[test]
    fn with_capacity() {
        let cache = KVCache::with_capacity(16, 8, 64, 2048);
        assert!(cache.is_empty());
        assert_eq!(cache.num_layers(), 16);
        assert_eq!(cache.entry_size(), 512);
        // Preallocation must not introduce visible elements.
        assert!(cache.k(0).is_empty());
        assert!(cache.v(15).is_empty());
    }

    // After `clear`, a new sequence must not see any data from the previous one.
    #[test]
    fn clear_resets_completely_for_independent_generation() {
        let mut cache = KVCache::new(2, 4, 16);
        let data_a = vec![1.0f32; 64];
        for _ in 0..5 {
            cache.append(0, &data_a, &data_a);
            cache.append(1, &data_a, &data_a);
            cache.advance();
        }
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 5 * 64);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.k(0).is_empty());
        assert!(cache.v(0).is_empty());
        assert!(cache.k(1).is_empty());
        assert!(cache.v(1).is_empty());
        let data_b = vec![2.0f32; 64];
        cache.append(0, &data_b, &data_b);
        cache.append(1, &data_b, &data_b);
        cache.advance();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(
            cache.k(0)[0],
            2.0,
            "after clear, new data should overwrite old content"
        );
    }

    // Fills a preallocated cache to its full capacity and checks first/last positions.
    #[test]
    fn cache_handles_max_realistic_sequence() {
        let num_layers = 4;
        let num_kv_heads = 2;
        let head_dim = 8;
        let max_seq = 512;
        let entry_size = num_kv_heads * head_dim;
        let mut cache = KVCache::with_capacity(num_layers, num_kv_heads, head_dim, max_seq);
        for pos in 0..max_seq {
            let data: Vec<f32> = (0..entry_size)
                .map(|j| (pos * entry_size + j) as f32)
                .collect();
            for layer in 0..num_layers {
                cache.append(layer, &data, &data);
            }
            cache.advance();
        }
        assert_eq!(cache.len(), max_seq);
        assert_eq!(cache.k(0).len(), max_seq * entry_size);
        assert_eq!(cache.k(0)[0], 0.0);
        assert_eq!(cache.k(0)[entry_size - 1], (entry_size - 1) as f32);
        let last_start = (max_seq - 1) * entry_size;
        assert_eq!(cache.k(0)[last_start], last_start as f32);
    }
}