1use std::collections::HashMap;
10
/// Key and value buffers for a single cache block, in `(keys, values)` order.
///
/// Each half is a `Vec<Vec<f32>>`; the constructors in this file build the outer `Vec` with
/// one entry per layer.
pub type KvBlockPair = (Vec<Vec<f32>>, Vec<Vec<f32>>);
13
/// One block of cached attention keys and values plus its bookkeeping.
#[derive(Debug, Clone)]
pub struct CacheBlock {
    /// Key buffers. [`CacheBlock::new`] creates one buffer per layer.
    pub keys: Vec<Vec<f32>>,
    /// Value buffers. [`CacheBlock::new`] creates one buffer per layer.
    pub values: Vec<Vec<f32>>,
    /// Token ids this block was stored for; empty for a freshly constructed block.
    pub token_ids: Vec<u32>,
    /// Recency stamp assigned by the owner of the block; starts at 0.
    pub last_used: u64,
    /// Number of outstanding users of the block; starts at 0.
    pub ref_count: usize,
}

impl CacheBlock {
    /// Creates a zero-filled block.
    ///
    /// `keys` and `values` each get `num_layers` buffers of
    /// `num_kv_heads * head_dim * block_size` floats. `token_ids` starts empty and both
    /// `last_used` and `ref_count` start at 0.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize, block_size: usize) -> Self {
        let per_layer = num_kv_heads * head_dim * block_size;
        Self {
            keys: vec![vec![0.0f32; per_layer]; num_layers],
            values: vec![vec![0.0f32; per_layer]; num_layers],
            token_ids: Vec::new(),
            last_used: 0,
            ref_count: 0,
        }
    }

    /// Returns the number of bytes of `f32` data held in `keys` and `values`.
    ///
    /// Every buffer is measured individually, so the result stays correct when the buffers
    /// were replaced after construction with layers of differing lengths, or when `values`
    /// does not have the same shape as `keys`.
    pub fn memory_bytes(&self) -> usize {
        let floats: usize = self
            .keys
            .iter()
            .chain(self.values.iter())
            .map(Vec::len)
            .sum();
        floats * std::mem::size_of::<f32>()
    }
}
56
/// A node of the prefix trie.
struct TrieNode {
    /// Outgoing edges: edge key -> index of the child node.
    children: HashMap<u32, usize>,
    /// Index of the cache block attached to this node, if any.
    block_idx: Option<usize>,
}

impl TrieNode {
    /// Builds a node with no children and no attached block.
    fn new() -> Self {
        let children = HashMap::new();
        TrieNode {
            children,
            block_idx: None,
        }
    }
}
80
81pub struct PrefixCache {
91 nodes: Vec<TrieNode>,
93 blocks: Vec<CacheBlock>,
95 occupied_blocks: Vec<usize>,
98 free_block_pool: Vec<usize>,
100 max_blocks: usize,
102 block_size: usize,
104 num_layers: usize,
105 num_kv_heads: usize,
106 head_dim: usize,
107 generation: u64,
109 pub hits: u64,
111 pub misses: u64,
113 pub evictions: u64,
115}
116
117impl PrefixCache {
118 pub fn new(
120 max_blocks: usize,
121 block_size: usize,
122 num_layers: usize,
123 num_kv_heads: usize,
124 head_dim: usize,
125 ) -> Self {
126 let root = TrieNode::new();
127 Self {
128 nodes: vec![root],
129 blocks: Vec::new(),
130 occupied_blocks: Vec::new(),
131 free_block_pool: Vec::new(),
132 max_blocks,
133 block_size,
134 num_layers,
135 num_kv_heads,
136 head_dim,
137 generation: 0,
138 hits: 0,
139 misses: 0,
140 evictions: 0,
141 }
142 }
143
144 pub fn lookup(&mut self, token_ids: &[u32]) -> (usize, Vec<&CacheBlock>) {
154 let mut node_idx = 0usize; let mut matched_len = 0usize;
156 let mut matched_block_indices: Vec<usize> = Vec::new();
157
158 let full_blocks = token_ids.len() / self.block_size;
159
160 for block_num in 0..full_blocks {
161 let block_start = block_num * self.block_size;
162 let block_end = block_start + self.block_size;
163 let block_tokens = &token_ids[block_start..block_end];
164
165 let edge_key = Self::block_edge_key(block_tokens);
172
173 match self.nodes[node_idx].children.get(&edge_key).copied() {
174 None => {
175 self.misses += 1;
177 break;
178 }
179 Some(child_node_idx) => {
180 let maybe_block_idx = self.nodes[child_node_idx].block_idx;
182 match maybe_block_idx {
183 None => {
184 self.misses += 1;
185 break;
186 }
187 Some(bidx) => {
188 if self.blocks[bidx].token_ids != block_tokens {
190 self.misses += 1;
191 break;
192 }
193 self.generation += 1;
195 self.blocks[bidx].last_used = self.generation;
196 self.blocks[bidx].ref_count += 1;
197 matched_len += self.block_size;
198 matched_block_indices.push(bidx);
199 self.hits += 1;
200 node_idx = child_node_idx;
201 }
202 }
203 }
204 }
205 }
206
207 let block_refs: Vec<&CacheBlock> = matched_block_indices
210 .iter()
211 .map(|&bidx| &self.blocks[bidx])
212 .collect();
213
214 (matched_len, block_refs)
215 }
216
217 pub fn insert(
222 &mut self,
223 token_ids: &[u32],
224 block_start: usize,
225 keys: Vec<Vec<f32>>,
226 values: Vec<Vec<f32>>,
227 ) -> usize {
228 while self.occupied_blocks.len() >= self.max_blocks {
230 if !self.evict_lru() {
231 break;
233 }
234 }
235
236 let block_end = block_start + self.block_size;
237 let block_tokens = token_ids[block_start..block_end.min(token_ids.len())].to_vec();
238
239 let mut node_idx = 0usize;
241 let num_full_blocks_before = block_start / self.block_size;
242
243 for blk in 0..num_full_blocks_before {
244 let seg_start = blk * self.block_size;
245 let seg_end = seg_start + self.block_size;
246 let seg = &token_ids[seg_start..seg_end];
247 let edge_key = Self::block_edge_key(seg);
248
249 if let Some(&child) = self.nodes[node_idx].children.get(&edge_key) {
250 node_idx = child;
251 } else {
252 let new_node_idx = self.nodes.len();
254 self.nodes.push(TrieNode::new());
255 self.nodes[node_idx].children.insert(edge_key, new_node_idx);
256 node_idx = new_node_idx;
257 }
258 }
259
260 let edge_key = Self::block_edge_key(&block_tokens);
262
263 let leaf_node_idx = if let Some(&existing) = self.nodes[node_idx].children.get(&edge_key) {
264 existing
265 } else {
266 let new_node_idx = self.nodes.len();
267 self.nodes.push(TrieNode::new());
268 self.nodes[node_idx].children.insert(edge_key, new_node_idx);
269 new_node_idx
270 };
271
272 self.generation += 1;
274 let block_idx = if let Some(reuse_idx) = self.free_block_pool.pop() {
275 let block = &mut self.blocks[reuse_idx];
277 block.keys = keys;
278 block.values = values;
279 block.token_ids = block_tokens;
280 block.last_used = self.generation;
281 block.ref_count = 0;
282 reuse_idx
283 } else {
284 let mut blk = CacheBlock::new(
286 self.num_layers,
287 self.num_kv_heads,
288 self.head_dim,
289 self.block_size,
290 );
291 blk.keys = keys;
292 blk.values = values;
293 blk.token_ids = block_tokens;
294 blk.last_used = self.generation;
295 blk.ref_count = 0;
296 let idx = self.blocks.len();
297 self.blocks.push(blk);
298 idx
299 };
300
301 self.nodes[leaf_node_idx].block_idx = Some(block_idx);
302 self.occupied_blocks.push(block_idx);
303
304 block_idx
305 }
306
307 pub fn release(&mut self, block_idx: usize) {
309 if block_idx < self.blocks.len() && self.blocks[block_idx].ref_count > 0 {
310 self.blocks[block_idx].ref_count -= 1;
311 }
312 }
313
314 pub fn len(&self) -> usize {
316 self.occupied_blocks.len()
317 }
318
319 pub fn is_empty(&self) -> bool {
321 self.occupied_blocks.is_empty()
322 }
323
324 pub fn capacity(&self) -> usize {
326 self.max_blocks
327 }
328
329 pub fn memory_bytes(&self) -> usize {
331 self.occupied_blocks
332 .iter()
333 .map(|&idx| self.blocks[idx].memory_bytes())
334 .sum()
335 }
336
337 pub fn hit_rate(&self) -> f32 {
339 let total = self.hits + self.misses;
340 if total == 0 {
341 0.0
342 } else {
343 self.hits as f32 / total as f32
344 }
345 }
346
347 pub fn block_size(&self) -> usize {
349 self.block_size
350 }
351
352 pub fn get_block(&self, idx: usize) -> Option<&CacheBlock> {
360 self.blocks.get(idx)
361 }
362
363 pub fn clear(&mut self) {
365 self.nodes.clear();
366 self.nodes.push(TrieNode::new());
367 self.blocks.clear();
368 self.occupied_blocks.clear();
369 self.free_block_pool.clear();
370 self.generation = 0;
371 }
373
374 fn block_edge_key(tokens: &[u32]) -> u32 {
383 let mut h: u64 = 0xcbf2_9ce4_8422_2325; for &t in tokens {
385 h ^= t as u64;
386 h = h.wrapping_mul(0x0000_0100_0000_01b3); }
388 ((h >> 32) ^ (h & 0xffff_ffff)) as u32
390 }
391
392 fn evict_lru(&mut self) -> bool {
396 let victim_pos = self
398 .occupied_blocks
399 .iter()
400 .enumerate()
401 .filter(|(_, &bidx)| self.blocks[bidx].ref_count == 0)
402 .min_by_key(|(_, &bidx)| self.blocks[bidx].last_used)
403 .map(|(pos, _)| pos);
404
405 let Some(pos) = victim_pos else {
406 return false;
407 };
408
409 let victim_block_idx = self.occupied_blocks.swap_remove(pos);
410
411 for node in &mut self.nodes {
414 if node.block_idx == Some(victim_block_idx) {
415 node.block_idx = None;
416 break;
417 }
418 }
419
420 self.free_block_pool.push(victim_block_idx);
422 self.evictions += 1;
423
424 true
425 }
426}
427
/// Result of matching one request against the prefix cache.
pub struct CacheSession {
    /// Number of leading tokens recorded as matched for this session.
    pub matched_prefix_len: usize,
    /// Indices of the cache blocks recorded for this session.
    pub block_indices: Vec<usize>,
}

impl CacheSession {
    /// Bundles a matched prefix length with the block indices that back it.
    pub fn new(matched_prefix_len: usize, block_indices: Vec<usize>) -> Self {
        CacheSession {
            block_indices,
            matched_prefix_len,
        }
    }

    /// Returns the token count covered by the recorded blocks, taking every block to hold
    /// `block_size` tokens.
    pub fn cached_tokens(&self, block_size: usize) -> usize {
        let held_blocks = self.block_indices.len();
        held_blocks * block_size
    }

    /// Returns `true` when the session records no blocks.
    pub fn is_empty(&self) -> bool {
        let held: &[usize] = &self.block_indices;
        held.is_empty()
    }
}
465
466pub struct PrefixAwarePrefill {
481 pub cache: PrefixCache,
483}
484
485impl PrefixAwarePrefill {
486 pub fn new(cache: PrefixCache) -> Self {
488 Self { cache }
489 }
490
491 pub fn prepare(&mut self, token_ids: &[u32]) -> (CacheSession, usize) {
496 let (matched_len, matched_blocks) = self.cache.lookup(token_ids);
499 let num_matched = matched_blocks.len();
500 drop(matched_blocks);
502
503 let block_indices: Vec<usize> = (0..num_matched)
505 .map(|blk_num| {
506 let block_start = blk_num * self.cache.block_size;
507 let block_tokens = &token_ids[block_start..block_start + self.cache.block_size];
508 let edge_key = PrefixCache::block_edge_key(block_tokens);
509 self.find_block_idx_for_edge(blk_num, token_ids, edge_key)
510 })
511 .collect();
512
513 let uncached_start = matched_len;
514 let session = CacheSession::new(matched_len, block_indices);
515 (session, uncached_start)
516 }
517
518 pub fn store_blocks(
523 &mut self,
524 token_ids: &[u32],
525 uncached_start: usize,
526 keys_by_block: Vec<KvBlockPair>,
527 ) {
528 let block_size = self.cache.block_size;
529 for (i, (keys, values)) in keys_by_block.into_iter().enumerate() {
530 let block_start = uncached_start + i * block_size;
531 let block_end = block_start + block_size;
532 if block_end > token_ids.len() {
533 break;
535 }
536 self.cache.insert(token_ids, block_start, keys, values);
537 }
538 }
539
540 pub fn release_session(&mut self, session: CacheSession) {
542 for bidx in session.block_indices {
543 self.cache.release(bidx);
544 }
545 }
546
547 pub fn stats(&self) -> PrefixCacheStats {
549 PrefixCacheStats {
550 hit_rate: self.cache.hit_rate(),
551 cached_blocks: self.cache.len(),
552 capacity_blocks: self.cache.capacity(),
553 memory_bytes: self.cache.memory_bytes(),
554 total_hits: self.cache.hits,
555 total_misses: self.cache.misses,
556 total_evictions: self.cache.evictions,
557 }
558 }
559
560 fn find_block_idx_for_edge(&self, blk_num: usize, token_ids: &[u32], edge_key: u32) -> usize {
564 let mut node_idx = 0usize;
566 for blk in 0..blk_num {
567 let seg_start = blk * self.cache.block_size;
568 let seg_end = seg_start + self.cache.block_size;
569 let seg = &token_ids[seg_start..seg_end];
570 let parent_edge_key = PrefixCache::block_edge_key(seg);
571 if let Some(&child) = self.cache.nodes[node_idx].children.get(&parent_edge_key) {
572 node_idx = child;
573 } else {
574 return usize::MAX;
576 }
577 }
578 if let Some(&child_idx) = self.cache.nodes[node_idx].children.get(&edge_key) {
580 self.cache.nodes[child_idx].block_idx.unwrap_or(usize::MAX)
581 } else {
582 usize::MAX
583 }
584 }
585}
586
/// Point-in-time statistics for a prefix cache; serializable via `serde`.
#[derive(Debug, serde::Serialize)]
pub struct PrefixCacheStats {
    /// Fraction of counted lookups that were hits (`hits / (hits + misses)`).
    pub hit_rate: f32,
    /// Number of blocks currently cached.
    pub cached_blocks: usize,
    /// Configured maximum number of blocks.
    pub capacity_blocks: usize,
    /// Bytes of key/value data held by the cached blocks.
    pub memory_bytes: usize,
    /// Total number of hits counted so far.
    pub total_hits: u64,
    /// Total number of misses counted so far.
    pub total_misses: u64,
    /// Total number of evictions counted so far.
    pub total_evictions: u64,
}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-filled block with the given geometry.
    fn make_block(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        block_size: usize,
    ) -> CacheBlock {
        CacheBlock::new(num_layers, num_kv_heads, head_dim, block_size)
    }

    /// Builds per-layer key buffers filled with `val` and value buffers filled with
    /// `val + 1.0`, so keys and values can be told apart.
    fn make_kv(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        block_size: usize,
        val: f32,
    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let per_layer = num_kv_heads * head_dim * block_size;
        let keys: Vec<Vec<f32>> = (0..num_layers).map(|_| vec![val; per_layer]).collect();
        let values: Vec<Vec<f32>> = (0..num_layers)
            .map(|_| vec![val + 1.0; per_layer])
            .collect();
        (keys, values)
    }

    #[test]
    fn test_cache_block_memory_bytes() {
        let block = make_block(2, 4, 8, 4);
        // keys + values, 2 layers, 4 heads * 8 dims * 4 tokens floats per layer.
        let expected = 2 * 2 * (4 * 8 * 4) * std::mem::size_of::<f32>();
        assert_eq!(block.memory_bytes(), expected);
    }

    #[test]
    fn test_prefix_cache_insert_and_lookup_hit() {
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
        let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

        // Only the first of the two blocks is cached.
        let (keys, values) = make_kv(2, 2, 8, 4, 1.0);
        cache.insert(&token_ids, 0, keys, values);

        let (matched, blocks) = cache.lookup(&token_ids);
        assert_eq!(matched, 4, "should match one full block of 4 tokens");
        assert_eq!(blocks.len(), 1);
        assert_eq!(cache.hits, 1);
    }

    #[test]
    fn test_prefix_cache_lookup_miss() {
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
        let token_ids: Vec<u32> = vec![10, 20, 30, 40];

        let (matched, blocks) = cache.lookup(&token_ids);
        assert_eq!(matched, 0);
        assert!(blocks.is_empty());
        assert_eq!(cache.misses, 1);
    }

    #[test]
    fn test_prefix_cache_partial_prefix_match() {
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
        let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let (keys0, values0) = make_kv(2, 2, 8, 4, 0.5);
        cache.insert(&token_ids, 0, keys0, values0);

        // Shares the first block with the cached sequence, then diverges.
        let query: Vec<u32> = vec![1, 2, 3, 4, 9, 10, 11, 12];
        let (matched, blocks) = cache.lookup(&query);
        assert_eq!(matched, 4);
        assert_eq!(blocks.len(), 1);
    }

    #[test]
    fn test_prefix_cache_lru_eviction() {
        let mut cache = PrefixCache::new(2, 4, 1, 1, 4);

        let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
        let tokens_b: Vec<u32> = vec![5, 6, 7, 8];
        let tokens_c: Vec<u32> = vec![9, 10, 11, 12];

        let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
        let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);
        let (kc, vc) = make_kv(1, 1, 4, 4, 3.0);

        cache.insert(&tokens_a, 0, ka, va);
        cache.insert(&tokens_b, 0, kb, vb);
        // Touch B so that A is the least recently used block.
        let _ = cache.lookup(&tokens_b);
        // The cache is full, so inserting C has to evict A.
        cache.insert(&tokens_c, 0, kc, vc);

        assert_eq!(
            cache.len(),
            2,
            "should have exactly 2 blocks after eviction"
        );
        assert_eq!(cache.evictions, 1);

        let (matched_a, _) = cache.lookup(&tokens_a);
        assert_eq!(matched_a, 0, "evicted block should not be found");
    }

    #[test]
    fn test_prefix_cache_ref_count_prevents_eviction() {
        let mut cache = PrefixCache::new(1, 4, 1, 1, 4);

        let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
        let tokens_b: Vec<u32> = vec![5, 6, 7, 8];

        let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
        let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);

        let bidx_a = cache.insert(&tokens_a, 0, ka, va);
        // Pin A by hand, as an in-flight request would.
        cache.blocks[bidx_a].ref_count += 1;

        // The cache is full and its only block is pinned.
        cache.insert(&tokens_b, 0, kb, vb);

        assert_eq!(cache.evictions, 0, "pinned block must not be evicted");

        cache.release(bidx_a);
        assert_eq!(cache.blocks[bidx_a].ref_count, 0);
    }

    #[test]
    fn test_prefix_cache_hit_rate() {
        let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
        let tokens: Vec<u32> = vec![1, 2, 3, 4];
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        cache.insert(&tokens, 0, k, v);

        // One hit followed by one miss.
        let _ = cache.lookup(&tokens);
        let _ = cache.lookup(&[99, 100, 101, 102]);

        let rate = cache.hit_rate();
        assert!(
            (rate - 0.5).abs() < 1e-5,
            "hit rate should be 0.5, got {rate}"
        );
    }

    #[test]
    fn test_prefix_cache_clear() {
        let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
        let tokens: Vec<u32> = vec![1, 2, 3, 4];
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        cache.insert(&tokens, 0, k, v);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        // Nothing may be found once the cache has been cleared.
        let (matched, _) = cache.lookup(&tokens);
        assert_eq!(matched, 0);
    }

    #[test]
    fn test_cache_session_cached_tokens() {
        let session = CacheSession::new(8, vec![0, 1]);
        assert_eq!(session.cached_tokens(4), 8);
        assert!(!session.is_empty());

        let empty = CacheSession::new(0, vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.cached_tokens(4), 0);
    }

    #[test]
    fn test_prefix_aware_prefill_prepare() {
        let inner = PrefixCache::new(8, 4, 1, 1, 4);
        let mut prefill = PrefixAwarePrefill::new(inner);

        let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        // Cache only the first block of the request.
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        prefill.cache.insert(&token_ids, 0, k, v);

        let (session, uncached_start) = prefill.prepare(&token_ids);
        assert_eq!(session.matched_prefix_len, 4);
        assert_eq!(uncached_start, 4);

        prefill.release_session(session);
    }

    #[test]
    fn test_prefix_cache_stats() {
        let inner = PrefixCache::new(8, 4, 1, 1, 4);
        let mut prefill = PrefixAwarePrefill::new(inner);

        let token_ids: Vec<u32> = vec![1, 2, 3, 4];
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        prefill.cache.insert(&token_ids, 0, k, v);

        let _ = prefill.prepare(&token_ids);

        // One block inserted and one successful lookup of it: every counter is known.
        let stats = prefill.stats();
        assert_eq!(stats.cached_blocks, 1);
        assert_eq!(stats.capacity_blocks, 8);
        assert_eq!(stats.total_hits, 1);
        assert_eq!(stats.total_misses, 0);
        assert_eq!(stats.total_evictions, 0);
        assert!((stats.hit_rate - 1.0).abs() < 1e-6);
        // Keys plus values, one layer of 1 * 4 * 4 floats each.
        assert_eq!(stats.memory_bytes, 2 * 16 * std::mem::size_of::<f32>());
    }

    #[test]
    fn test_prefix_cache_capacity_enforcement() {
        let capacity = 4usize;
        let mut cache = PrefixCache::new(capacity, 4, 1, 1, 4);

        // Insert two more distinct blocks than the cache can hold.
        for i in 0..capacity + 2 {
            let tokens: Vec<u32> = (0..4).map(|j| (i * 4 + j) as u32).collect();
            let (k, v) = make_kv(1, 1, 4, 4, i as f32);
            cache.insert(&tokens, 0, k, v);
        }

        assert!(
            cache.len() <= capacity,
            "cache should not exceed max_blocks={capacity}, got {}",
            cache.len()
        );
        assert!(
            cache.evictions >= 2,
            "should have evicted at least 2 blocks"
        );
    }
}