Skip to main content

oxibonsai_model/
prefix_cache.rs

//! Prefix KV-cache — share key/value tensors across requests with a common prefix.
//!
//! Architecture:
//!   - A `CacheBlock` holds the KV tensors for one "block" of tokens (`block_size` tokens).
//!   - Blocks are arranged in a trie in which each edge stands for one whole block of
//!     token IDs (the edge key is a hash of that block's tokens).
//!   - A `PrefixCache` owns the trie and enforces a capacity limit (`max_blocks`).
//!   - Cache eviction uses an LRU (least recently used) policy tracked via a generation counter.
8
9use std::collections::HashMap;
10
/// KV tensor pair for one block: `(keys, values)`, each holding one
/// `Vec<f32>` per layer.
///
/// This is the element type accepted by `PrefixAwarePrefill::store_blocks`,
/// which forwards the two halves to `PrefixCache::insert`.
pub type KvBlockPair = (Vec<Vec<f32>>, Vec<Vec<f32>>);
13
14// ──────────────────────────────────────────────────────────────────
15// CacheBlock
16// ──────────────────────────────────────────────────────────────────
17
/// One cache block: the KV tensors for a run of tokens, one tensor per layer.
pub struct CacheBlock {
    /// Key tensors, one `Vec<f32>` per layer.  [`CacheBlock::new`] sizes every
    /// layer as `num_kv_heads * head_dim * block_size` elements.
    pub keys: Vec<Vec<f32>>,
    /// Value tensors, one `Vec<f32>` per layer, sized like `keys` by
    /// [`CacheBlock::new`].
    pub values: Vec<Vec<f32>>,
    /// The exact token IDs this block covers.
    pub token_ids: Vec<u32>,
    /// LRU generation stamp — a higher value means more recently used.
    pub last_used: u64,
    /// How many live requests are currently using this block.
    pub ref_count: usize,
}

impl CacheBlock {
    /// Allocate a new, zero-filled cache block.
    ///
    /// Both `keys` and `values` get `num_layers` tensors of
    /// `num_kv_heads * head_dim * block_size` zeros.  The block starts with no
    /// token IDs, `last_used == 0` and `ref_count == 0`.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize, block_size: usize) -> Self {
        let per_layer = num_kv_heads * head_dim * block_size;
        Self {
            keys: vec![vec![0.0f32; per_layer]; num_layers],
            values: vec![vec![0.0f32; per_layer]; num_layers],
            token_ids: Vec::new(),
            last_used: 0,
            ref_count: 0,
        }
    }

    /// Total size in bytes of the `f32` elements held in `keys` and `values`.
    ///
    /// The element counts are summed from the tensors actually stored, so the
    /// result stays correct when layers differ in length or when `values` is
    /// shaped differently from `keys` (both fields are public and can be
    /// replaced wholesale).  For a block built by [`CacheBlock::new`] this
    /// equals `2 * num_layers * per_layer * 4`.
    pub fn memory_bytes(&self) -> usize {
        let elements: usize = self
            .keys
            .iter()
            .chain(&self.values)
            .map(|layer| layer.len())
            .sum();
        elements * std::mem::size_of::<f32>()
    }
}
56
57// ──────────────────────────────────────────────────────────────────
58// Trie internals
59// ──────────────────────────────────────────────────────────────────
60
/// One vertex of the prefix trie.
///
/// Vertices live in the `PrefixCache::nodes` arena and refer to one another
/// by arena index rather than by pointer, so an index handed out earlier stays
/// valid as more nodes are pushed.
struct TrieNode {
    /// Outgoing edges: block edge key (as produced by
    /// `PrefixCache::block_edge_key`) → arena index of the child node.
    children: HashMap<u32, usize>,
    /// Slot in `PrefixCache::blocks` holding this node's cached block, or
    /// `None` when the node currently carries no block.
    block_idx: Option<usize>,
}

impl TrieNode {
    /// Build a node with no outgoing edges and no attached block.
    fn new() -> Self {
        let children = HashMap::new();
        Self {
            children,
            block_idx: None,
        }
    }
}
80
81// ──────────────────────────────────────────────────────────────────
82// PrefixCache
83// ──────────────────────────────────────────────────────────────────
84
85/// Prefix KV-cache with trie-based lookup and LRU eviction.
86///
87/// The trie is keyed by complete blocks of `block_size` tokens.  Each
88/// internal node in the trie corresponds to one block boundary; leaf
89/// nodes that carry a `block_idx` have a fully populated `CacheBlock`.
90pub struct PrefixCache {
91    /// Arena of trie nodes.  Index 0 is always the root.
92    nodes: Vec<TrieNode>,
93    /// All allocated cache blocks (some may be logically free).
94    blocks: Vec<CacheBlock>,
95    /// Indices of blocks that are currently allocated (live).
96    /// (We track occupied block indices; eviction removes from here.)
97    occupied_blocks: Vec<usize>,
98    /// Pool of block slots that have been evicted and can be reused.
99    free_block_pool: Vec<usize>,
100    /// Maximum number of simultaneously live blocks.
101    max_blocks: usize,
102    /// Tokens per block.
103    block_size: usize,
104    num_layers: usize,
105    num_kv_heads: usize,
106    head_dim: usize,
107    /// Monotonically increasing counter used for LRU tracking.
108    generation: u64,
109    /// Total cache hits since creation.
110    pub hits: u64,
111    /// Total cache misses since creation.
112    pub misses: u64,
113    /// Total blocks evicted since creation.
114    pub evictions: u64,
115}
116
117impl PrefixCache {
118    /// Create a new, empty prefix cache.
119    pub fn new(
120        max_blocks: usize,
121        block_size: usize,
122        num_layers: usize,
123        num_kv_heads: usize,
124        head_dim: usize,
125    ) -> Self {
126        let root = TrieNode::new();
127        Self {
128            nodes: vec![root],
129            blocks: Vec::new(),
130            occupied_blocks: Vec::new(),
131            free_block_pool: Vec::new(),
132            max_blocks,
133            block_size,
134            num_layers,
135            num_kv_heads,
136            head_dim,
137            generation: 0,
138            hits: 0,
139            misses: 0,
140            evictions: 0,
141        }
142    }
143
144    // ── public API ─────────────────────────────────────────────────
145
146    /// Look up the longest cached prefix of `token_ids`.
147    ///
148    /// Walks the trie block-by-block.  For every complete block whose
149    /// tokens match and whose trie node carries a cached block, the
150    /// block's `last_used` stamp is refreshed and the block is returned.
151    ///
152    /// Returns `(matched_len, Vec<&CacheBlock>)`.
153    pub fn lookup(&mut self, token_ids: &[u32]) -> (usize, Vec<&CacheBlock>) {
154        let mut node_idx = 0usize; // root
155        let mut matched_len = 0usize;
156        let mut matched_block_indices: Vec<usize> = Vec::new();
157
158        let full_blocks = token_ids.len() / self.block_size;
159
160        for block_num in 0..full_blocks {
161            let block_start = block_num * self.block_size;
162            let block_end = block_start + self.block_size;
163            let block_tokens = &token_ids[block_start..block_end];
164
165            // All tokens in the block must follow the path in the trie.
166            // We encode an entire block as a single edge keyed by the *first*
167            // token of the block, using a compound key.  However, the trie is
168            // keyed per-token — to support block-level granularity we use the
169            // first token of the block as the edge key and store the full
170            // `token_ids` list on the CacheBlock for validation.
171            let edge_key = Self::block_edge_key(block_tokens);
172
173            match self.nodes[node_idx].children.get(&edge_key).copied() {
174                None => {
175                    // No match at this block level — stop.
176                    self.misses += 1;
177                    break;
178                }
179                Some(child_node_idx) => {
180                    // Check that the child node actually has a block.
181                    let maybe_block_idx = self.nodes[child_node_idx].block_idx;
182                    match maybe_block_idx {
183                        None => {
184                            self.misses += 1;
185                            break;
186                        }
187                        Some(bidx) => {
188                            // Validate token sequence matches exactly.
189                            if self.blocks[bidx].token_ids != block_tokens {
190                                self.misses += 1;
191                                break;
192                            }
193                            // Hit — refresh timestamp and continue.
194                            self.generation += 1;
195                            self.blocks[bidx].last_used = self.generation;
196                            self.blocks[bidx].ref_count += 1;
197                            matched_len += self.block_size;
198                            matched_block_indices.push(bidx);
199                            self.hits += 1;
200                            node_idx = child_node_idx;
201                        }
202                    }
203                }
204            }
205        }
206
207        // Collect immutable references (safe: we hold &mut self but return
208        // shared refs to elements of self.blocks, which is fine for the caller).
209        let block_refs: Vec<&CacheBlock> = matched_block_indices
210            .iter()
211            .map(|&bidx| &self.blocks[bidx])
212            .collect();
213
214        (matched_len, block_refs)
215    }
216
217    /// Insert a new block for `token_ids[block_start .. block_start + block_size]`.
218    ///
219    /// Evicts the LRU block if the cache is at capacity.
220    /// Returns the index of the inserted block in `self.blocks`.
221    pub fn insert(
222        &mut self,
223        token_ids: &[u32],
224        block_start: usize,
225        keys: Vec<Vec<f32>>,
226        values: Vec<Vec<f32>>,
227    ) -> usize {
228        // Evict if necessary.
229        while self.occupied_blocks.len() >= self.max_blocks {
230            if !self.evict_lru() {
231                // Nothing evictable — cache is pinned; caller must wait.
232                break;
233            }
234        }
235
236        let block_end = block_start + self.block_size;
237        let block_tokens = token_ids[block_start..block_end.min(token_ids.len())].to_vec();
238
239        // Navigate/build the trie path up to this block.
240        let mut node_idx = 0usize;
241        let num_full_blocks_before = block_start / self.block_size;
242
243        for blk in 0..num_full_blocks_before {
244            let seg_start = blk * self.block_size;
245            let seg_end = seg_start + self.block_size;
246            let seg = &token_ids[seg_start..seg_end];
247            let edge_key = Self::block_edge_key(seg);
248
249            if let Some(&child) = self.nodes[node_idx].children.get(&edge_key) {
250                node_idx = child;
251            } else {
252                // Intermediate node missing; create it (no block data).
253                let new_node_idx = self.nodes.len();
254                self.nodes.push(TrieNode::new());
255                self.nodes[node_idx].children.insert(edge_key, new_node_idx);
256                node_idx = new_node_idx;
257            }
258        }
259
260        // Insert/update the leaf node for this block.
261        let edge_key = Self::block_edge_key(&block_tokens);
262
263        let leaf_node_idx = if let Some(&existing) = self.nodes[node_idx].children.get(&edge_key) {
264            existing
265        } else {
266            let new_node_idx = self.nodes.len();
267            self.nodes.push(TrieNode::new());
268            self.nodes[node_idx].children.insert(edge_key, new_node_idx);
269            new_node_idx
270        };
271
272        // Assign or reuse a block slot.
273        self.generation += 1;
274        let block_idx = if let Some(reuse_idx) = self.free_block_pool.pop() {
275            // Reuse a previously evicted slot.
276            let block = &mut self.blocks[reuse_idx];
277            block.keys = keys;
278            block.values = values;
279            block.token_ids = block_tokens;
280            block.last_used = self.generation;
281            block.ref_count = 0;
282            reuse_idx
283        } else {
284            // Allocate a new slot.
285            let mut blk = CacheBlock::new(
286                self.num_layers,
287                self.num_kv_heads,
288                self.head_dim,
289                self.block_size,
290            );
291            blk.keys = keys;
292            blk.values = values;
293            blk.token_ids = block_tokens;
294            blk.last_used = self.generation;
295            blk.ref_count = 0;
296            let idx = self.blocks.len();
297            self.blocks.push(blk);
298            idx
299        };
300
301        self.nodes[leaf_node_idx].block_idx = Some(block_idx);
302        self.occupied_blocks.push(block_idx);
303
304        block_idx
305    }
306
307    /// Decrement the reference count of a block, making it eligible for eviction.
308    pub fn release(&mut self, block_idx: usize) {
309        if block_idx < self.blocks.len() && self.blocks[block_idx].ref_count > 0 {
310            self.blocks[block_idx].ref_count -= 1;
311        }
312    }
313
314    /// Number of currently live (occupied) blocks.
315    pub fn len(&self) -> usize {
316        self.occupied_blocks.len()
317    }
318
319    /// Returns `true` if the cache contains no live blocks.
320    pub fn is_empty(&self) -> bool {
321        self.occupied_blocks.is_empty()
322    }
323
324    /// Maximum number of blocks this cache can hold.
325    pub fn capacity(&self) -> usize {
326        self.max_blocks
327    }
328
329    /// Total memory consumed by all live blocks' KV tensors.
330    pub fn memory_bytes(&self) -> usize {
331        self.occupied_blocks
332            .iter()
333            .map(|&idx| self.blocks[idx].memory_bytes())
334            .sum()
335    }
336
337    /// Cache hit rate in [0, 1].  Returns 0.0 if no lookups have been made.
338    pub fn hit_rate(&self) -> f32 {
339        let total = self.hits + self.misses;
340        if total == 0 {
341            0.0
342        } else {
343            self.hits as f32 / total as f32
344        }
345    }
346
347    /// Tokens per block.
348    pub fn block_size(&self) -> usize {
349        self.block_size
350    }
351
352    /// Borrow a cached block by its index in the underlying arena.
353    ///
354    /// Block indices are returned by [`CacheSession::block_indices`] when a
355    /// session is prepared via
356    /// [`PrefixAwarePrefill::prepare`](crate::prefix_cache::PrefixAwarePrefill::prepare).
357    /// `None` means the index is out of bounds (e.g. the sentinel
358    /// `usize::MAX` placed for trie path failures).
359    pub fn get_block(&self, idx: usize) -> Option<&CacheBlock> {
360        self.blocks.get(idx)
361    }
362
363    /// Remove all cached blocks, resetting the trie to an empty root.
364    pub fn clear(&mut self) {
365        self.nodes.clear();
366        self.nodes.push(TrieNode::new());
367        self.blocks.clear();
368        self.occupied_blocks.clear();
369        self.free_block_pool.clear();
370        self.generation = 0;
371        // Statistics are intentionally preserved across clear().
372    }
373
374    // ── private helpers ────────────────────────────────────────────
375
376    /// Compute a single `u32` edge key that represents an entire block.
377    ///
378    /// We use a simple polynomial hash of the token IDs so that distinct
379    /// token sequences produce distinct keys with very high probability.
380    /// The trie node still stores the full `token_ids` in the `CacheBlock`
381    /// for exact-match validation on lookup.
382    fn block_edge_key(tokens: &[u32]) -> u32 {
383        let mut h: u64 = 0xcbf2_9ce4_8422_2325; // FNV-1a offset basis
384        for &t in tokens {
385            h ^= t as u64;
386            h = h.wrapping_mul(0x0000_0100_0000_01b3); // FNV-1a prime
387        }
388        // Fold to 32 bits.
389        ((h >> 32) ^ (h & 0xffff_ffff)) as u32
390    }
391
392    /// Evict the least-recently-used block with `ref_count == 0`.
393    ///
394    /// Returns `true` if a block was evicted, `false` if all blocks are pinned.
395    fn evict_lru(&mut self) -> bool {
396        // Find the occupied block index with the smallest `last_used` that has ref_count == 0.
397        let victim_pos = self
398            .occupied_blocks
399            .iter()
400            .enumerate()
401            .filter(|(_, &bidx)| self.blocks[bidx].ref_count == 0)
402            .min_by_key(|(_, &bidx)| self.blocks[bidx].last_used)
403            .map(|(pos, _)| pos);
404
405        let Some(pos) = victim_pos else {
406            return false;
407        };
408
409        let victim_block_idx = self.occupied_blocks.swap_remove(pos);
410
411        // Remove the corresponding trie node → block association.
412        // We search for the trie node whose block_idx == victim_block_idx.
413        for node in &mut self.nodes {
414            if node.block_idx == Some(victim_block_idx) {
415                node.block_idx = None;
416                break;
417            }
418        }
419
420        // Return the slot to the free pool for reuse.
421        self.free_block_pool.push(victim_block_idx);
422        self.evictions += 1;
423
424        true
425    }
426}
427
428// ──────────────────────────────────────────────────────────────────
429// CacheSession
430// ──────────────────────────────────────────────────────────────────
431
/// Record of the cache blocks one request matched.
///
/// The session only stores indices.  The reference counts themselves are
/// raised by `PrefixCache::lookup` and lowered again when the session is
/// handed to `PrefixAwarePrefill::release_session`; until then the referenced
/// blocks stay pinned and are skipped by eviction.
pub struct CacheSession {
    /// Number of prefix tokens that were already cached.
    pub matched_prefix_len: usize,
    /// Arena indices of the matched cache blocks, in prefix order.
    pub block_indices: Vec<usize>,
}

impl CacheSession {
    /// Bundle a matched prefix length with the indices of its blocks.
    pub fn new(matched_prefix_len: usize, block_indices: Vec<usize>) -> Self {
        CacheSession {
            matched_prefix_len,
            block_indices,
        }
    }

    /// Token count implied by this session's blocks: number of blocks times
    /// `block_size`.
    ///
    /// This is derived purely from `block_indices`, so it can differ from
    /// `matched_prefix_len` if the two were supplied inconsistently.
    pub fn cached_tokens(&self, block_size: usize) -> usize {
        let held_blocks = self.block_indices.len();
        held_blocks * block_size
    }

    /// Returns `true` when the session references no cache blocks.
    pub fn is_empty(&self) -> bool {
        self.block_indices.as_slice().is_empty()
    }
}
465
466// ──────────────────────────────────────────────────────────────────
467// PrefixAwarePrefill
468// ──────────────────────────────────────────────────────────────────
469
470/// Wraps a [`PrefixCache`] and exposes a higher-level prefill API.
471///
472/// The typical call pattern for one request is:
473///
474/// ```text
475/// let (session, uncached_start) = prefill.prepare(&token_ids);
476/// // run your model prefill on token_ids[uncached_start..]
477/// prefill.store_blocks(&token_ids, uncached_start, new_kv_blocks);
478/// prefill.release_session(session);
479/// ```
480pub struct PrefixAwarePrefill {
481    /// The underlying prefix cache.
482    pub cache: PrefixCache,
483}
484
485impl PrefixAwarePrefill {
486    /// Wrap an existing `PrefixCache`.
487    pub fn new(cache: PrefixCache) -> Self {
488        Self { cache }
489    }
490
491    /// Determine how much of `token_ids` is already cached.
492    ///
493    /// Returns `(session, uncached_start)` where `uncached_start` is the
494    /// index of the first token that must be processed by the model.
495    pub fn prepare(&mut self, token_ids: &[u32]) -> (CacheSession, usize) {
496        // Phase 1: lookup to get matched_len and number of matched blocks.
497        // lookup() already increments ref_counts for matched blocks.
498        let (matched_len, matched_blocks) = self.cache.lookup(token_ids);
499        let num_matched = matched_blocks.len();
500        // Drop the borrowed references immediately.
501        drop(matched_blocks);
502
503        // Phase 2: recover the block indices by walking the trie (no borrow conflict now).
504        let block_indices: Vec<usize> = (0..num_matched)
505            .map(|blk_num| {
506                let block_start = blk_num * self.cache.block_size;
507                let block_tokens = &token_ids[block_start..block_start + self.cache.block_size];
508                let edge_key = PrefixCache::block_edge_key(block_tokens);
509                self.find_block_idx_for_edge(blk_num, token_ids, edge_key)
510            })
511            .collect();
512
513        let uncached_start = matched_len;
514        let session = CacheSession::new(matched_len, block_indices);
515        (session, uncached_start)
516    }
517
518    /// After prefill, store the newly computed KV blocks back into the cache.
519    ///
520    /// `keys_by_block` is a list of `(keys, values)` for each newly computed
521    /// block, in order, starting from the block at `uncached_start`.
522    pub fn store_blocks(
523        &mut self,
524        token_ids: &[u32],
525        uncached_start: usize,
526        keys_by_block: Vec<KvBlockPair>,
527    ) {
528        let block_size = self.cache.block_size;
529        for (i, (keys, values)) in keys_by_block.into_iter().enumerate() {
530            let block_start = uncached_start + i * block_size;
531            let block_end = block_start + block_size;
532            if block_end > token_ids.len() {
533                // Incomplete final block — do not cache partial blocks.
534                break;
535            }
536            self.cache.insert(token_ids, block_start, keys, values);
537        }
538    }
539
540    /// Release all blocks held by a session (decrement their ref counts).
541    pub fn release_session(&mut self, session: CacheSession) {
542        for bidx in session.block_indices {
543            self.cache.release(bidx);
544        }
545    }
546
547    /// Snapshot of current cache statistics.
548    pub fn stats(&self) -> PrefixCacheStats {
549        PrefixCacheStats {
550            hit_rate: self.cache.hit_rate(),
551            cached_blocks: self.cache.len(),
552            capacity_blocks: self.cache.capacity(),
553            memory_bytes: self.cache.memory_bytes(),
554            total_hits: self.cache.hits,
555            total_misses: self.cache.misses,
556            total_evictions: self.cache.evictions,
557        }
558    }
559
560    // ── private helpers ────────────────────────────────────────────
561
562    /// Walk the trie to find the block index for the block at position `blk_num`.
563    fn find_block_idx_for_edge(&self, blk_num: usize, token_ids: &[u32], edge_key: u32) -> usize {
564        // Navigate to the parent of the target node.
565        let mut node_idx = 0usize;
566        for blk in 0..blk_num {
567            let seg_start = blk * self.cache.block_size;
568            let seg_end = seg_start + self.cache.block_size;
569            let seg = &token_ids[seg_start..seg_end];
570            let parent_edge_key = PrefixCache::block_edge_key(seg);
571            if let Some(&child) = self.cache.nodes[node_idx].children.get(&parent_edge_key) {
572                node_idx = child;
573            } else {
574                // Trie path broken — return a sentinel.
575                return usize::MAX;
576            }
577        }
578        // Now look up the target child node.
579        if let Some(&child_idx) = self.cache.nodes[node_idx].children.get(&edge_key) {
580            self.cache.nodes[child_idx].block_idx.unwrap_or(usize::MAX)
581        } else {
582            usize::MAX
583        }
584    }
585}
586
587// ──────────────────────────────────────────────────────────────────
588// PrefixCacheStats
589// ──────────────────────────────────────────────────────────────────
590
/// A snapshot of prefix-cache statistics for observability.
///
/// Built by `PrefixAwarePrefill::stats`, which copies each field from the
/// wrapped `PrefixCache` at the moment of the call.
#[derive(Debug, serde::Serialize)]
pub struct PrefixCacheStats {
    /// `hits / (hits + misses)` as reported by `PrefixCache::hit_rate`, in
    /// [0, 1]; 0.0 when nothing has been counted yet.
    pub hit_rate: f32,
    /// Number of blocks currently in the cache.
    pub cached_blocks: usize,
    /// Maximum number of blocks the cache can hold.
    pub capacity_blocks: usize,
    /// Total memory consumed by KV data in bytes.
    pub memory_bytes: usize,
    /// Cumulative cache hits.
    pub total_hits: u64,
    /// Cumulative cache misses.
    pub total_misses: u64,
    /// Cumulative evictions.
    pub total_evictions: u64,
}
609
610// ──────────────────────────────────────────────────────────────────
611// Tests
612// ──────────────────────────────────────────────────────────────────
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-filled `CacheBlock` with the given geometry.
    fn make_block(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        block_size: usize,
    ) -> CacheBlock {
        CacheBlock::new(num_layers, num_kv_heads, head_dim, block_size)
    }

    /// Per-layer tensors for one block: keys filled with `val`, values filled
    /// with `val + 1.0`.
    fn make_kv(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        block_size: usize,
        val: f32,
    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let per_layer = num_kv_heads * head_dim * block_size;
        (
            vec![vec![val; per_layer]; num_layers],
            vec![vec![val + 1.0; per_layer]; num_layers],
        )
    }

    #[test]
    fn test_cache_block_memory_bytes() {
        // 2 layers, 4 heads, head_dim = 8, block_size = 4
        //   → 128 f32s per layer
        //   → 2 (K+V) * 2 layers * 128 * 4 bytes = 2048
        let per_layer = 4 * 8 * 4;
        let expected = 2 * 2 * per_layer * std::mem::size_of::<f32>();
        assert_eq!(make_block(2, 4, 8, 4).memory_bytes(), expected);
    }

    #[test]
    fn test_prefix_cache_insert_and_lookup_hit() {
        let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);

        let (k, v) = make_kv(2, 2, 8, 4, 1.0);
        cache.insert(&token_ids, 0, k, v);

        let (matched, blocks) = cache.lookup(&token_ids);
        assert_eq!(matched, 4, "should match one full block of 4 tokens");
        assert_eq!(blocks.len(), 1);
        assert_eq!(cache.hits, 1);
    }

    #[test]
    fn test_prefix_cache_lookup_miss() {
        let token_ids: Vec<u32> = vec![10, 20, 30, 40];
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);

        // Nothing was inserted, so the very first block misses.
        let (matched, blocks) = cache.lookup(&token_ids);
        assert_eq!(matched, 0);
        assert!(blocks.is_empty());
        assert_eq!(cache.misses, 1);
    }

    #[test]
    fn test_prefix_cache_partial_prefix_match() {
        let mut cache = PrefixCache::new(8, 4, 2, 2, 8);

        // Cache only the first block (tokens 0..4) of the stored sequence.
        let stored: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let (k, v) = make_kv(2, 2, 8, 4, 0.5);
        cache.insert(&stored, 0, k, v);

        // Same first block, different second block: match the first (4
        // tokens) and miss on the second.
        let query: Vec<u32> = vec![1, 2, 3, 4, 9, 10, 11, 12];
        let (matched, blocks) = cache.lookup(&query);
        assert_eq!(matched, 4);
        assert_eq!(blocks.len(), 1);
    }

    #[test]
    fn test_prefix_cache_lru_eviction() {
        // Capacity of two, so the third insert has to evict.
        let mut cache = PrefixCache::new(2, 4, 1, 1, 4);

        let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
        let tokens_b: Vec<u32> = vec![5, 6, 7, 8];
        let tokens_c: Vec<u32> = vec![9, 10, 11, 12];

        let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
        cache.insert(&tokens_a, 0, ka, va);
        let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);
        cache.insert(&tokens_b, 0, kb, vb);

        // Touch b so that a becomes the least recently used block.
        let _ = cache.lookup(&tokens_b);

        // Inserting c should push a out.
        let (kc, vc) = make_kv(1, 1, 4, 4, 3.0);
        cache.insert(&tokens_c, 0, kc, vc);

        assert_eq!(
            cache.len(),
            2,
            "should have exactly 2 blocks after eviction"
        );
        assert_eq!(cache.evictions, 1);

        // a must be gone.
        let (matched_a, _) = cache.lookup(&tokens_a);
        assert_eq!(matched_a, 0, "evicted block should not be found");
    }

    #[test]
    fn test_prefix_cache_ref_count_prevents_eviction() {
        let mut cache = PrefixCache::new(1, 4, 1, 1, 4);

        let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
        let tokens_b: Vec<u32> = vec![5, 6, 7, 8];

        let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
        let bidx_a = cache.insert(&tokens_a, 0, ka, va);
        // Simulate an active session by pinning block a by hand.
        cache.blocks[bidx_a].ref_count += 1;

        // The cache is full and its only block is pinned, so this insert
        // cannot evict anything.
        let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);
        cache.insert(&tokens_b, 0, kb, vb);
        assert_eq!(cache.evictions, 0, "pinned block must not be evicted");

        // Dropping the manual pin brings the count back to zero.
        cache.release(bidx_a);
        assert_eq!(cache.blocks[bidx_a].ref_count, 0);
    }

    #[test]
    fn test_prefix_cache_hit_rate() {
        let tokens: Vec<u32> = vec![1, 2, 3, 4];
        let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        cache.insert(&tokens, 0, k, v);

        let _ = cache.lookup(&tokens); // one hit
        let _ = cache.lookup(&[99, 100, 101, 102]); // one miss

        let rate = cache.hit_rate();
        assert!(
            (rate - 0.5).abs() < 1e-5,
            "hit rate should be 0.5, got {rate}"
        );
    }

    #[test]
    fn test_prefix_cache_clear() {
        let tokens: Vec<u32> = vec![1, 2, 3, 4];
        let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        cache.insert(&tokens, 0, k, v);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);

        // A cleared cache no longer knows the block.
        let (matched, _) = cache.lookup(&tokens);
        assert_eq!(matched, 0);
    }

    #[test]
    fn test_cache_session_cached_tokens() {
        let populated = CacheSession::new(8, vec![0, 1]);
        assert_eq!(populated.cached_tokens(4), 8);
        assert!(!populated.is_empty());

        let empty = CacheSession::new(0, vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.cached_tokens(4), 0);
    }

    #[test]
    fn test_prefix_aware_prefill_prepare() {
        let mut prefill = PrefixAwarePrefill::new(PrefixCache::new(8, 4, 1, 1, 4));

        // Cache the first 4 of 8 tokens.
        let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        prefill.cache.insert(&token_ids, 0, k, v);

        // Exactly that first block should be reported as cached.
        let (session, uncached_start) = prefill.prepare(&token_ids);
        assert_eq!(session.matched_prefix_len, 4);
        assert_eq!(uncached_start, 4);

        prefill.release_session(session);
    }

    #[test]
    fn test_prefix_cache_stats() {
        let mut prefill = PrefixAwarePrefill::new(PrefixCache::new(8, 4, 1, 1, 4));

        let token_ids: Vec<u32> = vec![1, 2, 3, 4];
        let (k, v) = make_kv(1, 1, 4, 4, 1.0);
        prefill.cache.insert(&token_ids, 0, k, v);

        let _ = prefill.prepare(&token_ids);

        let stats = prefill.stats();
        assert!(stats.cached_blocks > 0 || stats.total_hits > 0 || stats.total_misses > 0);
        assert_eq!(stats.capacity_blocks, 8);
    }

    #[test]
    fn test_prefix_cache_capacity_enforcement() {
        let capacity = 4usize;
        let mut cache = PrefixCache::new(capacity, 4, 1, 1, 4);

        // Insert two more distinct single-block sequences than fit.
        for i in 0..capacity + 2 {
            let first_token = i * 4;
            let tokens: Vec<u32> = (first_token..first_token + 4).map(|t| t as u32).collect();
            let (k, v) = make_kv(1, 1, 4, 4, i as f32);
            cache.insert(&tokens, 0, k, v);
        }

        assert!(
            cache.len() <= capacity,
            "cache should not exceed max_blocks={capacity}, got {}",
            cache.len()
        );
        assert!(
            cache.evictions >= 2,
            "should have evicted at least 2 blocks"
        );
    }
}