// god_graph/transformer/kv_cache/mod.rs
1//! KV Cache module for efficient autoregressive generation
2//!
3//! KV Cache caches key and value states to avoid recomputing them
4//! during autoregressive generation, significantly improving inference speed.
5
6use crate::tensor::DenseTensor;
7use crate::tensor::traits::TensorBase;
8
/// KV Cache for caching key and value states during generation.
///
/// Avoids recomputing K/V projections for already-seen positions during
/// autoregressive decoding.
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Per-layer key cache: one tensor of shape
    /// [max_seq_len, num_kv_heads, head_dim] per layer (see `new`).
    key_cache: Vec<DenseTensor>,
    /// Per-layer value cache: same layout as `key_cache`.
    value_cache: Vec<DenseTensor>,
    /// Number of populated positions (highest written position + 1).
    current_len: usize,
    /// Maximum sequence length the cache was allocated for.
    max_seq_len: usize,
    /// Number of transformer layers.
    num_layers: usize,
    /// Hidden dimension; head_dim is derived as hidden_dim / num_kv_heads.
    hidden_dim: usize,
    /// Number of KV heads (fewer than query heads under GQA).
    num_kv_heads: usize,
}
27
28impl KVCache {
29    /// Create a new KV cache
30    ///
31    /// # Arguments
32    /// * `num_layers` - Number of transformer layers
33    /// * `max_seq_len` - Maximum sequence length to cache
34    /// * `hidden_dim` - Hidden dimension
35    /// * `num_kv_heads` - Number of KV heads (for GQA)
36    pub fn new(num_layers: usize, max_seq_len: usize, hidden_dim: usize, num_kv_heads: usize) -> Self {
37        let head_dim = hidden_dim / num_kv_heads;
38        let key_cache = vec![
39            DenseTensor::zeros(vec![max_seq_len, num_kv_heads, head_dim]);
40            num_layers
41        ];
42        let value_cache = vec![
43            DenseTensor::zeros(vec![max_seq_len, num_kv_heads, head_dim]);
44            num_layers
45        ];
46
47        Self {
48            key_cache,
49            value_cache,
50            current_len: 0,
51            max_seq_len,
52            num_layers,
53            hidden_dim,
54            num_kv_heads,
55        }
56    }
57
58    /// Update cache with new key and value
59    ///
60    /// # Arguments
61    /// * `layer` - Layer index
62    /// * `key` - New key [batch_size, num_kv_heads, head_dim]
63    /// * `value` - New value [batch_size, num_kv_heads, head_dim]
64    /// * `position` - Position to cache at
65    pub fn update(
66        &mut self,
67        layer: usize,
68        key: &DenseTensor,
69        value: &DenseTensor,
70        position: usize,
71    ) {
72        if layer >= self.num_layers || position >= self.max_seq_len {
73            return;
74        }
75
76        // Update key cache
77        if let Some(layer_key) = self.key_cache.get_mut(layer) {
78            Self::copy_to_cache_static(layer_key, key, position, self.num_kv_heads);
79        }
80
81        // Update value cache
82        if let Some(layer_value) = self.value_cache.get_mut(layer) {
83            Self::copy_to_cache_static(layer_value, value, position, self.num_kv_heads);
84        }
85
86        // Update current length
87        if position >= self.current_len {
88            self.current_len = position + 1;
89        }
90    }
91
92    /// Copy tensor to cache at specified position (static method to avoid borrow issues)
93    #[inline]
94    fn copy_to_cache_static(cache: &mut DenseTensor, tensor: &DenseTensor, position: usize, num_kv_heads: usize) {
95        let batch_size = tensor.shape()[0];
96        let head_dim = tensor.shape()[2];
97
98        for b in 0..batch_size {
99            for h in 0..num_kv_heads {
100                let src_offset = (b * num_kv_heads + h) * head_dim;
101                let dst_offset = (position * num_kv_heads + h) * head_dim;
102
103                let src_slice = &tensor.data()[src_offset..src_offset + head_dim];
104                let cache_data = cache.data_mut();
105                cache_data[dst_offset..dst_offset + head_dim].copy_from_slice(src_slice);
106            }
107        }
108    }
109
110    /// Get cached keys and values for a layer
111    ///
112    /// # Arguments
113    /// * `layer` - Layer index
114    /// * `length` - Number of cached positions to retrieve
115    ///
116    /// # Returns
117    /// Tuple of (key_cache, value_cache) with shape [batch_size, length, num_kv_heads, head_dim]
118    pub fn get(&self, layer: usize, length: Option<usize>) -> Option<(DenseTensor, DenseTensor)> {
119        if layer >= self.num_layers {
120            return None;
121        }
122
123        let key_cache = self.key_cache.get(layer)?;
124        let value_cache = self.value_cache.get(layer)?;
125
126        let seq_len = length.unwrap_or(self.current_len);
127
128        // Slice cache to current length
129        let key = self.slice_cache(key_cache, seq_len);
130        let value = self.slice_cache(value_cache, seq_len);
131
132        Some((key, value))
133    }
134
135    /// Slice cache to specified length
136    fn slice_cache(&self, cache: &DenseTensor, length: usize) -> DenseTensor {
137        let num_kv_heads = cache.shape()[1];
138        let head_dim = cache.shape()[2];
139
140        let mut data = Vec::with_capacity(length * num_kv_heads * head_dim);
141
142        for pos in 0..length {
143            for h in 0..num_kv_heads {
144                let offset = (pos * num_kv_heads + h) * head_dim;
145                data.extend_from_slice(&cache.data()[offset..offset + head_dim]);
146            }
147        }
148
149        DenseTensor::new(data, vec![length, num_kv_heads, head_dim])
150    }
151
152    /// Get all cached keys and values for a layer (full history)
153    ///
154    /// # Arguments
155    /// * `layer` - Layer index
156    ///
157    /// # Returns
158    /// Tuple of (key_cache, value_cache)
159    pub fn get_all(&self, layer: usize) -> Option<(DenseTensor, DenseTensor)> {
160        self.get(layer, Some(self.current_len))
161    }
162
163    /// Reset cache for new sequence
164    pub fn reset(&mut self) {
165        self.current_len = 0;
166
167        // Zero out caches
168        for key_cache in &mut self.key_cache {
169            *key_cache = DenseTensor::zeros(key_cache.shape().to_vec());
170        }
171        for value_cache in &mut self.value_cache {
172            *value_cache = DenseTensor::zeros(value_cache.shape().to_vec());
173        }
174    }
175
176    /// Get current sequence length
177    pub fn current_len(&self) -> usize {
178        self.current_len
179    }
180
181    /// Get maximum sequence length
182    pub fn max_seq_len(&self) -> usize {
183        self.max_seq_len
184    }
185
186    /// Get number of layers
187    pub fn num_layers(&self) -> usize {
188        self.num_layers
189    }
190
191    /// Get hidden dimension
192    pub fn hidden_dim(&self) -> usize {
193        self.hidden_dim
194    }
195
196    /// Get number of KV heads
197    pub fn num_kv_heads(&self) -> usize {
198        self.num_kv_heads
199    }
200
201    /// Check if cache is full
202    pub fn is_full(&self) -> bool {
203        self.current_len >= self.max_seq_len
204    }
205
206    /// Get remaining capacity
207    pub fn remaining_capacity(&self) -> usize {
208        self.max_seq_len - self.current_len
209    }
210
211    /// Append new token's KV without position argument (auto-increment)
212    ///
213    /// # Arguments
214    /// * `layer` - Layer index
215    /// * `key` - New key [1, num_kv_heads, head_dim]
216    /// * `value` - New value [1, num_kv_heads, head_dim]
217    pub fn append(&mut self, layer: usize, key: &DenseTensor, value: &DenseTensor) {
218        if self.is_full() {
219            return;
220        }
221        self.update(layer, key, value, self.current_len);
222    }
223
224    /// Concatenate cached KV with new KV for attention computation
225    ///
226    /// # Arguments
227    /// * `layer` - Layer index
228    /// * `new_key` - New key to append
229    /// * `new_value` - New value to append
230    ///
231    /// # Returns
232    /// Tuple of (concatenated_key, concatenated_value)
233    pub fn get_with_new(
234        &self,
235        layer: usize,
236        new_key: &DenseTensor,
237        new_value: &DenseTensor,
238    ) -> Option<(DenseTensor, DenseTensor)> {
239        let (cached_key, cached_value) = self.get(layer, None)?;
240
241        // Concatenate along sequence dimension
242        let key = self.concat_along_seq(&cached_key, new_key);
243        let value = self.concat_along_seq(&cached_value, new_value);
244
245        Some((key, value))
246    }
247
248    /// Concatenate two tensors along sequence dimension
249    fn concat_along_seq(&self, cached: &DenseTensor, new: &DenseTensor) -> DenseTensor {
250        let cached_len = cached.shape()[0];
251        let num_kv_heads = cached.shape()[1];
252        let head_dim = cached.shape()[2];
253
254        let new_len = new.shape()[0];
255        let total_len = cached_len + new_len;
256
257        let mut data = Vec::with_capacity(total_len * num_kv_heads * head_dim);
258
259        // Copy cached data
260        data.extend_from_slice(cached.data());
261
262        // Copy new data
263        data.extend_from_slice(new.data());
264
265        DenseTensor::new(data, vec![total_len, num_kv_heads, head_dim])
266    }
267}
268
/// Paged KV Cache for vLLM-style memory management.
///
/// Storage is split into fixed-size blocks; a block table maps logical block
/// indices to physical blocks.
#[derive(Debug, Clone)]
pub struct PagedKVCache {
    /// Block size (tokens per block).
    block_size: usize,
    /// Per-layer key blocks: one tensor of shape
    /// [num_blocks, block_size, num_kv_heads, head_dim] per layer (see `new`).
    key_blocks: Vec<DenseTensor>,
    /// Per-layer value blocks, same layout as `key_blocks`.
    value_blocks: Vec<DenseTensor>,
    /// Block table: logical block index -> physical block index.
    /// Currently an identity mapping (see `allocate_block`).
    block_table: Vec<usize>,
    /// Number of tokens appended so far.
    current_len: usize,
    /// Maximum sequence length.
    max_seq_len: usize,
    /// Number of layers (kept for introspection; not read internally).
    #[allow(dead_code)]
    num_layers: usize,
    /// Hidden dimension (kept for introspection; not read internally).
    #[allow(dead_code)]
    hidden_dim: usize,
    /// Number of KV heads.
    num_kv_heads: usize,
}
293
294impl PagedKVCache {
295    /// Create a new paged KV cache
296    ///
297    /// # Arguments
298    /// * `num_layers` - Number of transformer layers
299    /// * `max_seq_len` - Maximum sequence length
300    /// * `hidden_dim` - Hidden dimension
301    /// * `num_kv_heads` - Number of KV heads
302    /// * `block_size` - Tokens per block (typical: 16 or 32)
303    pub fn new(
304        num_layers: usize,
305        max_seq_len: usize,
306        hidden_dim: usize,
307        num_kv_heads: usize,
308        block_size: usize,
309    ) -> Self {
310        let num_blocks = max_seq_len.div_ceil(block_size);
311        let head_dim = hidden_dim / num_kv_heads;
312
313        let key_blocks = vec![
314            DenseTensor::zeros(vec![num_blocks, block_size, num_kv_heads, head_dim]);
315            num_layers
316        ];
317        let value_blocks = vec![
318            DenseTensor::zeros(vec![num_blocks, block_size, num_kv_heads, head_dim]);
319            num_layers
320        ];
321
322        Self {
323            block_size,
324            key_blocks,
325            value_blocks,
326            block_table: Vec::new(),
327            current_len: 0,
328            max_seq_len,
329            num_layers,
330            hidden_dim,
331            num_kv_heads,
332        }
333    }
334
335    /// Allocate a new block
336    fn allocate_block(&mut self) -> Option<usize> {
337        if self.block_table.len() * self.block_size >= self.max_seq_len {
338            return None; // No more capacity
339        }
340
341        let block_id = self.block_table.len();
342        self.block_table.push(block_id);
343        Some(block_id)
344    }
345
346    /// Update cache with new key and value
347    ///
348    /// # Arguments
349    /// * `layer` - Layer index
350    /// * `key` - New key [1, num_kv_heads, head_dim]
351    /// * `value` - New value [1, num_kv_heads, head_dim]
352    pub fn append(&mut self, layer: usize, key: &DenseTensor, value: &DenseTensor) {
353        if self.current_len >= self.max_seq_len {
354            return;
355        }
356
357        // Check if we need a new block
358        if self.current_len % self.block_size == 0 {
359            self.allocate_block();
360        }
361
362        let block_id = self.block_table.len().saturating_sub(1);
363        let block_offset = self.current_len % self.block_size;
364
365        if let Some(key_block) = self.key_blocks.get_mut(layer) {
366            Self::copy_to_block_static(key_block, block_id, block_offset, key, self.block_size, self.num_kv_heads);
367        }
368
369        if let Some(value_block) = self.value_blocks.get_mut(layer) {
370            Self::copy_to_block_static(value_block, block_id, block_offset, value, self.block_size, self.num_kv_heads);
371        }
372
373        self.current_len += 1;
374    }
375
376    /// Copy tensor to block at specified offset (static method to avoid borrow issues)
377    #[inline]
378    fn copy_to_block_static(
379        block: &mut DenseTensor,
380        block_id: usize,
381        offset: usize,
382        tensor: &DenseTensor,
383        block_size: usize,
384        num_kv_heads: usize,
385    ) {
386        let head_dim = tensor.shape()[2];
387
388        for h in 0..num_kv_heads {
389            let src_offset = h * head_dim;
390            let dst_offset = ((block_id * block_size + offset) * num_kv_heads + h) * head_dim;
391
392            let src_slice = &tensor.data()[src_offset..src_offset + head_dim];
393            let block_data = block.data_mut();
394            block_data[dst_offset..dst_offset + head_dim].copy_from_slice(src_slice);
395        }
396    }
397
398    /// Get current sequence length
399    pub fn current_len(&self) -> usize {
400        self.current_len
401    }
402
403    /// Get number of allocated blocks
404    pub fn num_blocks(&self) -> usize {
405        self.block_table.len()
406    }
407
408    /// Get block table
409    pub fn block_table(&self) -> &[usize] {
410        &self.block_table
411    }
412
413    /// Reset cache
414    pub fn reset(&mut self) {
415        self.current_len = 0;
416        self.block_table.clear();
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kv_cache_creation() {
        // Fresh cache reports its configuration and starts empty.
        let kv = KVCache::new(2, 512, 4096, 8);

        assert_eq!(kv.num_layers(), 2);
        assert_eq!(kv.max_seq_len(), 512);
        assert_eq!(kv.hidden_dim(), 4096);
        assert_eq!(kv.num_kv_heads(), 8);
        assert_eq!(kv.current_len(), 0);
    }

    #[test]
    fn test_kv_cache_update() {
        // Writing position 0 bumps the length and is readable back.
        let mut kv = KVCache::new(2, 512, 4096, 8);
        let k = DenseTensor::ones(vec![1, 8, 512]);
        let v = DenseTensor::ones(vec![1, 8, 512]);

        kv.update(0, &k, &v, 0);
        assert_eq!(kv.current_len(), 1);

        let (ck, cv) = kv.get(0, Some(1)).expect("layer 0 must be readable");
        assert_eq!(ck.shape(), &[1, 8, 512]);
        assert_eq!(cv.shape(), &[1, 8, 512]);
    }

    #[test]
    fn test_kv_cache_append() {
        // Five appends advance the cursor without filling the cache.
        let mut kv = KVCache::new(2, 512, 4096, 8);

        for step in 0..5 {
            let k = DenseTensor::full(&vec![1, 8, 512], step as f64);
            let v = DenseTensor::full(&vec![1, 8, 512], step as f64 * 2.0);
            kv.append(0, &k, &v);
        }

        assert_eq!(kv.current_len(), 5);
        assert!(!kv.is_full());
        assert_eq!(kv.remaining_capacity(), 512 - 5);
    }

    #[test]
    fn test_kv_cache_reset() {
        // Reset returns the cache to an empty state.
        let mut kv = KVCache::new(2, 512, 4096, 8);
        let k = DenseTensor::ones(vec![1, 8, 512]);
        let v = DenseTensor::ones(vec![1, 8, 512]);

        kv.update(0, &k, &v, 0);
        assert_eq!(kv.current_len(), 1);

        kv.reset();
        assert_eq!(kv.current_len(), 0);
    }

    #[test]
    fn test_paged_kv_cache() {
        // 20 tokens at block_size 16 span exactly two blocks.
        let mut paged = PagedKVCache::new(2, 128, 4096, 8, 16);

        for step in 0..20 {
            let k = DenseTensor::full(&vec![1, 8, 512], step as f64);
            let v = DenseTensor::full(&vec![1, 8, 512], step as f64);
            paged.append(0, &k, &v);
        }

        assert_eq!(paged.current_len(), 20);
        assert_eq!(paged.num_blocks(), 2);
    }

    #[test]
    fn test_gqa_kv_cache() {
        // LLaMA-3 8B layout: 32 Q heads sharing 8 KV heads.
        let mut kv = KVCache::new(32, 8192, 4096, 8);
        let k = DenseTensor::ones(vec![1, 8, 512]);
        let v = DenseTensor::ones(vec![1, 8, 512]);

        for layer in 0..32 {
            kv.update(layer, &k, &v, 0);
        }

        assert_eq!(kv.num_layers(), 32);
        assert_eq!(kv.num_kv_heads(), 8);
    }
}