oxibonsai_runtime/prefix_cache_engine.rs
1//! Prefix-cache-aware inference engine wrapper.
2//!
3//! [`PrefixCachedEngine`] wraps an [`InferenceEngine`] and transparently
4//! intercepts the prefill phase: identical prompt prefixes (e.g. a shared
5//! system prompt) are served from the KV-cache trie rather than being
6//! re-processed by the model, cutting prefill cost to near-zero for cached
7//! prefixes.
8//!
9//! ## Usage
10//!
11//! ```rust,no_run
12//! use oxibonsai_core::config::Qwen3Config;
13//! use oxibonsai_runtime::engine::InferenceEngine;
14//! use oxibonsai_runtime::sampling::SamplingParams;
15//! use oxibonsai_runtime::prefix_cache_engine::PrefixCachedEngine;
16//!
17//! let config = Qwen3Config::tiny_test();
18//! let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
19//! let mut cached = PrefixCachedEngine::new(engine, 64);
20//!
21//! let tokens = cached.generate(&[1, 2, 3, 4], &SamplingParams::default());
22//! let stats = cached.cache_stats();
23//! println!("hit rate: {:.1}%", stats.hit_rate * 100.0);
24//! ```
25//!
26//! ## Limitations
27//!
28//! Real prefix-cache reuse is only effective when the engine's forward
29//! path populates the CPU [`oxibonsai_model::KvCache`]. On Metal/CUDA tiers
30//! the GPU keeps its own KV state separate from the CPU cache; in that
31//! case the post-prefill extraction would yield all-zero tensors. This
32//! engine detects that case (the `real_cpu_kv` check below) and falls back
33//! to plain prefill without poisoning the trie. The session bookkeeping
34//! (hit-rate stats) still runs.
35
36use oxibonsai_model::prefix_cache::{
37 KvBlockPair, PrefixAwarePrefill, PrefixCache, PrefixCacheStats,
38};
39
40use crate::engine::{InferenceEngine, EOS_TOKEN_ID};
41use crate::sampling::SamplingParams;
42
/// Tokens per cache block. Only *full* blocks are ever stored in the trie
/// (the store step uses integer division by this size), so the trailing
/// `prompt_len % BLOCK_SIZE` tokens of a prompt are re-prefilled on every
/// request. Prompts shorter than one block therefore never hit the cache.
const BLOCK_SIZE: usize = 16;
45
/// An [`InferenceEngine`] augmented with prefix KV-cache reuse.
///
/// Designed for a single-engine, *sequential* request model: each call to
/// [`generate`](PrefixCachedEngine::generate) resets the model's KV cache
/// before serving, so concurrent use through one instance is not supported.
///
/// On each [`generate`](PrefixCachedEngine::generate) call the engine:
///
/// 1. Resets the model's KV cache (single-engine, sequential request model).
/// 2. Looks up the longest cached prefix in the trie.
/// 3. Injects the matched KV blocks back into the model's CPU cache.
/// 4. Runs prefill only on the uncached suffix at the correct `pos_start`.
/// 5. Extracts any newly produced full blocks of KV state and stores them
///    in the trie for subsequent requests (skipped on GPU tiers where the
///    CPU cache stays empty).
/// 6. Sample-decodes new tokens up to `params.max_tokens` or EOS.
/// 7. Releases the session (decrements ref counts) when done.
pub struct PrefixCachedEngine<'a> {
    /// The underlying inference engine. Public so callers (and tests) can
    /// reach engine-level introspection such as prefill counters.
    pub inner: InferenceEngine<'a>,
    /// Prefix-cache-aware prefill helper owning the block trie.
    pub prefix_cache: PrefixAwarePrefill,
}
65
66impl<'a> PrefixCachedEngine<'a> {
67 /// Wrap an existing [`InferenceEngine`] with a prefix cache.
68 ///
69 /// Derives `num_layers`, `num_kv_heads`, and `head_dim` directly from
70 /// the engine's model configuration, so no manual wiring is required.
71 ///
72 /// # Parameters
73 ///
74 /// - `engine` — the inference engine to wrap.
75 /// - `max_cache_blocks` — maximum number of simultaneously live cache
76 /// blocks. Each block holds `BLOCK_SIZE` (16) tokens of KV data for
77 /// every layer; memory per block is approximately
78 /// `2 × num_layers × num_kv_heads × head_dim × 16 × 4` bytes.
79 pub fn new(engine: InferenceEngine<'a>, max_cache_blocks: usize) -> Self {
80 let cfg = engine.model().config();
81 let cache = PrefixCache::new(
82 max_cache_blocks,
83 BLOCK_SIZE,
84 cfg.num_layers,
85 cfg.num_kv_heads,
86 cfg.head_dim,
87 );
88 let prefix_cache = PrefixAwarePrefill::new(cache);
89 Self {
90 inner: engine,
91 prefix_cache,
92 }
93 }
94
95 /// Generate tokens from `prompt_tokens`, reusing any cached prefix.
96 ///
97 /// Returns the generated token IDs (not including the prompt). On any
98 /// internal error the method logs via `tracing::warn` and returns an
99 /// empty vector — `generate` itself is infallible from the caller's
100 /// perspective so it can be dropped into batch pipelines.
101 pub fn generate(&mut self, prompt_tokens: &[u32], params: &SamplingParams) -> Vec<u32> {
102 if prompt_tokens.is_empty() {
103 return vec![];
104 }
105
106 // ── Step 1: reset model KV cache ─────────────────────────────────────
107 // We treat the wrapper as a single-engine, sequential request server.
108 self.inner.model_mut().reset();
109
110 // ── Step 2: query the prefix cache ───────────────────────────────────
111 let (session, uncached_start) = self.prefix_cache.prepare(prompt_tokens);
112 let block_size = self.prefix_cache.cache.block_size();
113 let cfg = self.inner.model().config().clone();
114 let num_layers = cfg.num_layers;
115
116 // ── Step 3: restore cached blocks into the model's CPU KV cache ──────
117 if uncached_start > 0 && !session.block_indices.is_empty() {
118 for (block_num, &bidx) in session.block_indices.iter().enumerate() {
119 if bidx == usize::MAX {
120 continue;
121 }
122 // Snapshot keys/values per layer before mutably borrowing model.
123 let snapshots: Option<Vec<(Vec<f32>, Vec<f32>)>> =
124 self.prefix_cache.cache.get_block(bidx).map(|block| {
125 (0..num_layers)
126 .map(|l| (block.keys[l].clone(), block.values[l].clone()))
127 .collect()
128 });
129 let snapshots = match snapshots {
130 Some(s) => s,
131 None => continue,
132 };
133 let block_start = block_num * block_size;
134 let kv = self.inner.model_mut().kv_cache_mut();
135 for (layer, (keys, values)) in snapshots.into_iter().enumerate() {
136 kv.inject_block(layer, block_start, block_size, &keys, &values);
137 }
138 }
139 self.inner
140 .model_mut()
141 .kv_cache_mut()
142 .set_seq_len(uncached_start);
143 }
144
145 // ── Step 4: prefill on the uncached suffix only ──────────────────────
146 let mut last_logits = if uncached_start < prompt_tokens.len() {
147 match self
148 .inner
149 .prefill_from_pos(&prompt_tokens[uncached_start..], uncached_start)
150 {
151 Ok(logits) => logits,
152 Err(e) => {
153 tracing::warn!(error = %e, "prefix-cache prefill failed");
154 self.prefix_cache.release_session(session);
155 return vec![];
156 }
157 }
158 } else {
159 // Entire prompt was cached — re-run the final token to get logits
160 // (we still need a fresh logits vector to drive the decode loop).
161 let last_pos = prompt_tokens.len().saturating_sub(1);
162 let last_tok = prompt_tokens[last_pos];
163 match self.inner.decode_step(last_tok, last_pos) {
164 Ok(logits) => logits,
165 Err(e) => {
166 tracing::warn!(error = %e, "prefix-cache decode_step failed");
167 self.prefix_cache.release_session(session);
168 return vec![];
169 }
170 }
171 };
172
173 // ── Step 5: detect whether the CPU KV cache was actually populated ──
174 // GPU tiers (Metal/CUDA) maintain their own KV cache and leave the
175 // CPU `KvCache` untouched; in that case any extraction yields zeros
176 // which would silently corrupt the trie. We sample one layer/head/
177 // range and skip the store_blocks step if everything is zero.
178 let real_cpu_kv = {
179 let kv = self.inner.model().kv_cache();
180 let probe_len = prompt_tokens.len().min(kv.max_seq_len());
181 kv.keys_for(0, 0, probe_len).iter().any(|&x| x != 0.0)
182 };
183
184 // ── Step 6: store newly computed blocks into the trie ────────────────
185 if real_cpu_kv {
186 let new_blocks_count = prompt_tokens.len().saturating_sub(uncached_start) / block_size;
187 if new_blocks_count > 0 {
188 let mut keys_by_block: Vec<KvBlockPair> = Vec::with_capacity(new_blocks_count);
189 for blk in 0..new_blocks_count {
190 let block_pos = uncached_start + blk * block_size;
191 let mut layer_keys: Vec<Vec<f32>> = Vec::with_capacity(num_layers);
192 let mut layer_values: Vec<Vec<f32>> = Vec::with_capacity(num_layers);
193 for layer in 0..num_layers {
194 let (k, v) = self
195 .inner
196 .model()
197 .kv_cache()
198 .extract_block(layer, block_pos, block_size);
199 layer_keys.push(k);
200 layer_values.push(v);
201 }
202 keys_by_block.push((layer_keys, layer_values));
203 }
204 self.prefix_cache
205 .store_blocks(prompt_tokens, uncached_start, keys_by_block);
206 }
207 }
208
209 // ── Step 7: decode loop ──────────────────────────────────────────────
210 // Swap in a per-request sampler matching `params` so that the wrapper
211 // honours per-call sampling while leaving the engine's persistent
212 // sampler unchanged.
213 let mut output = Vec::with_capacity(params.max_tokens);
214 let mut sampler = crate::sampling::Sampler::new(params.clone(), 0);
215 for (pos, _) in (prompt_tokens.len()..).zip(0..params.max_tokens) {
216 let next_token = match sampler.sample(&last_logits) {
217 Ok(t) => t,
218 Err(e) => {
219 tracing::warn!(error = %e, "prefix-cache sampler error");
220 break;
221 }
222 };
223 if next_token == EOS_TOKEN_ID {
224 break;
225 }
226 output.push(next_token);
227 last_logits = match self.inner.decode_step(next_token, pos) {
228 Ok(l) => l,
229 Err(e) => {
230 tracing::warn!(error = %e, "prefix-cache decode loop error");
231 break;
232 }
233 };
234 }
235
236 // ── Step 8: release session ──────────────────────────────────────────
237 self.prefix_cache.release_session(session);
238 output
239 }
240
241 /// Return a snapshot of the current prefix-cache statistics.
242 pub fn cache_stats(&self) -> PrefixCacheStats {
243 self.prefix_cache.stats()
244 }
245
246 /// Clear all entries from the prefix cache.
247 ///
248 /// Does *not* reset the inner engine's KV cache.
249 pub fn clear_cache(&mut self) {
250 self.prefix_cache.cache.clear();
251 }
252}
253
254// ──────────────────────────────────────────────────────────────────
255// Tests
256// ──────────────────────────────────────────────────────────────────
257
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;
    use oxibonsai_model::model::BonsaiModel;

    /// Engine over the tiny test config — the CPU KV cache may stay empty.
    fn make_engine_no_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        let inner = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
        PrefixCachedEngine::new(inner, max_blocks)
    }

    /// Build a config small enough to keep test runtimes tight while still
    /// satisfying the Q1_0_g128 constraint (in_features must be a multiple
    /// of 128).
    fn small_real_config() -> Qwen3Config {
        Qwen3Config {
            architecture: "qwen3".to_string(),
            model_name: "PrefixCacheTest".to_string(),
            hidden_size: 128,
            intermediate_size: 256,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 32,
            vocab_size: 256,
            max_context_length: 128,
            rms_norm_eps: 1e-6,
            rope_freq_base: 10_000.0,
        }
    }

    /// Engine whose forward path really populates the CPU KV cache.
    fn make_engine_with_real_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        use oxibonsai_kernels::{KernelDispatcher, KernelTier};
        // Pin the engine to the Reference (CPU) tier so the CPU KV cache is
        // populated by the forward path. With auto_detect on a GPU host the
        // GPU shortcut would bypass the CPU cache entirely.
        let model = BonsaiModel::new_for_testing_with_blocks(small_real_config());
        let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
        let inner = InferenceEngine::from_model_with_kernel(
            model,
            dispatcher,
            SamplingParams::default(),
            42,
        );
        PrefixCachedEngine::new(inner, max_blocks)
    }

    #[test]
    fn prefix_cached_engine_construction() {
        let wrapped = make_engine_no_blocks(16);
        let snapshot = wrapped.cache_stats();
        assert_eq!(snapshot.cached_blocks, 0);
        assert_eq!(snapshot.capacity_blocks, 16);
    }

    #[test]
    fn prefix_cached_engine_generate_empty() {
        let mut wrapped = make_engine_no_blocks(16);
        let produced = wrapped.generate(&[], &SamplingParams::default());
        assert!(produced.is_empty());
    }

    #[test]
    fn prefix_cached_engine_clear_cache() {
        let mut wrapped = make_engine_no_blocks(16);
        // Exercise generate once so the trie may pick up blocks, then clear.
        let prompt: Vec<u32> = (0..32).collect();
        let greedy = SamplingParams {
            max_tokens: 4,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };
        let _ = wrapped.generate(&prompt, &greedy);
        wrapped.clear_cache();
        assert_eq!(wrapped.cache_stats().cached_blocks, 0);
    }

    #[test]
    fn prefix_cached_engine_stats_structure() {
        let wrapped = make_engine_no_blocks(32);
        let snapshot = wrapped.cache_stats();
        assert_eq!(snapshot.capacity_blocks, 32);
        assert!((snapshot.hit_rate - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn prefix_cached_engine_repeated_prompt_builds_cache() {
        // Use a model with real blocks so the CPU KV cache is actually populated.
        let mut wrapped = make_engine_with_real_blocks(32);
        let prompt: Vec<u32> = (0..32).collect();
        let greedy = SamplingParams {
            max_tokens: 1,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };

        // Cold cache on the first call; warm on the second.
        let _ = wrapped.generate(&prompt, &greedy);
        let first = wrapped.cache_stats();
        let _ = wrapped.generate(&prompt, &greedy);
        let second = wrapped.cache_stats();

        assert!(
            first.cached_blocks > 0,
            "first call should have populated some cache blocks"
        );
        assert!(
            second.total_hits > 0,
            "second call should record cache hits"
        );
    }

    /// Acceptance criterion #5 from issue #2: a repeated prompt must
    /// actually skip prefill work, not merely record bookkeeping hits.
    #[test]
    fn prefix_cached_engine_avoids_redundant_prefill_work() {
        let mut wrapped = make_engine_with_real_blocks(64);
        let prompt: Vec<u32> = (0..32).collect();
        let greedy = SamplingParams {
            max_tokens: 2,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        };

        let cold_output = wrapped.generate(&prompt, &greedy);
        let prefilled_cold = wrapped.inner.prefill_token_count();

        let warm_output = wrapped.generate(&prompt, &greedy);
        let prefilled_warm = wrapped.inner.prefill_token_count();

        let warm_call_prefill = prefilled_warm - prefilled_cold;
        assert!(
            warm_call_prefill < prompt.len() as u64,
            "second call prefilled {} tokens, expected < {} (prefix cache should have skipped some)",
            warm_call_prefill,
            prompt.len()
        );
        assert!(
            wrapped.cache_stats().total_hits > 0,
            "cache should report hits"
        );
        // AC #3 from issue #2: cached path must produce identical output to
        // the cold-cache path. With temperature=0 and top_k=1 the sampler is
        // deterministic, so the two generations must match token-for-token.
        assert_eq!(
            cold_output, warm_output,
            "AC #3: cached path must produce identical output ({:?} vs {:?})",
            cold_output, warm_output
        );
    }
}
413}