Skip to main content

oxillama_runtime/kv_cache/
prefix.rs

//! Radix-tree based prefix KV cache.
//!
//! Stores KV cache states indexed by token prefix sequences.  When a new prompt
//! shares a prefix with a previously-cached sequence, the matching KV state is
//! reused and only the remaining tokens need prefill.
//!
//! ## How it works
//!
//! 1. Token sequences are stored in a radix tree (a trie with path compression).
//! 2. Each node stores a segment of tokens and, optionally, a KV cache snapshot.
//! 3. On lookup, the tree is walked down along the query's tokens.
//! 4. The KV state of the longest matching cached prefix can be restored directly.
//! 5. LRU eviction removes least-recently-used entries whenever the entry-count
//!    limit or the memory limit is exceeded.
14
15use std::collections::HashMap;
16use std::time::Instant;
17
18use oxillama_arch::traits::KvCacheAccess;
19
20use super::KvCache;
21
22// ── Configuration ────────────────────────────────────────────────────────────
23
/// Tuning knobs for the prefix KV cache.
///
/// Both size limits are checked after every store; least-recently-used
/// snapshots are evicted while either one is exceeded.
#[derive(Debug, Clone)]
pub struct PrefixCacheConfig {
    /// Upper bound on the number of stored snapshots (tree nodes carrying KV data).
    pub max_entries: usize,
    /// Upper bound, in bytes, on the combined size of all stored KV snapshots.
    pub max_memory_bytes: usize,
    /// Shortest prefix, in tokens, that is stored or reported as a hit.
    pub min_prefix_len: usize,
}

impl Default for PrefixCacheConfig {
    /// 256 entries, 512 MiB of KV data, and prefixes of at least 4 tokens.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        PrefixCacheConfig {
            min_prefix_len: 4,
            max_memory_bytes: 512 * MIB,
            max_entries: 256,
        }
    }
}
44
45// ── Cached KV state ──────────────────────────────────────────────────────────
46
/// Point-in-time copy of the KV cache contents for one token prefix.
#[derive(Clone)]
pub struct CachedKvState {
    /// Key data, one flattened buffer per layer: `[layer][seq_pos * kv_dim]`.
    keys: Vec<Vec<f32>>,
    /// Value data, one flattened buffer per layer.
    values: Vec<Vec<f32>>,
    /// How many tokens the snapshot represents.
    seq_len: usize,
}

impl CachedKvState {
    /// Build a snapshot from already-assembled per-layer buffers.
    ///
    /// Intended for re-creating a state from cloned data (for example after
    /// the `Mutex` guarding a `PrefixKvCache` has been released). `keys` and
    /// `values` are expected to hold one inner `Vec<f32>` per layer; this is
    /// not checked here.
    pub fn new(keys: Vec<Vec<f32>>, values: Vec<Vec<f32>>, seq_len: usize) -> Self {
        CachedKvState {
            seq_len,
            keys,
            values,
        }
    }

    /// Token count covered by this snapshot.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Borrow the per-layer key buffers.
    pub fn keys(&self) -> &[Vec<f32>] {
        self.keys.as_slice()
    }

    /// Borrow the per-layer value buffers.
    pub fn values(&self) -> &[Vec<f32>] {
        self.values.as_slice()
    }

    /// Approximate size of the stored floats, in bytes.
    ///
    /// Only the `f32` payload is counted; `Vec` headers and spare capacity
    /// are ignored.
    fn memory_bytes(&self) -> usize {
        fn total_len(buffers: &[Vec<f32>]) -> usize {
            buffers.iter().map(Vec::len).sum()
        }

        let floats = total_len(&self.keys) + total_len(&self.values);
        std::mem::size_of::<f32>() * floats
    }
}
98
99// ── Radix tree node ──────────────────────────────────────────────────────────
100
101/// A node in the radix tree.
102struct RadixNode {
103    /// Token segment stored at this node (compressed path).
104    tokens: Vec<u32>,
105    /// Children keyed by the first token of their segment.
106    children: HashMap<u32, Box<RadixNode>>,
107    /// Cached KV data for this prefix (`None` for internal-only nodes).
108    cached_kv: Option<CachedKvState>,
109    /// Last access timestamp for LRU eviction.
110    last_access: Instant,
111    /// Reference count (how many active sequences use this prefix).
112    ref_count: u32,
113}
114
115impl RadixNode {
116    /// Create a new node with the given token segment.
117    fn new(tokens: Vec<u32>) -> Self {
118        Self {
119            tokens,
120            children: HashMap::new(),
121            cached_kv: None,
122            last_access: Instant::now(),
123            ref_count: 0,
124        }
125    }
126
127    /// Walk down the tree, returning the best (deepest) node that has cached
128    /// KV data whose prefix matches the query tokens.
129    ///
130    /// Returns `(matched_token_count, reference_to_node)`.
131    fn lookup<'a>(
132        &'a mut self,
133        query: &[u32],
134        matched_so_far: usize,
135    ) -> Option<(usize, &'a CachedKvState)> {
136        // Match this node's segment against the beginning of `query`.
137        let common = common_prefix_len(&self.tokens, query);
138        if common < self.tokens.len() {
139            // Partial match only — cannot descend further.
140            // Return the cached KV at this node only if the segment fully matched
141            // (it didn't, so nothing from this node).
142            return None;
143        }
144
145        let total_matched = matched_so_far + common;
146        let remaining = &query[common..];
147
148        // Update access time since we're visiting this node.
149        self.last_access = Instant::now();
150
151        // Try to descend into a child.
152        let mut best: Option<(usize, &'a CachedKvState)> = None;
153
154        if let Some(&first_token) = remaining.first() {
155            if let Some(child) = self.children.get_mut(&first_token) {
156                best = child.lookup(remaining, total_matched);
157            }
158        }
159
160        // If no deeper match found, use this node's cache (if any).
161        if best.is_none() {
162            if let Some(ref kv) = self.cached_kv {
163                best = Some((total_matched, kv));
164            }
165        }
166
167        best
168    }
169
170    /// Insert KV data at the leaf matching `tokens`, splitting nodes as needed.
171    fn insert(&mut self, tokens: &[u32], kv: CachedKvState) {
172        if tokens.is_empty() {
173            self.cached_kv = Some(kv);
174            self.last_access = Instant::now();
175            return;
176        }
177
178        let common = common_prefix_len(&self.tokens, tokens);
179
180        if common < self.tokens.len() {
181            // Need to split this node.
182            self.split_at(common);
183        }
184
185        let remaining = &tokens[common..];
186        if remaining.is_empty() {
187            self.cached_kv = Some(kv);
188            self.last_access = Instant::now();
189            return;
190        }
191
192        let first = remaining[0];
193        let child = self
194            .children
195            .entry(first)
196            .or_insert_with(|| Box::new(RadixNode::new(remaining.to_vec())));
197
198        // If the child already exists, recurse into it.
199        if child.tokens == remaining {
200            child.cached_kv = Some(kv);
201            child.last_access = Instant::now();
202        } else {
203            child.insert(remaining, kv);
204        }
205    }
206
207    /// Split this node at position `pos`, pushing the suffix (and all
208    /// children / cached data) into a new child node.
209    fn split_at(&mut self, pos: usize) {
210        let suffix = self.tokens[pos..].to_vec();
211        let first_of_suffix = suffix[0];
212
213        let mut new_child = RadixNode::new(suffix);
214        new_child.children = std::mem::take(&mut self.children);
215        new_child.cached_kv = self.cached_kv.take();
216        new_child.last_access = self.last_access;
217        new_child.ref_count = self.ref_count;
218
219        self.tokens.truncate(pos);
220        self.children.insert(first_of_suffix, Box::new(new_child));
221    }
222
223    /// Count the number of nodes that carry cached KV data.
224    fn count_entries(&self) -> usize {
225        let mine = usize::from(self.cached_kv.is_some());
226        let children_count: usize = self.children.values().map(|c| c.count_entries()).sum();
227        mine + children_count
228    }
229
230    /// Sum the estimated memory of all cached KV states in this subtree.
231    fn total_memory(&self) -> usize {
232        let mine = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
233        let children_mem: usize = self.children.values().map(|c| c.total_memory()).sum();
234        mine + children_mem
235    }
236
237    /// Find and remove the LRU eviction candidate in this subtree.
238    ///
239    /// Returns the memory freed (0 if nothing was evicted).
240    fn evict_lru_one(&mut self) -> usize {
241        // Collect candidates: this node and all descendants.
242        let mut oldest_time = Instant::now();
243        let mut oldest_path: Option<Vec<u32>> = None;
244        let mut oldest_mem: usize = 0;
245
246        self.find_lru_candidate(&mut oldest_time, &mut oldest_path, &mut oldest_mem, &[]);
247
248        if let Some(path) = oldest_path {
249            self.remove_cached_at(&path)
250        } else {
251            0
252        }
253    }
254
255    /// Recursively find the LRU candidate with `ref_count == 0`.
256    fn find_lru_candidate(
257        &self,
258        oldest_time: &mut Instant,
259        oldest_path: &mut Option<Vec<u32>>,
260        oldest_mem: &mut usize,
261        prefix: &[u32],
262    ) {
263        if self.cached_kv.is_some() && self.ref_count == 0 && self.last_access < *oldest_time {
264            *oldest_time = self.last_access;
265            let mut path = prefix.to_vec();
266            path.extend_from_slice(&self.tokens);
267            *oldest_path = Some(path);
268            *oldest_mem = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
269        }
270
271        for child in self.children.values() {
272            let mut child_prefix = prefix.to_vec();
273            child_prefix.extend_from_slice(&self.tokens);
274            child.find_lru_candidate(oldest_time, oldest_path, oldest_mem, &child_prefix);
275        }
276    }
277
278    /// Remove cached KV data at the node reached by following `path` tokens.
279    ///
280    /// Returns the memory freed.
281    fn remove_cached_at(&mut self, path: &[u32]) -> usize {
282        let common = common_prefix_len(&self.tokens, path);
283        if common < self.tokens.len() {
284            return 0;
285        }
286
287        let remaining = &path[common..];
288        if remaining.is_empty() {
289            // This is the target node.
290            let freed = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
291            self.cached_kv = None;
292            return freed;
293        }
294
295        if let Some(&first) = remaining.first() {
296            if let Some(child) = self.children.get_mut(&first) {
297                let freed = child.remove_cached_at(remaining);
298                // If the child is now empty (no cache, no children), prune it.
299                if child.cached_kv.is_none() && child.children.is_empty() {
300                    self.children.remove(&first);
301                }
302                return freed;
303            }
304        }
305        0
306    }
307
308    /// Clear all cached data in this subtree.
309    fn clear_all(&mut self) {
310        self.cached_kv = None;
311        self.children.clear();
312    }
313}
314
/// Number of leading positions at which `a` and `b` hold equal tokens.
///
/// Comparison stops at the first mismatch or at the end of the shorter slice.
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    a.iter()
        .zip(b)
        .position(|(x, y)| x != y)
        .unwrap_or_else(|| a.len().min(b.len()))
}
319
320// ── PrefixKvCache ────────────────────────────────────────────────────────────
321
322/// A radix-tree based prefix KV cache.
323///
324/// Stores KV cache states indexed by token prefix sequences. When a new prompt
325/// shares a prefix with a previously-cached sequence, the matching KV state is
326/// reused and only the remaining tokens need prefill.
327pub struct PrefixKvCache {
328    /// Root of the radix tree (has an empty token segment).
329    root: RadixNode,
330    /// Configuration.
331    config: PrefixCacheConfig,
332    /// Cache hit counter.
333    hit_count: u64,
334    /// Cache miss counter.
335    miss_count: u64,
336}
337
338impl PrefixKvCache {
339    /// Create a new prefix KV cache with the given configuration.
340    pub fn new(config: PrefixCacheConfig) -> Self {
341        Self {
342            root: RadixNode::new(Vec::new()),
343            config,
344            hit_count: 0,
345            miss_count: 0,
346        }
347    }
348
349    /// Look up the longest matching prefix for the given tokens.
350    ///
351    /// Returns `(matching_prefix_length, cached_kv_state_ref)`.  Returns `None`
352    /// if no prefix matches or the match is shorter than `min_prefix_len`.
353    pub fn lookup(&mut self, tokens: &[u32]) -> Option<(usize, &CachedKvState)> {
354        if tokens.is_empty() {
355            self.miss_count += 1;
356            return None;
357        }
358
359        let result = self.root.lookup(tokens, 0);
360
361        match result {
362            Some((matched, kv)) if matched >= self.config.min_prefix_len => {
363                self.hit_count += 1;
364                Some((matched, kv))
365            }
366            _ => {
367                self.miss_count += 1;
368                None
369            }
370        }
371    }
372
373    /// Store KV cache state for a token prefix.
374    ///
375    /// Extracts the relevant KV data from the live cache via the
376    /// [`KvCacheAccess`] trait.  If the prefix is shorter than
377    /// `min_prefix_len`, the store is silently skipped.
378    pub fn store(
379        &mut self,
380        tokens: &[u32],
381        kv_cache: &dyn KvCacheAccess,
382        seq_len: usize,
383        kv_dim: usize,
384        num_layers: usize,
385    ) {
386        if tokens.len() < self.config.min_prefix_len {
387            return;
388        }
389
390        // Snapshot the KV state from the live cache.
391        let mut keys = Vec::with_capacity(num_layers);
392        let mut values = Vec::with_capacity(num_layers);
393
394        for layer in 0..num_layers {
395            let k = kv_cache.get_keys(layer).unwrap_or(&[]);
396            let v = kv_cache.get_values(layer).unwrap_or(&[]);
397            let end = seq_len * kv_dim;
398            keys.push(k[..end.min(k.len())].to_vec());
399            values.push(v[..end.min(v.len())].to_vec());
400        }
401
402        let snapshot = CachedKvState {
403            keys,
404            values,
405            seq_len,
406        };
407
408        self.root.insert(tokens, snapshot);
409
410        // Evict if over limits.
411        self.evict_lru();
412    }
413
414    /// Store a pre-built [`CachedKvState`] directly for a token prefix.
415    ///
416    /// This is useful when the caller has already constructed the snapshot.
417    pub fn store_snapshot(&mut self, tokens: &[u32], snapshot: CachedKvState) {
418        if tokens.len() < self.config.min_prefix_len {
419            return;
420        }
421        self.root.insert(tokens, snapshot);
422        self.evict_lru();
423    }
424
425    /// Restore a cached prefix into a live KV cache.
426    ///
427    /// Copies the cached KV data into the target cache's buffers and resets
428    /// the target's sequence position to match the snapshot.
429    pub fn restore(cached: &CachedKvState, target: &mut KvCache) {
430        target.restore_from_snapshot(&cached.keys, &cached.values, cached.seq_len);
431    }
432
433    /// Evict least-recently-used entries until memory is under the limit.
434    fn evict_lru(&mut self) {
435        // Evict by entry count.
436        while self.root.count_entries() > self.config.max_entries {
437            if self.root.evict_lru_one() == 0 {
438                break; // No more evictable entries.
439            }
440        }
441        // Evict by memory.
442        while self.root.total_memory() > self.config.max_memory_bytes {
443            if self.root.evict_lru_one() == 0 {
444                break;
445            }
446        }
447    }
448
449    /// Current number of cached prefixes (nodes with KV data).
450    pub fn len(&self) -> usize {
451        self.root.count_entries()
452    }
453
454    /// Whether the cache is empty (no cached KV data).
455    pub fn is_empty(&self) -> bool {
456        self.root.count_entries() == 0
457    }
458
459    /// Clear all cached entries.
460    pub fn clear(&mut self) {
461        self.root.clear_all();
462        self.hit_count = 0;
463        self.miss_count = 0;
464    }
465
466    /// Current estimated memory usage in bytes.
467    pub fn memory_usage(&self) -> usize {
468        self.root.total_memory()
469    }
470
471    /// Number of cache hits since creation.
472    pub fn hits(&self) -> u64 {
473        self.hit_count
474    }
475
476    /// Number of cache misses since creation.
477    pub fn misses(&self) -> u64 {
478        self.miss_count
479    }
480}
481
482// ── Tests ────────────────────────────────────────────────────────────────────
483
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_arch::traits::KvCacheAccess;

    /// Helper: build a tiny KV cache, fill it with deterministic data for
    /// `num_tokens` tokens, and return it along with the tokens used
    /// (`0..num_tokens`).
    fn make_filled_cache(
        num_layers: usize,
        kv_dim: usize,
        num_tokens: usize,
    ) -> (KvCache, Vec<u32>) {
        let mut cache = KvCache::new(num_layers, 128, kv_dim);
        let tokens: Vec<u32> = (0..num_tokens as u32).collect();

        for t in 0..num_tokens {
            for layer in 0..num_layers {
                // Encode layer and position in the data so mismatches are easy to spot.
                let base = (layer * 1000 + t) as f32;
                let key: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.01).collect();
                let val: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.02).collect();
                cache
                    .store_kv(layer, &key, &val)
                    .expect("store_kv should succeed");
            }
            cache.advance();
        }

        (cache, tokens)
    }

    fn default_config() -> PrefixCacheConfig {
        PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 1,
        }
    }

    // ── Basic insert / lookup ────────────────────────────────────────────

    #[test]
    fn test_insert_and_lookup_exact() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(2, 4, 5);

        pcache.store(&tokens, &cache, 5, 4, 2);
        assert_eq!(pcache.len(), 1);

        let result = pcache.lookup(&tokens);
        assert!(result.is_some());
        let (matched, kv) = result.expect("lookup should succeed");
        assert_eq!(matched, 5);
        assert_eq!(kv.seq_len(), 5);
    }

    #[test]
    fn test_lookup_longer_query_returns_cached_prefix() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(2, 4, 5);

        pcache.store(&tokens, &cache, 5, 4, 2);

        // Query with more tokens — should still match the cached 5-token prefix.
        let longer: Vec<u32> = (0..10).collect();
        let result = pcache.lookup(&longer);
        assert!(result.is_some());
        let (matched, _) = result.expect("lookup should succeed");
        assert_eq!(matched, 5);
    }

    #[test]
    fn test_lookup_no_match_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(1, 4, 5);
        pcache.store(&tokens, &cache, 5, 4, 1);

        // Completely different tokens.
        let other = vec![100, 200, 300];
        let result = pcache.lookup(&other);
        assert!(result.is_none());
    }

    #[test]
    fn test_empty_cache_lookup_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let result = pcache.lookup(&[1, 2, 3]);
        assert!(result.is_none());
    }

    #[test]
    fn test_empty_query_returns_none() {
        let mut pcache = PrefixKvCache::new(default_config());
        let result = pcache.lookup(&[]);
        assert!(result.is_none());
    }

    #[test]
    fn test_lookup_falls_back_to_shorter_cached_prefix() {
        let mut pcache = PrefixKvCache::new(default_config());

        // Storing [1,2] after [1,2,3,4] splits the existing node at [1,2].
        pcache.store_snapshot(
            &[1, 2, 3, 4],
            CachedKvState::new(vec![vec![4.0; 4]], vec![vec![4.0; 4]], 4),
        );
        pcache.store_snapshot(
            &[1, 2],
            CachedKvState::new(vec![vec![2.0; 4]], vec![vec![2.0; 4]], 2),
        );
        assert_eq!(pcache.len(), 2);

        // [1,2,3] ends inside the [3,4] segment, so the deepest usable
        // cached prefix is [1,2].
        let (matched, kv) = pcache.lookup(&[1, 2, 3]).expect("lookup [1,2,3]");
        assert_eq!(matched, 2);
        assert_eq!(kv.seq_len(), 2);
        assert_eq!(kv.keys()[0][0], 2.0);

        // The longer prefix is still reachable after the split.
        let (matched, kv) = pcache
            .lookup(&[1, 2, 3, 4, 5])
            .expect("lookup [1,2,3,4,5]");
        assert_eq!(matched, 4);
        assert_eq!(kv.keys()[0][0], 4.0);
    }

    #[test]
    fn test_store_same_prefix_replaces_snapshot() {
        let mut pcache = PrefixKvCache::new(default_config());

        pcache.store_snapshot(
            &[7, 8, 9],
            CachedKvState::new(vec![vec![1.0; 4]], vec![vec![1.0; 4]], 3),
        );
        pcache.store_snapshot(
            &[7, 8, 9],
            CachedKvState::new(vec![vec![5.0; 4]], vec![vec![5.0; 4]], 3),
        );

        // Same prefix → still a single entry, holding the newer data.
        assert_eq!(pcache.len(), 1);
        let (matched, kv) = pcache.lookup(&[7, 8, 9]).expect("lookup");
        assert_eq!(matched, 3);
        assert_eq!(kv.keys()[0][0], 5.0);
    }

    // ── Multiple prefixes with shared prefix ─────────────────────────────

    #[test]
    fn test_multiple_prefixes_with_shared_root() {
        let mut pcache = PrefixKvCache::new(default_config());

        // Two sequences that share tokens [0,1,2] but diverge after.
        let tokens_a = vec![0u32, 1, 2, 3, 4];
        let tokens_b = vec![0u32, 1, 2, 10, 11];

        let (cache_a, _) = make_filled_cache(1, 4, 5);
        let (cache_b, _) = make_filled_cache(1, 4, 5);

        pcache.store(&tokens_a, &cache_a, 5, 4, 1);
        pcache.store(&tokens_b, &cache_b, 5, 4, 1);

        assert_eq!(pcache.len(), 2);

        // Lookup each — should get exact match.
        let (m_a, _) = pcache.lookup(&tokens_a).expect("lookup A");
        assert_eq!(m_a, 5);

        let (m_b, _) = pcache.lookup(&tokens_b).expect("lookup B");
        assert_eq!(m_b, 5);

        // The shared part [0,1,2] is an internal (split) node without KV data,
        // and neither stored sequence is a prefix of it, so the lookup misses.
        let shared_only = vec![0u32, 1, 2];
        let result = pcache.lookup(&shared_only);
        assert!(result.is_none());
    }

    // ── LRU eviction ─────────────────────────────────────────────────────

    #[test]
    fn test_lru_eviction_by_entries() {
        let config = PrefixCacheConfig {
            max_entries: 2,
            max_memory_bytes: usize::MAX,
            min_prefix_len: 1,
        };
        let mut pcache = PrefixKvCache::new(config);

        for i in 0u32..3 {
            let tokens = vec![100 + i, 200 + i];
            let snapshot = CachedKvState {
                keys: vec![vec![i as f32; 4]],
                values: vec![vec![i as f32; 4]],
                seq_len: 2,
            };
            pcache.store_snapshot(&tokens, snapshot);
        }

        // Exactly one entry must have been evicted to get back to max_entries=2.
        assert_eq!(pcache.len(), 2);
    }

    #[test]
    fn test_lru_eviction_by_memory() {
        // Each entry: 1 layer, 4 floats for keys + 4 floats for values = 32 bytes.
        let config = PrefixCacheConfig {
            max_entries: 100,
            max_memory_bytes: 64, // room for 2 entries
            min_prefix_len: 1,
        };
        let mut pcache = PrefixKvCache::new(config);

        for i in 0u32..5 {
            let tokens = vec![100 + i, 200 + i];
            let snapshot = CachedKvState {
                keys: vec![vec![i as f32; 4]],
                values: vec![vec![i as f32; 4]],
                seq_len: 2,
            };
            pcache.store_snapshot(&tokens, snapshot);
        }

        assert!(pcache.memory_usage() <= 64);
    }

    #[test]
    fn test_eviction_keeps_sibling_reachable() {
        let config = PrefixCacheConfig {
            max_entries: 1,
            max_memory_bytes: usize::MAX,
            min_prefix_len: 1,
        };
        let mut pcache = PrefixKvCache::new(config);

        // The two prefixes share [1,2], so the second store splits the tree
        // before eviction removes one of them.
        pcache.store_snapshot(
            &[1, 2, 3, 4],
            CachedKvState::new(vec![vec![1.0; 4]], vec![vec![1.0; 4]], 4),
        );
        pcache.store_snapshot(
            &[1, 2, 5, 6],
            CachedKvState::new(vec![vec![2.0; 4]], vec![vec![2.0; 4]], 4),
        );
        assert_eq!(pcache.len(), 1);

        let a = pcache.lookup(&[1, 2, 3, 4]).map(|(m, _)| m);
        let b = pcache.lookup(&[1, 2, 5, 6]).map(|(m, _)| m);

        // Exactly one survives, and it is still found with its full length.
        assert_ne!(a.is_some(), b.is_some());
        assert_eq!(a.or(b), Some(4));
    }

    // ── Clear ────────────────────────────────────────────────────────────

    #[test]
    fn test_clear_resets_everything() {
        let mut pcache = PrefixKvCache::new(default_config());
        let (cache, tokens) = make_filled_cache(1, 4, 5);
        pcache.store(&tokens, &cache, 5, 4, 1);

        // Trigger a hit.
        let _ = pcache.lookup(&tokens);

        pcache.clear();

        assert!(pcache.is_empty());
        assert_eq!(pcache.len(), 0);
        assert_eq!(pcache.memory_usage(), 0);
        assert_eq!(pcache.hits(), 0);
        assert_eq!(pcache.misses(), 0);
    }

    // ── Store and restore round-trip ─────────────────────────────────────

    #[test]
    fn test_store_and_restore_round_trip() {
        let num_layers = 2;
        let kv_dim = 4;
        let num_tokens = 5;

        let mut pcache = PrefixKvCache::new(default_config());
        let (source_cache, tokens) = make_filled_cache(num_layers, kv_dim, num_tokens);

        pcache.store(&tokens, &source_cache, num_tokens, kv_dim, num_layers);

        let (_, cached_kv) = pcache.lookup(&tokens).expect("lookup must succeed");
        let cached_kv_clone = cached_kv.clone();

        // Restore into a fresh KvCache.
        let mut target = KvCache::new(num_layers, 128, kv_dim);
        PrefixKvCache::restore(&cached_kv_clone, &mut target);

        assert_eq!(target.seq_len(), num_tokens);

        // Verify all data matches the source.
        for layer in 0..num_layers {
            let src_keys = source_cache.get_keys(layer).expect("get_keys");
            let tgt_keys = target.get_keys(layer).expect("get_keys");
            assert_eq!(src_keys.len(), tgt_keys.len(), "layer {layer} key length");
            for (i, (&s, &t)) in src_keys.iter().zip(tgt_keys.iter()).enumerate() {
                assert!(
                    (s - t).abs() < 1e-7,
                    "layer {layer} key[{i}]: source={s}, target={t}"
                );
            }

            let src_vals = source_cache.get_values(layer).expect("get_values");
            let tgt_vals = target.get_values(layer).expect("get_values");
            assert_eq!(src_vals.len(), tgt_vals.len(), "layer {layer} value length");
            for (i, (&s, &t)) in src_vals.iter().zip(tgt_vals.iter()).enumerate() {
                assert!(
                    (s - t).abs() < 1e-7,
                    "layer {layer} value[{i}]: source={s}, target={t}"
                );
            }
        }
    }

    // ── Memory tracking ──────────────────────────────────────────────────

    #[test]
    fn test_memory_usage_tracking() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert_eq!(pcache.memory_usage(), 0);

        // 1 layer, kv_dim=4, 2 tokens → keys: 8 floats, values: 8 floats.
        let snapshot = CachedKvState {
            keys: vec![vec![0.0f32; 8]], // 2 tokens * kv_dim=4
            values: vec![vec![0.0f32; 8]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        // 8 floats * 4 bytes * 2 (keys + values) = 64 bytes.
        assert_eq!(pcache.memory_usage(), 64);
    }

    // ── Hit / miss counters ──────────────────────────────────────────────

    #[test]
    fn test_hit_miss_counters() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert_eq!(pcache.hits(), 0);
        assert_eq!(pcache.misses(), 0);

        // Miss on empty cache.
        let _ = pcache.lookup(&[1, 2, 3]);
        assert_eq!(pcache.misses(), 1);
        assert_eq!(pcache.hits(), 0);

        // Store something.
        let snapshot = CachedKvState {
            keys: vec![vec![0.0; 4]],
            values: vec![vec![0.0; 4]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        // Hit.
        let _ = pcache.lookup(&[1, 2]);
        assert_eq!(pcache.hits(), 1);
        assert_eq!(pcache.misses(), 1);

        // Another miss (different tokens).
        let _ = pcache.lookup(&[99, 100]);
        assert_eq!(pcache.hits(), 1);
        assert_eq!(pcache.misses(), 2);
    }

    // ── min_prefix_len filter ────────────────────────────────────────────

    #[test]
    fn test_min_prefix_len_filters_short_store() {
        let config = PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 5,
        };
        let mut pcache = PrefixKvCache::new(config);

        // Try to store a 3-token prefix with min_prefix_len=5.
        let (cache, _) = make_filled_cache(1, 4, 3);
        pcache.store(&[0, 1, 2], &cache, 3, 4, 1);

        // Should not have been stored.
        assert!(pcache.is_empty());
    }

    #[test]
    fn test_min_prefix_len_filters_short_lookup() {
        let config = PrefixCacheConfig {
            max_entries: 64,
            max_memory_bytes: 16 * 1024 * 1024,
            min_prefix_len: 5,
        };
        let mut pcache = PrefixKvCache::new(config);

        // Store a long prefix.
        let (cache, tokens) = make_filled_cache(1, 4, 10);
        pcache.store(&tokens, &cache, 10, 4, 1);
        assert_eq!(pcache.len(), 1);

        // The 3-token query ends partway through the stored 10-token segment,
        // so no stored prefix fits inside it and the lookup misses. (Its
        // length is also below min_prefix_len.)
        let short_query = vec![0u32, 1, 2];
        let result = pcache.lookup(&short_query);
        assert!(result.is_none());
    }

    // ── is_empty / len ───────────────────────────────────────────────────

    #[test]
    fn test_is_empty_and_len() {
        let mut pcache = PrefixKvCache::new(default_config());
        assert!(pcache.is_empty());
        assert_eq!(pcache.len(), 0);

        let snapshot = CachedKvState {
            keys: vec![vec![0.0; 4]],
            values: vec![vec![0.0; 4]],
            seq_len: 2,
        };
        pcache.store_snapshot(&[1, 2], snapshot);

        assert!(!pcache.is_empty());
        assert_eq!(pcache.len(), 1);
    }

    // ── common_prefix_len helper ─────────────────────────────────────────

    #[test]
    fn test_common_prefix_len() {
        assert_eq!(common_prefix_len(&[], &[]), 0);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[]), 0);
        assert_eq!(common_prefix_len(&[], &[1, 2, 3]), 0);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 3]), 3);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 4]), 2);
        assert_eq!(common_prefix_len(&[1, 2, 3], &[4, 5, 6]), 0);
        assert_eq!(common_prefix_len(&[1, 2], &[1, 2, 3, 4]), 2);
    }

    // ── Radix tree node splitting ────────────────────────────────────────

    #[test]
    fn test_node_split_preserves_data() {
        let mut pcache = PrefixKvCache::new(default_config());

        // Insert [1,2,3,4] then [1,2,5,6]. This forces a split at [1,2].
        let snap_a = CachedKvState {
            keys: vec![vec![1.0; 4]],
            values: vec![vec![2.0; 4]],
            seq_len: 4,
        };
        let snap_b = CachedKvState {
            keys: vec![vec![3.0; 4]],
            values: vec![vec![4.0; 4]],
            seq_len: 4,
        };

        pcache.store_snapshot(&[1, 2, 3, 4], snap_a);
        pcache.store_snapshot(&[1, 2, 5, 6], snap_b);

        assert_eq!(pcache.len(), 2);

        // Both lookups should still succeed.
        let (m_a, kv_a) = pcache.lookup(&[1, 2, 3, 4]).expect("lookup A");
        assert_eq!(m_a, 4);
        assert_eq!(kv_a.keys()[0][0], 1.0);

        let (m_b, kv_b) = pcache.lookup(&[1, 2, 5, 6]).expect("lookup B");
        assert_eq!(m_b, 4);
        assert_eq!(kv_b.keys()[0][0], 3.0);
    }
}