// candle_mi/cache/kv.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! KV-cache for efficient autoregressive generation.
4//!
5//! Stores key and value tensors from previous positions so they don't
6//! need to be recomputed at each generation step. This enables efficient
7//! token-by-token generation with O(1) complexity per token instead of O(n).
8//!
9//! ## Memory Layout
10//!
11//! Each layer stores:
12//! - keys: `[batch, num_kv_heads, seq_len, head_dim]`
13//! - values: `[batch, num_kv_heads, seq_len, head_dim]`
14//!
15//! ## Memory Estimation
16//!
17//! For a 7B model (typical hyperparameters):
18//! - `num_kv_heads` = 8 (GQA)
19//! - `head_dim` = 128
20//! - `num_layers` = 32
21//! - dtype = BF16 (2 bytes)
22//!
23//! Per token: 8 * 128 * 2 * 2 * 32 = 128KB
24//! For 2048 tokens: ~256MB
25
26use candle_core::Tensor;
27
28use crate::error::{MIError, Result};
29
/// KV-cache for efficient autoregressive generation.
///
/// Stores the key and value tensors from previous positions so they don't
/// need to be recomputed at each generation step. Each layer has its own
/// cache entry; `None` means that layer has not been populated yet.
///
/// # Shapes
///
/// - `keys[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
/// - `values[i]`: `[batch, num_kv_heads, seq_len, head_dim]`
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Cached key tensors per layer: `[batch, num_kv_heads, seq_len, head_dim]`.
    keys: Vec<Option<Tensor>>,
    /// Cached value tensors per layer: `[batch, num_kv_heads, seq_len, head_dim]`.
    values: Vec<Option<Tensor>>,
}
47
48impl KVCache {
49 /// Create a new empty cache for the given number of layers.
50 #[must_use]
51 pub fn new(n_layers: usize) -> Self {
52 Self {
53 keys: vec![None; n_layers],
54 values: vec![None; n_layers],
55 }
56 }
57
58 /// Current sequence length from the cache (0 if empty).
59 ///
60 /// # Errors
61 ///
62 /// Returns [`MIError::Model`] if a cached tensor has an unexpected shape.
63 pub fn seq_len(&self) -> Result<usize> {
64 match self.keys.iter().find_map(Option::as_ref) {
65 Some(k) => Ok(k.dim(2)?),
66 None => Ok(0),
67 }
68 }
69
70 /// Whether the cache is empty (no layers have been populated).
71 #[must_use]
72 pub fn is_empty(&self) -> bool {
73 self.keys.iter().all(Option::is_none)
74 }
75
76 /// Number of layers in the cache.
77 #[must_use]
78 pub const fn n_layers(&self) -> usize {
79 self.keys.len()
80 }
81
82 /// Clear all cached tensors.
83 pub fn clear(&mut self) {
84 for k in &mut self.keys {
85 *k = None;
86 }
87 for v in &mut self.values {
88 *v = None;
89 }
90 }
91
92 /// Get mutable references to the cache entry for a specific layer.
93 ///
94 /// Returns `(&mut Option<Tensor>, &mut Option<Tensor>)` for (key, value).
95 ///
96 /// # Errors
97 ///
98 /// Returns [`MIError::Hook`] if `layer` is out of range.
99 pub fn layer_mut(
100 &mut self,
101 layer: usize,
102 ) -> Result<(&mut Option<Tensor>, &mut Option<Tensor>)> {
103 if layer >= self.keys.len() {
104 return Err(MIError::Hook(format!(
105 "layer {layer} out of range for KV cache"
106 )));
107 }
108 // Bounds checked above; keys and values are separate fields so the
109 // borrow checker allows simultaneous mutable borrows.
110 #[allow(clippy::indexing_slicing)]
111 Ok((&mut self.keys[layer], &mut self.values[layer]))
112 }
113
114 /// Estimate memory usage in bytes.
115 ///
116 /// Returns the total memory used by all cached tensors.
117 #[must_use]
118 pub fn memory_usage(&self) -> usize {
119 let key_mem: usize = self
120 .keys
121 .iter()
122 .filter_map(Option::as_ref)
123 .map(|k| k.elem_count() * k.dtype().size_in_bytes())
124 .sum();
125 let value_mem: usize = self
126 .values
127 .iter()
128 .filter_map(Option::as_ref)
129 .map(|v| v.elem_count() * v.dtype().size_in_bytes())
130 .sum();
131 key_mem + value_mem
132 }
133
134 /// Trim the cache to keep only the last `max_seq_len` tokens.
135 ///
136 /// Useful for memory-constrained scenarios with long sequences.
137 /// Returns `Ok(true)` if trimming occurred, `Ok(false)` if no
138 /// trimming was needed.
139 ///
140 /// # Errors
141 ///
142 /// Returns [`MIError::Model`] if tensor operations fail.
143 pub fn trim_to(&mut self, max_seq_len: usize) -> Result<bool> {
144 let current_len = self.seq_len()?;
145 if current_len <= max_seq_len {
146 return Ok(false);
147 }
148
149 let trim_start = current_len - max_seq_len;
150
151 for tensor in self.keys.iter_mut().flatten() {
152 *tensor = tensor.narrow(2, trim_start, max_seq_len)?;
153 }
154 for tensor in self.values.iter_mut().flatten() {
155 *tensor = tensor.narrow(2, trim_start, max_seq_len)?;
156 }
157 Ok(true)
158 }
159
160 /// Check if cache exceeds memory limit and trim if needed.
161 ///
162 /// Trims to ~75% of current length if memory limit is exceeded.
163 /// Returns `Ok(true)` if trimming occurred.
164 ///
165 /// # Errors
166 ///
167 /// Returns [`MIError::Model`] if tensor operations fail.
168 pub fn enforce_memory_limit(&mut self, max_bytes: usize) -> Result<bool> {
169 let current = self.memory_usage();
170 if current > max_bytes {
171 let current_len = self.seq_len()?;
172 let target_len = (current_len * 3) / 4;
173 if target_len > 0 {
174 self.trim_to(target_len)?;
175 return Ok(true);
176 }
177 }
178 Ok(false)
179 }
180}
181
182impl Default for KVCache {
183 fn default() -> Self {
184 Self::new(0)
185 }
186}
187
188// ---------------------------------------------------------------------------
189// Tests
190// ---------------------------------------------------------------------------
191
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn new_cache() {
        let kv = KVCache::new(32);
        assert_eq!(kv.n_layers(), 32);
        assert!(kv.is_empty());
        assert_eq!(kv.seq_len().unwrap(), 0);
        assert_eq!(kv.memory_usage(), 0);
    }

    #[test]
    fn clear_cache() {
        let mut kv = KVCache::new(4);
        kv.clear();
        assert!(kv.is_empty());
    }

    #[test]
    fn layer_mut_valid() {
        let mut kv = KVCache::new(4);
        let (key_slot, value_slot) = kv.layer_mut(2).unwrap();
        assert!(key_slot.is_none() && value_slot.is_none());
    }

    #[test]
    fn layer_mut_out_of_range() {
        // Requesting a layer past the end must fail, not panic.
        assert!(KVCache::new(4).layer_mut(10).is_err());
    }

    #[test]
    fn default_cache() {
        let kv = KVCache::default();
        assert_eq!(kv.n_layers(), 0);
        assert!(kv.is_empty());
    }
}
233}