use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use dashmap::DashMap;
use crate::traits::Encoding;
/// Maximum number of entries `maybe_evict` inspects when searching for the
/// entry with the oldest access stamp.
const EVICTION_SAMPLE_SIZE: usize = 8;
/// A cached encoding paired with the logical time at which it was last used.
struct CachedEntry {
/// Shared handle to the cached encoding; a clone of this `Arc` is handed out on every hit.
encoding: Arc<Encoding>,
/// Value of the cache-wide access counter taken at insertion or at the most recent hit.
last_accessed: AtomicU64,
}
/// String-keyed cache of encodings, split into two maps according to the
/// `add_special_tokens` flag, with hit/miss counters and sampled
/// least-recently-used eviction.
pub struct L0Cache {
/// Entries stored with `add_special_tokens == false`.
map_plain: Arc<DashMap<String, CachedEntry>>,
/// Entries stored with `add_special_tokens == true`.
map_special: Arc<DashMap<String, CachedEntry>>,
/// Combined entry count (both maps) at which an insert first evicts an entry.
max_entries: usize,
/// Number of lookups that found an entry.
hits: AtomicU64,
/// Number of lookups that found nothing.
misses: AtomicU64,
/// Monotonic counter used as a logical clock for `last_accessed` stamps.
access_counter: AtomicU64,
}
impl L0Cache {
/// Creates an empty cache that begins evicting once the two maps together
/// hold `max_entries` entries.
///
/// Each map is pre-sized for half of `max_entries` (capped at 1024 overall)
/// plus one slot; the maps can still grow beyond that initial capacity.
pub fn new(max_entries: usize) -> Self {
    let initial_capacity = max_entries.min(1024) / 2 + 1;
    let map_plain = Arc::new(DashMap::with_capacity(initial_capacity));
    let map_special = Arc::new(DashMap::with_capacity(initial_capacity));
    Self {
        map_plain,
        map_special,
        max_entries,
        // All counters start at zero.
        hits: AtomicU64::default(),
        misses: AtomicU64::default(),
        access_counter: AtomicU64::default(),
    }
}
/// Returns the map that stores entries for the given `add_special_tokens` flag.
#[inline]
fn map_for(&self, add_special_tokens: bool) -> &DashMap<String, CachedEntry> {
    if !add_special_tokens {
        return &self.map_plain;
    }
    &self.map_special
}
/// Advances the logical clock by one and returns the value it held before
/// the increment.
#[inline]
fn next_timestamp(&self) -> u64 {
    let clock = &self.access_counter;
    clock.fetch_add(1, Ordering::Relaxed)
}
/// Looks up `key` in the map selected by `add_special_tokens`.
///
/// On a hit the hit counter is incremented, the entry is stamped with a
/// fresh logical timestamp, and a clone of the stored `Arc` is returned.
/// On a miss the miss counter is incremented and `None` is returned.
#[inline]
pub fn get(&self, key: &str, add_special_tokens: bool) -> Option<Arc<Encoding>> {
    let map = self.map_for(add_special_tokens);
    if let Some(slot) = map.get(key) {
        self.hits.fetch_add(1, Ordering::Relaxed);
        // Refresh the stamp so the eviction scan treats this entry as recent.
        let now = self.next_timestamp();
        let cached = slot.value();
        cached.last_accessed.store(now, Ordering::Relaxed);
        return Some(Arc::clone(&cached.encoding));
    }
    self.misses.fetch_add(1, Ordering::Relaxed);
    None
}
/// Removes at most one entry when the cache holds `max_entries` or more.
///
/// The victim comes from whichever map currently has more entries (the
/// plain map wins ties). Only the first `EVICTION_SAMPLE_SIZE` entries
/// yielded by that map's iterator are inspected; the one with the smallest
/// `last_accessed` stamp among them is removed (first seen wins ties).
fn maybe_evict(&self) {
    if self.len() < self.max_entries {
        return;
    }

    let plain_count = self.map_plain.len();
    let special_count = self.map_special.len();
    let crowded = if plain_count < special_count {
        &self.map_special
    } else {
        &self.map_plain
    };

    // The iterator (and every guard it hands out) is consumed by `fold`,
    // so nothing from the scan is still borrowed when `remove` runs.
    let (_, victim) = crowded.iter().take(EVICTION_SAMPLE_SIZE).fold(
        (u64::MAX, None::<String>),
        |(best_stamp, best_key), entry| {
            let stamp = entry.value().last_accessed.load(Ordering::Relaxed);
            if stamp < best_stamp {
                (stamp, Some(entry.key().clone()))
            } else {
                (best_stamp, best_key)
            }
        },
    );

    if let Some(stale_key) = victim {
        crowded.remove(&stale_key);
    }
}
/// Stores `value` under `key` in the map selected by `add_special_tokens`.
///
/// Convenience wrapper around [`L0Cache::insert_arc`] that wraps the
/// encoding in a fresh `Arc` first; see that method for the eviction rules.
pub fn insert(&self, key: String, add_special_tokens: bool, value: Encoding) {
    self.insert_arc(key, add_special_tokens, Arc::new(value));
}

/// Stores an already shared encoding under `key` in the map selected by
/// `add_special_tokens`, stamping it with a fresh logical timestamp.
///
/// * A cache created with `max_entries == 0` stores nothing, so the total
///   entry count never exceeds the configured maximum.
/// * When `key` is not yet present in the target map, `maybe_evict` runs
///   first to make room.
/// * When `key` is already present, the entry is replaced in place and no
///   eviction happens, because the total entry count does not grow.
pub fn insert_arc(&self, key: String, add_special_tokens: bool, value: Arc<Encoding>) {
    if self.max_entries == 0 {
        return;
    }

    let map = self.map_for(add_special_tokens);
    // Evicting before an overwrite would discard a live entry (possibly the
    // very key being written) without freeing any space that is needed.
    if !map.contains_key(&key) {
        self.maybe_evict();
    }

    let entry = CachedEntry {
        encoding: value,
        last_accessed: AtomicU64::new(self.next_timestamp()),
    };
    map.insert(key, entry);
}
/// Total number of entries across both maps.
pub fn len(&self) -> usize {
    let plain = self.map_plain.len();
    let special = self.map_special.len();
    plain + special
}
/// Returns `true` when neither map holds any entry.
pub fn is_empty(&self) -> bool {
    if !self.map_plain.is_empty() {
        return false;
    }
    self.map_special.is_empty()
}
/// Takes a snapshot of the hit/miss counters and the current entry count.
///
/// `hit_rate` is `hits / (hits + misses)`, or `0.0` when no lookup has been
/// counted yet.
pub fn stats(&self) -> CacheStats {
    let hits = self.hits.load(Ordering::Relaxed);
    let misses = self.misses.load(Ordering::Relaxed);
    let hit_rate = match hits + misses {
        0 => 0.0,
        total => hits as f64 / total as f64,
    };
    CacheStats {
        hits,
        misses,
        entries: self.len(),
        hit_rate,
    }
}
/// Empties both maps and resets the hit, miss and access counters to zero.
pub fn clear(&self) {
    for map in [&self.map_plain, &self.map_special] {
        map.clear();
    }
    for counter in [&self.hits, &self.misses, &self.access_counter] {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Rough memory estimate: the current entry count multiplied by a fixed
/// per-entry figure of 2200.
// NOTE(review): 2200 is presumably an average size in bytes per entry — confirm.
pub fn memory_usage(&self) -> usize {
    const ASSUMED_COST_PER_ENTRY: usize = 2200;
    self.len() * ASSUMED_COST_PER_ENTRY
}
}
/// Point-in-time snapshot of cache counters as produced by `L0Cache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Lookups that found an entry.
pub hits: u64,
/// Lookups that found nothing.
pub misses: u64,
/// Entries held across both maps when the snapshot was taken.
pub entries: usize,
/// `hits / (hits + misses)`, or `0.0` when no lookups were counted.
pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
use crate::{traits::Encoding, *};
/// Test helper: wraps raw token ids in the `Encoding::Plain` variant.
fn mock_encoding(tokens: Vec<u32>) -> Encoding {
Encoding::Plain(tokens)
}
/// A key misses before it is inserted and yields the stored token ids afterwards.
#[test]
fn test_basic_get_set() {
    let cache = L0Cache::new(10);
    let key = "hello";
    assert!(cache.get(key, false).is_none());

    cache.insert(key.to_string(), false, mock_encoding(vec![1, 2, 3]));

    let cached = cache.get(key, false);
    assert!(cached.is_some());
    assert_eq!(cached.unwrap().token_ids(), &[1, 2, 3]);
}
/// The same text cached under both flag values is kept as two independent entries.
#[test]
fn test_add_special_tokens_flag_separates_entries() {
    let cache = L0Cache::new(10);
    let text = "hello";
    cache.insert(text.to_string(), false, mock_encoding(vec![1, 2, 3]));
    cache.insert(text.to_string(), true, mock_encoding(vec![100, 1, 2, 3, 101]));

    let plain = cache.get(text, false).unwrap();
    let special = cache.get(text, true).unwrap();

    assert_eq!(plain.token_ids(), &[1, 2, 3]);
    assert_eq!(special.token_ids(), &[100, 1, 2, 3, 101]);
    assert_eq!(cache.len(), 2);
}
/// Inserting a third entry into a two-entry cache keeps the size at two.
#[test]
fn test_eviction() {
    let cache = L0Cache::new(2);
    for (key, token) in [("a", 1), ("b", 2), ("c", 3)] {
        cache.insert(key.to_string(), false, mock_encoding(vec![token]));
    }
    assert_eq!(cache.len(), 2);
}
/// The capacity limit applies to both maps combined, not to each map separately.
#[test]
fn test_eviction_across_maps() {
    let cache = L0Cache::new(2);

    // Fill the cache to capacity using only the plain map.
    for (key, token) in [("a", 1), ("b", 2)] {
        cache.insert(key.to_string(), false, mock_encoding(vec![token]));
    }
    assert_eq!(cache.len(), 2);

    // An insert into the other map must still make room first.
    cache.insert("c".to_string(), true, mock_encoding(vec![3]));
    assert_eq!(cache.len(), 2, "total entries must not exceed max_entries");
}
/// One miss followed by one hit is reported as 1/1 with a 50% hit rate.
#[test]
fn test_stats() {
    let cache = L0Cache::new(10);
    cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

    // First lookup misses, second one hits.
    let _ = cache.get("missing", false);
    let _ = cache.get("test", false);

    let snapshot = cache.stats();
    assert_eq!(snapshot.hits, 1);
    assert_eq!(snapshot.misses, 1);
    assert_eq!(snapshot.hit_rate, 0.5);
}
/// `clear` drops every entry, so a previously cached key misses afterwards.
#[test]
fn test_clear() {
    let cache = L0Cache::new(10);
    let key = "test";
    cache.insert(key.to_string(), false, mock_encoding(vec![1, 2, 3]));
    assert_eq!(cache.len(), 1);

    cache.clear();

    assert_eq!(cache.len(), 0);
    assert!(cache.get(key, false).is_none());
}
/// Ten threads each insert and read back their own key; all ten entries remain.
#[test]
fn test_concurrent_access() {
    use std::thread;

    let cache = Arc::new(L0Cache::new(1000));

    // Spawn every worker before joining any of them.
    let workers: Vec<_> = (0..10)
        .map(|i| {
            let shared = Arc::clone(&cache);
            thread::spawn(move || {
                let key = format!("key_{i}");
                shared.insert(key.clone(), false, mock_encoding(vec![i as u32]));
                let found = shared.get(&key, false);
                assert!(found.is_some());
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    assert_eq!(cache.len(), 10);
}
/// Repeated hits on one key hand out clones of the same `Arc` allocation.
#[test]
fn test_arc_reuse() {
    let cache = L0Cache::new(10);
    cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

    let first = cache.get("test", false).unwrap();
    let second = cache.get("test", false).unwrap();

    assert!(Arc::ptr_eq(&first, &second));
}
/// An entry that is read before every insert survives a stream of one-off
/// inserts, while the early untouched entries are all evicted.
#[test]
fn test_lru_eviction_keeps_frequently_accessed() {
    let cache = L0Cache::new(4);

    // Fill the cache: one hot entry followed by three one-off queries.
    cache.insert(
        "system_prompt".to_string(),
        false,
        mock_encoding(vec![10, 20, 30]),
    );
    for i in 1..=3 {
        cache.insert(format!("query_{i}"), false, mock_encoding(vec![i]));
    }
    assert_eq!(cache.len(), 4);

    // Touch the hot entry before each further insert so it is never the oldest.
    for i in 4..12 {
        let hot = cache.get("system_prompt", false);
        assert!(
            hot.is_some(),
            "system_prompt should still be in the cache after query_{} insertion",
            i - 1
        );
        cache.insert(format!("query_{i}"), false, mock_encoding(vec![i]));
    }

    let hot = cache.get("system_prompt", false);
    assert!(
        hot.is_some(),
        "system_prompt should survive eviction because it was recently accessed"
    );
    assert_eq!(hot.unwrap().token_ids(), &[10, 20, 30]);
    assert!(cache.len() <= 4);

    let mut early_queries_remaining = 0;
    for i in 1..=3 {
        if cache.get(&format!("query_{i}"), false).is_some() {
            early_queries_remaining += 1;
        }
    }
    assert_eq!(
        early_queries_remaining, 0,
        "all early one-off queries should have been evicted"
    );
}
/// When the cache is full, a recently read entry is kept and an untouched
/// one is evicted instead.
#[test]
fn test_lru_eviction_prefers_untouched_entries() {
    let cache = L0Cache::new(3);
    for (key, token) in [("keep_me", 1), ("stale_1", 2), ("stale_2", 3)] {
        cache.insert(key.to_string(), false, mock_encoding(vec![token]));
    }

    // Reading `keep_me` gives it the newest stamp of the three.
    let _ = cache.get("keep_me", false);

    cache.insert("new_entry".to_string(), false, mock_encoding(vec![4]));
    assert_eq!(cache.len(), 3);
    assert!(
        cache.get("keep_me", false).is_some(),
        "keep_me should survive eviction because it was recently accessed"
    );

    let mut stale_remaining = 0;
    for key in ["stale_1", "stale_2"] {
        if cache.get(key, false).is_some() {
            stale_remaining += 1;
        }
    }
    assert!(
        stale_remaining < 2,
        "at least one stale entry should have been evicted"
    );
}
}