use sapient_core::Tensor;
use std::collections::HashMap;
/// Key/value tensors cached for a single layer.
///
/// Each call to [`LayerKVCache::append`] pushes exactly one key tensor and one value
/// tensor, so `keys` and `values` are expected to always have the same length. The fields
/// are public; code that mutates them directly is responsible for keeping them in step.
#[derive(Debug, Clone, Default)]
pub struct LayerKVCache {
    /// Cached key tensors, in append order.
    pub keys: Vec<Tensor>,
    /// Cached value tensors, in append order; `values[i]` pairs with `keys[i]`.
    pub values: Vec<Tensor>,
}

impl LayerKVCache {
    /// Creates a cache with no stored entries (same as [`Default::default`]).
    #[must_use]
    pub const fn empty() -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Appends one key/value pair and returns the number of stored entries afterwards.
    pub fn append(&mut self, k: Tensor, v: Tensor) -> usize {
        // The pairing invariant can only be broken through the public fields; catch that
        // in debug builds instead of letting `seq_len` silently report the key count only.
        debug_assert_eq!(
            self.keys.len(),
            self.values.len(),
            "LayerKVCache: keys and values are out of sync"
        );
        self.keys.push(k);
        self.values.push(v);
        self.keys.len()
    }

    /// Number of cached entries (one per [`LayerKVCache::append`] call).
    ///
    /// NOTE(review): this counts appended tensors, not token positions. It equals the
    /// number of cached tokens only if each appended tensor covers a single position —
    /// confirm against the callers of `append`.
    pub fn seq_len(&self) -> usize {
        self.keys.len()
    }

    /// Removes all cached entries. The underlying vectors keep their allocated capacity.
    pub fn clear(&mut self) {
        self.keys.clear();
        self.values.clear();
    }
}
/// Key/value cache for a multi-layer model: one [`LayerKVCache`] per layer.
///
/// The number of layers is fixed when the cache is built with [`KVCache::new`].
#[derive(Debug, Clone)]
pub struct KVCache {
    layers: Vec<LayerKVCache>,
}

impl KVCache {
    /// Builds a cache holding `n_layers` empty per-layer caches.
    pub fn new(n_layers: usize) -> Self {
        let layers = std::iter::repeat_with(LayerKVCache::empty)
            .take(n_layers)
            .collect();
        Self { layers }
    }

    /// Number of per-layer caches held.
    pub fn n_layers(&self) -> usize {
        self.layers.len()
    }

    /// Shared access to the cache of layer `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.n_layers()`.
    pub fn layer(&self, idx: usize) -> &LayerKVCache {
        &self.layers[idx]
    }

    /// Mutable access to the cache of layer `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.n_layers()`.
    pub fn layer_mut(&mut self, idx: usize) -> &mut LayerKVCache {
        &mut self.layers[idx]
    }

    /// Entry count reported by the first layer, or 0 when there are no layers.
    ///
    /// Only layer 0 is consulted; presumably all layers are appended in lock-step —
    /// TODO confirm with the forward pass that fills the cache.
    pub fn seq_len(&self) -> usize {
        self.layers.first().map_or(0, LayerKVCache::seq_len)
    }

    /// Empties every layer's cache; the number of layers is unchanged.
    pub fn clear(&mut self) {
        self.layers.iter_mut().for_each(LayerKVCache::clear);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kv_cache_grows() {
        let mut cache = KVCache::new(4);
        assert_eq!(cache.n_layers(), 4);
        assert_eq!(cache.seq_len(), 0);
        let k = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        let v = Tensor::zeros(vec![1, 2, 1, 64], sapient_core::DType::F32).unwrap();
        // `append` reports the entry count after the push.
        assert_eq!(cache.layer_mut(0).append(k.clone(), v.clone()), 1);
        assert_eq!(cache.layer_mut(0).append(k, v), 2);
        assert_eq!(cache.layer(0).seq_len(), 2);
        assert_eq!(cache.layer(0).values.len(), 2);
        // Layers are independent: an untouched layer stays empty.
        assert_eq!(cache.layer(1).seq_len(), 0);
        // The cache-wide length is taken from the first layer.
        assert_eq!(cache.seq_len(), 2);
    }

    #[test]
    fn kv_cache_clear() {
        let mut cache = KVCache::new(2);
        let t = Tensor::zeros(vec![1, 1, 1, 64], sapient_core::DType::F32).unwrap();
        cache.layer_mut(0).append(t.clone(), t);
        assert_eq!(cache.seq_len(), 1);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        assert!(cache.layer(0).keys.is_empty());
        assert!(cache.layer(0).values.is_empty());
        // Clearing empties the layers but does not remove them.
        assert_eq!(cache.n_layers(), 2);
    }

    #[test]
    fn kv_cache_without_layers_reports_zero() {
        let cache = KVCache::new(0);
        assert_eq!(cache.n_layers(), 0);
        assert_eq!(cache.seq_len(), 0);
    }
}