Skip to main content

rnn/kv_cache/
view.rs

1use super::KvCacheError;
2
/// Mutable view over a pair of flat `f32` buffers used as a key/value cache.
///
/// Tokens are stored back to back: token `t` occupies elements
/// `[t * stride, (t + 1) * stride)` of both `key` and `value`, where
/// `stride = num_heads * head_dim`.
pub struct KvCacheView<'a> {
    /// Flat storage for key data, one `stride`-sized slot per token.
    pub key: &'a mut [f32],
    /// Flat storage for value data, laid out identically to `key`.
    pub value: &'a mut [f32],
    /// Upper bound on the number of tokens the view accepts.
    pub max_tokens: usize,
    /// Per-head element count; one factor of the per-token stride.
    pub head_dim: usize,
    /// Number of heads; the other factor of the per-token stride.
    pub num_heads: usize,
    /// Number of tokens currently stored; also the index of the next free slot.
    pub used_tokens: usize,
}
11
12impl<'a> KvCacheView<'a> {
13    pub fn token_stride(&self) -> Option<usize> {
14        self.num_heads.checked_mul(self.head_dim)
15    }
16
17    pub fn append_token(&mut self, key_token: &[f32], value_token: &[f32]) -> Result<(), KvCacheError> {
18        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
19        if key_token.len() != stride || value_token.len() != stride {
20            return Err(KvCacheError::ShapeMismatch);
21        }
22        if self.used_tokens >= self.max_tokens {
23            return Err(KvCacheError::CapacityExceeded);
24        }
25
26        let off = self.used_tokens.checked_mul(stride).ok_or(KvCacheError::ShapeMismatch)?;
27        if off + stride > self.key.len() || off + stride > self.value.len() {
28            return Err(KvCacheError::CapacityExceeded);
29        }
30        self.key[off..off + stride].copy_from_slice(key_token);
31        self.value[off..off + stride].copy_from_slice(value_token);
32        self.used_tokens += 1;
33        Ok(())
34    }
35
36    pub fn len_tokens(&self) -> usize {
37        self.used_tokens
38    }
39
40    pub fn is_empty(&self) -> bool {
41        self.used_tokens == 0
42    }
43
44    pub fn is_full(&self) -> bool {
45        self.used_tokens >= self.max_tokens
46    }
47
48    pub fn remaining_tokens(&self) -> usize {
49        self.max_tokens.saturating_sub(self.used_tokens)
50    }
51
52    pub fn clear(&mut self) {
53        self.used_tokens = 0;
54    }
55
56    pub fn token_slices(&self, token_index: usize) -> Result<(&[f32], &[f32]), KvCacheError> {
57        if token_index >= self.used_tokens {
58            return Err(KvCacheError::CapacityExceeded);
59        }
60        let stride = self.token_stride().ok_or(KvCacheError::ShapeMismatch)?;
61        let off = token_index
62            .checked_mul(stride)
63            .ok_or(KvCacheError::ShapeMismatch)?;
64        let end = off.checked_add(stride).ok_or(KvCacheError::ShapeMismatch)?;
65        if end > self.key.len() || end > self.value.len() {
66            return Err(KvCacheError::ShapeMismatch);
67        }
68        Ok((&self.key[off..end], &self.value[off..end]))
69    }
70}