1use crate::tensor::DenseTensor;
7use crate::tensor::traits::TensorBase;
8
9#[derive(Debug, Clone)]
11pub struct KVCache {
12 key_cache: Vec<DenseTensor>,
14 value_cache: Vec<DenseTensor>,
16 current_len: usize,
18 max_seq_len: usize,
20 num_layers: usize,
22 hidden_dim: usize,
24 num_kv_heads: usize,
26}
27
28impl KVCache {
29 pub fn new(num_layers: usize, max_seq_len: usize, hidden_dim: usize, num_kv_heads: usize) -> Self {
37 let head_dim = hidden_dim / num_kv_heads;
38 let key_cache = vec![
39 DenseTensor::zeros(vec![max_seq_len, num_kv_heads, head_dim]);
40 num_layers
41 ];
42 let value_cache = vec![
43 DenseTensor::zeros(vec![max_seq_len, num_kv_heads, head_dim]);
44 num_layers
45 ];
46
47 Self {
48 key_cache,
49 value_cache,
50 current_len: 0,
51 max_seq_len,
52 num_layers,
53 hidden_dim,
54 num_kv_heads,
55 }
56 }
57
58 pub fn update(
66 &mut self,
67 layer: usize,
68 key: &DenseTensor,
69 value: &DenseTensor,
70 position: usize,
71 ) {
72 if layer >= self.num_layers || position >= self.max_seq_len {
73 return;
74 }
75
76 if let Some(layer_key) = self.key_cache.get_mut(layer) {
78 Self::copy_to_cache_static(layer_key, key, position, self.num_kv_heads);
79 }
80
81 if let Some(layer_value) = self.value_cache.get_mut(layer) {
83 Self::copy_to_cache_static(layer_value, value, position, self.num_kv_heads);
84 }
85
86 if position >= self.current_len {
88 self.current_len = position + 1;
89 }
90 }
91
92 #[inline]
94 fn copy_to_cache_static(cache: &mut DenseTensor, tensor: &DenseTensor, position: usize, num_kv_heads: usize) {
95 let batch_size = tensor.shape()[0];
96 let head_dim = tensor.shape()[2];
97
98 for b in 0..batch_size {
99 for h in 0..num_kv_heads {
100 let src_offset = (b * num_kv_heads + h) * head_dim;
101 let dst_offset = (position * num_kv_heads + h) * head_dim;
102
103 let src_slice = &tensor.data()[src_offset..src_offset + head_dim];
104 let cache_data = cache.data_mut();
105 cache_data[dst_offset..dst_offset + head_dim].copy_from_slice(src_slice);
106 }
107 }
108 }
109
110 pub fn get(&self, layer: usize, length: Option<usize>) -> Option<(DenseTensor, DenseTensor)> {
119 if layer >= self.num_layers {
120 return None;
121 }
122
123 let key_cache = self.key_cache.get(layer)?;
124 let value_cache = self.value_cache.get(layer)?;
125
126 let seq_len = length.unwrap_or(self.current_len);
127
128 let key = self.slice_cache(key_cache, seq_len);
130 let value = self.slice_cache(value_cache, seq_len);
131
132 Some((key, value))
133 }
134
135 fn slice_cache(&self, cache: &DenseTensor, length: usize) -> DenseTensor {
137 let num_kv_heads = cache.shape()[1];
138 let head_dim = cache.shape()[2];
139
140 let mut data = Vec::with_capacity(length * num_kv_heads * head_dim);
141
142 for pos in 0..length {
143 for h in 0..num_kv_heads {
144 let offset = (pos * num_kv_heads + h) * head_dim;
145 data.extend_from_slice(&cache.data()[offset..offset + head_dim]);
146 }
147 }
148
149 DenseTensor::new(data, vec![length, num_kv_heads, head_dim])
150 }
151
152 pub fn get_all(&self, layer: usize) -> Option<(DenseTensor, DenseTensor)> {
160 self.get(layer, Some(self.current_len))
161 }
162
163 pub fn reset(&mut self) {
165 self.current_len = 0;
166
167 for key_cache in &mut self.key_cache {
169 *key_cache = DenseTensor::zeros(key_cache.shape().to_vec());
170 }
171 for value_cache in &mut self.value_cache {
172 *value_cache = DenseTensor::zeros(value_cache.shape().to_vec());
173 }
174 }
175
176 pub fn current_len(&self) -> usize {
178 self.current_len
179 }
180
181 pub fn max_seq_len(&self) -> usize {
183 self.max_seq_len
184 }
185
186 pub fn num_layers(&self) -> usize {
188 self.num_layers
189 }
190
191 pub fn hidden_dim(&self) -> usize {
193 self.hidden_dim
194 }
195
196 pub fn num_kv_heads(&self) -> usize {
198 self.num_kv_heads
199 }
200
201 pub fn is_full(&self) -> bool {
203 self.current_len >= self.max_seq_len
204 }
205
206 pub fn remaining_capacity(&self) -> usize {
208 self.max_seq_len - self.current_len
209 }
210
211 pub fn append(&mut self, layer: usize, key: &DenseTensor, value: &DenseTensor) {
218 if self.is_full() {
219 return;
220 }
221 self.update(layer, key, value, self.current_len);
222 }
223
224 pub fn get_with_new(
234 &self,
235 layer: usize,
236 new_key: &DenseTensor,
237 new_value: &DenseTensor,
238 ) -> Option<(DenseTensor, DenseTensor)> {
239 let (cached_key, cached_value) = self.get(layer, None)?;
240
241 let key = self.concat_along_seq(&cached_key, new_key);
243 let value = self.concat_along_seq(&cached_value, new_value);
244
245 Some((key, value))
246 }
247
248 fn concat_along_seq(&self, cached: &DenseTensor, new: &DenseTensor) -> DenseTensor {
250 let cached_len = cached.shape()[0];
251 let num_kv_heads = cached.shape()[1];
252 let head_dim = cached.shape()[2];
253
254 let new_len = new.shape()[0];
255 let total_len = cached_len + new_len;
256
257 let mut data = Vec::with_capacity(total_len * num_kv_heads * head_dim);
258
259 data.extend_from_slice(cached.data());
261
262 data.extend_from_slice(new.data());
264
265 DenseTensor::new(data, vec![total_len, num_kv_heads, head_dim])
266 }
267}
268
269#[derive(Debug, Clone)]
271pub struct PagedKVCache {
272 block_size: usize,
274 key_blocks: Vec<DenseTensor>,
276 value_blocks: Vec<DenseTensor>,
278 block_table: Vec<usize>,
280 current_len: usize,
282 max_seq_len: usize,
284 #[allow(dead_code)]
286 num_layers: usize,
287 #[allow(dead_code)]
289 hidden_dim: usize,
290 num_kv_heads: usize,
292}
293
294impl PagedKVCache {
295 pub fn new(
304 num_layers: usize,
305 max_seq_len: usize,
306 hidden_dim: usize,
307 num_kv_heads: usize,
308 block_size: usize,
309 ) -> Self {
310 let num_blocks = max_seq_len.div_ceil(block_size);
311 let head_dim = hidden_dim / num_kv_heads;
312
313 let key_blocks = vec![
314 DenseTensor::zeros(vec![num_blocks, block_size, num_kv_heads, head_dim]);
315 num_layers
316 ];
317 let value_blocks = vec![
318 DenseTensor::zeros(vec![num_blocks, block_size, num_kv_heads, head_dim]);
319 num_layers
320 ];
321
322 Self {
323 block_size,
324 key_blocks,
325 value_blocks,
326 block_table: Vec::new(),
327 current_len: 0,
328 max_seq_len,
329 num_layers,
330 hidden_dim,
331 num_kv_heads,
332 }
333 }
334
335 fn allocate_block(&mut self) -> Option<usize> {
337 if self.block_table.len() * self.block_size >= self.max_seq_len {
338 return None; }
340
341 let block_id = self.block_table.len();
342 self.block_table.push(block_id);
343 Some(block_id)
344 }
345
346 pub fn append(&mut self, layer: usize, key: &DenseTensor, value: &DenseTensor) {
353 if self.current_len >= self.max_seq_len {
354 return;
355 }
356
357 if self.current_len % self.block_size == 0 {
359 self.allocate_block();
360 }
361
362 let block_id = self.block_table.len().saturating_sub(1);
363 let block_offset = self.current_len % self.block_size;
364
365 if let Some(key_block) = self.key_blocks.get_mut(layer) {
366 Self::copy_to_block_static(key_block, block_id, block_offset, key, self.block_size, self.num_kv_heads);
367 }
368
369 if let Some(value_block) = self.value_blocks.get_mut(layer) {
370 Self::copy_to_block_static(value_block, block_id, block_offset, value, self.block_size, self.num_kv_heads);
371 }
372
373 self.current_len += 1;
374 }
375
376 #[inline]
378 fn copy_to_block_static(
379 block: &mut DenseTensor,
380 block_id: usize,
381 offset: usize,
382 tensor: &DenseTensor,
383 block_size: usize,
384 num_kv_heads: usize,
385 ) {
386 let head_dim = tensor.shape()[2];
387
388 for h in 0..num_kv_heads {
389 let src_offset = h * head_dim;
390 let dst_offset = ((block_id * block_size + offset) * num_kv_heads + h) * head_dim;
391
392 let src_slice = &tensor.data()[src_offset..src_offset + head_dim];
393 let block_data = block.data_mut();
394 block_data[dst_offset..dst_offset + head_dim].copy_from_slice(src_slice);
395 }
396 }
397
398 pub fn current_len(&self) -> usize {
400 self.current_len
401 }
402
403 pub fn num_blocks(&self) -> usize {
405 self.block_table.len()
406 }
407
408 pub fn block_table(&self) -> &[usize] {
410 &self.block_table
411 }
412
413 pub fn reset(&mut self) {
415 self.current_len = 0;
416 self.block_table.clear();
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kv_cache_creation() {
        let cache = KVCache::new(2, 512, 4096, 8);

        assert_eq!(cache.num_layers(), 2);
        assert_eq!(cache.max_seq_len(), 512);
        assert_eq!(cache.hidden_dim(), 4096);
        assert_eq!(cache.num_kv_heads(), 8);
        assert_eq!(cache.current_len(), 0);
    }

    #[test]
    fn test_kv_cache_update() {
        let mut cache = KVCache::new(2, 512, 4096, 8);

        let key = DenseTensor::ones(vec![1, 8, 512]);
        let value = DenseTensor::ones(vec![1, 8, 512]);

        cache.update(0, &key, &value, 0);

        assert_eq!(cache.current_len(), 1);

        let (cached_key, cached_value) = cache.get(0, Some(1)).unwrap();
        assert_eq!(cached_key.shape(), &[1, 8, 512]);
        assert_eq!(cached_value.shape(), &[1, 8, 512]);
    }

    #[test]
    fn test_kv_cache_append() {
        let mut cache = KVCache::new(2, 512, 4096, 8);

        for i in 0..5 {
            let key = DenseTensor::full(&vec![1, 8, 512], i as f64);
            let value = DenseTensor::full(&vec![1, 8, 512], i as f64 * 2.0);
            cache.append(0, &key, &value);
        }

        assert_eq!(cache.current_len(), 5);
        assert!(!cache.is_full());
        assert_eq!(cache.remaining_capacity(), 512 - 5);
    }

    #[test]
    fn test_kv_cache_reset() {
        let mut cache = KVCache::new(2, 512, 4096, 8);

        let key = DenseTensor::ones(vec![1, 8, 512]);
        let value = DenseTensor::ones(vec![1, 8, 512]);
        cache.update(0, &key, &value, 0);

        assert_eq!(cache.current_len(), 1);

        cache.reset();

        assert_eq!(cache.current_len(), 0);
    }

    #[test]
    fn test_paged_kv_cache() {
        let mut cache = PagedKVCache::new(2, 128, 4096, 8, 16);

        for i in 0..20 {
            let key = DenseTensor::full(&vec![1, 8, 512], i as f64);
            let value = DenseTensor::full(&vec![1, 8, 512], i as f64);
            cache.append(0, &key, &value);
        }

        assert_eq!(cache.current_len(), 20);
        // 20 tokens with 16 positions per block need two blocks.
        assert_eq!(cache.num_blocks(), 2);
        assert_eq!(cache.block_table(), &[0, 1]);

        cache.reset();

        assert_eq!(cache.current_len(), 0);
        assert_eq!(cache.num_blocks(), 0);
    }

    #[test]
    fn test_gqa_kv_cache() {
        // 32 layers with 8 KV heads. max_seq_len is kept small: with 8192 positions the
        // buffers alone would be 32 layers * 2 * 8192 * 4096 elements (~17 GB as f64).
        let mut cache = KVCache::new(32, 16, 4096, 8);

        let key = DenseTensor::ones(vec![1, 8, 512]);
        let value = DenseTensor::ones(vec![1, 8, 512]);

        for layer in 0..32 {
            cache.update(layer, &key, &value, 0);
        }

        assert_eq!(cache.num_layers(), 32);
        assert_eq!(cache.num_kv_heads(), 8);
        assert_eq!(cache.current_len(), 1);

        let (last_key, last_value) = cache.get(31, Some(1)).unwrap();
        assert_eq!(last_key.shape(), &[1, 8, 512]);
        assert_eq!(last_value.shape(), &[1, 8, 512]);
    }
}