use std::collections::HashMap;
use std::time::Instant;
use oxillama_arch::traits::KvCacheAccess;
use super::KvCache;
/// Capacity limits and matching threshold for [`PrefixKvCache`].
#[derive(Debug, Clone)]
pub struct PrefixCacheConfig {
/// Maximum number of cached snapshots; least-recently-used entries are evicted beyond this.
pub max_entries: usize,
/// Upper bound on the summed size, in bytes of `f32` key/value data, of all cached snapshots.
pub max_memory_bytes: usize,
/// Minimum token count for a sequence to be stored, and for a lookup match to count as a hit.
pub min_prefix_len: usize,
}
impl Default for PrefixCacheConfig {
fn default() -> Self {
Self {
max_entries: 256,
max_memory_bytes: 512 * 1024 * 1024, min_prefix_len: 4,
}
}
}
/// Owned copy of per-layer key/value data, cached under a token sequence.
#[derive(Clone)]
pub struct CachedKvState {
/// Key data, one `Vec` per layer.
keys: Vec<Vec<f32>>,
/// Value data, one `Vec` per layer.
values: Vec<Vec<f32>>,
/// Sequence length recorded when the snapshot was taken.
seq_len: usize,
}
impl CachedKvState {
    /// Sequence length recorded for this snapshot.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Per-layer key data.
    pub fn keys(&self) -> &[Vec<f32>] {
        self.keys.as_slice()
    }

    /// Per-layer value data.
    pub fn values(&self) -> &[Vec<f32>] {
        self.values.as_slice()
    }

    /// Bytes occupied by the `f32` payload of all key and value layers
    /// (lengths only; spare `Vec` capacity is not counted).
    fn memory_bytes(&self) -> usize {
        let key_floats: usize = self.keys.iter().map(Vec::len).sum();
        let value_floats: usize = self.values.iter().map(Vec::len).sum();
        (key_floats + value_floats) * std::mem::size_of::<f32>()
    }
}
/// Node of a radix (compressed prefix) tree over token ids.
struct RadixNode {
/// Edge label: the run of tokens this node appends to its parent's path (empty for the root).
tokens: Vec<u32>,
/// Children keyed by the first token of their edge label.
children: HashMap<u32, Box<RadixNode>>,
/// Snapshot cached for the full path ending at this node, if any.
cached_kv: Option<CachedKvState>,
/// Refreshed when a lookup fully matches this node or an insert stores here; drives LRU eviction.
last_access: Instant,
/// Nodes with a non-zero count are skipped by eviction.
/// NOTE(review): nothing in this file ever increments it — confirm intended use.
ref_count: u32,
}
impl RadixNode {
    /// Creates a node labelled with `tokens`: no children, no cached state,
    /// `ref_count` zero and `last_access` set to the current instant.
    fn new(tokens: Vec<u32>) -> Self {
        Self {
            tokens,
            children: HashMap::new(),
            cached_kv: None,
            last_access: Instant::now(),
            ref_count: 0,
        }
    }

    /// Finds the deepest cached snapshot along the path spelled by `query`.
    ///
    /// `query` is matched starting at this node's edge label, and
    /// `matched_so_far` is the number of tokens already consumed by ancestors.
    /// Returns `(total_matched_tokens, snapshot)` for the longest cached
    /// prefix. Returns `None` if this node's label is not fully matched, or if
    /// no node on the matched path holds a snapshot. Every fully matched node
    /// has its `last_access` refreshed, which keeps hot prefixes away from LRU
    /// eviction.
    fn lookup<'a>(
        &'a mut self,
        query: &[u32],
        matched_so_far: usize,
    ) -> Option<(usize, &'a CachedKvState)> {
        let common = common_prefix_len(&self.tokens, query);
        if common < self.tokens.len() {
            // The query ends or diverges inside this edge; nothing here applies.
            return None;
        }
        let total_matched = matched_so_far + common;
        let remaining = &query[common..];
        self.last_access = Instant::now();
        // Prefer a deeper (longer) cached prefix from the continuing child.
        let mut best: Option<(usize, &'a CachedKvState)> = None;
        if let Some(&first_token) = remaining.first() {
            if let Some(child) = self.children.get_mut(&first_token) {
                best = child.lookup(remaining, total_matched);
            }
        }
        if best.is_none() {
            if let Some(ref kv) = self.cached_kv {
                best = Some((total_matched, kv));
            }
        }
        best
    }

    /// Stores `kv` for the path `tokens`, matched starting at this node's label.
    ///
    /// An empty `tokens` stores `kv` on this node itself. This node is split
    /// when `tokens` ends or diverges inside its label. A new leaf is created
    /// when no child continues the path. Any snapshot already stored at the
    /// exact path is overwritten. The node receiving `kv` has its
    /// `last_access` refreshed.
    fn insert(&mut self, tokens: &[u32], kv: CachedKvState) {
        if tokens.is_empty() {
            self.cached_kv = Some(kv);
            self.last_access = Instant::now();
            return;
        }
        let common = common_prefix_len(&self.tokens, tokens);
        if common < self.tokens.len() {
            self.split_at(common);
        }
        let remaining = &tokens[common..];
        if remaining.is_empty() {
            self.cached_kv = Some(kv);
            self.last_access = Instant::now();
            return;
        }
        let first = remaining[0];
        let child = self
            .children
            .entry(first)
            .or_insert_with(|| Box::new(RadixNode::new(remaining.to_vec())));
        if child.tokens == remaining {
            // Either freshly created for exactly this path, or an exact match.
            child.cached_kv = Some(kv);
            child.last_access = Instant::now();
        } else {
            child.insert(remaining, kv);
        }
    }

    /// Splits this node's label at `pos`.
    ///
    /// The node keeps `tokens[..pos]`. A new child takes the suffix together
    /// with this node's children, snapshot, `last_access` and `ref_count`.
    /// `pos` must be less than `tokens.len()`; otherwise `suffix[0]` panics.
    fn split_at(&mut self, pos: usize) {
        let suffix = self.tokens[pos..].to_vec();
        let first_of_suffix = suffix[0];
        let mut new_child = RadixNode::new(suffix);
        new_child.children = std::mem::take(&mut self.children);
        new_child.cached_kv = self.cached_kv.take();
        new_child.last_access = self.last_access;
        new_child.ref_count = self.ref_count;
        self.tokens.truncate(pos);
        self.children.insert(first_of_suffix, Box::new(new_child));
    }

    /// Number of nodes in this subtree that hold a cached snapshot.
    fn count_entries(&self) -> usize {
        let mine = usize::from(self.cached_kv.is_some());
        let children_count: usize = self.children.values().map(|c| c.count_entries()).sum();
        mine + children_count
    }

    /// Summed `memory_bytes` of every snapshot cached in this subtree.
    fn total_memory(&self) -> usize {
        let mine = self.cached_kv.as_ref().map_or(0, |kv| kv.memory_bytes());
        let children_mem: usize = self.children.values().map(|c| c.total_memory()).sum();
        mine + children_mem
    }

    /// Evicts the least-recently-accessed snapshot in this subtree.
    ///
    /// Only nodes with `ref_count == 0` are candidates. Returns the number of
    /// bytes freed. `0` is returned both when there is no candidate and when
    /// the evicted snapshot held no data. Callers that need to know whether an
    /// entry was removed should compare `count_entries` before and after.
    fn evict_lru_one(&mut self) -> usize {
        let mut path = Vec::new();
        let mut oldest: Option<(Instant, Vec<u32>)> = None;
        self.find_lru_candidate(&mut path, &mut oldest);
        match oldest {
            Some((_, victim)) => self.remove_cached_at(&victim),
            None => 0,
        }
    }

    /// Depth-first search for the evictable node with the oldest `last_access`.
    ///
    /// `prefix` holds the tokens from the root down to (but excluding) this
    /// node, and is restored to that length before returning. The first
    /// candidate found is always accepted, so a node touched at the same
    /// instant the scan began is still evictable. On ties the node visited
    /// first wins.
    fn find_lru_candidate(&self, prefix: &mut Vec<u32>, oldest: &mut Option<(Instant, Vec<u32>)>) {
        let parent_len = prefix.len();
        prefix.extend_from_slice(&self.tokens);
        if self.cached_kv.is_some() && self.ref_count == 0 {
            let is_older = match oldest.as_ref() {
                Some((time, _)) => self.last_access < *time,
                None => true,
            };
            if is_older {
                *oldest = Some((self.last_access, prefix.clone()));
            }
        }
        for child in self.children.values() {
            child.find_lru_candidate(prefix, oldest);
        }
        prefix.truncate(parent_len);
    }

    /// Removes the snapshot stored at the full root-relative `path` (matched
    /// from this node's label) and returns its `memory_bytes`, or 0 if the
    /// path does not exist.
    ///
    /// Children left with neither a snapshot nor children are pruned on the
    /// way back up.
    fn remove_cached_at(&mut self, path: &[u32]) -> usize {
        let common = common_prefix_len(&self.tokens, path);
        if common < self.tokens.len() {
            return 0;
        }
        let remaining = &path[common..];
        if remaining.is_empty() {
            return self.cached_kv.take().map_or(0, |kv| kv.memory_bytes());
        }
        let first = remaining[0];
        if let Some(child) = self.children.get_mut(&first) {
            let freed = child.remove_cached_at(remaining);
            if child.cached_kv.is_none() && child.children.is_empty() {
                self.children.remove(&first);
            }
            return freed;
        }
        0
    }

    /// Drops this node's snapshot and its entire subtree; the label is kept.
    fn clear_all(&mut self) {
        self.cached_kv = None;
        self.children.clear();
    }
}
/// Length of the longest common leading run of `a` and `b`.
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    // Index of the first mismatch; if none, one slice is a prefix of the other.
    a.iter()
        .zip(b)
        .position(|(x, y)| x != y)
        .unwrap_or_else(|| a.len().min(b.len()))
}
/// Cache of KV snapshots keyed by token prefixes, stored in a radix tree with LRU eviction.
pub struct PrefixKvCache {
/// Root of the radix tree; created with an empty edge label.
root: RadixNode,
/// Capacity limits and minimum prefix length.
config: PrefixCacheConfig,
/// Number of `lookup` calls that returned a snapshot.
hit_count: u64,
/// Number of `lookup` calls that returned `None`.
miss_count: u64,
}
impl PrefixKvCache {
    /// Creates an empty prefix cache governed by `config`, with zeroed counters.
    pub fn new(config: PrefixCacheConfig) -> Self {
        Self {
            root: RadixNode::new(Vec::new()),
            config,
            hit_count: 0,
            miss_count: 0,
        }
    }

    /// Looks up the longest cached prefix of `tokens`.
    ///
    /// Returns `(matched_tokens, snapshot)` and counts a hit when the longest
    /// cached prefix is at least `min_prefix_len` tokens long. Otherwise it
    /// counts a miss and returns `None`. Empty queries, and queries shorter
    /// than `min_prefix_len`, are rejected without walking the tree, so they
    /// do not refresh any LRU timestamps.
    pub fn lookup(&mut self, tokens: &[u32]) -> Option<(usize, &CachedKvState)> {
        // A match never exceeds the query length, so a query shorter than
        // `min_prefix_len` can never produce a hit.
        if tokens.is_empty() || tokens.len() < self.config.min_prefix_len {
            self.miss_count += 1;
            return None;
        }
        match self.root.lookup(tokens, 0) {
            Some((matched, kv)) if matched >= self.config.min_prefix_len => {
                self.hit_count += 1;
                Some((matched, kv))
            }
            _ => {
                self.miss_count += 1;
                None
            }
        }
    }

    /// Snapshots `kv_cache` and caches it under `tokens`.
    ///
    /// For each of the first `num_layers` layers, copies up to
    /// `seq_len * kv_dim` floats of keys and values, truncated to what the
    /// source holds. A layer the source cannot provide is stored empty.
    /// Sequences shorter than `min_prefix_len` are ignored. May evict
    /// least-recently-used entries to stay within the configured limits.
    pub fn store(
        &mut self,
        tokens: &[u32],
        kv_cache: &dyn KvCacheAccess,
        seq_len: usize,
        kv_dim: usize,
        num_layers: usize,
    ) {
        if tokens.len() < self.config.min_prefix_len {
            return;
        }
        // Floats per layer covering `seq_len` positions (saturating: it is
        // clamped to the source length below anyway).
        let wanted = seq_len.saturating_mul(kv_dim);
        let mut keys = Vec::with_capacity(num_layers);
        let mut values = Vec::with_capacity(num_layers);
        for layer in 0..num_layers {
            let k = kv_cache.get_keys(layer).unwrap_or(&[]);
            let v = kv_cache.get_values(layer).unwrap_or(&[]);
            keys.push(k[..wanted.min(k.len())].to_vec());
            values.push(v[..wanted.min(v.len())].to_vec());
        }
        let snapshot = CachedKvState {
            keys,
            values,
            seq_len,
        };
        self.store_snapshot(tokens, snapshot);
    }

    /// Caches an already-built `snapshot` under `tokens`.
    ///
    /// Any snapshot stored at the same sequence is replaced. Sequences shorter
    /// than `min_prefix_len` are ignored. May evict least-recently-used entries
    /// to stay within the configured limits.
    pub fn store_snapshot(&mut self, tokens: &[u32], snapshot: CachedKvState) {
        if tokens.len() < self.config.min_prefix_len {
            return;
        }
        self.root.insert(tokens, snapshot);
        self.evict_lru();
    }

    /// Writes `cached` into `target` via `KvCache::restore_from_snapshot`.
    pub fn restore(cached: &CachedKvState, target: &mut KvCache) {
        target.restore_from_snapshot(&cached.keys, &cached.values, cached.seq_len);
    }

    /// Evicts least-recently-used snapshots until both `max_entries` and
    /// `max_memory_bytes` are respected, or no further entry can be evicted.
    fn evict_lru(&mut self) {
        let mut entries = self.root.count_entries();
        let mut memory = self.root.total_memory();
        while entries > self.config.max_entries || memory > self.config.max_memory_bytes {
            let freed = self.root.evict_lru_one();
            // `evict_lru_one` returns 0 both when nothing was evictable and when
            // the evicted snapshot held no data, so detect progress through the
            // entry count rather than the freed byte count.
            let remaining = self.root.count_entries();
            if remaining >= entries {
                break;
            }
            entries = remaining;
            memory = memory.saturating_sub(freed);
        }
    }

    /// Number of cached snapshots.
    pub fn len(&self) -> usize {
        self.root.count_entries()
    }

    /// `true` when no snapshot is cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Drops every cached snapshot and resets the hit/miss counters.
    pub fn clear(&mut self) {
        self.root.clear_all();
        self.hit_count = 0;
        self.miss_count = 0;
    }

    /// Bytes of `f32` key/value data held by all cached snapshots.
    pub fn memory_usage(&self) -> usize {
        self.root.total_memory()
    }

    /// Number of lookups that returned a snapshot since creation or the last `clear`.
    pub fn hits(&self) -> u64 {
        self.hit_count
    }

    /// Number of lookups that returned `None` since creation or the last `clear`.
    pub fn misses(&self) -> u64 {
        self.miss_count
    }
}
// Unit tests for `PrefixKvCache`: lookup semantics, LRU eviction by entry count
// and by memory, hit/miss counters, `min_prefix_len` filtering, and round-tripping
// a stored snapshot back into a `KvCache`.
#[cfg(test)]
mod tests {
use super::*;
use oxillama_arch::traits::KvCacheAccess;
// Builds a `KvCache` and fills `num_tokens` positions. For each token and layer it
// stores a `kv_dim`-long key and value derived from (layer, token, dim), then calls
// `advance()`. Returns the cache together with the token ids `0..num_tokens`.
// NOTE(review): the `128` passed to `KvCache::new` is presumably the maximum
// sequence length — confirm against `KvCache::new`.
fn make_filled_cache(
num_layers: usize,
kv_dim: usize,
num_tokens: usize,
) -> (KvCache, Vec<u32>) {
let mut cache = KvCache::new(num_layers, 128, kv_dim);
let tokens: Vec<u32> = (0..num_tokens as u32).collect();
for t in 0..num_tokens {
for layer in 0..num_layers {
// Distinct base per (layer, token), so round-trip comparisons catch misplaced data.
let base = (layer * 1000 + t) as f32;
let key: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.01).collect();
let val: Vec<f32> = (0..kv_dim).map(|d| base + d as f32 * 0.02).collect();
cache
.store_kv(layer, &key, &val)
.expect("store_kv should succeed");
}
cache.advance();
}
(cache, tokens)
}
// Generous limits, plus `min_prefix_len: 1` so that short test sequences are cached.
fn default_config() -> PrefixCacheConfig {
PrefixCacheConfig {
max_entries: 64,
max_memory_bytes: 16 * 1024 * 1024,
min_prefix_len: 1,
}
}
// Looking up exactly the stored sequence matches all of it.
#[test]
fn test_insert_and_lookup_exact() {
let mut pcache = PrefixKvCache::new(default_config());
let (cache, tokens) = make_filled_cache(2, 4, 5);
pcache.store(&tokens, &cache, 5, 4, 2);
assert_eq!(pcache.len(), 1);
let result = pcache.lookup(&tokens);
assert!(result.is_some());
let (matched, kv) = result.expect("lookup should succeed");
assert_eq!(matched, 5);
assert_eq!(kv.seq_len(), 5);
}
// A query that extends a cached sequence matches only the cached prefix length.
#[test]
fn test_lookup_longer_query_returns_cached_prefix() {
let mut pcache = PrefixKvCache::new(default_config());
let (cache, tokens) = make_filled_cache(2, 4, 5);
pcache.store(&tokens, &cache, 5, 4, 2);
let longer: Vec<u32> = (0..10).collect();
let result = pcache.lookup(&longer);
assert!(result.is_some());
let (matched, _) = result.expect("lookup should succeed");
assert_eq!(matched, 5);
}
// A query sharing no leading token with any stored sequence misses.
#[test]
fn test_lookup_no_match_returns_none() {
let mut pcache = PrefixKvCache::new(default_config());
let (cache, tokens) = make_filled_cache(1, 4, 5);
pcache.store(&tokens, &cache, 5, 4, 1);
let other = vec![100, 200, 300];
let result = pcache.lookup(&other);
assert!(result.is_none());
}
// Lookups on a freshly created cache miss.
#[test]
fn test_empty_cache_lookup_returns_none() {
let mut pcache = PrefixKvCache::new(default_config());
let result = pcache.lookup(&[1, 2, 3]);
assert!(result.is_none());
}
// An empty query always misses.
#[test]
fn test_empty_query_returns_none() {
let mut pcache = PrefixKvCache::new(default_config());
let result = pcache.lookup(&[]);
assert!(result.is_none());
}
// Two sequences sharing [0, 1, 2] are both retrievable. The shared prefix alone
// has no snapshot of its own, so looking it up misses.
#[test]
fn test_multiple_prefixes_with_shared_root() {
let mut pcache = PrefixKvCache::new(default_config());
let tokens_a = vec![0u32, 1, 2, 3, 4];
let tokens_b = vec![0u32, 1, 2, 10, 11];
let (cache_a, _) = make_filled_cache(1, 4, 5);
let (cache_b, _) = make_filled_cache(1, 4, 5);
pcache.store(&tokens_a, &cache_a, 5, 4, 1);
pcache.store(&tokens_b, &cache_b, 5, 4, 1);
assert_eq!(pcache.len(), 2);
let (m_a, _) = pcache.lookup(&tokens_a).expect("lookup A");
assert_eq!(m_a, 5);
let (m_b, _) = pcache.lookup(&tokens_b).expect("lookup B");
assert_eq!(m_b, 5);
let shared_only = vec![0u32, 1, 2];
let result = pcache.lookup(&shared_only);
assert!(result.is_none());
}
// Storing three entries against `max_entries: 2` must evict back within the limit.
#[test]
fn test_lru_eviction_by_entries() {
let config = PrefixCacheConfig {
max_entries: 2,
max_memory_bytes: usize::MAX,
min_prefix_len: 1,
};
let mut pcache = PrefixKvCache::new(config);
for i in 0u32..3 {
let tokens = vec![100 + i, 200 + i];
let snapshot = CachedKvState {
keys: vec![vec![i as f32; 4]],
values: vec![vec![i as f32; 4]],
seq_len: 2,
};
pcache.store_snapshot(&tokens, snapshot);
}
assert!(pcache.len() <= 2);
}
// Each snapshot below holds 8 floats (32 bytes), so a 64-byte budget fits at most two.
#[test]
fn test_lru_eviction_by_memory() {
let config = PrefixCacheConfig {
max_entries: 100,
max_memory_bytes: 64, min_prefix_len: 1,
};
let mut pcache = PrefixKvCache::new(config);
for i in 0u32..5 {
let tokens = vec![100 + i, 200 + i];
let snapshot = CachedKvState {
keys: vec![vec![i as f32; 4]],
values: vec![vec![i as f32; 4]],
seq_len: 2,
};
pcache.store_snapshot(&tokens, snapshot);
}
assert!(pcache.memory_usage() <= 64);
}
// `clear` drops all entries and also resets the hit/miss counters.
#[test]
fn test_clear_resets_everything() {
let mut pcache = PrefixKvCache::new(default_config());
let (cache, tokens) = make_filled_cache(1, 4, 5);
pcache.store(&tokens, &cache, 5, 4, 1);
let _ = pcache.lookup(&tokens);
pcache.clear();
assert!(pcache.is_empty());
assert_eq!(pcache.len(), 0);
assert_eq!(pcache.memory_usage(), 0);
assert_eq!(pcache.hits(), 0);
assert_eq!(pcache.misses(), 0);
}
// Data stored from one `KvCache` and restored into a fresh one must match the source exactly.
#[test]
fn test_store_and_restore_round_trip() {
let num_layers = 2;
let kv_dim = 4;
let num_tokens = 5;
let mut pcache = PrefixKvCache::new(default_config());
let (source_cache, tokens) = make_filled_cache(num_layers, kv_dim, num_tokens);
pcache.store(&tokens, &source_cache, num_tokens, kv_dim, num_layers);
let (_, cached_kv) = pcache.lookup(&tokens).expect("lookup must succeed");
let cached_kv_clone = cached_kv.clone();
let mut target = KvCache::new(num_layers, 128, kv_dim);
PrefixKvCache::restore(&cached_kv_clone, &mut target);
assert_eq!(target.seq_len(), num_tokens);
for layer in 0..num_layers {
let src_keys = source_cache.get_keys(layer).expect("get_keys");
let tgt_keys = target.get_keys(layer).expect("get_keys");
assert_eq!(src_keys.len(), tgt_keys.len(), "layer {layer} key length");
for (i, (&s, &t)) in src_keys.iter().zip(tgt_keys.iter()).enumerate() {
assert!(
(s - t).abs() < 1e-7,
"layer {layer} key[{i}]: source={s}, target={t}"
);
}
let src_vals = source_cache.get_values(layer).expect("get_values");
let tgt_vals = target.get_values(layer).expect("get_values");
assert_eq!(src_vals.len(), tgt_vals.len(), "layer {layer} value length");
for (i, (&s, &t)) in src_vals.iter().zip(tgt_vals.iter()).enumerate() {
assert!(
(s - t).abs() < 1e-7,
"layer {layer} value[{i}]: source={s}, target={t}"
);
}
}
}
// Memory usage counts only the `f32` payload of cached snapshots.
#[test]
fn test_memory_usage_tracking() {
let mut pcache = PrefixKvCache::new(default_config());
assert_eq!(pcache.memory_usage(), 0);
// 8 key floats + 8 value floats = 16 * 4 bytes = 64 bytes.
let snapshot = CachedKvState {
keys: vec![vec![0.0f32; 8]], values: vec![vec![0.0f32; 8]],
seq_len: 2,
};
pcache.store_snapshot(&[1, 2], snapshot);
assert_eq!(pcache.memory_usage(), 64);
}
// Each lookup increments exactly one of the hit or miss counters.
#[test]
fn test_hit_miss_counters() {
let mut pcache = PrefixKvCache::new(default_config());
assert_eq!(pcache.hits(), 0);
assert_eq!(pcache.misses(), 0);
let _ = pcache.lookup(&[1, 2, 3]);
assert_eq!(pcache.misses(), 1);
assert_eq!(pcache.hits(), 0);
let snapshot = CachedKvState {
keys: vec![vec![0.0; 4]],
values: vec![vec![0.0; 4]],
seq_len: 2,
};
pcache.store_snapshot(&[1, 2], snapshot);
let _ = pcache.lookup(&[1, 2]);
assert_eq!(pcache.hits(), 1);
assert_eq!(pcache.misses(), 1);
let _ = pcache.lookup(&[99, 100]);
assert_eq!(pcache.hits(), 1);
assert_eq!(pcache.misses(), 2);
}
// Sequences shorter than `min_prefix_len` are not stored.
#[test]
fn test_min_prefix_len_filters_short_store() {
let config = PrefixCacheConfig {
max_entries: 64,
max_memory_bytes: 16 * 1024 * 1024,
min_prefix_len: 5,
};
let mut pcache = PrefixKvCache::new(config);
let (cache, _) = make_filled_cache(1, 4, 3);
pcache.store(&[0, 1, 2], &cache, 3, 4, 1);
assert!(pcache.is_empty());
}
// A query shorter than `min_prefix_len` misses even though a longer entry exists.
#[test]
fn test_min_prefix_len_filters_short_lookup() {
let config = PrefixCacheConfig {
max_entries: 64,
max_memory_bytes: 16 * 1024 * 1024,
min_prefix_len: 5,
};
let mut pcache = PrefixKvCache::new(config);
let (cache, tokens) = make_filled_cache(1, 4, 10);
pcache.store(&tokens, &cache, 10, 4, 1);
assert_eq!(pcache.len(), 1);
let short_query = vec![0u32, 1, 2];
let result = pcache.lookup(&short_query);
assert!(result.is_none());
}
// `is_empty` and `len` track stored snapshots.
#[test]
fn test_is_empty_and_len() {
let mut pcache = PrefixKvCache::new(default_config());
assert!(pcache.is_empty());
assert_eq!(pcache.len(), 0);
let snapshot = CachedKvState {
keys: vec![vec![0.0; 4]],
values: vec![vec![0.0; 4]],
seq_len: 2,
};
pcache.store_snapshot(&[1, 2], snapshot);
assert!(!pcache.is_empty());
assert_eq!(pcache.len(), 1);
}
// Edge cases of the shared-prefix helper: empty inputs, full match, partial match, no match.
#[test]
fn test_common_prefix_len() {
assert_eq!(common_prefix_len(&[], &[]), 0);
assert_eq!(common_prefix_len(&[1, 2, 3], &[]), 0);
assert_eq!(common_prefix_len(&[], &[1, 2, 3]), 0);
assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 3]), 3);
assert_eq!(common_prefix_len(&[1, 2, 3], &[1, 2, 4]), 2);
assert_eq!(common_prefix_len(&[1, 2, 3], &[4, 5, 6]), 0);
assert_eq!(common_prefix_len(&[1, 2], &[1, 2, 3, 4]), 2);
}
// Inserting a sequence that diverges mid-edge splits the shared node. Both
// snapshots must remain reachable with their original data.
#[test]
fn test_node_split_preserves_data() {
let mut pcache = PrefixKvCache::new(default_config());
let snap_a = CachedKvState {
keys: vec![vec![1.0; 4]],
values: vec![vec![2.0; 4]],
seq_len: 4,
};
let snap_b = CachedKvState {
keys: vec![vec![3.0; 4]],
values: vec![vec![4.0; 4]],
seq_len: 4,
};
pcache.store_snapshot(&[1, 2, 3, 4], snap_a);
pcache.store_snapshot(&[1, 2, 5, 6], snap_b);
assert_eq!(pcache.len(), 2);
let (m_a, kv_a) = pcache.lookup(&[1, 2, 3, 4]).expect("lookup A");
assert_eq!(m_a, 4);
assert_eq!(kv_a.keys()[0][0], 1.0);
let (m_b, kv_b) = pcache.lookup(&[1, 2, 5, 6]).expect("lookup B");
assert_eq!(m_b, 4);
assert_eq!(kv_b.keys()[0][0], 3.0);
}
}