oxibonsai_runtime/prefix_cache_engine.rs

//! Prefix-cache-aware inference engine wrapper.
//!
//! [`PrefixCachedEngine`] wraps an [`InferenceEngine`] and transparently
//! intercepts the prefill phase: identical prompt prefixes (e.g. a shared
//! system prompt) are served from the KV-cache trie rather than being
//! re-processed by the model, cutting prefill cost to near-zero for cached
//! prefixes.
//!
//! ## Usage
//!
//! ```rust,no_run
//! use oxibonsai_core::config::Qwen3Config;
//! use oxibonsai_runtime::engine::InferenceEngine;
//! use oxibonsai_runtime::sampling::SamplingParams;
//! use oxibonsai_runtime::prefix_cache_engine::PrefixCachedEngine;
//!
//! let config = Qwen3Config::tiny_test();
//! let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
//! let mut cached = PrefixCachedEngine::new(engine, 64);
//!
//! let tokens = cached.generate(&[1, 2, 3, 4], &SamplingParams::default());
//! let stats = cached.cache_stats();
//! println!("hit rate: {:.1}%", stats.hit_rate * 100.0);
//! ```
//!
//! ## Limitations
//!
//! Real prefix-cache reuse is only effective when the engine's forward
//! path populates the CPU [`oxibonsai_model::KvCache`]. On Metal/CUDA tiers
//! the GPU keeps its own KV state separate from the CPU cache; in that
//! case the post-prefill extraction would yield all-zero tensors. This
//! engine detects that case (the `real_cpu_kv` check below) and falls back
//! to plain prefill without poisoning the trie. The session bookkeeping
//! (hit-rate stats) still runs.

use oxibonsai_model::prefix_cache::{
    KvBlockPair, PrefixAwarePrefill, PrefixCache, PrefixCacheStats,
};

use crate::engine::{InferenceEngine, EOS_TOKEN_ID};
use crate::sampling::SamplingParams;

/// Tokens per cache block. Only whole blocks are stored in the trie, so a
/// prompt caches the largest multiple of 16 tokens it covers and any
/// trailing partial block is re-prefilled on the next request.
const BLOCK_SIZE: usize = 16;

/// An [`InferenceEngine`] augmented with prefix KV-cache reuse.
///
/// On each [`generate`](PrefixCachedEngine::generate) call the engine:
///
/// 1. Resets the model's KV cache (single-engine, sequential request model).
/// 2. Looks up the longest cached prefix in the trie.
/// 3. Injects the matched KV blocks back into the model's CPU cache.
/// 4. Runs prefill only on the uncached suffix at the correct `pos_start`.
/// 5. Extracts any newly produced full blocks of KV state and stores them
///    in the trie for subsequent requests (skipped on GPU tiers where the
///    CPU cache stays empty).
/// 6. Sample-decodes new tokens up to `params.max_tokens` or EOS.
/// 7. Releases the session (decrements ref counts) when done.
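///
/// # Example
///
/// A minimal cold-then-warm round trip, reusing the tiny test config from
/// the module-level example (illustrative sketch only; real reuse requires
/// the CPU KV path described under "Limitations"):
///
/// ```rust,no_run
/// # use oxibonsai_core::config::Qwen3Config;
/// # use oxibonsai_runtime::engine::InferenceEngine;
/// # use oxibonsai_runtime::sampling::SamplingParams;
/// # use oxibonsai_runtime::prefix_cache_engine::PrefixCachedEngine;
/// let engine = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
/// let mut cached = PrefixCachedEngine::new(engine, 32);
/// let prompt: Vec<u32> = (0..32).collect();
/// let _ = cached.generate(&prompt, &SamplingParams::default()); // cold: may store blocks
/// let _ = cached.generate(&prompt, &SamplingParams::default()); // warm: may reuse them
/// println!("hits so far: {}", cached.cache_stats().total_hits);
/// ```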
pub struct PrefixCachedEngine<'a> {
    /// The underlying inference engine.
    pub inner: InferenceEngine<'a>,
    /// Prefix-cache-aware prefill helper with the block trie.
    pub prefix_cache: PrefixAwarePrefill,
}

impl<'a> PrefixCachedEngine<'a> {
    /// Wrap an existing [`InferenceEngine`] with a prefix cache.
    ///
    /// Derives `num_layers`, `num_kv_heads`, and `head_dim` directly from
    /// the engine's model configuration, so no manual wiring is required.
    ///
    /// # Parameters
    ///
    /// - `engine` — the inference engine to wrap.
    /// - `max_cache_blocks` — maximum number of simultaneously live cache
    ///   blocks.  Each block holds `BLOCK_SIZE` (16) tokens of KV data for
    ///   every layer; memory per block is approximately
    ///   `2 × num_layers × num_kv_heads × head_dim × 16 × 4` bytes.
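    ///
    ///   As a worked example, using the 1-layer test config from this
    ///   file's test module (`num_kv_heads = 2`, `head_dim = 32`):
    ///   `2 × 1 × 2 × 32 × 16 × 4` = 8192 bytes, i.e. about 8 KiB per
    ///   block of f32 keys and values.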
    pub fn new(engine: InferenceEngine<'a>, max_cache_blocks: usize) -> Self {
        let cfg = engine.model().config();
        let cache = PrefixCache::new(
            max_cache_blocks,
            BLOCK_SIZE,
            cfg.num_layers,
            cfg.num_kv_heads,
            cfg.head_dim,
        );
        let prefix_cache = PrefixAwarePrefill::new(cache);
        Self {
            inner: engine,
            prefix_cache,
        }
    }

    /// Generate tokens from `prompt_tokens`, reusing any cached prefix.
    ///
    /// Returns the generated token IDs (not including the prompt). On any
    /// internal error the method logs via `tracing::warn` and returns an
    /// empty vector — `generate` itself is infallible from the caller's
    /// perspective so it can be dropped into batch pipelines.
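    ///
    /// # Example
    ///
    /// A sketch of batch use with a shared system-prompt prefix (the token
    /// values here are placeholders):
    ///
    /// ```rust,no_run
    /// # use oxibonsai_core::config::Qwen3Config;
    /// # use oxibonsai_runtime::engine::InferenceEngine;
    /// # use oxibonsai_runtime::sampling::SamplingParams;
    /// # use oxibonsai_runtime::prefix_cache_engine::PrefixCachedEngine;
    /// # let engine = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
    /// let mut cached = PrefixCachedEngine::new(engine, 64);
    /// let system_prefix: Vec<u32> = (0..16).collect();
    /// for user_tokens in [vec![100, 101], vec![200, 201]] {
    ///     let mut prompt = system_prefix.clone();
    ///     prompt.extend(user_tokens);
    ///     // Never panics or errors: failures are logged and yield `vec![]`.
    ///     let tokens = cached.generate(&prompt, &SamplingParams::default());
    ///     println!("generated {} tokens", tokens.len());
    /// }
    /// ```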
    pub fn generate(&mut self, prompt_tokens: &[u32], params: &SamplingParams) -> Vec<u32> {
        if prompt_tokens.is_empty() {
            return vec![];
        }

        // ── Step 1: reset model KV cache ─────────────────────────────────────
        // We treat the wrapper as a single-engine, sequential request server.
        self.inner.model_mut().reset();

        // ── Step 2: query the prefix cache ───────────────────────────────────
        let (session, uncached_start) = self.prefix_cache.prepare(prompt_tokens);
        let block_size = self.prefix_cache.cache.block_size();
        let cfg = self.inner.model().config().clone();
        let num_layers = cfg.num_layers;

        // ── Step 3: restore cached blocks into the model's CPU KV cache ──────
        if uncached_start > 0 && !session.block_indices.is_empty() {
            for (block_num, &bidx) in session.block_indices.iter().enumerate() {
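                // A `bidx` of `usize::MAX` marks a slot with no resident
                // cache block to inject; skip it.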
                if bidx == usize::MAX {
                    continue;
                }
                // Snapshot keys/values per layer before mutably borrowing model.
                let snapshots: Option<Vec<(Vec<f32>, Vec<f32>)>> =
                    self.prefix_cache.cache.get_block(bidx).map(|block| {
                        (0..num_layers)
                            .map(|l| (block.keys[l].clone(), block.values[l].clone()))
                            .collect()
                    });
                let snapshots = match snapshots {
                    Some(s) => s,
                    None => continue,
                };
                let block_start = block_num * block_size;
                let kv = self.inner.model_mut().kv_cache_mut();
                for (layer, (keys, values)) in snapshots.into_iter().enumerate() {
                    kv.inject_block(layer, block_start, block_size, &keys, &values);
                }
            }
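            // Advance the cache's logical length so the suffix prefill
            // appends after the injected prefix rather than at position 0.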
            self.inner
                .model_mut()
                .kv_cache_mut()
                .set_seq_len(uncached_start);
        }

        // ── Step 4: prefill on the uncached suffix only ──────────────────────
        let mut last_logits = if uncached_start < prompt_tokens.len() {
            match self
                .inner
                .prefill_from_pos(&prompt_tokens[uncached_start..], uncached_start)
            {
                Ok(logits) => logits,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache prefill failed");
                    self.prefix_cache.release_session(session);
                    return vec![];
                }
            }
        } else {
            // Entire prompt was cached — re-run the final token to get logits
            // (we still need a fresh logits vector to drive the decode loop).
            let last_pos = prompt_tokens.len().saturating_sub(1);
            let last_tok = prompt_tokens[last_pos];
            match self.inner.decode_step(last_tok, last_pos) {
                Ok(logits) => logits,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache decode_step failed");
                    self.prefix_cache.release_session(session);
                    return vec![];
                }
            }
        };

        // ── Step 5: detect whether the CPU KV cache was actually populated ──
        // GPU tiers (Metal/CUDA) maintain their own KV cache and leave the
        // CPU `KvCache` untouched; in that case any extraction yields zeros
        // which would silently corrupt the trie. We sample one layer/head/
        // range and skip the store_blocks step if everything is zero.
        let real_cpu_kv = {
            let kv = self.inner.model().kv_cache();
            let probe_len = prompt_tokens.len().min(kv.max_seq_len());
            kv.keys_for(0, 0, probe_len).iter().any(|&x| x != 0.0)
        };

        // ── Step 6: store newly computed blocks into the trie ────────────────
        if real_cpu_kv {
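            // Integer division drops any trailing partial block: only whole
            // `block_size`-token blocks are stored in the trie.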
            let new_blocks_count = prompt_tokens.len().saturating_sub(uncached_start) / block_size;
            if new_blocks_count > 0 {
                let mut kv_by_block: Vec<KvBlockPair> = Vec::with_capacity(new_blocks_count);
                for blk in 0..new_blocks_count {
                    let block_pos = uncached_start + blk * block_size;
                    let mut layer_keys: Vec<Vec<f32>> = Vec::with_capacity(num_layers);
                    let mut layer_values: Vec<Vec<f32>> = Vec::with_capacity(num_layers);
                    for layer in 0..num_layers {
                        let (k, v) = self
                            .inner
                            .model()
                            .kv_cache()
                            .extract_block(layer, block_pos, block_size);
                        layer_keys.push(k);
                        layer_values.push(v);
                    }
                    kv_by_block.push((layer_keys, layer_values));
                }
                self.prefix_cache
                    .store_blocks(prompt_tokens, uncached_start, kv_by_block);
            }
        }

        // ── Step 7: decode loop ──────────────────────────────────────────────
        // Swap in a per-request sampler matching `params` so that the wrapper
        // honours per-call sampling while leaving the engine's persistent
        // sampler unchanged.
        let mut output = Vec::with_capacity(params.max_tokens);
        let mut sampler = crate::sampling::Sampler::new(params.clone(), 0);
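        // `pos` starts right after the prompt; the zip caps the loop at
        // `max_tokens` iterations.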
        for (pos, _) in (prompt_tokens.len()..).zip(0..params.max_tokens) {
            let next_token = match sampler.sample(&last_logits) {
                Ok(t) => t,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache sampler error");
                    break;
                }
            };
            if next_token == EOS_TOKEN_ID {
                break;
            }
            output.push(next_token);
            last_logits = match self.inner.decode_step(next_token, pos) {
                Ok(l) => l,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache decode loop error");
                    break;
                }
            };
        }

        // ── Step 8: release session ──────────────────────────────────────────
        self.prefix_cache.release_session(session);
        output
    }

    /// Return a snapshot of the current prefix-cache statistics.
    pub fn cache_stats(&self) -> PrefixCacheStats {
        self.prefix_cache.stats()
    }

    /// Clear all entries from the prefix cache.
    ///
    /// Does *not* reset the inner engine's KV cache.
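    ///
    /// A quick sanity check (mirrors the unit test below):
    ///
    /// ```rust,no_run
    /// # use oxibonsai_core::config::Qwen3Config;
    /// # use oxibonsai_runtime::engine::InferenceEngine;
    /// # use oxibonsai_runtime::sampling::SamplingParams;
    /// # use oxibonsai_runtime::prefix_cache_engine::PrefixCachedEngine;
    /// # let engine = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
    /// let mut cached = PrefixCachedEngine::new(engine, 16);
    /// cached.clear_cache();
    /// assert_eq!(cached.cache_stats().cached_blocks, 0);
    /// ```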
    pub fn clear_cache(&mut self) {
        self.prefix_cache.cache.clear();
    }
}

// ──────────────────────────────────────────────────────────────────
// Tests
// ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;
    use oxibonsai_model::model::BonsaiModel;

    fn make_engine_no_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        let config = Qwen3Config::tiny_test();
        let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
        PrefixCachedEngine::new(engine, max_blocks)
    }

    /// Build a config small enough to keep test runtimes tight while still
    /// satisfying the Q1_0_g128 constraint (in_features must be a multiple
    /// of 128).
    fn small_real_config() -> Qwen3Config {
        Qwen3Config {
            hidden_size: 128,
            intermediate_size: 256,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 32,
            vocab_size: 256,
            max_context_length: 128,
            rms_norm_eps: 1e-6,
            rope_freq_base: 10_000.0,
            architecture: "qwen3".to_string(),
            model_name: "PrefixCacheTest".to_string(),
        }
    }

    fn make_engine_with_real_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        use oxibonsai_kernels::{KernelDispatcher, KernelTier};
        let config = small_real_config();
        let model = BonsaiModel::new_for_testing_with_blocks(config);
        // Pin the engine to the Reference (CPU) tier so the CPU KV cache is
        // populated by the forward path. With auto_detect on a GPU host the
        // GPU shortcut would bypass the CPU cache entirely.
        let kernel = KernelDispatcher::with_tier(KernelTier::Reference);
        let engine =
            InferenceEngine::from_model_with_kernel(model, kernel, SamplingParams::default(), 42);
        PrefixCachedEngine::new(engine, max_blocks)
    }

    #[test]
    fn prefix_cached_engine_construction() {
        let engine = make_engine_no_blocks(16);
        let stats = engine.cache_stats();
        assert_eq!(stats.cached_blocks, 0);
        assert_eq!(stats.capacity_blocks, 16);
    }

    #[test]
    fn prefix_cached_engine_generate_empty() {
        let mut engine = make_engine_no_blocks(16);
        let tokens = engine.generate(&[], &SamplingParams::default());
        assert!(tokens.is_empty());
    }

    #[test]
    fn prefix_cached_engine_clear_cache() {
        let mut engine = make_engine_no_blocks(16);
        // Run a generate so the cache might get some blocks.
        let prompt: Vec<u32> = (0..32).collect();
        let fast_params = SamplingParams {
            max_tokens: 4,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };
        let _ = engine.generate(&prompt, &fast_params);
        engine.clear_cache();
        let stats = engine.cache_stats();
        assert_eq!(stats.cached_blocks, 0);
    }

    #[test]
    fn prefix_cached_engine_stats_structure() {
        let engine = make_engine_no_blocks(32);
        let stats = engine.cache_stats();
        assert_eq!(stats.capacity_blocks, 32);
        assert!((stats.hit_rate - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn prefix_cached_engine_repeated_prompt_builds_cache() {
        // Use a model with real blocks so the CPU KV cache is actually populated.
        let mut engine = make_engine_with_real_blocks(32);
        let prompt: Vec<u32> = (0..32).collect();
        let fast_params = SamplingParams {
            max_tokens: 1,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };

        // First call: cold cache.
        let _ = engine.generate(&prompt, &fast_params);
        let stats_after_first = engine.cache_stats();

        // Second call: same prompt; should record at least one hit and the
        // cache should contain entries.
        let _ = engine.generate(&prompt, &fast_params);
        let stats_after_second = engine.cache_stats();

        assert!(
            stats_after_first.cached_blocks > 0,
            "first call should have populated some cache blocks"
        );
        assert!(
            stats_after_second.total_hits > 0,
            "second call should record cache hits"
        );
    }

    /// Acceptance criterion #5 from issue #2: a repeated prompt must
    /// actually skip prefill work, not merely record bookkeeping hits.
    #[test]
    fn prefix_cached_engine_avoids_redundant_prefill_work() {
        let mut engine = make_engine_with_real_blocks(64);
        let prompt: Vec<u32> = (0..32).collect();
        let fast_params = SamplingParams {
            max_tokens: 2,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };

        let out1 = engine.generate(&prompt, &fast_params);
        let prefill_after_first = engine.inner.prefill_token_count();

        let out2 = engine.generate(&prompt, &fast_params);
        let prefill_after_second = engine.inner.prefill_token_count();

        let second_call_prefill = prefill_after_second - prefill_after_first;
        assert!(
            second_call_prefill < prompt.len() as u64,
            "second call prefilled {} tokens, expected < {} (prefix cache should have skipped some)",
            second_call_prefill,
            prompt.len()
        );
        assert!(
            engine.cache_stats().total_hits > 0,
            "cache should report hits"
        );
        // AC #3 from issue #2: cached path must produce identical output to
        // the cold-cache path. With temperature=0 and top_k=1 the sampler is
        // deterministic, so the two generations must match token-for-token.
        assert_eq!(
            out1, out2,
            "AC #3: cached path must produce identical output ({:?} vs {:?})",
            out1, out2
        );
    }
}
413}