use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use dashmap::DashMap;
use crate::traits::Encoding;
/// Thread-shareable cache mapping input strings to shared [`Encoding`] values.
///
/// Values are stored behind `Arc` so that a lookup hands out a cheap reference-counted
/// clone rather than copying the encoding. Hit/miss counters live alongside the map so
/// that [`CacheStats`] snapshots can be produced without extra bookkeeping.
pub struct L0Cache {
    /// Number of lookups that found an entry.
    hits: AtomicU64,
    /// Number of lookups that found no entry.
    misses: AtomicU64,
    /// Upper bound on stored entries; enforced by a check-then-insert, so it is a soft limit
    /// under concurrent writers.
    max_entries: usize,
    /// Concurrent key -> encoding storage.
    map: Arc<DashMap<String, Arc<Encoding>>>,
}
impl L0Cache {
    /// Fixed per-entry size used by [`L0Cache::memory_usage`].
    ///
    /// NOTE(review): this is a hard-coded heuristic, not a measurement; presumably tuned for
    /// typical key + encoding sizes — confirm before relying on it for capacity planning.
    const ESTIMATED_BYTES_PER_ENTRY: usize = 2200;

    /// Creates an empty cache that stores at most `max_entries` entries.
    ///
    /// The backing map pre-allocates room for `min(max_entries, 1024)` entries so that a
    /// very large limit does not trigger a large up-front allocation. A limit of `0`
    /// yields a cache that never stores anything.
    pub fn new(max_entries: usize) -> Self {
        Self {
            map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
            max_entries,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Looks up `key`, returning a shared handle to the cached encoding if present.
    ///
    /// Every call increments exactly one of the hit/miss counters reported by
    /// [`L0Cache::stats`] (relaxed ordering; the counters are statistics only).
    #[inline]
    pub fn get(&self, key: &str) -> Option<Arc<Encoding>> {
        match self.map.get(key) {
            Some(entry) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                Some(Arc::clone(entry.value()))
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
        }
    }

    /// Inserts or replaces the encoding for `key`, wrapping it in an `Arc`.
    ///
    /// Eviction behaves exactly as described on [`L0Cache::insert_arc`].
    pub fn insert(&self, key: String, value: Encoding) {
        self.insert_arc(key, Arc::new(value));
    }

    /// Inserts or replaces an already-shared encoding for `key`.
    ///
    /// * If `key` is new and the cache is full, one arbitrary entry (whichever the map's
    ///   iterator yields first — this is not LRU) is evicted first.
    /// * Replacing an existing key never evicts a different entry, since it does not grow
    ///   the map.
    /// * With `max_entries == 0` nothing is stored.
    ///
    /// The size check and the insertion are separate map operations, so concurrent
    /// inserts can briefly push the entry count above `max_entries`.
    pub fn insert_arc(&self, key: String, value: Arc<Encoding>) {
        if self.max_entries == 0 {
            return;
        }
        // `contains_key` first: it touches one shard, while `len` visits every shard.
        if !self.map.contains_key(key.as_str()) && self.map.len() >= self.max_entries {
            self.evict_one();
        }
        self.map.insert(key, value);
    }

    /// Removes one arbitrary entry, if the map is non-empty.
    fn evict_one(&self) {
        // Clone the victim key out in its own statement so the iterator (and the shard
        // read guard it holds) is dropped before `remove`; DashMap documents that
        // `remove` may deadlock while a reference into the map is held.
        let victim = self.map.iter().next().map(|entry| entry.key().clone());
        if let Some(key) = victim {
            self.map.remove(&key);
        }
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Returns a snapshot of the hit/miss counters and current entry count.
    ///
    /// `hits` and `misses` are loaded separately with relaxed ordering, so under concurrent
    /// traffic the snapshot is approximate. `hit_rate` is `0.0` when no lookups have been
    /// recorded.
    pub fn stats(&self) -> CacheStats {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        // Saturate rather than overflow (which would panic in debug builds).
        let total_requests = hits.saturating_add(misses);
        CacheStats {
            hits,
            misses,
            entries: self.len(),
            hit_rate: if total_requests > 0 {
                hits as f64 / total_requests as f64
            } else {
                0.0
            },
        }
    }

    /// Removes every entry and resets the hit/miss counters to zero.
    ///
    /// The map clear and the two counter resets are independent operations; lookups that
    /// run concurrently may be attributed to either side of the reset.
    pub fn clear(&self) {
        self.map.clear();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    /// Rough memory estimate in bytes: entry count times a fixed per-entry size.
    ///
    /// This does not inspect the stored encodings; see
    /// [`L0Cache::ESTIMATED_BYTES_PER_ENTRY`].
    pub fn memory_usage(&self) -> usize {
        self.len() * Self::ESTIMATED_BYTES_PER_ENTRY
    }
}
/// Point-in-time snapshot of cache counters, as produced by `L0Cache::stats`.
///
/// All fields are plain values, so the snapshot is `Copy` and can be compared directly.
/// `PartialEq` (not `Eq`) because `hit_rate` is an `f64`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found no entry.
    pub misses: u64,
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use crate::{traits::Encoding, *};

    /// Builds a test encoding directly from raw token ids.
    fn mock_encoding(tokens: Vec<u32>) -> Encoding {
        Encoding::Sp(tokens)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);
        assert!(cache.get("hello").is_none());

        cache.insert(String::from("hello"), mock_encoding(vec![1, 2, 3]));

        let hit = cache.get("hello").expect("freshly inserted key must be present");
        assert_eq!(hit.token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);
        for (key, token) in [("a", 1), ("b", 2), ("c", 3)] {
            cache.insert(key.to_string(), mock_encoding(vec![token]));
        }
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);
        cache.insert(String::from("test"), mock_encoding(vec![1, 2, 3]));

        // One miss followed by one hit.
        let _ = cache.get("missing");
        let _ = cache.get("test");

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);
        cache.insert(String::from("test"), mock_encoding(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_concurrent_access() {
        let cache = Arc::new(L0Cache::new(1000));

        let workers: Vec<_> = (0..10u32)
            .map(|id| {
                let shared = Arc::clone(&cache);
                thread::spawn(move || {
                    let key = format!("key_{}", id);
                    shared.insert(key.clone(), mock_encoding(vec![id]));
                    assert!(shared.get(&key).is_some());
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("worker thread panicked");
        }
        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_arc_reuse() {
        let cache = L0Cache::new(10);
        cache.insert(String::from("test"), mock_encoding(vec![1, 2, 3]));

        let first = cache.get("test").expect("entry present");
        let second = cache.get("test").expect("entry present");

        // Both lookups must hand out the same allocation, not copies.
        assert!(Arc::ptr_eq(&first, &second));
    }
}