/// Growable per-layer store of key (`k`) and value (`v`) data, kept as flat `f32` buffers.
///
/// Every [`append`](Self::append) call adds one entry of
/// [`entry_size`](Self::entry_size) floats to a single layer's key buffer and
/// value buffer. The position counter returned by [`len`](Self::len) is
/// independent of the buffers: it only changes through
/// [`advance`](Self::advance) and [`clear`](Self::clear).
#[derive(Debug, Clone)]
pub struct KVCache {
    /// One flat key buffer per layer; entries are stored back to back.
    k: Vec<Vec<f32>>,
    /// One flat value buffer per layer, laid out like `k`.
    v: Vec<Vec<f32>>,
    /// Number of layers the cache was constructed with.
    num_layers: usize,
    /// First factor of the per-entry float count.
    num_kv_heads: usize,
    /// Second factor of the per-entry float count.
    head_dim: usize,
    /// Number of `advance` calls since construction or the last `clear`.
    len: usize,
}

impl KVCache {
    /// Creates an empty cache with `num_layers` layers and no pre-allocated storage.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        Self {
            k: vec![Vec::new(); num_layers],
            v: vec![Vec::new(); num_layers],
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Creates an empty cache whose per-layer buffers already have room for
    /// `max_seq_len` entries.
    ///
    /// The reservation is only an allocation hint: the buffers still grow on
    /// demand if more entries are appended.
    pub fn with_capacity(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        let floats_per_layer = max_seq_len * (num_kv_heads * head_dim);
        // Each buffer is built individually: cloning a `Vec` would not carry
        // its reserved capacity over to the copy.
        let reserve_layers = || -> Vec<Vec<f32>> {
            std::iter::repeat_with(|| Vec::with_capacity(floats_per_layer))
                .take(num_layers)
                .collect()
        };
        Self {
            k: reserve_layers(),
            v: reserve_layers(),
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Appends one key entry and one value entry to the buffers of `layer`.
    ///
    /// This does not change [`len`](Self::len); call [`advance`](Self::advance)
    /// for that.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index. In builds with debug
    /// assertions enabled, also panics if `k_data` or `v_data` is not exactly
    /// [`entry_size`](Self::entry_size) floats long.
    pub fn append(&mut self, layer: usize, k_data: &[f32], v_data: &[f32]) {
        debug_assert_eq!(k_data.len(), self.entry_size());
        debug_assert_eq!(v_data.len(), self.entry_size());
        self.k[layer].extend_from_slice(k_data);
        self.v[layer].extend_from_slice(v_data);
    }

    /// Increments the position counter by one.
    pub fn advance(&mut self) {
        self.len += 1;
    }

    /// Returns everything appended so far to the key buffer of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index.
    pub fn k(&self, layer: usize) -> &[f32] {
        self.k[layer].as_slice()
    }

    /// Returns everything appended so far to the value buffer of `layer`.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index.
    pub fn v(&self, layer: usize) -> &[f32] {
        self.v[layer].as_slice()
    }

    /// Returns the position counter, i.e. how many times
    /// [`advance`](Self::advance) has run since construction or the last
    /// [`clear`](Self::clear).
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` while the position counter is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of layers the cache was constructed with.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Empties every key and value buffer and resets the position counter.
    ///
    /// The buffers keep their allocations, so the cache can be refilled
    /// without reallocating.
    pub fn clear(&mut self) {
        for buffer in self.k.iter_mut().chain(self.v.iter_mut()) {
            buffer.clear();
        }
        self.len = 0;
    }

    /// Returns the number of floats in a single entry (`num_kv_heads * head_dim`).
    pub fn entry_size(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh cache is empty; one append per layer followed by `advance`
    /// yields a single stored position.
    #[test]
    fn create_and_append() {
        let mut cache = KVCache::new(2, 4, 16);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.entry_size(), 64);

        let entry = [1.0f32; 64];
        for layer in 0..2 {
            cache.append(layer, &entry, &entry);
        }
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(cache.v(1).len(), 64);
    }

    /// Entries are stored back to back in append order.
    #[test]
    fn append_multiple_tokens() {
        let mut cache = KVCache::new(1, 2, 8);

        for fill in [0.0f32, 1.0, 2.0, 3.0, 4.0] {
            let entry = [fill; 16];
            cache.append(0, &entry, &entry);
            cache.advance();
        }

        assert_eq!(cache.len(), 5);
        let keys = cache.k(0);
        assert_eq!(keys.len(), 80);
        // First float of the first entry and first float of the fifth entry.
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[64], 4.0);
    }

    /// `clear` resets both the position counter and the stored data.
    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(2, 4, 16);
        let entry = [1.0f32; 64];
        for layer in 0..2 {
            cache.append(layer, &entry, &entry);
        }
        cache.advance();
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert!(cache.k(0).is_empty());
    }

    /// Pre-reserving storage does not change the observable state.
    #[test]
    fn with_capacity() {
        let cache = KVCache::with_capacity(16, 8, 64, 2048);
        assert!(cache.is_empty());
        assert_eq!(cache.num_layers(), 16);
        assert_eq!(cache.entry_size(), 512);
    }

    /// After `clear`, nothing from the first run is visible to a second run.
    #[test]
    fn clear_resets_completely_for_independent_generation() {
        let mut cache = KVCache::new(2, 4, 16);

        // First run: five positions filled with 1.0.
        let data_a = [1.0f32; 64];
        for _ in 0..5 {
            for layer in 0..2 {
                cache.append(layer, &data_a, &data_a);
            }
            cache.advance();
        }
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.k(0).len(), 5 * 64);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        for layer in 0..2 {
            assert!(cache.k(layer).is_empty());
            assert!(cache.v(layer).is_empty());
        }

        // Second run: a single position filled with 2.0.
        let data_b = [2.0f32; 64];
        for layer in 0..2 {
            cache.append(layer, &data_b, &data_b);
        }
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(
            cache.k(0)[0],
            2.0,
            "after clear, new data should overwrite old content"
        );
    }

    /// Fills a pre-sized cache to its planned length and spot-checks the layout.
    #[test]
    fn cache_handles_max_realistic_sequence() {
        let (num_layers, num_kv_heads, head_dim, max_seq) = (4, 2, 8, 512);
        let entry_size = num_kv_heads * head_dim;
        let mut cache = KVCache::with_capacity(num_layers, num_kv_heads, head_dim, max_seq);

        for pos in 0..max_seq {
            // Every float holds its own flat index within the layer buffer.
            let base = pos * entry_size;
            let entry: Vec<f32> = (base..base + entry_size).map(|idx| idx as f32).collect();
            for layer in 0..num_layers {
                cache.append(layer, &entry, &entry);
            }
            cache.advance();
        }

        assert_eq!(cache.len(), max_seq);
        let keys = cache.k(0);
        assert_eq!(keys.len(), max_seq * entry_size);

        // First entry: first and last float.
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[entry_size - 1], (entry_size - 1) as f32);

        // Last entry: first float.
        let last_start = (max_seq - 1) * entry_size;
        assert_eq!(keys[last_start], last_start as f32);
    }
}