use super::KvCacheError;
/// Mutable, borrowed view over a key/value cache stored token-major.
///
/// Token `i` occupies elements `i * stride .. (i + 1) * stride` of both
/// [`key`](Self::key) and [`value`](Self::value), where
/// `stride = num_heads * head_dim` (see [`token_stride`](Self::token_stride)).
/// Only the first [`used_tokens`](Self::used_tokens) tokens are considered
/// written.
pub struct KvCacheView<'a> {
    /// Backing storage for key vectors.
    pub key: &'a mut [f32],
    /// Backing storage for value vectors.
    pub value: &'a mut [f32],
    /// Logical capacity in tokens; appends once this many tokens are stored
    /// are rejected.
    pub max_tokens: usize,
    /// Per-head width; multiplied by `num_heads` to give the per-token stride.
    pub head_dim: usize,
    /// Head count; multiplied by `head_dim` to give the per-token stride.
    pub num_heads: usize,
    /// Number of tokens currently written to the cache.
    pub used_tokens: usize,
}

impl<'a> KvCacheView<'a> {
    /// Number of `f32` elements per token (`num_heads * head_dim`).
    ///
    /// Returns `None` if the product overflows `usize`.
    pub fn token_stride(&self) -> Option<usize> {
        self.num_heads.checked_mul(self.head_dim)
    }

    /// Half-open element range `(start, end)` occupied by token `token_index`
    /// in `key`/`value`.
    ///
    /// Only the arithmetic is checked here: any overflow of the stride, the
    /// offset, or the end yields [`KvCacheError::ShapeMismatch`]. The buffer
    /// lengths are NOT checked; callers decide which error a short buffer maps
    /// to.
    fn token_span(&self, token_index: usize) -> Result<(usize, usize), KvCacheError> {
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        let start = token_index
            .checked_mul(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let end = start
            .checked_add(stride)
            .ok_or(KvCacheError::ShapeMismatch)?;
        Ok((start, end))
    }

    /// Appends one token's key and value vectors at position `used_tokens`
    /// and increments `used_tokens`.
    ///
    /// Both destination ranges are validated before either is written, so on
    /// any error the buffers and `used_tokens` are left unchanged.
    ///
    /// # Errors
    ///
    /// - [`KvCacheError::ShapeMismatch`] if either input's length differs from
    ///   the per-token stride, or if the stride / element offset overflows
    ///   `usize`.
    /// - [`KvCacheError::CapacityExceeded`] if `used_tokens >= max_tokens`, or
    ///   if `key` or `value` is too short to hold another token.
    pub fn append_token(
        &mut self,
        key_token: &[f32],
        value_token: &[f32],
    ) -> Result<(), KvCacheError> {
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        if key_token.len() != stride || value_token.len() != stride {
            return Err(KvCacheError::ShapeMismatch);
        }
        if self.used_tokens >= self.max_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        // Checked arithmetic: an unchecked `start + stride` would panic in
        // debug builds, and in release would wrap and then panic on slicing.
        let (start, end) = self.token_span(self.used_tokens)?;
        // `key` and `value` are disjoint fields, so both can be borrowed
        // mutably at once; both bounds are confirmed before any copy happens.
        match (self.key.get_mut(start..end), self.value.get_mut(start..end)) {
            (Some(key_dst), Some(value_dst)) => {
                key_dst.copy_from_slice(key_token);
                value_dst.copy_from_slice(value_token);
            }
            _ => return Err(KvCacheError::CapacityExceeded),
        }
        // Cannot overflow: used_tokens < max_tokens <= usize::MAX here.
        self.used_tokens += 1;
        Ok(())
    }

    /// Number of tokens currently stored (`used_tokens`).
    pub fn len_tokens(&self) -> usize {
        self.used_tokens
    }

    /// `true` when no tokens are stored.
    pub fn is_empty(&self) -> bool {
        self.used_tokens == 0
    }

    /// `true` once `used_tokens` has reached `max_tokens`.
    ///
    /// This compares against the logical capacity only; it does not look at
    /// whether `key`/`value` are long enough for `max_tokens` tokens.
    pub fn is_full(&self) -> bool {
        self.used_tokens >= self.max_tokens
    }

    /// Tokens that can still be appended before `max_tokens` is reached
    /// (saturates at zero).
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    /// Resets the logical length to zero.
    ///
    /// The contents of `key` and `value` are not modified; subsequent appends
    /// overwrite them.
    pub fn clear(&mut self) {
        self.used_tokens = 0;
    }

    /// Borrows the key and value vectors of token `token_index`.
    ///
    /// # Errors
    ///
    /// - [`KvCacheError::CapacityExceeded`] if `token_index >= used_tokens`.
    /// - [`KvCacheError::ShapeMismatch`] if the stride / offset arithmetic
    ///   overflows, or if `key` or `value` is too short to contain the token.
    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let (start, end) = self.token_span(token_index)?;
        match (self.key.get(start..end), self.value.get(start..end)) {
            (Some(key_src), Some(value_src)) => Ok((key_src, value_src)),
            _ => Err(KvCacheError::ShapeMismatch),
        }
    }
}