use oxibonsai_model::prefix_cache::{
KvBlockPair, PrefixAwarePrefill, PrefixCache, PrefixCacheStats,
};
use crate::engine::{InferenceEngine, EOS_TOKEN_ID};
use crate::sampling::SamplingParams;
/// Block size handed to `PrefixCache::new` when the cache is built.
///
/// `generate` treats the cache's block size as a count of token positions:
/// block `n` starts at position `n * block_size`.
const BLOCK_SIZE: usize = 16;
/// An [`InferenceEngine`] paired with a prefix cache of KV blocks.
///
/// `generate` consults the cache before prefilling a prompt, injects any
/// cached blocks into the model's KV cache, and only prefills the tokens
/// that were not covered.
pub struct PrefixCachedEngine<'a> {
/// The wrapped engine; it owns the model whose KV cache is read and written.
pub inner: InferenceEngine<'a>,
/// Prefix-aware prefill helper; its `cache` field is the block cache itself.
pub prefix_cache: PrefixAwarePrefill,
}
impl<'a> PrefixCachedEngine<'a> {
    /// Wraps `engine` with a prefix cache holding at most `max_cache_blocks` blocks.
    ///
    /// The layer count, KV-head count and head dimension passed to the cache are
    /// read from the engine's model configuration; the block size is `BLOCK_SIZE`.
    pub fn new(engine: InferenceEngine<'a>, max_cache_blocks: usize) -> Self {
        let cfg = engine.model().config();
        let cache = PrefixCache::new(
            max_cache_blocks,
            BLOCK_SIZE,
            cfg.num_layers,
            cfg.num_kv_heads,
            cfg.head_dim,
        );
        let prefix_cache = PrefixAwarePrefill::new(cache);
        Self {
            inner: engine,
            prefix_cache,
        }
    }

    /// Generates up to `params.max_tokens` tokens for `prompt_tokens`, reusing
    /// cached KV blocks for the prompt prefix where the cache reports a match.
    ///
    /// Steps, in order:
    /// 1. Reset the model and ask the prefix cache which blocks are already known.
    /// 2. Inject those blocks into the model's KV cache and set its sequence
    ///    length to `uncached_start`.
    /// 3. Prefill the remaining prompt tokens, or re-run the last prompt token
    ///    when the whole prompt was cached, to obtain the first logits.
    /// 4. Store the newly computed full blocks back into the prefix cache.
    /// 5. Sample and decode until `EOS_TOKEN_ID`, an error, or `max_tokens`.
    ///
    /// # Returns
    ///
    /// The generated tokens, excluding the prompt and the EOS token. An empty
    /// vector is returned for an empty prompt, and also when the initial
    /// prefill/decode step fails (the error is logged with `tracing::warn!`).
    /// Sampler or decode errors during the loop end generation early and return
    /// the tokens produced so far.
    pub fn generate(&mut self, prompt_tokens: &[u32], params: &SamplingParams) -> Vec<u32> {
        if prompt_tokens.is_empty() {
            return vec![];
        }

        self.inner.model_mut().reset();

        let (session, uncached_start) = self.prefix_cache.prepare(prompt_tokens);
        let block_size = self.prefix_cache.cache.block_size();
        // `num_layers` is a plain integer; read it directly instead of cloning
        // the whole configuration.
        let num_layers = self.inner.model().config().num_layers;

        if uncached_start > 0 && !session.block_indices.is_empty() {
            for (block_num, &bidx) in session.block_indices.iter().enumerate() {
                // `usize::MAX` is treated as "no cached block for this slot".
                if bidx == usize::MAX {
                    continue;
                }
                // NOTE(review): a slot that is skipped here (sentinel or `None`)
                // is not re-prefilled if it lies before `uncached_start`. This
                // assumes `prepare` only reports a contiguous cached prefix —
                // confirm against `PrefixAwarePrefill::prepare`.
                if let Some(block) = self.prefix_cache.cache.get_block(bidx) {
                    let block_start = block_num * block_size;
                    // `prefix_cache` and `inner` are disjoint fields, so the block
                    // can be injected straight from the cache without copying its
                    // per-layer buffers first.
                    let kv = self.inner.model_mut().kv_cache_mut();
                    for layer in 0..num_layers {
                        kv.inject_block(
                            layer,
                            block_start,
                            block_size,
                            &block.keys[layer],
                            &block.values[layer],
                        );
                    }
                }
            }
            self.inner
                .model_mut()
                .kv_cache_mut()
                .set_seq_len(uncached_start);
        }

        let mut last_logits = if uncached_start < prompt_tokens.len() {
            // Only the part of the prompt that the cache did not cover is prefilled.
            match self
                .inner
                .prefill_from_pos(&prompt_tokens[uncached_start..], uncached_start)
            {
                Ok(logits) => logits,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache prefill failed");
                    self.prefix_cache.release_session(session);
                    return vec![];
                }
            }
        } else {
            // The whole prompt was cached: re-run the final prompt token at its
            // own position to obtain logits for the first generated token.
            let last_pos = prompt_tokens.len().saturating_sub(1);
            let last_tok = prompt_tokens[last_pos];
            match self.inner.decode_step(last_tok, last_pos) {
                Ok(logits) => logits,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache decode_step failed");
                    self.prefix_cache.release_session(session);
                    return vec![];
                }
            }
        };

        if self.kv_holds_cpu_data(prompt_tokens.len()) {
            self.store_new_blocks(prompt_tokens, uncached_start, block_size, num_layers);
        }

        // `max_tokens` is only an upper bound; cap the capacity hint by the
        // remaining context so a very large limit cannot trigger a huge
        // allocation (or a capacity-overflow panic). The vector still grows on
        // demand if more tokens are produced.
        let context_left = self
            .inner
            .model()
            .kv_cache()
            .max_seq_len()
            .saturating_sub(prompt_tokens.len());
        let mut output = Vec::with_capacity(params.max_tokens.min(context_left));
        let mut sampler = crate::sampling::Sampler::new(params.clone(), 0);

        for pos in (prompt_tokens.len()..).take(params.max_tokens) {
            let next_token = match sampler.sample(&last_logits) {
                Ok(t) => t,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache sampler error");
                    break;
                }
            };
            if next_token == EOS_TOKEN_ID {
                break;
            }
            output.push(next_token);
            last_logits = match self.inner.decode_step(next_token, pos) {
                Ok(l) => l,
                Err(e) => {
                    tracing::warn!(error = %e, "prefix-cache decode loop error");
                    break;
                }
            };
        }

        self.prefix_cache.release_session(session);
        output
    }

    /// Returns the statistics reported by the prefix cache.
    pub fn cache_stats(&self) -> PrefixCacheStats {
        self.prefix_cache.stats()
    }

    /// Clears the underlying block cache.
    pub fn clear_cache(&mut self) {
        self.prefix_cache.cache.clear();
    }

    /// Reports whether the model's KV cache contains any non-zero key value in
    /// the probed range `keys_for(0, 0, probe_len)`.
    ///
    /// `probe_len` is `prompt_len` clamped to the cache's `max_seq_len()`.
    /// NOTE(review): presumably an all-zero result means the KV data lives
    /// somewhere other than this CPU-side cache, so there is nothing worth
    /// storing — confirm the intent with the KV cache implementation.
    fn kv_holds_cpu_data(&self, prompt_len: usize) -> bool {
        let kv = self.inner.model().kv_cache();
        let probe_len = prompt_len.min(kv.max_seq_len());
        kv.keys_for(0, 0, probe_len).iter().any(|&x| x != 0.0)
    }

    /// Extracts every full block computed after `uncached_start` from the
    /// model's KV cache and hands them to the prefix cache.
    ///
    /// Only whole blocks are stored; a trailing partial block is ignored.
    fn store_new_blocks(
        &mut self,
        prompt_tokens: &[u32],
        uncached_start: usize,
        block_size: usize,
        num_layers: usize,
    ) {
        let new_blocks_count = prompt_tokens.len().saturating_sub(uncached_start) / block_size;
        if new_blocks_count == 0 {
            return;
        }

        let kv = self.inner.model().kv_cache();
        let blocks: Vec<KvBlockPair> = (0..new_blocks_count)
            .map(|blk| {
                let block_pos = uncached_start + blk * block_size;
                // One (keys, values) pair per layer, split into two per-layer lists.
                let (layer_keys, layer_values): (Vec<Vec<f32>>, Vec<Vec<f32>>) = (0..num_layers)
                    .map(|layer| kv.extract_block(layer, block_pos, block_size))
                    .unzip();
                (layer_keys, layer_values)
            })
            .collect();

        self.prefix_cache
            .store_blocks(prompt_tokens, uncached_start, blocks);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::config::Qwen3Config;
    use oxibonsai_model::model::BonsaiModel;

    /// Sampling parameters used by the generation tests: `top_k = 1`,
    /// temperature `0.0`, and the given token budget.
    fn greedy_params(max_tokens: usize) -> SamplingParams {
        SamplingParams {
            max_tokens,
            top_k: 1,
            temperature: 0.0,
            ..SamplingParams::default()
        }
    }

    /// Prompt shared by the generation tests: token ids `0..32`.
    fn sequential_prompt() -> Vec<u32> {
        (0..32).collect()
    }

    /// Builds a cached engine from `Qwen3Config::tiny_test()` via `InferenceEngine::new`.
    fn make_engine_no_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        let engine = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
        PrefixCachedEngine::new(engine, max_blocks)
    }

    /// Small single-layer configuration used by the tests that need a model
    /// created with `BonsaiModel::new_for_testing_with_blocks`.
    fn small_real_config() -> Qwen3Config {
        Qwen3Config {
            architecture: String::from("qwen3"),
            model_name: String::from("PrefixCacheTest"),
            vocab_size: 256,
            max_context_length: 128,
            hidden_size: 128,
            intermediate_size: 256,
            num_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 32,
            rms_norm_eps: 1e-6,
            rope_freq_base: 10_000.0,
        }
    }

    /// Builds a cached engine around a test model with blocks, using the
    /// reference kernel tier.
    fn make_engine_with_real_blocks(max_blocks: usize) -> PrefixCachedEngine<'static> {
        use oxibonsai_kernels::{KernelDispatcher, KernelTier};

        let model = BonsaiModel::new_for_testing_with_blocks(small_real_config());
        let kernel = KernelDispatcher::with_tier(KernelTier::Reference);
        let engine =
            InferenceEngine::from_model_with_kernel(model, kernel, SamplingParams::default(), 42);
        PrefixCachedEngine::new(engine, max_blocks)
    }

    /// A freshly built engine has an empty cache with the requested capacity.
    #[test]
    fn prefix_cached_engine_construction() {
        let stats = make_engine_no_blocks(16).cache_stats();
        assert_eq!(stats.cached_blocks, 0);
        assert_eq!(stats.capacity_blocks, 16);
    }

    /// An empty prompt yields no tokens.
    #[test]
    fn prefix_cached_engine_generate_empty() {
        let mut engine = make_engine_no_blocks(16);
        let generated = engine.generate(&[], &SamplingParams::default());
        assert!(generated.is_empty());
    }

    /// After `clear_cache` the cache reports zero cached blocks.
    #[test]
    fn prefix_cached_engine_clear_cache() {
        let mut engine = make_engine_no_blocks(16);
        let prompt = sequential_prompt();
        let _ = engine.generate(&prompt, &greedy_params(4));
        engine.clear_cache();
        assert_eq!(engine.cache_stats().cached_blocks, 0);
    }

    /// Stats of an unused engine: requested capacity and a zero hit rate.
    #[test]
    fn prefix_cached_engine_stats_structure() {
        let stats = make_engine_no_blocks(32).cache_stats();
        assert_eq!(stats.capacity_blocks, 32);
        assert!(stats.hit_rate.abs() < f32::EPSILON);
    }

    /// Running the same prompt twice populates the cache, then hits it.
    #[test]
    fn prefix_cached_engine_repeated_prompt_builds_cache() {
        let mut engine = make_engine_with_real_blocks(32);
        let prompt = sequential_prompt();
        let params = greedy_params(1);

        let _ = engine.generate(&prompt, &params);
        let first_stats = engine.cache_stats();
        let _ = engine.generate(&prompt, &params);
        let second_stats = engine.cache_stats();

        assert!(
            first_stats.cached_blocks > 0,
            "first call should have populated some cache blocks"
        );
        assert!(
            second_stats.total_hits > 0,
            "second call should record cache hits"
        );
    }

    /// The second run of an identical prompt prefills fewer tokens than the
    /// prompt length, records cache hits, and produces the same output.
    #[test]
    fn prefix_cached_engine_avoids_redundant_prefill_work() {
        let mut engine = make_engine_with_real_blocks(64);
        let prompt = sequential_prompt();
        let params = greedy_params(2);

        let out1 = engine.generate(&prompt, &params);
        let prefill_before = engine.inner.prefill_token_count();
        let out2 = engine.generate(&prompt, &params);
        let second_call_prefill = engine.inner.prefill_token_count() - prefill_before;

        assert!(
            second_call_prefill < prompt.len() as u64,
            "second call prefilled {} tokens, expected < {} (prefix cache should have skipped some)",
            second_call_prefill,
            prompt.len()
        );
        assert!(
            engine.cache_stats().total_hits > 0,
            "cache should report hits"
        );
        assert_eq!(
            out1, out2,
            "AC #3: cached path must produce identical output ({:?} vs {:?})",
            out1, out2
        );
    }
}