use std::collections::HashMap;
/// Key and value data for a single cache block, as a `(keys, values)` pair.
///
/// `PrefixAwarePrefill::store_blocks` destructures each pair and hands both halves to
/// `PrefixCache::insert`, which stores them as a block's `keys` and `values`.
/// NOTE(review): each inner `Vec<f32>` is presumably one layer's flattened buffer, as in
/// `CacheBlock::new` — confirm against the code that produces these pairs.
pub type KvBlockPair = (Vec<Vec<f32>>, Vec<Vec<f32>>);
/// One cached block of key/value data together with its bookkeeping.
pub struct CacheBlock {
    /// Key buffers, one `Vec<f32>` per layer.
    pub keys: Vec<Vec<f32>>,
    /// Value buffers, one `Vec<f32>` per layer.
    pub values: Vec<Vec<f32>>,
    /// The tokens this block was stored for.
    pub token_ids: Vec<u32>,
    /// Recency stamp; `PrefixCache` evicts the unpinned block with the smallest stamp.
    pub last_used: u64,
    /// Number of outstanding pins; `PrefixCache` never evicts a block while this is non-zero.
    pub ref_count: usize,
}

impl CacheBlock {
    /// Creates a zero-filled block.
    ///
    /// Both `keys` and `values` get `num_layers` buffers of
    /// `num_kv_heads * head_dim * block_size` floats each. The block starts with no
    /// tokens, `last_used == 0` and `ref_count == 0`.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize, block_size: usize) -> Self {
        let per_layer = num_kv_heads * head_dim * block_size;
        let zeroed = || vec![vec![0.0f32; per_layer]; num_layers];
        Self {
            keys: zeroed(),
            values: zeroed(),
            token_ids: Vec::new(),
            last_used: 0,
            ref_count: 0,
        }
    }

    /// Returns the number of bytes of `f32` payload held in `keys` and `values`.
    ///
    /// The buffers are public and are replaced wholesale when a block is stored, so their
    /// shapes are not guaranteed to be uniform. Every buffer is therefore measured instead
    /// of extrapolating from the first key layer. Only payload is counted, not `Vec`
    /// headers or spare capacity.
    pub fn memory_bytes(&self) -> usize {
        let floats: usize = self.keys.iter().chain(&self.values).map(Vec::len).sum();
        floats * std::mem::size_of::<f32>()
    }
}
/// A node of the prefix trie used by `PrefixCache`.
struct TrieNode {
    /// Outgoing edges: block edge key -> index of the child node in `PrefixCache::nodes`.
    children: HashMap<u32, usize>,
    /// Index into `PrefixCache::blocks` of the block stored at this node, if any.
    block_idx: Option<usize>,
}

impl TrieNode {
    /// Builds a node with no children and no attached block.
    fn new() -> Self {
        let children = HashMap::new();
        let block_idx = None;
        Self {
            children,
            block_idx,
        }
    }
}
/// A block-granular prefix cache for per-layer key/value data.
///
/// Token sequences are split into consecutive blocks of `block_size` tokens. Each block is
/// an edge in a trie, keyed by a 32-bit hash of its tokens, and the node at the end of an
/// edge may own one [`CacheBlock`]. Once `max_blocks` blocks are occupied, blocks whose
/// `ref_count` is zero are evicted in least-recently-used order.
///
/// Known limitations of the current design:
/// * Trie nodes are never reclaimed; `nodes` only shrinks on [`PrefixCache::clear`].
/// * Edge keys are hashes, so two different blocks under the same parent can collide.
///   [`PrefixCache::lookup`] compares the stored tokens and reports a colliding block as a
///   miss rather than a false hit.
pub struct PrefixCache {
    /// Trie storage; index 0 is always the root.
    nodes: Vec<TrieNode>,
    /// Block storage; slots are recycled through `free_block_pool`.
    blocks: Vec<CacheBlock>,
    /// Indices into `blocks` that currently hold live data.
    occupied_blocks: Vec<usize>,
    /// Indices into `blocks` freed by eviction and available for reuse.
    free_block_pool: Vec<usize>,
    /// Occupancy at which inserting a new block first tries to evict.
    max_blocks: usize,
    /// Number of tokens per block.
    block_size: usize,
    num_layers: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// Monotonic counter used as the LRU clock.
    generation: u64,
    /// Number of blocks matched by `lookup` so far.
    pub hits: u64,
    /// Number of lookups that stopped early on a missing or mismatching block.
    pub misses: u64,
    /// Number of blocks evicted so far.
    pub evictions: u64,
}

impl PrefixCache {
    /// Creates an empty cache holding up to `max_blocks` blocks of `block_size` tokens.
    ///
    /// `block_size` must be non-zero: `lookup` and `insert` panic otherwise.
    pub fn new(
        max_blocks: usize,
        block_size: usize,
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        Self {
            nodes: vec![TrieNode::new()],
            blocks: Vec::new(),
            occupied_blocks: Vec::new(),
            free_block_pool: Vec::new(),
            max_blocks,
            block_size,
            num_layers,
            num_kv_heads,
            head_dim,
            generation: 0,
            hits: 0,
            misses: 0,
            evictions: 0,
        }
    }

    /// Finds the longest cached prefix of `token_ids`, in whole blocks.
    ///
    /// Returns the number of matched tokens and the matched blocks in order. A trailing
    /// partial block is ignored. Every matched block counts as one hit, has its recency
    /// stamp refreshed, and has its `ref_count` incremented, which pins it against
    /// eviction until [`PrefixCache::release`] is called for it. A lookup that stops at a
    /// missing, evicted or mismatching block counts as one miss.
    ///
    /// # Panics
    ///
    /// Panics if the cache was created with `block_size == 0`.
    pub fn lookup(&mut self, token_ids: &[u32]) -> (usize, Vec<&CacheBlock>) {
        let block_size = self.block_size;
        let mut node_idx = 0usize;
        let mut matched: Vec<usize> = Vec::new();

        for block_tokens in token_ids.chunks_exact(block_size) {
            let edge_key = Self::block_edge_key(block_tokens);
            // A hit needs the edge, a block on the child node, and identical tokens (the
            // edge key is only a hash).
            let hit = self.nodes[node_idx]
                .children
                .get(&edge_key)
                .copied()
                .and_then(|child| self.nodes[child].block_idx.map(|bidx| (child, bidx)))
                .filter(|&(_, bidx)| self.blocks[bidx].token_ids == block_tokens);
            let Some((child, bidx)) = hit else {
                self.misses += 1;
                break;
            };

            self.generation += 1;
            let block = &mut self.blocks[bidx];
            block.last_used = self.generation;
            block.ref_count += 1;
            self.hits += 1;
            matched.push(bidx);
            node_idx = child;
        }

        let matched_len = matched.len() * block_size;
        let block_refs = matched.iter().map(|&bidx| &self.blocks[bidx]).collect();
        (matched_len, block_refs)
    }

    /// Stores `keys`/`values` for the block of `token_ids` starting at `block_start` and
    /// returns the index of the block that now holds them.
    ///
    /// The trie path for the blocks preceding `block_start` is created if needed. If the
    /// target node already caches the same tokens, that block is overwritten in place and
    /// keeps its index and pins. No slot is consumed and nothing is evicted. Otherwise
    /// unpinned blocks are evicted in LRU order while the cache is full. If every
    /// occupied block is pinned, the block is stored anyway and the cache temporarily
    /// exceeds `max_blocks`.
    ///
    /// # Panics
    ///
    /// Panics if `block_start > token_ids.len()` or if `block_size == 0`.
    pub fn insert(
        &mut self,
        token_ids: &[u32],
        block_start: usize,
        keys: Vec<Vec<f32>>,
        values: Vec<Vec<f32>>,
    ) -> usize {
        let block_size = self.block_size;
        let block_end = (block_start + block_size).min(token_ids.len());
        let block_tokens = token_ids[block_start..block_end].to_vec();

        // Resolve the leaf before evicting anything, so a re-insert can be recognised.
        let mut node_idx = 0usize;
        for segment in token_ids[..block_start].chunks_exact(block_size) {
            node_idx = self.child_or_create(node_idx, Self::block_edge_key(segment));
        }
        let leaf_node_idx = self.child_or_create(node_idx, Self::block_edge_key(&block_tokens));

        self.generation += 1;
        let stamp = self.generation;

        // Same tokens already cached here: refresh the block instead of allocating a second
        // slot that would leave the old one occupied but unreachable.
        if let Some(existing) = self.nodes[leaf_node_idx].block_idx {
            let block = &mut self.blocks[existing];
            if block.token_ids == block_tokens {
                block.keys = keys;
                block.values = values;
                block.last_used = stamp;
                return existing;
            }
            // Different tokens under the same edge key (hash collision): fall through and
            // repoint the node at the new block, as before.
        }

        while self.occupied_blocks.len() >= self.max_blocks {
            if !self.evict_lru() {
                // Everything left is pinned; store over capacity rather than fail.
                break;
            }
        }

        let block_idx = match self.free_block_pool.pop() {
            Some(reuse_idx) => {
                let block = &mut self.blocks[reuse_idx];
                block.keys = keys;
                block.values = values;
                block.token_ids = block_tokens;
                block.last_used = stamp;
                block.ref_count = 0;
                reuse_idx
            }
            None => {
                let mut fresh = CacheBlock::new(
                    self.num_layers,
                    self.num_kv_heads,
                    self.head_dim,
                    self.block_size,
                );
                fresh.keys = keys;
                fresh.values = values;
                fresh.token_ids = block_tokens;
                fresh.last_used = stamp;
                fresh.ref_count = 0;
                self.blocks.push(fresh);
                self.blocks.len() - 1
            }
        };

        self.nodes[leaf_node_idx].block_idx = Some(block_idx);
        self.occupied_blocks.push(block_idx);
        block_idx
    }

    /// Drops one pin from the block at `block_idx`.
    ///
    /// Out-of-range indices and blocks that are not pinned are ignored.
    pub fn release(&mut self, block_idx: usize) {
        if let Some(block) = self.blocks.get_mut(block_idx) {
            block.ref_count = block.ref_count.saturating_sub(1);
        }
    }

    /// Returns the number of occupied blocks.
    pub fn len(&self) -> usize {
        self.occupied_blocks.len()
    }

    /// Returns `true` when no block is occupied.
    pub fn is_empty(&self) -> bool {
        self.occupied_blocks.is_empty()
    }

    /// Returns the configured `max_blocks`.
    pub fn capacity(&self) -> usize {
        self.max_blocks
    }

    /// Returns the summed [`CacheBlock::memory_bytes`] of all occupied blocks.
    pub fn memory_bytes(&self) -> usize {
        self.occupied_blocks
            .iter()
            .map(|&idx| self.blocks[idx].memory_bytes())
            .sum()
    }

    /// Returns `hits / (hits + misses)`, or `0.0` before any hit or miss was recorded.
    pub fn hit_rate(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            return 0.0;
        }
        // Divide in f64 so large counters are not rounded before the division.
        (self.hits as f64 / total as f64) as f32
    }

    /// Returns the number of tokens per block.
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Returns the block stored in slot `idx`, or `None` if the slot does not exist.
    ///
    /// The slot may belong to an evicted block that has not been reused yet.
    pub fn get_block(&self, idx: usize) -> Option<&CacheBlock> {
        self.blocks.get(idx)
    }

    /// Removes every block and trie node and resets the LRU clock.
    ///
    /// The `hits`, `misses` and `evictions` counters are left untouched.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.nodes.push(TrieNode::new());
        self.blocks.clear();
        self.occupied_blocks.clear();
        self.free_block_pool.clear();
        self.generation = 0;
    }

    /// Hashes a block of tokens to the 32-bit key used for trie edges.
    ///
    /// Uses the 64-bit FNV-1a offset basis and prime, mixing one whole token per step,
    /// then folds the 64-bit state to 32 bits by XOR-ing its two halves.
    fn block_edge_key(tokens: &[u32]) -> u32 {
        let mut h: u64 = 0xcbf2_9ce4_8422_2325;
        for &t in tokens {
            h ^= t as u64;
            h = h.wrapping_mul(0x0000_0100_0000_01b3);
        }
        ((h >> 32) ^ (h & 0xffff_ffff)) as u32
    }

    /// Returns the child of `parent` reached through `edge_key`, creating it if absent.
    fn child_or_create(&mut self, parent: usize, edge_key: u32) -> usize {
        if let Some(&child) = self.nodes[parent].children.get(&edge_key) {
            return child;
        }
        let child = self.nodes.len();
        self.nodes.push(TrieNode::new());
        self.nodes[parent].children.insert(edge_key, child);
        child
    }

    /// Evicts the least recently used unpinned block.
    ///
    /// Returns `false`, changing nothing, when every occupied block is pinned or the
    /// cache is empty. The victim's slot goes to the free pool and the trie node that
    /// pointed at it is detached; the node itself stays in the trie.
    fn evict_lru(&mut self) -> bool {
        let victim_pos = self
            .occupied_blocks
            .iter()
            .enumerate()
            .filter(|&(_, &bidx)| self.blocks[bidx].ref_count == 0)
            .min_by_key(|&(_, &bidx)| self.blocks[bidx].last_used)
            .map(|(pos, _)| pos);
        let Some(pos) = victim_pos else {
            return false;
        };

        let victim_block_idx = self.occupied_blocks.swap_remove(pos);
        let owner = self
            .nodes
            .iter_mut()
            .find(|node| node.block_idx == Some(victim_block_idx));
        if let Some(node) = owner {
            node.block_idx = None;
        }
        self.free_block_pool.push(victim_block_idx);
        self.evictions += 1;
        true
    }
}
/// The outcome of a prefix lookup: how many tokens matched and which blocks back them.
pub struct CacheSession {
    /// Number of leading tokens that were found in the cache.
    pub matched_prefix_len: usize,
    /// Cache block indices backing the matched prefix, in prefix order.
    pub block_indices: Vec<usize>,
}

impl CacheSession {
    /// Bundles a matched prefix length with the indices of the blocks that cover it.
    pub fn new(matched_prefix_len: usize, block_indices: Vec<usize>) -> Self {
        CacheSession {
            block_indices,
            matched_prefix_len,
        }
    }

    /// Returns the token count covered by this session's blocks at the given `block_size`.
    pub fn cached_tokens(&self, block_size: usize) -> usize {
        let num_blocks = self.block_indices.len();
        num_blocks * block_size
    }

    /// Returns `true` when the session references no blocks.
    pub fn is_empty(&self) -> bool {
        let indices: &[usize] = &self.block_indices;
        indices.is_empty()
    }
}
/// Prefill helper that reuses cached prefix blocks and stores newly computed ones.
pub struct PrefixAwarePrefill {
    /// The underlying prefix cache.
    pub cache: PrefixCache,
}

impl PrefixAwarePrefill {
    /// Placeholder index recorded for a matched block that cannot be resolved in the trie.
    /// [`PrefixCache::release`] ignores it because it is out of range.
    const UNRESOLVED_BLOCK: usize = usize::MAX;

    /// Wraps an existing cache.
    pub fn new(cache: PrefixCache) -> Self {
        Self { cache }
    }

    /// Looks up the cached prefix of `token_ids`.
    ///
    /// Returns a session naming the matched blocks, and the token offset at which
    /// uncached work starts (equal to the matched prefix length). The lookup pins every
    /// matched block; pass the session to [`PrefixAwarePrefill::release_session`] to unpin
    /// them.
    pub fn prepare(&mut self, token_ids: &[u32]) -> (CacheSession, usize) {
        let (matched_len, matched_blocks) = self.cache.lookup(token_ids);
        let num_matched = matched_blocks.len();
        // `lookup` returns references, not indices; release the borrow and recover them.
        drop(matched_blocks);

        let block_indices = self.matched_block_indices(token_ids, num_matched);
        let session = CacheSession::new(matched_len, block_indices);
        (session, matched_len)
    }

    /// Inserts freshly computed blocks, starting at token offset `uncached_start`.
    ///
    /// Entry `i` of `keys_by_block` is stored for the block beginning at
    /// `uncached_start + i * block_size`. Processing stops at the first block that would
    /// extend past the end of `token_ids`, so partial trailing blocks are never stored.
    pub fn store_blocks(
        &mut self,
        token_ids: &[u32],
        uncached_start: usize,
        keys_by_block: Vec<KvBlockPair>,
    ) {
        let block_size = self.cache.block_size;
        for (i, (keys, values)) in keys_by_block.into_iter().enumerate() {
            let block_start = uncached_start + i * block_size;
            if block_start + block_size > token_ids.len() {
                break;
            }
            self.cache.insert(token_ids, block_start, keys, values);
        }
    }

    /// Unpins every block referenced by `session`.
    pub fn release_session(&mut self, session: CacheSession) {
        for bidx in session.block_indices {
            self.cache.release(bidx);
        }
    }

    /// Returns a snapshot of the cache's counters and occupancy.
    pub fn stats(&self) -> PrefixCacheStats {
        PrefixCacheStats {
            hit_rate: self.cache.hit_rate(),
            cached_blocks: self.cache.len(),
            capacity_blocks: self.cache.capacity(),
            memory_bytes: self.cache.memory_bytes(),
            total_hits: self.cache.hits,
            total_misses: self.cache.misses,
            total_evictions: self.cache.evictions,
        }
    }

    /// Resolves the block indices of the first `num_blocks` blocks of `token_ids`.
    ///
    /// Walks the trie once from the root instead of restarting at the root for every
    /// block. A block whose edge or stored block is missing is recorded as
    /// [`Self::UNRESOLVED_BLOCK`]. Once an edge is missing, so are all later blocks.
    fn matched_block_indices(&self, token_ids: &[u32], num_blocks: usize) -> Vec<usize> {
        let mut indices = Vec::with_capacity(num_blocks);
        let mut node_idx = Some(0usize);
        let blocks = token_ids
            .chunks_exact(self.cache.block_size)
            .take(num_blocks);
        for block_tokens in blocks {
            let edge_key = PrefixCache::block_edge_key(block_tokens);
            node_idx = node_idx
                .and_then(|node| self.cache.nodes[node].children.get(&edge_key).copied());
            let block_idx = node_idx
                .and_then(|node| self.cache.nodes[node].block_idx)
                .unwrap_or(Self::UNRESOLVED_BLOCK);
            indices.push(block_idx);
        }
        indices
    }
}
/// Snapshot of prefix-cache statistics, as produced by `PrefixAwarePrefill::stats`.
#[derive(Debug, serde::Serialize)]
pub struct PrefixCacheStats {
/// Result of `PrefixCache::hit_rate`: hits / (hits + misses), or 0.0 when both are zero.
pub hit_rate: f32,
/// Number of blocks currently occupied (`PrefixCache::len`).
pub cached_blocks: usize,
/// Configured maximum number of blocks (`PrefixCache::capacity`).
pub capacity_blocks: usize,
/// Bytes reported by `PrefixCache::memory_bytes` for the occupied blocks.
pub memory_bytes: usize,
/// Value of the cache's `hits` counter.
pub total_hits: u64,
/// Value of the cache's `misses` counter.
pub total_misses: u64,
/// Value of the cache's `evictions` counter.
pub total_evictions: u64,
}
// Unit tests for the cache block, the prefix cache and the prefill wrapper.
#[cfg(test)]
mod tests {
use super::*;
// Builds a zero-filled block with the given dimensions.
fn make_block(
num_layers: usize,
num_kv_heads: usize,
head_dim: usize,
block_size: usize,
) -> CacheBlock {
CacheBlock::new(num_layers, num_kv_heads, head_dim, block_size)
}
// Builds per-layer key buffers filled with `val` and value buffers filled with `val + 1.0`,
// each of length `num_kv_heads * head_dim * block_size`.
fn make_kv(
num_layers: usize,
num_kv_heads: usize,
head_dim: usize,
block_size: usize,
val: f32,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
let per_layer = num_kv_heads * head_dim * block_size;
let keys: Vec<Vec<f32>> = (0..num_layers).map(|_| vec![val; per_layer]).collect();
let values: Vec<Vec<f32>> = (0..num_layers)
.map(|_| vec![val + 1.0; per_layer])
.collect();
(keys, values)
}
// memory_bytes covers keys and values: 2 tensors * 2 layers * (4 * 8 * 4) floats.
#[test]
fn test_cache_block_memory_bytes() {
let block = make_block(2, 4, 8, 4);
let expected = 2 * 2 * (4 * 8 * 4) * std::mem::size_of::<f32>();
assert_eq!(block.memory_bytes(), expected);
}
// Only the first of two blocks is stored, so exactly one block (4 tokens) matches.
#[test]
fn test_prefix_cache_insert_and_lookup_hit() {
let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let (keys, values) = make_kv(2, 2, 8, 4, 1.0);
cache.insert(&token_ids, 0, keys, values);
let (matched, blocks) = cache.lookup(&token_ids);
assert_eq!(matched, 4, "should match one full block of 4 tokens");
assert_eq!(blocks.len(), 1);
assert_eq!(cache.hits, 1);
}
// Looking up in an empty cache matches nothing and records one miss.
#[test]
fn test_prefix_cache_lookup_miss() {
let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
let token_ids: Vec<u32> = vec![10, 20, 30, 40];
let (matched, blocks) = cache.lookup(&token_ids);
assert_eq!(matched, 0);
assert!(blocks.is_empty());
assert_eq!(cache.misses, 1);
}
// A query sharing only its first block with the cached sequence matches that block alone.
#[test]
fn test_prefix_cache_partial_prefix_match() {
let mut cache = PrefixCache::new(8, 4, 2, 2, 8);
let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let (keys0, values0) = make_kv(2, 2, 8, 4, 0.5);
cache.insert(&token_ids, 0, keys0, values0);
let query: Vec<u32> = vec![1, 2, 3, 4, 9, 10, 11, 12];
let (matched, blocks) = cache.lookup(&query);
assert_eq!(matched, 4);
assert_eq!(blocks.len(), 1);
}
// With capacity 2, inserting a third block evicts the least recently used one (A).
#[test]
fn test_prefix_cache_lru_eviction() {
let mut cache = PrefixCache::new(2, 4, 1, 1, 4);
let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
let tokens_b: Vec<u32> = vec![5, 6, 7, 8];
let tokens_c: Vec<u32> = vec![9, 10, 11, 12];
let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);
let (kc, vc) = make_kv(1, 1, 4, 4, 3.0);
cache.insert(&tokens_a, 0, ka, va);
cache.insert(&tokens_b, 0, kb, vb);
// This lookup refreshes B's recency stamp and also pins it (ref_count becomes 1).
let _ = cache.lookup(&tokens_b);
cache.insert(&tokens_c, 0, kc, vc);
assert_eq!(
cache.len(),
2,
"should have exactly 2 blocks after eviction"
);
assert_eq!(cache.evictions, 1);
let (matched_a, _) = cache.lookup(&tokens_a);
assert_eq!(matched_a, 0, "evicted block should not be found");
}
// A block with a non-zero ref_count is never chosen as an eviction victim.
#[test]
fn test_prefix_cache_ref_count_prevents_eviction() {
let mut cache = PrefixCache::new(1, 4, 1, 1, 4);
let tokens_a: Vec<u32> = vec![1, 2, 3, 4];
let tokens_b: Vec<u32> = vec![5, 6, 7, 8];
let (ka, va) = make_kv(1, 1, 4, 4, 1.0);
let (kb, vb) = make_kv(1, 1, 4, 4, 2.0);
let bidx_a = cache.insert(&tokens_a, 0, ka, va);
// Pin A by hand; the insert below still goes ahead because eviction cannot free a slot.
cache.blocks[bidx_a].ref_count += 1;
cache.insert(&tokens_b, 0, kb, vb);
assert_eq!(cache.evictions, 0, "pinned block must not be evicted");
cache.release(bidx_a);
assert_eq!(cache.blocks[bidx_a].ref_count, 0);
}
// One matching lookup (1 hit) plus one non-matching lookup (1 miss) gives a rate of 0.5.
#[test]
fn test_prefix_cache_hit_rate() {
let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
let tokens: Vec<u32> = vec![1, 2, 3, 4];
let (k, v) = make_kv(1, 1, 4, 4, 1.0);
cache.insert(&tokens, 0, k, v);
let _ = cache.lookup(&tokens);
let _ = cache.lookup(&[99, 100, 101, 102]);
let rate = cache.hit_rate();
assert!(
(rate - 0.5).abs() < 1e-5,
"hit rate should be 0.5, got {rate}"
);
}
// clear() empties the cache, so a previously cached sequence no longer matches.
#[test]
fn test_prefix_cache_clear() {
let mut cache = PrefixCache::new(8, 4, 1, 1, 4);
let tokens: Vec<u32> = vec![1, 2, 3, 4];
let (k, v) = make_kv(1, 1, 4, 4, 1.0);
cache.insert(&tokens, 0, k, v);
assert!(!cache.is_empty());
cache.clear();
assert!(cache.is_empty());
assert_eq!(cache.len(), 0);
let (matched, _) = cache.lookup(&tokens);
assert_eq!(matched, 0);
}
// cached_tokens is the number of block indices times the block size.
#[test]
fn test_cache_session_cached_tokens() {
let session = CacheSession::new(8, vec![0, 1]);
assert_eq!(session.cached_tokens(4), 8);
assert!(!session.is_empty());
let empty = CacheSession::new(0, vec![]);
assert!(empty.is_empty());
assert_eq!(empty.cached_tokens(4), 0);
}
// prepare() reports the matched prefix length and the same value as the uncached start.
#[test]
fn test_prefix_aware_prefill_prepare() {
let inner = PrefixCache::new(8, 4, 1, 1, 4);
let mut prefill = PrefixAwarePrefill::new(inner);
let token_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let (k, v) = make_kv(1, 1, 4, 4, 1.0);
prefill.cache.insert(&token_ids, 0, k, v);
let (session, uncached_start) = prefill.prepare(&token_ids);
assert_eq!(session.matched_prefix_len, 4);
assert_eq!(uncached_start, 4);
prefill.release_session(session);
}
// stats() reflects the wrapped cache; only the capacity is pinned to an exact value here.
#[test]
fn test_prefix_cache_stats() {
let inner = PrefixCache::new(8, 4, 1, 1, 4);
let mut prefill = PrefixAwarePrefill::new(inner);
let token_ids: Vec<u32> = vec![1, 2, 3, 4];
let (k, v) = make_kv(1, 1, 4, 4, 1.0);
prefill.cache.insert(&token_ids, 0, k, v);
let _ = prefill.prepare(&token_ids);
let stats = prefill.stats();
assert!(stats.cached_blocks > 0 || stats.total_hits > 0 || stats.total_misses > 0);
assert_eq!(stats.capacity_blocks, 8);
}
// Inserting capacity + 2 distinct unpinned blocks must evict at least two of them.
#[test]
fn test_prefix_cache_capacity_enforcement() {
let capacity = 4usize;
let mut cache = PrefixCache::new(capacity, 4, 1, 1, 4);
for i in 0..capacity + 2 {
let tokens: Vec<u32> = (0..4).map(|j| (i * 4 + j) as u32).collect();
let (k, v) = make_kv(1, 1, 4, 4, i as f32);
cache.insert(&tokens, 0, k, v);
}
assert!(
cache.len() <= capacity,
"cache should not exceed max_blocks={capacity}, got {}",
cache.len()
);
assert!(
cache.evictions >= 2,
"should have evicted at least 2 blocks"
);
}
}