use candle_core::Tensor;
use crate::error::{MIError, Result};
#[derive(Debug, Clone)]
pub struct KVCache {
keys: Vec<Option<Tensor>>,
values: Vec<Option<Tensor>>,
}
impl KVCache {
#[must_use]
pub fn new(n_layers: usize) -> Self {
Self {
keys: vec![None; n_layers],
values: vec![None; n_layers],
}
}
pub fn seq_len(&self) -> Result<usize> {
match self.keys.iter().find_map(Option::as_ref) {
Some(k) => Ok(k.dim(2)?),
None => Ok(0),
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.keys.iter().all(Option::is_none)
}
#[must_use]
pub const fn n_layers(&self) -> usize {
self.keys.len()
}
pub fn clear(&mut self) {
for k in &mut self.keys {
*k = None;
}
for v in &mut self.values {
*v = None;
}
}
pub fn layer_mut(
&mut self,
layer: usize,
) -> Result<(&mut Option<Tensor>, &mut Option<Tensor>)> {
if layer >= self.keys.len() {
return Err(MIError::Hook(format!(
"layer {layer} out of range for KV cache"
)));
}
#[allow(clippy::indexing_slicing)]
Ok((&mut self.keys[layer], &mut self.values[layer]))
}
#[must_use]
pub fn memory_usage(&self) -> usize {
let key_mem: usize = self
.keys
.iter()
.filter_map(Option::as_ref)
.map(|k| k.elem_count() * k.dtype().size_in_bytes())
.sum();
let value_mem: usize = self
.values
.iter()
.filter_map(Option::as_ref)
.map(|v| v.elem_count() * v.dtype().size_in_bytes())
.sum();
key_mem + value_mem
}
pub fn trim_to(&mut self, max_seq_len: usize) -> Result<bool> {
let current_len = self.seq_len()?;
if current_len <= max_seq_len {
return Ok(false);
}
let trim_start = current_len - max_seq_len;
for tensor in self.keys.iter_mut().flatten() {
*tensor = tensor.narrow(2, trim_start, max_seq_len)?;
}
for tensor in self.values.iter_mut().flatten() {
*tensor = tensor.narrow(2, trim_start, max_seq_len)?;
}
Ok(true)
}
pub fn enforce_memory_limit(&mut self, max_bytes: usize) -> Result<bool> {
let current = self.memory_usage();
if current > max_bytes {
let current_len = self.seq_len()?;
let target_len = (current_len * 3) / 4;
if target_len > 0 {
self.trim_to(target_len)?;
return Ok(true);
}
}
Ok(false)
}
}
impl Default for KVCache {
fn default() -> Self {
Self::new(0)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn new_cache() {
let cache = KVCache::new(32);
assert_eq!(cache.n_layers(), 32);
assert!(cache.is_empty());
assert_eq!(cache.seq_len().unwrap(), 0);
assert_eq!(cache.memory_usage(), 0);
}
#[test]
fn clear_cache() {
let mut cache = KVCache::new(4);
cache.clear();
assert!(cache.is_empty());
}
#[test]
fn layer_mut_valid() {
let mut cache = KVCache::new(4);
let (k, v) = cache.layer_mut(2).unwrap();
assert!(k.is_none());
assert!(v.is_none());
}
#[test]
fn layer_mut_out_of_range() {
let mut cache = KVCache::new(4);
assert!(cache.layer_mut(10).is_err());
}
#[test]
fn default_cache() {
let cache = KVCache::default();
assert_eq!(cache.n_layers(), 0);
assert!(cache.is_empty());
}
}