1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
//! Key-Value cache for efficient autoregressive generation.
//!
//! The KV-cache stores computed key and value tensors from previous
//! positions, avoiding redundant computation during generation.
//!
//! Reference: transformers DynamicCache (modeling.py:767-786)
use candle_core::{Result, Tensor};
/// Per-layer key-value cache entry.
///
/// Holds the accumulated key and value tensors for one transformer layer. New positions
/// are appended along dimension 2 by [`KVCacheEntry::update`].
#[derive(Debug, Clone)]
pub struct KVCacheEntry {
    /// Cached key tensor: (batch, num_kv_heads, seq_len, head_dim)
    pub key: Tensor,
    /// Cached value tensor: (batch, num_kv_heads, seq_len, head_dim)
    pub value: Tensor,
}

impl KVCacheEntry {
    /// Create a new cache entry from key and value tensors.
    pub fn new(key: Tensor, value: Tensor) -> Self {
        Self { key, value }
    }

    /// Get the current cached sequence length (the size of dimension 2 of `key`).
    ///
    /// # Errors
    ///
    /// Returns an error if `key` has fewer than three dimensions.
    pub fn seq_len(&self) -> Result<usize> {
        self.key.dim(2)
    }

    /// Update the cache by appending new key/value tensors along dimension 2.
    ///
    /// Args:
    ///   new_key: New key tensor to append (batch, num_kv_heads, new_len, head_dim)
    ///   new_value: New value tensor to append (batch, num_kv_heads, new_len, head_dim)
    ///
    /// Returns:
    ///   Updated (full_key, full_value) tensors
    ///
    /// # Errors
    ///
    /// Returns an error if either concatenation fails (e.g. mismatched shapes). In that
    /// case the entry is left unchanged.
    pub fn update(&mut self, new_key: &Tensor, new_value: &Tensor) -> Result<(Tensor, Tensor)> {
        // Build both tensors before committing either one, so a failure while
        // concatenating the value cannot leave the key already extended.
        let key = Tensor::cat(&[&self.key, new_key], 2)?;
        let value = Tensor::cat(&[&self.value, new_value], 2)?;
        self.key = key.clone();
        self.value = value.clone();
        Ok((key, value))
    }
}
/// Dynamic KV-cache that grows as generation progresses.
///
/// Stores one optional [`KVCacheEntry`] per transformer layer. The layer table grows on
/// demand in [`KVCache::update`], and the tracked sequence length follows layer 0.
#[derive(Debug, Clone, Default)]
pub struct KVCache {
    /// Per-layer cache entries indexed by `layer_idx`. A slot is `None` until that layer
    /// is first updated, and again after [`KVCache::clear`].
    entries: Vec<Option<KVCacheEntry>>,
    /// Cached sequence length, mirrored from layer 0's entry.
    seq_len: usize,
}

impl KVCache {
    /// Create a new empty cache with no layer slots.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
            seq_len: 0,
        }
    }

    /// Create an empty cache with `num_layers` pre-allocated (empty) layer slots.
    pub fn with_num_layers(num_layers: usize) -> Self {
        Self {
            entries: vec![None; num_layers],
            seq_len: 0,
        }
    }

    /// Get the current cached sequence length.
    ///
    /// Only layer 0 is tracked. The value advances as soon as layer 0 is updated, so
    /// code running for later layers in the same step already sees the new length.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Check whether the cache is empty, meaning nothing is cached for layer 0.
    pub fn is_empty(&self) -> bool {
        self.seq_len == 0
    }

    /// Get the cache entry for a specific layer.
    ///
    /// Returns `None` if the layer has never been updated, has been cleared, or if
    /// `layer_idx` is out of range.
    pub fn get(&self, layer_idx: usize) -> Option<&KVCacheEntry> {
        self.entries.get(layer_idx).and_then(Option::as_ref)
    }

    /// Update the cache for a specific layer.
    ///
    /// If this is the first update for this layer, creates a new entry. Otherwise,
    /// appends to the existing entry along dimension 2. Layer slots up to `layer_idx`
    /// are created as needed. Updating layer 0 also refreshes [`KVCache::seq_len`].
    ///
    /// Args:
    ///   layer_idx: The layer index
    ///   key: Key tensor (batch, num_kv_heads, new_len, head_dim)
    ///   value: Value tensor (batch, num_kv_heads, new_len, head_dim)
    ///
    /// Returns:
    ///   The full (key, value) tensors after update
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - `key` has fewer than three dimensions. This is checked before any state changes.
    /// - Concatenation with the layer's existing entry fails.
    pub fn update(
        &mut self,
        layer_idx: usize,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        // Reject malformed input before touching any state.
        key.dim(2)?;
        // Grow the layer table on demand so callers need not pre-size it.
        if self.entries.len() <= layer_idx {
            self.entries.resize_with(layer_idx + 1, || None);
        }
        let (full_key, full_value) = match &mut self.entries[layer_idx] {
            Some(entry) => entry.update(key, value)?,
            None => {
                self.entries[layer_idx] = Some(KVCacheEntry::new(key.clone(), value.clone()));
                (key.clone(), value.clone())
            }
        };
        // Layer 0 is the reference layer for positions. Mirror its actual cached length
        // so `seq_len` cannot drift from the stored tensors.
        if layer_idx == 0 {
            self.seq_len = full_key.dim(2)?;
        }
        Ok((full_key, full_value))
    }

    /// Clear the cache (e.g., at the start of a new generation).
    ///
    /// Every layer slot is reset to `None`, but the number of slots is kept.
    pub fn clear(&mut self) {
        self.entries.iter_mut().for_each(|slot| *slot = None);
        self.seq_len = 0;
    }

    /// Get the cache position for RoPE (the next position to fill).
    ///
    /// NOTE: this equals [`KVCache::seq_len`]. If it is read after layer 0 has already
    /// been updated in the current step, it already includes that step's new tokens.
    pub fn cache_position(&self) -> usize {
        self.seq_len
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Zero-filled (1, 4, seq_len, 64) tensor, used as key and value in these tests.
    fn kv(device: &Device, seq_len: usize) -> Result<Tensor> {
        Tensor::zeros((1, 4, seq_len, 64), DType::F32, device)
    }

    #[test]
    fn test_kv_cache_basic() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = KVCache::with_num_layers(2);
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);

        // First update creates the entry.
        let (key1, value1) = (kv(&device, 5)?, kv(&device, 5)?);
        let (k, v) = cache.update(0, &key1, &value1)?;
        assert_eq!(k.dims(), &[1, 4, 5, 64]);
        assert_eq!(v.dims(), &[1, 4, 5, 64]);
        assert_eq!(cache.seq_len(), 5);

        // Second update appends along the sequence dimension.
        let (key2, value2) = (kv(&device, 1)?, kv(&device, 1)?);
        let (k, v) = cache.update(0, &key2, &value2)?;
        assert_eq!(k.dims(), &[1, 4, 6, 64]);
        assert_eq!(v.dims(), &[1, 4, 6, 64]);
        assert_eq!(cache.seq_len(), 6);
        Ok(())
    }

    #[test]
    fn test_kv_cache_clear() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = KVCache::with_num_layers(2);
        let (key, value) = (kv(&device, 5)?, kv(&device, 5)?);
        cache.update(0, &key, &value)?;
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.seq_len(), 0);
        assert!(cache.get(0).is_none());
        Ok(())
    }

    #[test]
    fn test_kv_cache_grows_layers_and_tracks_layer0() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = KVCache::new();
        let prefill = kv(&device, 3)?;
        let step = kv(&device, 1)?;

        // Updating past the end of the layer table grows it on demand.
        cache.update(0, &prefill, &prefill)?;
        cache.update(2, &prefill, &prefill)?;
        assert!(cache.get(1).is_none());
        assert!(cache.get(3).is_none());
        let layer2 = cache.get(2).expect("layer 2 should be populated");
        assert_eq!(layer2.seq_len()?, 3);

        // Only layer 0 advances the tracked sequence length.
        assert_eq!(cache.seq_len(), 3);
        cache.update(2, &step, &step)?;
        assert_eq!(cache.seq_len(), 3);
        cache.update(0, &step, &step)?;
        assert_eq!(cache.seq_len(), 4);
        assert_eq!(cache.cache_position(), 4);
        Ok(())
    }
}