// candle_mi/cache/kv.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! KV-cache for efficient autoregressive generation.
//!
//! Stores key and value tensors from previous positions so they don't
//! need to be recomputed at each generation step. This enables efficient
//! token-by-token generation with O(1) key/value computation per token
//! instead of O(n).
//!
//! ## Memory Layout
//!
//! Each layer stores:
//! - keys: `[batch, num_kv_heads, seq_len, head_dim]`
//! - values: `[batch, num_kv_heads, seq_len, head_dim]`
//!
//! ## Memory Estimation
//!
//! For a 7B model (typical hyperparameters):
//! - `num_kv_heads` = 8 (GQA)
//! - `head_dim` = 128
//! - `num_layers` = 32
//! - dtype = BF16 (2 bytes)
//!
//! Per token: 8 heads * 128 dims * 2 bytes * 2 (K and V) * 32 layers = 128 KB
//! For 2048 tokens: ~256 MB
26use candle_core::Tensor;
27
28use crate::error::{MIError, Result};
29
30/// KV-cache for efficient autoregressive generation.
31///
32/// Stores the key and value tensors from previous positions so they don't
33/// need to be recomputed at each generation step. Each layer has its own
34/// cache entry.
35///
36/// # Shapes
37///
38/// - `keys[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
39/// - `values[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
/// KV-cache for efficient autoregressive generation.
///
/// Stores the key and value tensors from previous positions so they don't
/// need to be recomputed at each generation step. Each layer has its own
/// cache entry; `None` means that layer has not been populated yet.
///
/// # Shapes
///
/// - `keys[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
/// - `values[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Cached key tensors per layer: `[batch, num_kv_heads, seq_len, head_dim]`.
    keys: Vec<Option<Tensor>>,
    /// Cached value tensors per layer: `[batch, num_kv_heads, seq_len, head_dim]`.
    values: Vec<Option<Tensor>>,
}
47
48impl KVCache {
49    /// Create a new empty cache for the given number of layers.
50    #[must_use]
51    pub fn new(n_layers: usize) -> Self {
52        Self {
53            keys: vec![None; n_layers],
54            values: vec![None; n_layers],
55        }
56    }
57
58    /// Current sequence length from the cache (0 if empty).
59    ///
60    /// # Errors
61    ///
62    /// Returns [`MIError::Model`] if a cached tensor has an unexpected shape.
63    pub fn seq_len(&self) -> Result<usize> {
64        match self.keys.iter().find_map(Option::as_ref) {
65            Some(k) => Ok(k.dim(2)?),
66            None => Ok(0),
67        }
68    }
69
70    /// Whether the cache is empty (no layers have been populated).
71    #[must_use]
72    pub fn is_empty(&self) -> bool {
73        self.keys.iter().all(Option::is_none)
74    }
75
76    /// Number of layers in the cache.
77    #[must_use]
78    pub const fn n_layers(&self) -> usize {
79        self.keys.len()
80    }
81
82    /// Clear all cached tensors.
83    pub fn clear(&mut self) {
84        for k in &mut self.keys {
85            *k = None;
86        }
87        for v in &mut self.values {
88            *v = None;
89        }
90    }
91
92    /// Get mutable references to the cache entry for a specific layer.
93    ///
94    /// Returns `(&mut Option<Tensor>, &mut Option<Tensor>)` for (key, value).
95    ///
96    /// # Errors
97    ///
98    /// Returns [`MIError::Hook`] if `layer` is out of range.
99    pub fn layer_mut(
100        &mut self,
101        layer: usize,
102    ) -> Result<(&mut Option<Tensor>, &mut Option<Tensor>)> {
103        if layer >= self.keys.len() {
104            return Err(MIError::Hook(format!(
105                "layer {layer} out of range for KV cache"
106            )));
107        }
108        // Bounds checked above; keys and values are separate fields so the
109        // borrow checker allows simultaneous mutable borrows.
110        #[allow(clippy::indexing_slicing)]
111        Ok((&mut self.keys[layer], &mut self.values[layer]))
112    }
113
114    /// Estimate memory usage in bytes.
115    ///
116    /// Returns the total memory used by all cached tensors.
117    #[must_use]
118    pub fn memory_usage(&self) -> usize {
119        let key_mem: usize = self
120            .keys
121            .iter()
122            .filter_map(Option::as_ref)
123            .map(|k| k.elem_count() * k.dtype().size_in_bytes())
124            .sum();
125        let value_mem: usize = self
126            .values
127            .iter()
128            .filter_map(Option::as_ref)
129            .map(|v| v.elem_count() * v.dtype().size_in_bytes())
130            .sum();
131        key_mem + value_mem
132    }
133
134    /// Trim the cache to keep only the last `max_seq_len` tokens.
135    ///
136    /// Useful for memory-constrained scenarios with long sequences.
137    /// Returns `Ok(true)` if trimming occurred, `Ok(false)` if no
138    /// trimming was needed.
139    ///
140    /// # Errors
141    ///
142    /// Returns [`MIError::Model`] if tensor operations fail.
143    pub fn trim_to(&mut self, max_seq_len: usize) -> Result<bool> {
144        let current_len = self.seq_len()?;
145        if current_len <= max_seq_len {
146            return Ok(false);
147        }
148
149        let trim_start = current_len - max_seq_len;
150
151        for tensor in self.keys.iter_mut().flatten() {
152            *tensor = tensor.narrow(2, trim_start, max_seq_len)?;
153        }
154        for tensor in self.values.iter_mut().flatten() {
155            *tensor = tensor.narrow(2, trim_start, max_seq_len)?;
156        }
157        Ok(true)
158    }
159
160    /// Check if cache exceeds memory limit and trim if needed.
161    ///
162    /// Trims to ~75% of current length if memory limit is exceeded.
163    /// Returns `Ok(true)` if trimming occurred.
164    ///
165    /// # Errors
166    ///
167    /// Returns [`MIError::Model`] if tensor operations fail.
168    pub fn enforce_memory_limit(&mut self, max_bytes: usize) -> Result<bool> {
169        let current = self.memory_usage();
170        if current > max_bytes {
171            let current_len = self.seq_len()?;
172            let target_len = (current_len * 3) / 4;
173            if target_len > 0 {
174                self.trim_to(target_len)?;
175                return Ok(true);
176            }
177        }
178        Ok(false)
179    }
180}
181
182impl Default for KVCache {
183    fn default() -> Self {
184        Self::new(0)
185    }
186}
187
188// ---------------------------------------------------------------------------
189// Tests
190// ---------------------------------------------------------------------------
191
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn new_cache() {
        let cache = KVCache::new(32);
        assert!(cache.is_empty());
        assert_eq!(cache.n_layers(), 32);
        assert_eq!(cache.memory_usage(), 0);
        assert_eq!(cache.seq_len().unwrap(), 0);
    }

    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(4);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn layer_mut_valid() {
        let mut cache = KVCache::new(4);
        let (key_slot, value_slot) = cache.layer_mut(2).unwrap();
        assert!(key_slot.is_none() && value_slot.is_none());
    }

    #[test]
    fn layer_mut_out_of_range() {
        // Index 10 is past the last layer (3), so this must error.
        assert!(KVCache::new(4).layer_mut(10).is_err());
    }

    #[test]
    fn default_cache() {
        let cache = KVCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.n_layers(), 0);
    }
}