Skip to main content

forgellm_runtime/
kv_cache.rs

//! KV cache for autoregressive transformer generation.
//!
//! Stores the key and value projections of every layer for each cached
//! sequence position. Designed for single-sequence generation (batch = 1).
5
/// KV cache for a single model.
///
/// Holds one key buffer and one value buffer per layer. Each buffer is a flat
/// `Vec<f32>` of per-token entries, appended in token order. Each entry is the
/// caller-provided `num_kv_heads * head_dim` floats for one token.
///
/// Buffers grow on demand. [`KVCache::with_capacity`] reserves room for a
/// maximum sequence length up front, so generation within that length does not
/// reallocate. [`KVCache::new`] reserves nothing.
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Key cache: one flat buffer per layer. Each buffer holds every entry
    /// appended to that layer so far (`entries * num_kv_heads * head_dim` floats).
    k: Vec<Vec<f32>>,
    /// Value cache: one flat buffer per layer, laid out exactly like `k`.
    v: Vec<Vec<f32>>,
    /// Number of layers (the outer length of `k` and `v`).
    num_layers: usize,
    /// Number of KV heads per entry.
    num_kv_heads: usize,
    /// Dimension of each head.
    head_dim: usize,
    /// Number of tokens committed via `advance`. Between `append` and `advance`
    /// the per-layer buffers may hold more entries than this.
    len: usize,
}
25
26impl KVCache {
27    /// Create a new empty KV cache.
28    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
29        Self {
30            k: (0..num_layers).map(|_| Vec::new()).collect(),
31            v: (0..num_layers).map(|_| Vec::new()).collect(),
32            num_layers,
33            num_kv_heads,
34            head_dim,
35            len: 0,
36        }
37    }
38
39    /// Create a new KV cache with pre-allocated capacity.
40    pub fn with_capacity(
41        num_layers: usize,
42        num_kv_heads: usize,
43        head_dim: usize,
44        max_seq_len: usize,
45    ) -> Self {
46        let entry_size = num_kv_heads * head_dim;
47        let capacity = max_seq_len * entry_size;
48        Self {
49            k: (0..num_layers)
50                .map(|_| Vec::with_capacity(capacity))
51                .collect(),
52            v: (0..num_layers)
53                .map(|_| Vec::with_capacity(capacity))
54                .collect(),
55            num_layers,
56            num_kv_heads,
57            head_dim,
58            len: 0,
59        }
60    }
61
62    /// Append K and V vectors for the current token to a specific layer.
63    ///
64    /// `k_data` and `v_data` should each have length `num_kv_heads * head_dim`.
65    pub fn append(&mut self, layer: usize, k_data: &[f32], v_data: &[f32]) {
66        debug_assert_eq!(k_data.len(), self.num_kv_heads * self.head_dim);
67        debug_assert_eq!(v_data.len(), self.num_kv_heads * self.head_dim);
68        self.k[layer].extend_from_slice(k_data);
69        self.v[layer].extend_from_slice(v_data);
70    }
71
72    /// Advance the sequence position by one token.
73    /// Call this after appending K/V data to all layers.
74    pub fn advance(&mut self) {
75        self.len += 1;
76    }
77
78    /// Get the full K cache for a layer: `[len * num_kv_heads * head_dim]`.
79    pub fn k(&self, layer: usize) -> &[f32] {
80        &self.k[layer]
81    }
82
83    /// Get the full V cache for a layer: `[len * num_kv_heads * head_dim]`.
84    pub fn v(&self, layer: usize) -> &[f32] {
85        &self.v[layer]
86    }
87
88    /// Current sequence length.
89    pub fn len(&self) -> usize {
90        self.len
91    }
92
93    /// Whether the cache is empty.
94    pub fn is_empty(&self) -> bool {
95        self.len == 0
96    }
97
98    /// Number of layers.
99    pub fn num_layers(&self) -> usize {
100        self.num_layers
101    }
102
103    /// Clear the cache for a new generation.
104    pub fn clear(&mut self) {
105        for layer_k in &mut self.k {
106            layer_k.clear();
107        }
108        for layer_v in &mut self.v {
109            layer_v.clear();
110        }
111        self.len = 0;
112    }
113
114    /// Entry size (num_kv_heads * head_dim) per token per layer.
115    pub fn entry_size(&self) -> usize {
116        self.num_kv_heads * self.head_dim
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_and_append() {
        let mut cache = KVCache::new(2, 4, 16);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.entry_size(), 64);

        // Append one token to both layers.
        let kv_data = vec![1.0f32; 64];
        cache.append(0, &kv_data, &kv_data);
        cache.append(1, &kv_data, &kv_data);
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(cache.v(1).len(), 64);
    }

    #[test]
    fn append_multiple_tokens() {
        let mut cache = KVCache::new(1, 2, 8);

        for token_idx in 0..5 {
            let data = vec![token_idx as f32; 16];
            cache.append(0, &data, &data);
            cache.advance();
        }

        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 80); // 5 tokens * 2 heads * 8 dim

        // Entries are stored in token order: token 0 starts at offset 0 and
        // the last token (value 4.0) occupies offsets 64..80.
        assert_eq!(cache.k(0)[0], 0.0);
        assert_eq!(cache.k(0)[64], 4.0);
        assert_eq!(cache.k(0)[79], 4.0);
    }

    #[test]
    fn append_is_visible_before_advance() {
        let mut cache = KVCache::new(1, 1, 4);
        cache.append(0, &[1.0; 4], &[2.0; 4]);

        // The entry is readable immediately, but the token is not committed yet.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.k(0).len(), 4);

        cache.advance();
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn keys_and_values_are_stored_separately() {
        let mut cache = KVCache::new(1, 1, 2);
        cache.append(0, &[1.0, 2.0], &[3.0, 4.0]);
        cache.advance();

        assert_eq!(cache.k(0), &[1.0f32, 2.0][..]);
        assert_eq!(cache.v(0), &[3.0f32, 4.0][..]);
    }

    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(2, 4, 16);
        let data = vec![1.0f32; 64];
        cache.append(0, &data, &data);
        cache.append(1, &data, &data);
        cache.advance();
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.k(0).is_empty());
        assert!(cache.v(1).is_empty());
    }

    #[test]
    fn clear_keeps_capacity_for_reuse() {
        let mut cache = KVCache::with_capacity(1, 1, 2, 4);
        for _ in 0..4 {
            cache.append(0, &[1.0, 2.0], &[3.0, 4.0]);
            cache.advance();
        }

        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.k[0].capacity() >= 8);
        assert!(cache.v[0].capacity() >= 8);
    }

    #[test]
    fn with_capacity() {
        let cache = KVCache::with_capacity(16, 8, 64, 2048);
        assert!(cache.is_empty());
        assert_eq!(cache.num_layers(), 16);
        assert_eq!(cache.entry_size(), 512);
        assert!(cache.k(0).is_empty());
        assert!(cache.k[0].capacity() >= 2048 * 512);
    }

    #[test]
    #[should_panic]
    fn append_to_missing_layer_panics() {
        let mut cache = KVCache::new(1, 1, 1);
        cache.append(1, &[0.0], &[0.0]);
    }
}