1use super::KvCacheError;
2
/// Mutable view over caller-owned key and value buffers, addressed token by token.
///
/// Both buffers are treated as flat arrays of `f32` in which token `i` occupies the
/// range `[i * stride, (i + 1) * stride)`, where `stride = num_heads * head_dim`
/// (see [`KvCacheView::token_stride`]).
///
/// All fields are public, so the view does not enforce any invariant between them;
/// the methods re-validate sizes on every call instead of trusting the fields.
#[derive(Debug)]
pub struct KvCacheView<'a> {
    /// Flat storage for the key entries of every cached token.
    pub key: &'a mut [f32],
    /// Flat storage for the value entries of every cached token; laid out like `key`.
    pub value: &'a mut [f32],
    /// Upper bound on `used_tokens`; appending is refused once it is reached.
    pub max_tokens: usize,
    /// Number of `f32` elements per head within one token.
    pub head_dim: usize,
    /// Number of heads stored per token.
    pub num_heads: usize,
    /// Number of tokens currently stored; also the index the next append writes to.
    pub used_tokens: usize,
}
11
12impl<'a> KvCacheView<'a> {
13 pub fn token_stride(&self) -> Option<usize> {
14 self.num_heads.checked_mul(self.head_dim)
15 }
16
17 pub fn append_token(&mut self, key_token: &[f32], value_token: &[f32]) -> Result<(), KvCacheError> {
18 let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
19 if key_token.len() != stride || value_token.len() != stride {
20 return Err(KvCacheError::ShapeMismatch);
21 }
22 if self.used_tokens >= self.max_tokens {
23 return Err(KvCacheError::CapacityExceeded);
24 }
25
26 let off = self.used_tokens.checked_mul(stride).ok_or(KvCacheError::ShapeMismatch)?;
27 if off + stride > self.key.len() || off + stride > self.value.len() {
28 return Err(KvCacheError::CapacityExceeded);
29 }
30 self.key[off..off + stride].copy_from_slice(key_token);
31 self.value[off..off + stride].copy_from_slice(value_token);
32 self.used_tokens += 1;
33 Ok(())
34 }
35
36 pub fn len_tokens(&self) -> usize {
37 self.used_tokens
38 }
39
40 pub fn is_empty(&self) -> bool {
41 self.used_tokens == 0
42 }
43
44 pub fn is_full(&self) -> bool {
45 self.used_tokens >= self.max_tokens
46 }
47
48 pub fn remaining_tokens(&self) -> usize {
49 self.max_tokens.saturating_sub(self.used_tokens)
50 }
51
52 pub fn clear(&mut self) {
53 self.used_tokens = 0;
54 }
55
56 pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
57 if token_index >= self.used_tokens {
58 return Err(KvCacheError::CapacityExceeded);
59 }
60 let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
61 let off = token_index
62 .checked_mul(stride)
63 .ok_or(KvCacheError::ShapeMismatch)?;
64 let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
65 if end > self.key.len() || end > self.value.len() {
66 return Err(KvCacheError::ShapeMismatch);
67 }
68 Ok((&self.key[off..end], &self.value[off..end]))
69 }
70}