/// Owned key/value cache holding one flat `f32` buffer for keys and one for values.
///
/// Token `i` occupies the element range `[i * stride, (i + 1) * stride)` of each
/// buffer, where `stride = num_heads * head_dim`. The accessors in this module
/// treat only the first `used_tokens` tokens as populated.
pub struct KvCacheStd {
    /// Key storage; `KvCacheStd::new` sizes it to `max_tokens * head_dim * num_heads`.
    pub key: Vec<f32>,
    /// Value storage, laid out exactly like `key`.
    pub value: Vec<f32>,
    /// Number of token slots the cache was created with.
    pub max_tokens: usize,
    /// Per-head width; multiplied by `num_heads` to get the per-token stride.
    pub head_dim: usize,
    /// Number of heads per token.
    pub num_heads: usize,
    /// Count of populated tokens; advanced by `append_token`, reset by `clear`.
    pub used_tokens: usize,
}
/// Errors reported by the key/value cache types in this module.
///
/// The variants carry no data, so `Clone`/`Copy`/`PartialEq`/`Eq`/`Hash` are
/// derived: callers can compare results directly
/// (e.g. `assert_eq!(err, KvCacheError::CapacityExceeded)`) and pass errors by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KvCacheError {
    /// `append_token` forwards the underlying cache's `ShapeMismatch`;
    /// `token_slices` returns it when a stride/offset computation overflows or a
    /// buffer is too short to hold the requested token.
    ShapeMismatch,
    /// `append_token` forwards the underlying cache's `CapacityExceeded`
    /// (presumably: no free token slot); `token_slices` returns it when
    /// `token_index >= used_tokens`.
    CapacityExceeded,
}

impl core::fmt::Display for KvCacheError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Lower-case, no trailing punctuation, following the std error-message convention.
        let message = match self {
            KvCacheError::ShapeMismatch => "shape mismatch",
            KvCacheError::CapacityExceeded => "capacity exceeded",
        };
        f.write_str(message)
    }
}

impl std::error::Error for KvCacheError {}
impl KvCacheStd {
    /// Allocates a zero-filled cache with `max_tokens` token slots of
    /// `head_dim * num_heads` floats each, for both keys and values.
    ///
    /// # Panics
    ///
    /// Panics if `head_dim * num_heads` or `max_tokens * head_dim * num_heads`
    /// overflows `usize`.
    pub fn new(max_tokens: usize, head_dim: usize, num_heads: usize) -> Self {
        let floats_per_token = head_dim
            .checked_mul(num_heads)
            .expect("kv_cache: head_dim * num_heads overflow");
        let buffer_len = max_tokens
            .checked_mul(floats_per_token)
            .expect("kv_cache: max_tokens * stride overflow");
        Self {
            key: vec![0f32; buffer_len],
            value: vec![0f32; buffer_len],
            max_tokens,
            head_dim,
            num_heads,
            used_tokens: 0,
        }
    }

    /// Floats occupied by one token in each buffer, or `None` if the product overflows.
    fn token_stride(&self) -> Option<usize> {
        self.head_dim.checked_mul(self.num_heads)
    }

    /// Appends one token's key and value vectors after the populated tokens.
    ///
    /// The write is delegated to `native_neural_network::kv_cache::KvCacheView`
    /// built over this cache's buffers. `used_tokens` is copied back from that
    /// view only when the append succeeds.
    ///
    /// # Errors
    ///
    /// Maps the underlying view's `ShapeMismatch` / `CapacityExceeded` onto the
    /// identically named [`KvCacheError`] variant.
    pub fn append_token(
        &mut self,
        key_token: &[f32],
        value_token: &[f32],
    ) -> Result<(), KvCacheError> {
        let mut inner = native_neural_network::kv_cache::KvCacheView {
            key: &mut self.key[..],
            value: &mut self.value[..],
            max_tokens: self.max_tokens,
            head_dim: self.head_dim,
            num_heads: self.num_heads,
            used_tokens: self.used_tokens,
        };
        inner
            .append_token(key_token, value_token)
            .map_err(|err| match err {
                native_neural_network::kv_cache::KvCacheError::ShapeMismatch => {
                    KvCacheError::ShapeMismatch
                }
                native_neural_network::kv_cache::KvCacheError::CapacityExceeded => {
                    KvCacheError::CapacityExceeded
                }
            })?;
        self.used_tokens = inner.used_tokens;
        Ok(())
    }

    /// Number of populated tokens (`used_tokens`).
    pub fn len_tokens(&self) -> usize {
        self.used_tokens
    }

    /// `true` when no token is populated.
    pub fn is_empty(&self) -> bool {
        self.len_tokens() == 0
    }

    /// `true` once `used_tokens` has reached (or passed) `max_tokens`.
    pub fn is_full(&self) -> bool {
        self.remaining_tokens() == 0
    }

    /// Free token slots left, clamped at zero.
    pub fn remaining_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    /// Marks the cache as empty by resetting `used_tokens`.
    ///
    /// The key/value buffers are left untouched (not zeroed, not shrunk).
    pub fn clear(&mut self) {
        self.used_tokens = 0;
    }

    /// Borrows the key and value slices of the token at `token_index`.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - `ShapeMismatch` if `num_heads * head_dim` overflows;
    /// - `CapacityExceeded` if `token_index >= used_tokens`;
    /// - `ShapeMismatch` if the token's offset overflows or either buffer is too short.
    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let span = token_index
            .checked_mul(stride)
            .and_then(|start| start.checked_add(stride).map(|end| start..end))
            .ok_or(KvCacheError::ShapeMismatch)?;
        // `get` yields `None` exactly when the span's end exceeds the buffer length.
        match (self.key.get(span.clone()), self.value.get(span)) {
            (Some(key_row), Some(value_row)) => Ok((key_row, value_row)),
            _ => Err(KvCacheError::ShapeMismatch),
        }
    }
}
/// Diagnostic summary of the cache's shape and fill level.
///
/// The `key`/`value` buffers are deliberately omitted: `new` sizes each of them to
/// `max_tokens * head_dim * num_heads` floats. `finish_non_exhaustive` appends `..`
/// so the output signals that fields were left out.
impl core::fmt::Debug for KvCacheStd {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("KvCacheStd")
            .field("max_tokens", &self.max_tokens)
            .field("head_dim", &self.head_dim)
            .field("num_heads", &self.num_heads)
            .field("used_tokens", &self.used_tokens)
            .finish_non_exhaustive()
    }
}
/// Borrowed key/value cache over caller-owned mutable slices.
///
/// Token `i` occupies `[i * stride, (i + 1) * stride)` of both `key` and `value`,
/// with `stride = num_heads * head_dim`; only the first `used_tokens` tokens are
/// treated as populated.
pub struct KvCacheViewStd<'a> {
    /// Key storage, laid out token by token.
    pub key: &'a mut [f32],
    /// Value storage, laid out like `key`.
    pub value: &'a mut [f32],
    /// Number of token slots the view is meant to hold.
    pub max_tokens: usize,
    /// Per-head width; multiplied by `num_heads` to get the per-token stride.
    pub head_dim: usize,
    /// Number of heads per token.
    pub num_heads: usize,
    /// Count of populated tokens.
    pub used_tokens: usize,
}

impl KvCacheViewStd<'_> {
    /// Borrows the key and value slices of the token at `token_index`.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - `CapacityExceeded` if `token_index >= used_tokens` (reported even when the
    ///   shape is also invalid, because the index is validated first);
    /// - `ShapeMismatch` if `num_heads * head_dim` or the token's offset overflows,
    ///   or if either slice is too short.
    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
        if token_index >= self.used_tokens {
            return Err(KvCacheError::CapacityExceeded);
        }
        let width = self
            .num_heads
            .checked_mul(self.head_dim)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let begin = token_index
            .checked_mul(width)
            .ok_or(KvCacheError::ShapeMismatch)?;
        let finish = begin
            .checked_add(width)
            .ok_or(KvCacheError::ShapeMismatch)?;
        // `get` returns `None` exactly when `finish` runs past the slice length.
        let key_row = self.key.get(begin..finish).ok_or(KvCacheError::ShapeMismatch)?;
        let value_row = self
            .value
            .get(begin..finish)
            .ok_or(KvCacheError::ShapeMismatch)?;
        Ok((key_row, value_row))
    }
}

/// Shorter alias for [`KvCacheViewStd`].
pub type KvCacheView<'a> = KvCacheViewStd<'a>;
/// Initialization hook for the KV-cache subsystem.
///
/// Currently a no-op that always reports success.
fn kv_cache_init_hook() -> Result<(), crate::InitError> {
    Ok(())
}

/// Registers [`kv_cache_init_hook`] via `crate::register_init_hook`, under the
/// name `"kv_cache"` and the `crate::InitSubsystem::KvCache` subsystem.
///
/// # Errors
///
/// Propagates whatever `crate::register_init_hook` returns.
pub fn register_module_init_hooks() -> Result<(), crate::RegisterError> {
    const HOOK_NAME: &str = "kv_cache";
    let subsystem = crate::InitSubsystem::KvCache;
    crate::register_init_hook(HOOK_NAME, subsystem, kv_cache_init_hook)
}