Skip to main content

forgellm_runtime/
kv_cache.rs

//! KV cache for autoregressive transformer generation.
//!
//! Stores the key and value projections of every layer for each cached sequence position.
//! Designed for single-sequence generation (batch = 1).

/// KV cache for a single model.
///
/// Holds one growable key buffer and one value buffer per layer. Every call to
/// [`KVCache::append`] adds exactly one entry of [`entry_size`](Self::entry_size)
/// (`num_kv_heads * head_dim`) floats to a layer. The `t`-th entry appended to a layer
/// therefore occupies `[t * entry_size, (t + 1) * entry_size)` of that layer's buffer.
///
/// [`KVCache::with_capacity`] reserves room for a maximum sequence length up front, so
/// generation within that length does not reallocate. [`KVCache::new`] reserves nothing and
/// grows on demand. Neither constructor caps the length: appending past the reserved
/// capacity simply reallocates.
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Key cache: `[num_layers][tokens * num_kv_heads * head_dim]`, one entry per token.
    k: Vec<Vec<f32>>,
    /// Value cache: same layout as `k`.
    v: Vec<Vec<f32>>,
    /// Number of layers (the length of both `k` and `v`).
    num_layers: usize,
    /// Number of KV heads.
    num_kv_heads: usize,
    /// Head dimension.
    head_dim: usize,
    /// Current sequence length: how many times [`KVCache::advance`] has been called since
    /// construction or the last [`KVCache::clear`].
    len: usize,
}

impl KVCache {
    /// Create a new empty KV cache without reserving any buffer space.
    ///
    /// Buffers grow on demand as tokens are appended. Prefer [`KVCache::with_capacity`]
    /// when the maximum sequence length is known.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        // `Vec::with_capacity(0)` does not allocate, so this is equivalent to `Vec::new()`.
        Self::with_capacity(num_layers, num_kv_heads, head_dim, 0)
    }

    /// Create a new empty KV cache with room for `max_seq_len` tokens per layer.
    ///
    /// Each of the `2 * num_layers` buffers reserves `max_seq_len * num_kv_heads * head_dim`
    /// floats. `max_seq_len` is a capacity hint, not a hard limit: appending beyond it
    /// reallocates.
    pub fn with_capacity(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        let entry_size = num_kv_heads * head_dim;
        let capacity = max_seq_len * entry_size;
        // One independently reserved buffer per layer; called once for K and once for V.
        let per_layer_buffers = || -> Vec<Vec<f32>> {
            (0..num_layers)
                .map(|_| Vec::with_capacity(capacity))
                .collect()
        };
        Self {
            k: per_layer_buffers(),
            v: per_layer_buffers(),
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Append the K and V vectors of one token to `layer`.
    ///
    /// `k_data` and `v_data` must each hold exactly [`entry_size`](Self::entry_size)
    /// (`num_kv_heads * head_dim`) floats. This does not change [`len`](Self::len). Call
    /// [`advance`](Self::advance) once the token has been appended to every layer.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`, or if either slice has the wrong length. The length
    /// check also runs in release builds. A mis-sized entry would otherwise be accepted
    /// silently and shift the offset of every later token in this layer.
    pub fn append(&mut self, layer: usize, k_data: &[f32], v_data: &[f32]) {
        let entry_size = self.entry_size();
        assert_eq!(
            k_data.len(),
            entry_size,
            "KVCache::append: k_data length must equal num_kv_heads * head_dim"
        );
        assert_eq!(
            v_data.len(),
            entry_size,
            "KVCache::append: v_data length must equal num_kv_heads * head_dim"
        );
        self.k[layer].extend_from_slice(k_data);
        self.v[layer].extend_from_slice(v_data);
    }

    /// Advance the sequence position by one token.
    ///
    /// Call this after appending the token's K/V data to all layers.
    pub fn advance(&mut self) {
        self.len += 1;
    }

    /// All keys appended to `layer` so far, as a flat buffer of
    /// `tokens * num_kv_heads * head_dim` floats.
    ///
    /// Usually `tokens == len()`. However, entries appended before the matching
    /// [`advance`](Self::advance) are already included. Mid-step, the slice can therefore
    /// contain the current token's entry as well.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`.
    pub fn k(&self, layer: usize) -> &[f32] {
        &self.k[layer]
    }

    /// All values appended to `layer` so far. The layout and caveats are the same as for
    /// [`k`](Self::k).
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`.
    pub fn v(&self, layer: usize) -> &[f32] {
        &self.v[layer]
    }

    /// Current sequence length (number of [`advance`](Self::advance) calls).
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the sequence length is zero.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Clear the cache for a new generation.
    ///
    /// Drops every cached entry and resets the length to zero. Allocated capacity is kept,
    /// so a cache built with [`KVCache::with_capacity`] can be reused without reallocating.
    pub fn clear(&mut self) {
        self.k
            .iter_mut()
            .chain(self.v.iter_mut())
            .for_each(Vec::clear);
        self.len = 0;
    }

    /// Entry size (`num_kv_heads * head_dim`) per token per layer.
    pub fn entry_size(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_and_append() {
        let mut cache = KVCache::new(2, 4, 16);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.entry_size(), 64);

        // Append one token to both layers.
        let kv_data = vec![1.0f32; 64];
        cache.append(0, &kv_data, &kv_data);
        cache.append(1, &kv_data, &kv_data);
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
        for layer in 0..cache.num_layers() {
            assert_eq!(cache.k(layer), kv_data.as_slice());
            assert_eq!(cache.v(layer), kv_data.as_slice());
        }
    }

    #[test]
    fn append_multiple_tokens() {
        let mut cache = KVCache::new(1, 2, 8);
        let entry_size = cache.entry_size(); // 2 heads * 8 dim = 16

        for token_idx in 0..5 {
            let data = vec![token_idx as f32; entry_size];
            cache.append(0, &data, &data);
            cache.advance();
        }

        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 80); // 5 tokens * 2 heads * 8 dim
        // The first element belongs to token 0; index 64 is the first element of token 4.
        assert_eq!(cache.k(0)[0], 0.0);
        assert_eq!(cache.k(0)[64], 4.0);

        // Every entry of token `t` must hold `t`, in both K and V.
        let entries = cache
            .k(0)
            .chunks_exact(entry_size)
            .zip(cache.v(0).chunks_exact(entry_size));
        for (token_idx, (k_entry, v_entry)) in entries.enumerate() {
            let expected = token_idx as f32;
            assert!(
                k_entry.iter().chain(v_entry).all(|&x| x == expected),
                "token {} has corrupted K/V data",
                token_idx
            );
        }
    }

    #[test]
    fn append_without_advance_keeps_len() {
        // `append` writes data immediately, but only `advance` moves the sequence position.
        let mut cache = KVCache::new(1, 1, 4);
        cache.append(0, &[1.0f32; 4], &[2.0f32; 4]);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.k(0), &[1.0f32; 4][..]);
        assert_eq!(cache.v(0), &[2.0f32; 4][..]);

        cache.advance();
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(2, 4, 16);
        let data = vec![1.0f32; 64];
        cache.append(0, &data, &data);
        cache.append(1, &data, &data);
        cache.advance();
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.k(0).is_empty());
        assert!(cache.v(1).is_empty());
    }

    #[test]
    fn with_capacity() {
        let cache = KVCache::with_capacity(16, 8, 64, 2048);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.num_layers(), 16);
        assert_eq!(cache.entry_size(), 512);
        // Reserved capacity must not show up as cached data.
        assert!(cache.k(15).is_empty());
        assert!(cache.v(15).is_empty());
    }

    // ── Real-world validation tests ──────────────────────────────────────

    #[test]
    fn clear_resets_completely_for_independent_generation() {
        // After clear(), the cache should behave identically to a fresh cache.
        // This matters for multi-turn serving where the model resets between requests.
        let mut cache = KVCache::new(2, 4, 16);

        // Fill the cache with some data.
        let data_a = vec![1.0f32; 64];
        for _ in 0..5 {
            cache.append(0, &data_a, &data_a);
            cache.append(1, &data_a, &data_a);
            cache.advance();
        }
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 5 * 64);

        // Clear and verify a complete reset.
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert!(cache.k(0).is_empty());
        assert!(cache.v(0).is_empty());
        assert!(cache.k(1).is_empty());
        assert!(cache.v(1).is_empty());

        // Data appended after clear() must be independent of the previous content.
        let data_b = vec![2.0f32; 64];
        cache.append(0, &data_b, &data_b);
        cache.append(1, &data_b, &data_b);
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(
            cache.k(0)[0],
            2.0,
            "after clear, new data should overwrite old content"
        );
        for layer in 0..cache.num_layers() {
            assert_eq!(cache.k(layer), data_b.as_slice());
            assert_eq!(cache.v(layer), data_b.as_slice());
        }
    }

    #[test]
    fn cache_handles_max_realistic_sequence() {
        // Fill a cache up to a realistic max sequence length (512 tokens) for a small model
        // and verify that no data is corrupted anywhere, including at the boundary.
        let num_layers = 4;
        let num_kv_heads = 2;
        let head_dim = 8;
        let max_seq = 512;
        let entry_size = num_kv_heads * head_dim; // 16

        let mut cache = KVCache::with_capacity(num_layers, num_kv_heads, head_dim, max_seq);

        for pos in 0..max_seq {
            let data: Vec<f32> = (0..entry_size)
                .map(|j| (pos * entry_size + j) as f32)
                .collect();
            for layer in 0..num_layers {
                cache.append(layer, &data, &data);
            }
            cache.advance();
        }

        assert_eq!(cache.len(), max_seq);
        assert_eq!(cache.k(0).len(), max_seq * entry_size);

        // First token in layer 0: values 0..16.
        assert_eq!(cache.k(0)[0], 0.0);
        assert_eq!(cache.k(0)[entry_size - 1], (entry_size - 1) as f32);
        // Last token in layer 0: values (511*16)..(511*16+16).
        let last_start = (max_seq - 1) * entry_size;
        assert_eq!(cache.k(0)[last_start], last_start as f32);

        // Each element was written as its own flat index (at most 8191, exact in f32).
        // Any misplaced entry in any layer, in K or V, therefore fails this check.
        for layer in 0..num_layers {
            for buf in [cache.k(layer), cache.v(layer)] {
                assert_eq!(buf.len(), max_seq * entry_size);
                assert!(
                    buf.iter().enumerate().all(|(i, &x)| x == i as f32),
                    "layer {} has corrupted K/V data",
                    layer
                );
            }
        }
    }
}