use ferrum_types::{FerrumError, Result, TokenId};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, trace};
/// Key identifying a cached prefix by its exact token sequence.
///
/// Equality and hashing are derived from the wrapped token vector, so two
/// `PrefixId`s are the same map key exactly when their tokens are equal.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct PrefixId(Vec<TokenId>);

impl PrefixId {
    /// Wraps an owned token sequence without copying it.
    pub fn new(tokens: Vec<TokenId>) -> Self {
        PrefixId(tokens)
    }

    /// Borrows the underlying token sequence.
    pub fn tokens(&self) -> &[TokenId] {
        self.0.as_slice()
    }

    /// Number of tokens in the prefix.
    pub fn len(&self) -> usize {
        self.tokens().len()
    }

    /// Returns `true` when the prefix holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens().is_empty()
    }
}

impl From<Vec<TokenId>> for PrefixId {
    /// Takes ownership of `tokens`.
    fn from(tokens: Vec<TokenId>) -> Self {
        PrefixId(tokens)
    }
}

impl From<&[TokenId]> for PrefixId {
    /// Copies the borrowed `tokens` into a new owned key.
    fn from(tokens: &[TokenId]) -> Self {
        PrefixId(Vec::from(tokens))
    }
}
/// One cache entry: a prefix's KV-cache handle, its stored logits, and the
/// bookkeeping used for reference counting and LRU ordering.
#[derive(Clone, Debug)]
pub struct CachedPrefix {
    /// Token sequence this entry is stored under.
    pub prefix_id: PrefixId,
    /// Shared handle to the KV cache associated with `prefix_id`.
    pub kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
    /// Logits stored with the prefix (presumably those of its final position;
    /// confirm with callers).
    pub last_logits: Vec<f32>,
    /// Number of outstanding references; a fresh entry starts at one.
    pub ref_count: usize,
    /// Creation time or time of the most recent `touch`.
    pub last_access: std::time::Instant,
    /// Length of `prefix_id`, in tokens.
    pub size: usize,
}

impl CachedPrefix {
    /// Creates an entry holding a single reference, stamped with the current time.
    pub fn new(
        prefix_id: PrefixId,
        kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
        last_logits: Vec<f32>,
    ) -> Self {
        Self {
            // Evaluated before `prefix_id` is moved into its field below.
            size: prefix_id.len(),
            prefix_id,
            kv_handle,
            last_logits,
            ref_count: 1,
            last_access: std::time::Instant::now(),
        }
    }

    /// Adds one reference and refreshes the access time.
    pub fn add_ref(&mut self) {
        self.ref_count += 1;
        self.touch();
    }

    /// Releases one reference. The access time is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error, leaving the count untouched, when
    /// the count is already zero.
    pub fn remove_ref(&mut self) -> Result<()> {
        match self.ref_count.checked_sub(1) {
            Some(remaining) => {
                self.ref_count = remaining;
                Ok(())
            }
            None => Err(FerrumError::invalid_parameter(
                "Cannot remove ref from zero-ref prefix",
            )),
        }
    }

    /// Records "now" as the most recent access.
    pub fn touch(&mut self) {
        self.last_access = std::time::Instant::now();
    }

    /// Returns `true` once no references remain.
    pub fn can_evict(&self) -> bool {
        self.ref_count == 0
    }
}
#[derive(Debug)]
pub struct PrefixCache {
prefixes: RwLock<HashMap<PrefixId, CachedPrefix>>,
max_prefixes: usize,
min_prefix_length: usize,
hits: parking_lot::Mutex<usize>,
misses: parking_lot::Mutex<usize>,
evictions: parking_lot::Mutex<usize>,
}
impl PrefixCache {
pub fn new(max_prefixes: usize, min_prefix_length: usize) -> Self {
Self {
prefixes: RwLock::new(HashMap::new()),
max_prefixes,
min_prefix_length,
hits: parking_lot::Mutex::new(0),
misses: parking_lot::Mutex::new(0),
evictions: parking_lot::Mutex::new(0),
}
}
pub fn find_prefix(
&self,
tokens: &[TokenId],
) -> Option<(
PrefixId,
Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
Vec<f32>,
)> {
if tokens.len() < self.min_prefix_length {
return None;
}
let prefixes = self.prefixes.read();
let mut best_match = None;
let mut best_len = 0;
for (prefix_id, cached_prefix) in prefixes.iter() {
if tokens.starts_with(prefix_id.tokens()) && prefix_id.len() > best_len {
best_match = Some((
prefix_id.clone(),
cached_prefix.kv_handle.clone(),
cached_prefix.last_logits.clone(),
));
best_len = prefix_id.len();
}
}
if let Some(ref match_info) = best_match {
*self.hits.lock() += 1;
trace!("Prefix cache hit: {} tokens", best_len);
drop(prefixes); let mut prefixes = self.prefixes.write();
if let Some(cached_prefix) = prefixes.get_mut(&match_info.0) {
cached_prefix.touch();
}
} else {
*self.misses.lock() += 1;
trace!("Prefix cache miss for {} tokens", tokens.len());
}
best_match
}
pub fn store_prefix(
&self,
prefix_tokens: &[TokenId],
kv_handle: Arc<dyn ferrum_interfaces::KvCacheHandle + Send + Sync>,
last_logits: Vec<f32>,
) -> Result<()> {
if prefix_tokens.len() < self.min_prefix_length {
return Ok(()); }
let prefix_id = PrefixId::from(prefix_tokens);
let cached_prefix = CachedPrefix::new(prefix_id.clone(), kv_handle, last_logits);
let mut prefixes = self.prefixes.write();
if prefixes.len() >= self.max_prefixes && !prefixes.contains_key(&prefix_id) {
self.evict_lru(&mut prefixes);
}
if let Some(existing) = prefixes.get_mut(&prefix_id) {
existing.add_ref();
} else {
prefixes.insert(prefix_id, cached_prefix);
debug!("Stored new prefix: {} tokens", prefix_tokens.len());
}
Ok(())
}
pub fn remove_ref(&self, prefix_tokens: &[TokenId]) -> Result<()> {
let prefix_id = PrefixId::from(prefix_tokens);
let mut prefixes = self.prefixes.write();
if let Some(cached_prefix) = prefixes.get_mut(&prefix_id) {
cached_prefix.remove_ref()?;
if cached_prefix.ref_count == 0 {
prefixes.remove(&prefix_id);
debug!(
"Removed unreferenced prefix: {} tokens",
prefix_tokens.len()
);
}
}
Ok(())
}
fn evict_lru(&self, prefixes: &mut HashMap<PrefixId, CachedPrefix>) {
let mut oldest_id = None;
let mut oldest_time = None;
for (prefix_id, cached_prefix) in prefixes.iter() {
if cached_prefix.can_evict() {
if let Some(current_oldest) = oldest_time {
if cached_prefix.last_access < current_oldest {
oldest_time = Some(cached_prefix.last_access);
oldest_id = Some(prefix_id.clone());
}
} else {
oldest_time = Some(cached_prefix.last_access);
oldest_id = Some(prefix_id.clone());
}
}
}
if oldest_id.is_none() {
for (prefix_id, cached_prefix) in prefixes.iter() {
if let Some(current_oldest) = oldest_time {
if cached_prefix.last_access < current_oldest {
oldest_time = Some(cached_prefix.last_access);
oldest_id = Some(prefix_id.clone());
}
} else {
oldest_time = Some(cached_prefix.last_access);
oldest_id = Some(prefix_id.clone());
}
}
}
if let Some(prefix_id) = oldest_id {
prefixes.remove(&prefix_id);
*self.evictions.lock() += 1;
debug!("Evicted LRU prefix: {} tokens", prefix_id.len());
}
}
pub fn evict_n(&self, n: usize) -> usize {
let mut prefixes = self.prefixes.write();
let mut evicted = 0;
for _ in 0..n {
if prefixes.is_empty() {
break;
}
self.evict_lru(&mut prefixes);
evicted += 1;
}
evicted
}
pub fn stats(&self) -> PrefixCacheStats {
let hits = *self.hits.lock();
let misses = *self.misses.lock();
let evictions = *self.evictions.lock();
let prefixes = self.prefixes.read();
let total_size: usize = prefixes.values().map(|p| p.size).sum();
let active_prefixes = prefixes.len();
drop(prefixes);
PrefixCacheStats {
hits,
misses,
evictions,
active_prefixes,
total_cached_tokens: total_size,
hit_rate: {
if hits + misses > 0 {
hits as f32 / (hits + misses) as f32
} else {
0.0
}
},
}
}
pub fn clear(&self) {
let mut prefixes = self.prefixes.write();
prefixes.clear();
*self.hits.lock() = 0;
*self.misses.lock() = 0;
*self.evictions.lock() = 0;
debug!("Cleared prefix cache");
}
pub fn config(&self) -> (usize, usize) {
(self.max_prefixes, self.min_prefix_length)
}
}
impl Default for PrefixCache {
fn default() -> Self {
Self::new(100, 8) }
}
/// Point-in-time snapshot of [`PrefixCache`] counters and occupancy, as
/// produced by [`PrefixCache::stats`].
#[derive(Clone, Debug)]
pub struct PrefixCacheStats {
    /// Lookups that found a cached prefix.
    pub hits: usize,
    /// Lookups (passing the minimum-length check) that found no match.
    pub misses: usize,
    /// Entries removed by LRU eviction.
    pub evictions: usize,
    /// Entries currently stored.
    pub active_prefixes: usize,
    /// Sum of the token lengths of all stored prefixes.
    pub total_cached_tokens: usize,
    /// `hits / (hits + misses)`, or `0.0` before any lookup has been counted.
    pub hit_rate: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `KvCacheHandle` stand-in for these tests: it reports a fixed
    /// token count and returns no key/value tensors.
    #[derive(Debug, Clone)]
    struct MockKvHandle {
        tokens: usize,
        device: ferrum_types::Device,
        block_table: ferrum_interfaces::BlockTable,
    }
    impl MockKvHandle {
        // CPU-resident mock reporting `tokens` cached tokens. The `16` passed to
        // `BlockTable::new` is presumably a block size; confirm against the
        // `BlockTable` API.
        fn new(tokens: usize) -> Self {
            Self {
                tokens,
                device: ferrum_types::Device::CPU,
                block_table: ferrum_interfaces::BlockTable::new(16),
            }
        }
    }
    // Trait implementation with fixed geometry (32 layers, 32 heads, head
    // dim 128); the prefix cache only stores and returns the handle.
    impl ferrum_interfaces::KvCacheHandle for MockKvHandle {
        fn block_table(&self) -> &ferrum_interfaces::BlockTable {
            &self.block_table
        }
        fn block_table_mut(&mut self) -> &mut ferrum_interfaces::BlockTable {
            &mut self.block_table
        }
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn device(&self) -> ferrum_types::Device {
            self.device.clone()
        }
        fn num_tokens(&self) -> usize {
            self.tokens
        }
        fn num_layers(&self) -> usize {
            32
        }
        fn num_heads(&self) -> usize {
            32
        }
        fn head_dim(&self) -> usize {
            128
        }
        // No tensors are backed by this mock.
        fn key_cache(
            &self,
            _layer: usize,
        ) -> ferrum_types::Result<Option<ferrum_interfaces::TensorRef>> {
            Ok(None)
        }
        fn value_cache(
            &self,
            _layer: usize,
        ) -> ferrum_types::Result<Option<ferrum_interfaces::TensorRef>> {
            Ok(None)
        }
        fn clone_handle(&self) -> ferrum_types::Result<Arc<dyn ferrum_interfaces::KvCacheHandle>> {
            Ok(Arc::new(Self {
                tokens: self.tokens,
                device: self.device.clone(),
                block_table: self.block_table.clone(),
            }))
        }
        fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
            ferrum_interfaces::kv_cache::CacheHandleStats {
                memory_bytes: 0,
                blocks_allocated: 0,
                tokens_stored: self.tokens,
                utilization: 0.0,
                last_access: std::time::Instant::now(),
            }
        }
        fn is_valid(&self) -> bool {
            true
        }
        fn cache_id(&self) -> String {
            "mock".to_string()
        }
    }

    // `config()` echoes the constructor arguments.
    #[test]
    fn test_prefix_cache_creation() {
        let cache = PrefixCache::new(50, 4);
        let (max_prefixes, min_len) = cache.config();
        assert_eq!(max_prefixes, 50);
        assert_eq!(min_len, 4);
    }

    // A stored prefix is found by an exact query and by a longer query that
    // starts with it; the longer query reports the stored (shorter) key.
    #[test]
    fn test_prefix_storage_and_retrieval() {
        let cache = PrefixCache::new(10, 2);
        let tokens = vec![TokenId::new(1), TokenId::new(2), TokenId::new(3)];
        let handle = Arc::new(MockKvHandle::new(3));
        cache
            .store_prefix(&tokens, handle.clone(), vec![0.1; 10])
            .unwrap();
        let result = cache.find_prefix(&tokens);
        assert!(result.is_some());
        let longer_tokens = vec![
            TokenId::new(1),
            TokenId::new(2),
            TokenId::new(3),
            TokenId::new(4),
        ];
        let result = cache.find_prefix(&longer_tokens);
        assert!(result.is_some());
        let (found_prefix, _, _) = result.unwrap();
        assert_eq!(found_prefix.tokens(), &tokens);
    }

    // Sequences shorter than `min_prefix_length` (5 here) are neither stored
    // nor looked up.
    #[test]
    fn test_prefix_length_filtering() {
        let cache = PrefixCache::new(10, 5);
        let short_tokens = vec![TokenId::new(1), TokenId::new(2)];
        let handle = Arc::new(MockKvHandle::new(2));
        cache
            .store_prefix(&short_tokens, handle, vec![0.1; 10])
            .unwrap();
        let result = cache.find_prefix(&short_tokens);
        assert!(result.is_none());
    }

    // With capacity 2, looking up `tokens1` refreshes it, so inserting
    // `tokens3` evicts `tokens2` as the least recently used entry.
    // NOTE(review): this relies on successive `Instant::now()` calls being
    // strictly increasing. On a coarse clock, `tokens1` and `tokens2` could
    // share a timestamp, and the victim would then depend on HashMap
    // iteration order.
    #[test]
    fn test_lru_eviction() {
        let cache = PrefixCache::new(2, 1);
        let tokens1 = vec![TokenId::new(1)];
        let tokens2 = vec![TokenId::new(2)];
        let tokens3 = vec![TokenId::new(3)];
        let handle = Arc::new(MockKvHandle::new(1));
        cache
            .store_prefix(&tokens1, handle.clone(), vec![0.1; 10])
            .unwrap();
        cache
            .store_prefix(&tokens2, handle.clone(), vec![0.1; 10])
            .unwrap();
        cache.find_prefix(&tokens1);
        cache
            .store_prefix(&tokens3, handle.clone(), vec![0.1; 10])
            .unwrap();
        assert!(cache.find_prefix(&tokens1).is_some());
        assert!(cache.find_prefix(&tokens2).is_none());
        assert!(cache.find_prefix(&tokens3).is_some());
    }

    // One hit plus one miss yields a 0.5 hit rate with a single stored entry.
    #[test]
    fn test_cache_stats() {
        let cache = PrefixCache::new(10, 2);
        let tokens = vec![TokenId::new(1), TokenId::new(2)];
        let handle = Arc::new(MockKvHandle::new(2));
        cache.store_prefix(&tokens, handle, vec![0.1; 10]).unwrap();
        cache.find_prefix(&tokens);
        let other_tokens = vec![TokenId::new(3), TokenId::new(4)];
        cache.find_prefix(&other_tokens);
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
        assert_eq!(stats.active_prefixes, 1);
    }
}