/// Policy selector for how a key/value cache stores its entries.
///
/// NOTE(review): nothing visible alongside this enum reads it, so the exact
/// meaning of the non-default variants is not established here — confirm
/// against callers.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum KvCachePolicy {
    /// The policy returned by `KvCachePolicy::default()`.
    #[default]
    Standard,
    /// Presumably half-precision storage (inferred from the name only; TODO confirm).
    Fp16,
    /// Presumably a sliding window whose size is the payload, likely counted in
    /// positions (inferred from the name only; TODO confirm).
    SlidingWindow(usize),
}
18
/// Key/value cache with its full capacity allocated up front.
///
/// Keys and values each live in one flat `Vec<f32>` laid out as
/// `[layer][kv_head][position][head_dim]`, so the entries of a single
/// `(layer, head)` pair occupy one contiguous run of `max_seq_len * head_dim`
/// floats. The cache maintains `seq_len <= max_seq_len`.
#[derive(Debug)]
pub struct KvCache {
    /// Number of layers that have cache slots.
    num_layers: usize,
    /// Number of key/value heads per layer.
    num_kv_heads: usize,
    /// Length of every stored key/value vector.
    head_dim: usize,
    /// Capacity, in positions, of each `(layer, head)` run.
    max_seq_len: usize,
    /// Number of positions currently marked as valid.
    seq_len: usize,
    /// Flat key storage (layout described on the struct).
    keys: Vec<f32>,
    /// Flat value storage, same layout as `keys`.
    values: Vec<f32>,
}

impl KvCache {
    /// Allocates a zero-filled cache for `num_layers * num_kv_heads` heads, each
    /// able to hold `max_seq_len` vectors of length `head_dim`.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        let total = num_layers * num_kv_heads * max_seq_len * head_dim;
        Self {
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            seq_len: 0,
            keys: vec![0.0; total],
            values: vec![0.0; total],
        }
    }

    /// Number of positions currently marked as valid.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Capacity, in positions, of each `(layer, head)`.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Writes the key vector for `(layer, head, pos)`.
    ///
    /// Does not change [`seq_len`](Self::seq_len); call [`advance`](Self::advance)
    /// for that.
    ///
    /// # Panics
    /// Panics if `key.len() != head_dim`. Debug builds also panic when `layer`,
    /// `head` or `pos` is out of range; release builds skip those checks, so an
    /// out-of-range index may overwrite another head's slots.
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(key.len(), self.head_dim);

        let offset = self.cache_offset(layer, head, pos);
        self.keys[offset..offset + self.head_dim].copy_from_slice(key);
    }

    /// Writes the value vector for `(layer, head, pos)`.
    ///
    /// # Panics
    /// Same conditions as [`store_key`](Self::store_key), applied to `value`.
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(value.len(), self.head_dim);

        let offset = self.cache_offset(layer, head, pos);
        self.values[offset..offset + self.head_dim].copy_from_slice(value);
    }

    /// Returns the key vectors for positions `0..seq_len` of `(layer, head)` as one
    /// contiguous slice of `seq_len * head_dim` floats.
    ///
    /// `seq_len` is explicit and need not equal [`seq_len`](Self::seq_len).
    ///
    /// # Panics
    /// Debug builds panic when `layer`, `head` or `seq_len` exceeds the cache bounds.
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        let range = self.head_range(layer, head, seq_len);
        &self.keys[range]
    }

    /// Returns the value vectors for positions `0..seq_len` of `(layer, head)`.
    ///
    /// # Panics
    /// Same conditions as [`keys_for`](Self::keys_for).
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        let range = self.head_range(layer, head, seq_len);
        &self.values[range]
    }

    /// Marks one more position as valid.
    ///
    /// Saturates at [`max_seq_len`](Self::max_seq_len), preserving the same
    /// `seq_len <= max_seq_len` invariant that [`set_seq_len`](Self::set_seq_len)
    /// enforces, so `keys_for(.., self.seq_len())` never over-reads.
    pub fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }

    /// Resets [`seq_len`](Self::seq_len) to zero.
    ///
    /// The key/value buffers are neither zeroed nor freed.
    pub fn clear(&mut self) {
        self.seq_len = 0;
    }

    /// Index of the first float of `(layer, head, pos)` in the flat buffers.
    fn cache_offset(&self, layer: usize, head: usize, pos: usize) -> usize {
        ((layer * self.num_kv_heads + head) * self.max_seq_len + pos) * self.head_dim
    }

    /// Float range covering positions `0..seq_len` of `(layer, head)`.
    fn head_range(&self, layer: usize, head: usize, seq_len: usize) -> std::ops::Range<usize> {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        // A longer read would silently continue into the next head's positions.
        debug_assert!(seq_len <= self.max_seq_len);
        let start = self.cache_offset(layer, head, 0);
        start..start + seq_len * self.head_dim
    }

    /// Number of floats per head that a block starting at `start_pos` can copy,
    /// i.e. `block_size` positions clipped to the cache capacity.
    fn block_copy_len(&self, start_pos: usize, block_size: usize) -> usize {
        self.max_seq_len.saturating_sub(start_pos).min(block_size) * self.head_dim
    }

    /// Bytes held by the key and value buffers.
    pub fn memory_bytes(&self) -> usize {
        (self.keys.len() + self.values.len()) * std::mem::size_of::<f32>()
    }

    /// `seq_len / max_seq_len`, or `0.0` for a zero-capacity cache.
    pub fn utilization_ratio(&self) -> f64 {
        if self.max_seq_len == 0 {
            return 0.0;
        }
        self.seq_len as f64 / self.max_seq_len as f64
    }

    /// Number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Number of key/value heads per layer.
    pub fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }

    /// Length of each key/value vector.
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Sets the valid-position count, clamped to [`max_seq_len`](Self::max_seq_len).
    pub fn set_seq_len(&mut self, n: usize) {
        self.seq_len = n.min(self.max_seq_len);
    }

    /// Copies positions `start_pos..start_pos + block_size` of every head in
    /// `layer` into freshly allocated key and value buffers.
    ///
    /// Each buffer holds `num_kv_heads * block_size * head_dim` floats laid out as
    /// `[head][offset][head_dim]`. Positions at or beyond `max_seq_len` read as zero.
    pub fn extract_block(
        &self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
    ) -> (Vec<f32>, Vec<f32>) {
        debug_assert!(layer < self.num_layers);
        let block_len = block_size * self.head_dim;
        let mut keys = vec![0.0f32; self.num_kv_heads * block_len];
        let mut values = vec![0.0f32; self.num_kv_heads * block_len];

        // Guard first: when nothing is in range, `start_pos` may lie past the
        // buffer end and must not be used to slice.
        let copy_len = self.block_copy_len(start_pos, block_size);
        if copy_len == 0 {
            return (keys, values);
        }

        // Within one head both source and destination positions are contiguous,
        // so the whole in-range run is moved with a single copy.
        for head in 0..self.num_kv_heads {
            let src = self.cache_offset(layer, head, start_pos);
            let dst = head * block_len;
            keys[dst..dst + copy_len].copy_from_slice(&self.keys[src..src + copy_len]);
            values[dst..dst + copy_len].copy_from_slice(&self.values[src..src + copy_len]);
        }

        (keys, values)
    }

    /// Writes a block produced by [`extract_block`](Self::extract_block) back into
    /// positions `start_pos..start_pos + block_size` of every head in `layer`.
    ///
    /// `keys` and `values` must use the `[head][offset][head_dim]` layout.
    /// Positions at or beyond `max_seq_len` are skipped. Does not change
    /// [`seq_len`](Self::seq_len).
    ///
    /// # Panics
    /// Panics if `keys`/`values` are too short for the in-range part of the
    /// block; debug builds also check their exact length and `layer`.
    pub fn inject_block(
        &mut self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
        keys: &[f32],
        values: &[f32],
    ) {
        debug_assert!(layer < self.num_layers);
        let block_len = block_size * self.head_dim;
        debug_assert_eq!(keys.len(), self.num_kv_heads * block_len);
        debug_assert_eq!(values.len(), self.num_kv_heads * block_len);

        let copy_len = self.block_copy_len(start_pos, block_size);
        if copy_len == 0 {
            return;
        }

        for head in 0..self.num_kv_heads {
            let src = head * block_len;
            let dst = self.cache_offset(layer, head, start_pos);
            self.keys[dst..dst + copy_len].copy_from_slice(&keys[src..src + copy_len]);
            self.values[dst..dst + copy_len].copy_from_slice(&values[src..src + copy_len]);
        }
    }
}
239
/// Number of positions per page used by [`PagedKvCache::new`].
const DEFAULT_PAGE_SIZE: usize = 256;

/// One fixed-size page of keys and values for a single `(layer, head)`.
#[derive(Debug, Clone)]
struct KvPage {
    /// `page_size * head_dim` key floats, position-major.
    keys: Vec<f32>,
    /// `page_size * head_dim` value floats, same layout as `keys`.
    values: Vec<f32>,
    /// High-water mark: one past the highest in-page offset written so far.
    used: usize,
}

impl KvPage {
    /// Allocates a zero-filled page for `page_size` positions of width `head_dim`.
    fn new(page_size: usize, head_dim: usize) -> Self {
        let len = page_size * head_dim;
        Self {
            keys: vec![0.0; len],
            values: vec![0.0; len],
            used: 0,
        }
    }

    /// Records that in-page slot `offset` has been written.
    fn mark_used(&mut self, offset: usize) {
        self.used = self.used.max(offset + 1);
    }
}

/// Key/value cache that allocates storage lazily in fixed-size pages.
///
/// Each `(layer, head)` pair owns its own list of pages; page `i` holds
/// positions `i * page_size .. (i + 1) * page_size`. A write allocates the page
/// it touches plus any missing earlier pages; [`clear`](Self::clear) releases
/// them all.
#[derive(Debug)]
pub struct PagedKvCache {
    /// `pages[layer][head]` is the page list for that head.
    pages: Vec<Vec<Vec<KvPage>>>,
    /// Number of layers.
    num_layers: usize,
    /// Number of key/value heads per layer.
    num_kv_heads: usize,
    /// Length of every stored key/value vector.
    head_dim: usize,
    /// Positions per page; always non-zero.
    page_size: usize,
    /// Largest position count a head may hold.
    max_seq_len: usize,
    /// Number of positions currently marked as valid (kept `<= max_seq_len`).
    seq_len: usize,
}

impl PagedKvCache {
    /// Creates an empty cache using [`DEFAULT_PAGE_SIZE`] positions per page.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        Self::with_page_size(
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            DEFAULT_PAGE_SIZE,
        )
    }

    /// Creates an empty cache with `page_size` positions per page. No pages are
    /// allocated until the first write.
    ///
    /// # Panics
    /// Panics if `page_size` is zero (it is used as a divisor for every lookup).
    pub fn with_page_size(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
        page_size: usize,
    ) -> Self {
        assert!(page_size > 0, "PagedKvCache page_size must be non-zero");
        let pages = (0..num_layers)
            .map(|_| vec![Vec::new(); num_kv_heads])
            .collect();

        Self {
            pages,
            num_layers,
            num_kv_heads,
            head_dim,
            page_size,
            max_seq_len,
            seq_len: 0,
        }
    }

    /// Writes the key vector for `(layer, head, pos)`, allocating its page if needed.
    ///
    /// # Panics
    /// Panics if `key.len() != head_dim`; debug builds also check `layer`, `head`
    /// and `pos` against the cache bounds.
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        debug_assert_eq!(key.len(), self.head_dim);
        let head_dim = self.head_dim;
        let (page, offset) = self.page_for_write(layer, head, pos);
        let start = offset * head_dim;
        page.keys[start..start + head_dim].copy_from_slice(key);
        page.mark_used(offset);
    }

    /// Writes the value vector for `(layer, head, pos)`, allocating its page if needed.
    ///
    /// # Panics
    /// Same conditions as [`store_key`](Self::store_key), applied to `value`.
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        debug_assert_eq!(value.len(), self.head_dim);
        let head_dim = self.head_dim;
        let (page, offset) = self.page_for_write(layer, head, pos);
        let start = offset * head_dim;
        page.values[start..start + head_dim].copy_from_slice(value);
        page.mark_used(offset);
    }

    /// Gathers the key vectors for positions `0..seq_len` of `(layer, head)` into a
    /// new `seq_len * head_dim` buffer. Positions in unallocated pages read as zero.
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.keys.as_slice())
    }

    /// Gathers the value vectors for positions `0..seq_len` of `(layer, head)`.
    /// Positions in unallocated pages read as zero.
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.values.as_slice())
    }

    /// Number of positions currently marked as valid.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Marks one more position as valid, saturating at `max_seq_len`.
    pub fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }

    /// Resets `seq_len` to zero and drops every allocated page.
    pub fn clear(&mut self) {
        self.seq_len = 0;
        self.pages.iter_mut().flatten().for_each(Vec::clear);
    }

    /// Bytes held by allocated pages (keys plus values). Every page counts in
    /// full, regardless of how many of its slots were written.
    pub fn memory_usage_bytes(&self) -> usize {
        self.total_pages() * self.page_size * self.head_dim * std::mem::size_of::<f32>() * 2
    }

    /// Fraction of slots in allocated pages that lie below each page's high-water
    /// mark. Returns `0.0` when no pages are allocated.
    pub fn utilization_ratio(&self) -> f64 {
        let (used_slots, total_slots) = self
            .pages
            .iter()
            .flatten()
            .flatten()
            .fold((0usize, 0usize), |(used, total), page| {
                (used + page.used, total + self.page_size)
            });
        if total_slots == 0 {
            return 0.0;
        }
        used_slots as f64 / total_slots as f64
    }

    /// Total number of allocated pages across all layers and heads.
    pub fn total_pages(&self) -> usize {
        self.pages.iter().flatten().map(Vec::len).sum()
    }

    /// Positions per page.
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Returns the page that holds `pos` for `(layer, head)`, allocating it (and
    /// any missing earlier pages) first, together with `pos`'s in-page offset.
    fn page_for_write(&mut self, layer: usize, head: usize, pos: usize) -> (&mut KvPage, usize) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);

        let page_idx = pos / self.page_size;
        let offset = pos % self.page_size;
        self.ensure_page(layer, head, page_idx);
        (&mut self.pages[layer][head][page_idx], offset)
    }

    /// Copies positions `0..seq_len` of `(layer, head)` out of the buffer chosen
    /// by `select`, one page-sized run at a time, zero-filling missing pages.
    fn gather(
        &self,
        layer: usize,
        head: usize,
        seq_len: usize,
        select: impl Fn(&KvPage) -> &[f32],
    ) -> Vec<f32> {
        let head_dim = self.head_dim;
        let head_pages = &self.pages[layer][head];
        let mut out = Vec::with_capacity(seq_len * head_dim);

        let mut pos = 0;
        while pos < seq_len {
            let page_idx = pos / self.page_size;
            let offset = pos % self.page_size;
            // Remaining positions in this page, capped by what was requested;
            // always >= 1 because page_size > 0 and offset < page_size.
            let run = (self.page_size - offset).min(seq_len - pos);
            match head_pages.get(page_idx) {
                Some(page) => {
                    let start = offset * head_dim;
                    out.extend_from_slice(&select(page)[start..start + run * head_dim]);
                }
                None => out.resize(out.len() + run * head_dim, 0.0),
            }
            pos += run;
        }

        out
    }

    /// Grows the page list of `(layer, head)` so that index `page_idx` exists.
    fn ensure_page(&mut self, layer: usize, head: usize, page_idx: usize) {
        let (page_size, head_dim) = (self.page_size, self.head_dim);
        let head_pages = &mut self.pages[layer][head];
        if head_pages.len() <= page_idx {
            head_pages.resize_with(page_idx + 1, || KvPage::new(page_size, head_dim));
        }
    }
}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kv_cache_store_and_retrieve() {
        let mut cache = KvCache::new(2, 8, 128, 16);

        let key = vec![1.0f32; 128];
        let value = vec![2.0f32; 128];

        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();

        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);

        assert_eq!(keys.len(), 128);
        assert_eq!(values.len(), 128);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn kv_cache_multiple_positions() {
        let mut cache = KvCache::new(1, 1, 4, 8);

        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.advance();
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        cache.advance();

        let keys = cache.keys_for(0, 0, 2);
        assert_eq!(keys.len(), 8);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn kv_cache_memory_size() {
        let cache = KvCache::new(36, 8, 128, 4096);
        // layers * heads * positions * head_dim * sizeof(f32) * (keys + values)
        let expected = 36 * 8 * 4096 * 128 * 4 * 2;
        assert_eq!(cache.memory_bytes(), expected);
    }

    #[test]
    fn kv_cache_utilization() {
        let mut cache = KvCache::new(1, 1, 4, 10);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);

        cache.advance();
        cache.advance();
        cache.advance();
        assert!((cache.utilization_ratio() - 0.3).abs() < 1e-10);
    }

    #[test]
    fn kv_cache_policy_default() {
        let policy = KvCachePolicy::default();
        assert_eq!(policy, KvCachePolicy::Standard);
    }

    #[test]
    fn kv_cache_set_seq_len_clamps_to_max() {
        let mut cache = KvCache::new(1, 1, 4, 8);
        cache.set_seq_len(4);
        assert_eq!(cache.seq_len(), 4);
        cache.set_seq_len(100);
        assert_eq!(cache.seq_len(), 8);
    }

    #[test]
    fn kv_cache_extract_inject_roundtrip() {
        let num_layers = 2;
        let num_kv_heads = 2;
        let head_dim = 4;
        let block_size = 4;
        let max_seq = 16;
        let mut cache = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);

        // Distinct, position- and head-dependent data in layer 1.
        for head in 0..num_kv_heads {
            for pos in 0..block_size {
                let key: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 100.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                let value: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 1000.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                cache.store_key(1, head, pos, &key);
                cache.store_value(1, head, pos, &value);
            }
        }

        let (k_block, v_block) = cache.extract_block(1, 0, block_size);
        let per_layer = num_kv_heads * block_size * head_dim;
        assert_eq!(k_block.len(), per_layer);
        assert_eq!(v_block.len(), per_layer);

        let mut fresh = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);
        fresh.inject_block(1, 0, block_size, &k_block, &v_block);
        fresh.set_seq_len(block_size);

        let (k_block_2, v_block_2) = fresh.extract_block(1, 0, block_size);
        assert_eq!(k_block_2, k_block);
        assert_eq!(v_block_2, v_block);

        for head in 0..num_kv_heads {
            let original_keys = cache.keys_for(1, head, block_size);
            let restored_keys = fresh.keys_for(1, head, block_size);
            assert_eq!(
                original_keys, restored_keys,
                "head {head} keys must round-trip"
            );
            let original_values = cache.values_for(1, head, block_size);
            let restored_values = fresh.values_for(1, head, block_size);
            assert_eq!(
                original_values, restored_values,
                "head {head} values must round-trip"
            );
        }
    }

    #[test]
    fn kv_cache_extract_inject_at_offset() {
        let mut cache = KvCache::new(1, 1, 2, 16);
        for pos in 0..4 {
            let key = vec![pos as f32, pos as f32 + 0.5];
            let value = vec![-(pos as f32), -(pos as f32) - 0.5];
            cache.store_key(0, 0, 4 + pos, &key);
            cache.store_value(0, 0, 4 + pos, &value);
        }
        let (k, v) = cache.extract_block(0, 4, 4);
        let mut other = KvCache::new(1, 1, 2, 16);
        other.inject_block(0, 4, 4, &k, &v);

        // Positions 0..8 cover both the untouched prefix and the injected block;
        // the data are plain copies, so exact equality is expected.
        assert_eq!(other.keys_for(0, 0, 8), cache.keys_for(0, 0, 8));
        assert_eq!(other.values_for(0, 0, 8), cache.values_for(0, 0, 8));
    }

    #[test]
    fn kv_cache_extract_block_past_capacity_zero_fills() {
        let mut cache = KvCache::new(1, 1, 2, 4);
        cache.store_key(0, 0, 3, &[7.0, 8.0]);
        cache.store_value(0, 0, 3, &[-7.0, -8.0]);

        let (k, v) = cache.extract_block(0, 3, 3);
        assert_eq!(k, vec![7.0, 8.0, 0.0, 0.0, 0.0, 0.0]);
        assert_eq!(v, vec![-7.0, -8.0, 0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn paged_kv_cache_store_and_retrieve() {
        let mut cache = PagedKvCache::with_page_size(2, 1, 4, 16, 4);

        let key = vec![1.0, 2.0, 3.0, 4.0];
        let value = vec![5.0, 6.0, 7.0, 8.0];

        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();

        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);

        assert_eq!(keys.len(), 4);
        assert_eq!(values.len(), 4);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn paged_kv_cache_cross_page_boundary() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 2);

        // page_size = 2: positions 0,1 land in page 0 and position 2 in page 1.
        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        cache.store_key(0, 0, 2, &[9.0, 10.0, 11.0, 12.0]);

        let keys = cache.keys_for(0, 0, 3);
        assert_eq!(keys.len(), 12);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
        assert!((keys[8] - 9.0).abs() < 1e-5);
    }

    #[test]
    fn paged_kv_cache_unallocated_positions_read_as_zero() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 2, 16, 2);
        cache.store_value(0, 0, 1, &[3.0, 4.0]);

        assert_eq!(
            cache.values_for(0, 0, 4),
            vec![0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
        );
        assert_eq!(cache.keys_for(0, 0, 2), vec![0.0; 4]);
    }

    #[test]
    fn paged_kv_cache_lazy_allocation() {
        let cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 256);
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.memory_usage_bytes(), 0);
    }

    #[test]
    fn paged_kv_cache_memory_grows() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 4);

        assert_eq!(cache.memory_usage_bytes(), 0);

        cache.store_key(0, 0, 0, &[1.0; 4]);
        // page_size * head_dim * sizeof(f32) * (keys + values)
        let one_page_bytes = 4 * 4 * 4 * 2;
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes);

        // Position 4 is the first slot of the second page.
        cache.store_key(0, 0, 4, &[1.0; 4]);
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes * 2);
    }

    #[test]
    fn paged_kv_cache_clear() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        cache.store_key(0, 0, 0, &[1.0; 4]);
        cache.advance();

        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn paged_kv_cache_utilization() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);

        // One written slot out of a single 4-slot page.
        cache.store_key(0, 0, 0, &[1.0; 4]);
        assert!((cache.utilization_ratio() - 0.25).abs() < 1e-10);
    }
}