1use std::collections::HashMap;
16use std::time::Instant;
17
18use oxillama_arch::traits::KvCacheAccess;
19
20use super::KvCache;
21
/// Capacity and filtering limits for [`PrefixKvCache`].
#[derive(Debug, Clone)]
pub struct PrefixCacheConfig {
    /// Maximum number of stored snapshots; exceeding it after an insertion triggers LRU eviction.
    pub max_entries: usize,
    /// Budget for the f32 payload of all snapshots, measured from buffer lengths
    /// (allocation overhead is not counted). Exceeding it triggers LRU eviction.
    pub max_memory_bytes: usize,
    /// Token sequences shorter than this are not stored, and lookups that match fewer
    /// tokens than this are reported as misses.
    pub min_prefix_len: usize,
}

impl Default for PrefixCacheConfig {
    /// 256 entries, a 512 MiB payload budget, and a 4-token minimum prefix.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        PrefixCacheConfig {
            min_prefix_len: 4,
            max_entries: 256,
            max_memory_bytes: 512 * MIB,
        }
    }
}
44
/// Owned copy of key and value buffers captured from a KV cache.
///
/// When built by `PrefixKvCache::store`, `keys[i]` / `values[i]` hold the (possibly truncated)
/// buffers of layer `i`. `seq_len` is the sequence length recorded alongside the buffers; it is
/// not validated against their sizes.
#[derive(Clone)]
pub struct CachedKvState {
    /// Key buffers, one entry per captured layer.
    keys: Vec<Vec<f32>>,
    /// Value buffers, one entry per captured layer.
    values: Vec<Vec<f32>>,
    /// Sequence length recorded with the snapshot.
    seq_len: usize,
}

impl CachedKvState {
    /// Wraps already-captured buffers as a snapshot without copying or validating them.
    pub fn new(keys: Vec<Vec<f32>>, values: Vec<Vec<f32>>, seq_len: usize) -> Self {
        CachedKvState {
            seq_len,
            keys,
            values,
        }
    }

    /// Sequence length recorded with the snapshot.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Borrowed view of the captured key buffers.
    pub fn keys(&self) -> &[Vec<f32>] {
        self.keys.as_slice()
    }

    /// Borrowed view of the captured value buffers.
    pub fn values(&self) -> &[Vec<f32>] {
        self.values.as_slice()
    }

    /// Bytes of f32 payload held by all key and value buffers (lengths, not capacities).
    fn memory_bytes(&self) -> usize {
        fn float_count(buffers: &[Vec<f32>]) -> usize {
            buffers.iter().map(Vec::len).sum()
        }
        (float_count(&self.keys) + float_count(&self.values)) * std::mem::size_of::<f32>()
    }
}
98
99struct RadixNode {
103 tokens: Vec<u32>,
105 children: HashMap<u32, Box<RadixNode>>,
107 cached_kv: Option<CachedKvState>,
109 last_access: Instant,
111 ref_count: u32,
113}
114
115impl RadixNode {
116 fn new(tokens: Vec<u32>) -> Self {
118 Self {
119 tokens,
120 children: HashMap::new(),
121 cached_kv: None,
122 last_access: Instant::now(),
123 ref_count: 0,
124 }
125 }
126
127 fn lookup<'a>(
132 &'a mut self,
133 query: &[u32],
134 matched_so_far: usize,
135 ) -> Option<(usize, &'a CachedKvState)> {
136 let common = common_prefix_len(&self.tokens, query);
138 if common < self.tokens.len() {
139 return None;
143 }
144
145 let total_matched = matched_so_far + common;
146 let remaining = &query[common..];
147
148 self.last_access = Instant::now();
150
151 let mut best: Option<(usize, &'a CachedKvState)> = None;
153
154 if let Some(&first_token) = remaining.first() {
155 if let Some(child) = self.children.get_mut(&first_token) {
156 best = child.lookup(remaining, total_matched);
157 }
158 }
159
160 if best.is_none() {
162 if let Some(ref kv) = self.cached_kv {
163 best = Some((total_matched, kv));
164 }
165 }
166
167 best
168 }
169
170 fn insert(&mut self, tokens: &[u32], kv: CachedKvState) {
172 if tokens.is_empty() {
173 self.cached_kv = Some(kv);
174 self.last_access = Instant::now();
175 return;
176 }
177
178 let common = common_prefix_len(&self.tokens, tokens);
179
180 if common < self.tokens.len() {
181 self.split_at(common);
183 }
184
185 let remaining = &tokens[common..];
186 if remaining.is_empty() {
187 self.cached_kv = Some(kv);
188 self.last_access = Instant::now();
189 return;
190 }
191
192 let first = remaining[0];
193 let child = self
194 .children
195 .entry(first)
196 .or_insert_with(|| Box::new(RadixNode::new(remaining.to_vec())));
197
198 if child.tokens == remaining {
200 child.cached_kv = Some(kv);
201 child.last_access = Instant::now();
202 } else {
203 child.insert(remaining, kv);
204 }
205 }
206
207 fn split_at(&mut self, pos: usize) {
210 let suffix = self.tokens[pos..].to_vec();
211 let first_of_suffix = suffix[0];
212
213 let mut new_child = RadixNode::new(suffix);
214 new_child.children = std::mem::take(&mut self.children);
215 new_child.cached_kv = self.cached_kv.take();
216 new_child.last_access = self.last_access;
217 new_child.ref_count = self.ref_count;
218
219 self.tokens.truncate(pos);
220 self.children.insert(first_of_suffix, Box::new(new_child));
221 }
222
223 fn count_entries(&self) -> usize {
225 let mine = usize::from(self.cached_kv.is_some());
226 let children_count: usize = self.children.values().map(|c| c.count_entries()).sum();
227 mine + children_count
228 }
229
230 fn total_memory(&self) -> usize {
232 let mine = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
233 let children_mem: usize = self.children.values().map(|c| c.total_memory()).sum();
234 mine + children_mem
235 }
236
237 fn evict_lru_one(&mut self) -> usize {
241 let mut oldest_time = Instant::now();
243 let mut oldest_path: Option<Vec<u32>> = None;
244 let mut oldest_mem: usize = 0;
245
246 self.find_lru_candidate(&mut oldest_time, &mut oldest_path, &mut oldest_mem, &[]);
247
248 if let Some(path) = oldest_path {
249 self.remove_cached_at(&path)
250 } else {
251 0
252 }
253 }
254
255 fn find_lru_candidate(
257 &self,
258 oldest_time: &mut Instant,
259 oldest_path: &mut Option<Vec<u32>>,
260 oldest_mem: &mut usize,
261 prefix: &[u32],
262 ) {
263 if self.cached_kv.is_some() && self.ref_count == 0 && self.last_access < *oldest_time {
264 *oldest_time = self.last_access;
265 let mut path = prefix.to_vec();
266 path.extend_from_slice(&self.tokens);
267 *oldest_path = Some(path);
268 *oldest_mem = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
269 }
270
271 for child in self.children.values() {
272 let mut child_prefix = prefix.to_vec();
273 child_prefix.extend_from_slice(&self.tokens);
274 child.find_lru_candidate(oldest_time, oldest_path, oldest_mem, &child_prefix);
275 }
276 }
277
278 fn remove_cached_at(&mut self, path: &[u32]) -> usize {
282 let common = common_prefix_len(&self.tokens, path);
283 if common < self.tokens.len() {
284 return 0;
285 }
286
287 let remaining = &path[common..];
288 if remaining.is_empty() {
289 let freed = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
291 self.cached_kv = None;
292 return freed;
293 }
294
295 if let Some(&first) = remaining.first() {
296 if let Some(child) = self.children.get_mut(&first) {
297 let freed = child.remove_cached_at(remaining);
298 if child.cached_kv.is_none() && child.children.is_empty() {
300 self.children.remove(&first);
301 }
302 return freed;
303 }
304 }
305 0
306 }
307
308 fn clear_all(&mut self) {
310 self.cached_kv = None;
311 self.children.clear();
312 }
313}
314
315fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
317 a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
318}
319
320pub struct PrefixKvCache {
328 root: RadixNode,
330 config: PrefixCacheConfig,
332 hit_count: u64,
334 miss_count: u64,
336}
337
338impl PrefixKvCache {
339 pub fn new(config: PrefixCacheConfig) -> Self {
341 Self {
342 root: RadixNode::new(Vec::new()),
343 config,
344 hit_count: 0,
345 miss_count: 0,
346 }
347 }
348
349 pub fn lookup(&mut self, tokens: &[u32]) -> Option<(usize, &CachedKvState)> {
354 if tokens.is_empty() {
355 self.miss_count += 1;
356 return None;
357 }
358
359 let result = self.root.lookup(tokens, 0);
360
361 match result {
362 Some((matched, kv)) if matched >= self.config.min_prefix_len => {
363 self.hit_count += 1;
364 Some((matched, kv))
365 }
366 _ => {
367 self.miss_count += 1;
368 None
369 }
370 }
371 }
372
373 pub fn store(
379 &mut self,
380 tokens: &[u32],
381 kv_cache: &dyn KvCacheAccess,
382 seq_len: usize,
383 kv_dim: usize,
384 num_layers: usize,
385 ) {
386 if tokens.len() < self.config.min_prefix_len {
387 return;
388 }
389
390 let mut keys = Vec::with_capacity(num_layers);
392 let mut values = Vec::with_capacity(num_layers);
393
394 for layer in 0..num_layers {
395 let k = kv_cache.get_keys(layer).unwrap_or(&[]);
396 let v = kv_cache.get_values(layer).unwrap_or(&[]);
397 let end = seq_len * kv_dim;
398 keys.push(k[..end.min(k.len())].to_vec());
399 values.push(v[..end.min(v.len())].to_vec());
400 }
401
402 let snapshot = CachedKvState {
403 keys,
404 values,
405 seq_len,
406 };
407
408 self.root.insert(tokens, snapshot);
409
410 self.evict_lru();
412 }
413
414 pub fn store_snapshot(&mut self, tokens: &[u32], snapshot: CachedKvState) {
418 if tokens.len() < self.config.min_prefix_len {
419 return;
420 }
421 self.root.insert(tokens, snapshot);
422 self.evict_lru();
423 }
424
425 pub fn restore(cached: &CachedKvState, target: &mut KvCache) {
430 target.restore_from_snapshot(&cached.keys, &cached.values, cached.seq_len);
431 }
432
433 fn evict_lru(&mut self) {
435 while self.root.count_entries() > self.config.max_entries {
437 if self.root.evict_lru_one() == 0 {
438 break; }
440 }
441 while self.root.total_memory() > self.config.max_memory_bytes {
443 if self.root.evict_lru_one() == 0 {
444 break;
445 }
446 }
447 }
448
449 pub fn len(&self) -> usize {
451 self.root.count_entries()
452 }
453
454 pub fn is_empty(&self) -> bool {
456 self.root.count_entries() == 0
457 }
458
459 pub fn clear(&mut self) {
461 self.root.clear_all();
462 self.hit_count = 0;
463 self.miss_count = 0;
464 }
465
466 pub fn memory_usage(&self) -> usize {
468 self.root.total_memory()
469 }
470
471 pub fn hits(&self) -> u64 {
473 self.hit_count
474 }
475
476 pub fn misses(&self) -> u64 {
478 self.miss_count
479 }
480}
481
// Unit tests for `PrefixKvCache`, the radix-tree helpers, and snapshot round-tripping.
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_arch::traits::KvCacheAccess;

    /// Builds a `KvCache` via `KvCache::new(num_layers, 128, kv_dim)` and fills it with
    /// deterministic data.
    ///
    /// For each of `num_tokens` steps, it calls `store_kv` on every layer and then
    /// `advance()`. Returns the cache together with token ids `0..num_tokens`.
    fn make_filled_cache(
        num_layers: usize,
        kv_dim: usize,
        num_tokens: usize,
    ) -> (KvCache, Vec<u32>) {
        // NOTE(review): 128 is presumably the maximum sequence length — confirm against KvCache::new.
        let mut cache = KvCache::new(num_layers, 128, kv_dim);
        let tokens: Vec<u32> = (0..num_tokens as u32).collect();

        for t in 0..num_tokens {
            for layer in 0..num_layers {
                // Encode layer and step in the value so mismatched data is easy to spot.
                let base = (layer * 1000 + t) as f32;
                let key: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.01).collect();
                let val: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.02).collect();
                cache
                    .store_kv(layer, &key, &val)
                    .expect("store_kv should succeed");
            }
            cache.advance();
        }

        (cache, tokens)
    }

    /// Generous limits with `min_prefix_len: 1`, so eviction and length filtering stay out of
    /// the way.
    fn default_config() -> PrefixCacheConfig {
        PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 1,
        }
    }

    /// Storing a sequence and looking up the same sequence returns its full length.
    #[test]
    fn test_insert_and_lookup_exact() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(2, 4, 5);

        pcache.store(&tokens, &cache, 5, 4, 2);
        assert_eq!(pcache.len(), 1);

        let result = pcache.lookup(&tokens);
        assert!(result.is_some());
        let (matched, kv) = result.expect("lookup should succeed");
        assert_eq!(matched, 5);
        assert_eq!(kv.seq_len(), 5);
    }

    /// A query extending a stored sequence matches only the stored part.
    #[test]
    fn test_lookup_longer_query_returns_cached_prefix() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(2, 4, 5);

        pcache.store(&tokens, &cache, 5, 4, 2);

        // Tokens 0..5 are stored; 5..10 are extra.
        let longer: Vec<u32> = (0..10).collect();
        let result = pcache.lookup(&longer);
        assert!(result.is_some());
        let (matched, _) = result.expect("lookup should succeed");
        assert_eq!(matched, 5);
    }

    /// A query sharing no leading token with any stored sequence misses.
    #[test]
    fn test_lookup_no_match_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(1, 4, 5);
        pcache.store(&tokens, &cache, 5, 4, 1);

        // The stored sequence starts at token 0.
        let other = vec![100, 200, 300];
        let result = pcache.lookup(&other);
        assert!(result.is_none());
    }

    /// Looking up anything in an empty cache misses.
    #[test]
    fn test_empty_cache_lookup_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let result = pcache.lookup(&[1, 2, 3]);
        assert!(result.is_none());
    }

    /// An empty query always misses.
    #[test]
    fn test_empty_query_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let result = pcache.lookup(&[]);
        assert!(result.is_none());
    }

    /// Two sequences sharing a leading run are stored and found independently.
    #[test]
    fn test_multiple_prefixes_with_shared_root() {
        let mut pcache = PrefixKvCache::new(default_config());

        // Both share [0, 1, 2], which forces a node split on the second insertion.
        let tokens_a = vec![0u32, 1, 2, 3, 4];
        let tokens_b = vec![0u32, 1, 2, 10, 11];

        let (cache_a, _) = make_filled_cache(1, 4, 5);
        let (cache_b, _) = make_filled_cache(1, 4, 5);

        pcache.store(&tokens_a, &cache_a, 5, 4, 1);
        pcache.store(&tokens_b, &cache_b, 5, 4, 1);

        assert_eq!(pcache.len(), 2);

        let (m_a, _) = pcache.lookup(&tokens_a).expect("lookup A");
        assert_eq!(m_a, 5);

        let (m_b, _) = pcache.lookup(&tokens_b).expect("lookup B");
        assert_eq!(m_b, 5);

        // The shared run itself has no snapshot, and a partial match of a longer stored
        // sequence is not served, so this misses.
        let shared_only = vec![0u32, 1, 2];
        let result = pcache.lookup(&shared_only);
        assert!(result.is_none());
    }

    /// Inserting past `max_entries` triggers eviction down to the limit.
    #[test]
    fn test_lru_eviction_by_entries() {
        let config = PrefixCacheConfig {
            max_entries: 2,
            max_memory_bytes: usize::MAX,
            min_prefix_len: 1,
        };
        let mut pcache = PrefixKvCache::new(config);

        // Three disjoint two-token keys.
        for i in 0u32..3 {
            let tokens = vec![100 + i, 200 + i];
            let snapshot = CachedKvState {
                keys: vec![vec![i as f32; 4]],
                values: vec![vec![i as f32; 4]],
                seq_len: 2,
            };
            pcache.store_snapshot(&tokens, snapshot);
        }

        assert!(pcache.len() <= 2);
    }

    /// Inserting past `max_memory_bytes` triggers eviction down to the budget.
    #[test]
    fn test_lru_eviction_by_memory() {
        // Each snapshot below holds 8 f32 values (32 bytes), so at most two fit in 64 bytes.
        let config = PrefixCacheConfig {
            max_entries: 100,
            max_memory_bytes: 64,
            min_prefix_len: 1,
        };
        let mut pcache = PrefixKvCache::new(config);

        for i in 0u32..5 {
            let tokens = vec![100 + i, 200 + i];
            let snapshot = CachedKvState {
                keys: vec![vec![i as f32; 4]],
                values: vec![vec![i as f32; 4]],
                seq_len: 2,
            };
            pcache.store_snapshot(&tokens, snapshot);
        }

        assert!(pcache.memory_usage() <= 64);
    }

    /// `clear` drops all snapshots and resets both counters.
    #[test]
    fn test_clear_resets_everything() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(1, 4, 5);
        pcache.store(&tokens, &cache, 5, 4, 1);

        // Bump the hit counter so the reset is observable.
        let _ = pcache.lookup(&tokens);

        pcache.clear();

        assert!(pcache.is_empty());
        assert_eq!(pcache.len(), 0);
        assert_eq!(pcache.memory_usage(), 0);
        assert_eq!(pcache.hits(), 0);
        assert_eq!(pcache.misses(), 0);
    }

    /// Storing from one `KvCache` and restoring into a fresh one reproduces its data.
    #[test]
    fn test_store_and_restore_round_trip() {
        let num_layers = 2;
        let kv_dim = 4;
        let num_tokens = 5;

        let mut pcache = PrefixKvCache::new(default_config());
        let (source_cache, tokens) = make_filled_cache(num_layers, kv_dim, num_tokens);

        pcache.store(&tokens, &source_cache, num_tokens, kv_dim, num_layers);

        let (_, cached_kv) = pcache.lookup(&tokens).expect("lookup must succeed");
        let cached_kv_clone = cached_kv.clone();

        let mut target = KvCache::new(num_layers, 128, kv_dim);
        PrefixKvCache::restore(&cached_kv_clone, &mut target);

        assert_eq!(target.seq_len(), num_tokens);

        // Compare every key and value element, layer by layer.
        for layer in 0..num_layers {
            let src_keys = source_cache.get_keys(layer).expect("get_keys");
            let tgt_keys = target.get_keys(layer).expect("get_keys");
            assert_eq!(src_keys.len(), tgt_keys.len(), "layer {layer} key length");
            for (i, (&s, &t)) in src_keys.iter().zip(tgt_keys.iter()).enumerate() {
                assert!(
                    (s - t).abs() < 1e-7,
                    "layer {layer} key[{i}]: source={s}, target={t}"
                );
            }

            let src_vals = source_cache.get_values(layer).expect("get_values");
            let tgt_vals = target.get_values(layer).expect("get_values");
            assert_eq!(src_vals.len(), tgt_vals.len(), "layer {layer} value length");
            for (i, (&s, &t)) in src_vals.iter().zip(tgt_vals.iter()).enumerate() {
                assert!(
                    (s - t).abs() < 1e-7,
                    "layer {layer} value[{i}]: source={s}, target={t}"
                );
            }
        }
    }

    /// `memory_usage` reports f32 payload bytes of stored snapshots.
    #[test]
    fn test_memory_usage_tracking() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert_eq!(pcache.memory_usage(), 0);

        // 8 key floats + 8 value floats = 16 * 4 bytes = 64 bytes.
        let snapshot = CachedKvState {
            keys: vec![vec![0.0f32; 8]],
            values: vec![vec![0.0f32; 8]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        assert_eq!(pcache.memory_usage(), 64);
    }

    /// Each lookup increments exactly one of the hit/miss counters.
    #[test]
    fn test_hit_miss_counters() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert_eq!(pcache.hits(), 0);
        assert_eq!(pcache.misses(), 0);

        // Empty cache: miss.
        let _ = pcache.lookup(&[1, 2, 3]);
        assert_eq!(pcache.misses(), 1);
        assert_eq!(pcache.hits(), 0);

        let snapshot = CachedKvState {
            keys: vec![vec![0.0; 4]],
            values: vec![vec![0.0; 4]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        // Exact match: hit.
        let _ = pcache.lookup(&[1, 2]);
        assert_eq!(pcache.hits(), 1);
        assert_eq!(pcache.misses(), 1);

        // Unrelated tokens: miss.
        let _ = pcache.lookup(&[99, 100]);
        assert_eq!(pcache.hits(), 1);
        assert_eq!(pcache.misses(), 2);
    }

    /// Sequences shorter than `min_prefix_len` are not stored.
    #[test]
    fn test_min_prefix_len_filters_short_store() {
        let config = PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 5,
        };
        let mut pcache = PrefixKvCache::new(config);

        // Three tokens < min_prefix_len of 5.
        let (cache, _) = make_filled_cache(1, 4, 3);
        pcache.store(&[0, 1, 2], &cache, 3, 4, 1);

        assert!(pcache.is_empty());
    }

    /// A query whose match would be shorter than `min_prefix_len` misses.
    #[test]
    fn test_min_prefix_len_filters_short_lookup() {
        let config = PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 5,
        };
        let mut pcache = PrefixKvCache::new(config);

        // A 10-token sequence passes the store-side filter.
        let (cache, tokens) = make_filled_cache(1, 4, 10);
        pcache.store(&tokens, &cache, 10, 4, 1);
        assert_eq!(pcache.len(), 1);

        // Three tokens: no snapshot covers exactly this, and it is below the minimum anyway.
        let short_query = vec![0u32, 1, 2];
        let result = pcache.lookup(&short_query);
        assert!(result.is_none());
    }

    /// `is_empty` and `len` track insertions.
    #[test]
    fn test_is_empty_and_len() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert!(pcache.is_empty());
        assert_eq!(pcache.len(), 0);

        let snapshot = CachedKvState {
            keys: vec![vec![0.0; 4]],
            values: vec![vec![0.0; 4]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        assert!(!pcache.is_empty());
        assert_eq!(pcache.len(), 1);
    }

    /// Edge cases of the common-prefix helper: empty inputs, full/partial/no overlap, and
    /// unequal lengths.
    #[test]
    fn test_common_prefix_len() {
        assert_eq!(common_prefix_len(&[], &[]), 0);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[]), 0);
        assert_eq!(common_prefix_len(&[], &[1, 2, 3]), 0);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 3]), 3);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 4]), 2);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[4, 5, 6]), 0);
        assert_eq!(common_prefix_len(&[1, 2], &[1, 2, 3, 4]), 2);
    }

    /// Splitting a node on the second insertion keeps each snapshot under its own key.
    #[test]
    fn test_node_split_preserves_data() {
        let mut pcache = PrefixKvCache::new(default_config());

        // Distinct key values identify which snapshot a lookup returned.
        let snap_a = CachedKvState {
            keys: vec![vec![1.0; 4]],
            values: vec![vec![2.0; 4]],
            seq_len: 4,
        };
        let snap_b = CachedKvState {
            keys: vec![vec![3.0; 4]],
            values: vec![vec![4.0; 4]],
            seq_len: 4,
        };

        // The second key diverges after [1, 2], splitting the node created for the first.
        pcache.store_snapshot(&[1, 2, 3, 4], snap_a);
        pcache.store_snapshot(&[1, 2, 5, 6], snap_b);

        assert_eq!(pcache.len(), 2);

        let (m_a, kv_a) = pcache.lookup(&[1, 2, 3, 4]).expect("lookup A");
        assert_eq!(m_a, 4);
        assert_eq!(kv_a.keys()[0][0], 1.0);

        let (m_b, kv_b) = pcache.lookup(&[1, 2, 5, 6]).expect("lookup B");
        assert_eq!(m_b, 4);
        assert_eq!(kv_b.keys()[0][0], 3.0);
    }
}