llm_tokenizer/cache/
l0.rs

1//! L0 Cache: Whole-string exact match cache
2//!
3//! This is the simplest and most effective cache layer.
4//! Key: input string → Value: full encoding result (Arc-wrapped for zero-copy cache hits)
5//!
6//! Expected hit rate: 60-90% for workloads with repeated system prompts
7
8use std::sync::{
9    atomic::{AtomicU64, Ordering},
10    Arc,
11};
12
13use dashmap::DashMap;
14
15use crate::traits::Encoding;
16
17/// L0 cache implementation using DashMap for lock-free reads
18/// Uses Arc<Encoding> internally to provide zero-copy cache hits
19pub struct L0Cache {
20    /// The cache map: input string → Arc-wrapped encoding for cheap cloning
21    map: Arc<DashMap<String, Arc<Encoding>>>,
22    /// Maximum number of entries before eviction
23    max_entries: usize,
24    /// Cache hit counter
25    hits: AtomicU64,
26    /// Cache miss counter
27    misses: AtomicU64,
28}
29
30impl L0Cache {
31    /// Create a new L0 cache with the specified capacity
32    pub fn new(max_entries: usize) -> Self {
33        Self {
34            map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
35            max_entries,
36            hits: AtomicU64::new(0),
37            misses: AtomicU64::new(0),
38        }
39    }
40
41    /// Get an encoding from the cache (returns Arc for zero-copy access)
42    #[inline]
43    pub fn get(&self, key: &str) -> Option<Arc<Encoding>> {
44        match self.map.get(key) {
45            Some(entry) => {
46                self.hits.fetch_add(1, Ordering::Relaxed);
47                // Arc::clone is cheap (just increment reference count)
48                Some(Arc::clone(entry.value()))
49            }
50            None => {
51                self.misses.fetch_add(1, Ordering::Relaxed);
52                None
53            }
54        }
55    }
56
57    /// Insert an encoding into the cache
58    pub fn insert(&self, key: String, value: Encoding) {
59        // Simple eviction: if we're at capacity, remove a random entry
60        // DashMap doesn't support LRU directly, so we use a simple strategy
61        if self.map.len() >= self.max_entries {
62            let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) };
63
64            // Now remove it
65            if let Some(k) = key_to_remove {
66                self.map.remove(&k);
67            }
68        }
69
70        self.map.insert(key, Arc::new(value));
71    }
72
73    /// Insert a pre-wrapped Arc encoding into the cache (avoids double-wrapping)
74    pub fn insert_arc(&self, key: String, value: Arc<Encoding>) {
75        if self.map.len() >= self.max_entries {
76            let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) };
77            if let Some(k) = key_to_remove {
78                self.map.remove(&k);
79            }
80        }
81        self.map.insert(key, value);
82    }
83
84    /// Get the current number of entries in the cache
85    pub fn len(&self) -> usize {
86        self.map.len()
87    }
88
89    /// Check if the cache is empty
90    pub fn is_empty(&self) -> bool {
91        self.map.is_empty()
92    }
93
94    /// Get cache statistics
95    pub fn stats(&self) -> CacheStats {
96        let hits = self.hits.load(Ordering::Relaxed);
97        let misses = self.misses.load(Ordering::Relaxed);
98        let total_requests = hits + misses;
99
100        CacheStats {
101            hits,
102            misses,
103            entries: self.len(),
104            hit_rate: if total_requests > 0 {
105                hits as f64 / total_requests as f64
106            } else {
107                0.0
108            },
109        }
110    }
111
112    /// Clear the cache
113    pub fn clear(&self) {
114        self.map.clear();
115        self.hits.store(0, Ordering::Relaxed);
116        self.misses.store(0, Ordering::Relaxed);
117    }
118
119    /// Estimate memory usage in bytes
120    pub fn memory_usage(&self) -> usize {
121        // Rough estimate:
122        // - Each entry: key (string) + value (encoding ~250 tokens * 4 bytes) + overhead
123        // - Average: ~2.2KB per entry
124        self.len() * 2200
125    }
126}
127
/// Point-in-time snapshot of L0 cache statistics, as produced by the cache's
/// `stats()` method.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct CacheStats {
    /// Lookups that found a cached encoding (since creation or the last clear).
    pub hits: u64,
    /// Lookups that found nothing (since creation or the last clear).
    pub misses: u64,
    /// Number of entries in the cache when the snapshot was taken.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an encoding directly from raw token ids.
    fn encoding_of(ids: Vec<u32>) -> Encoding {
        Encoding::Sp(ids)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        // Nothing cached yet, so this lookup misses.
        assert!(cache.get("hello").is_none());

        cache.insert("hello".to_string(), encoding_of(vec![1, 2, 3]));

        // Now it hits and hands back the shared Arc<Encoding>.
        let hit = cache.get("hello").expect("entry was just inserted");
        assert_eq!(hit.token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        // The third insert exceeds capacity and must evict one entry.
        for (key, id) in [("a", 1), ("b", 2), ("c", 3)] {
            cache.insert(key.to_string(), encoding_of(vec![id]));
        }

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), encoding_of(vec![1, 2, 3]));

        // One miss followed by one hit.
        let _ = cache.get("missing");
        let _ = cache.get("test");

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), encoding_of(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let cache = Arc::new(L0Cache::new(1000));

        // Ten workers, each inserting its own key and reading it straight back.
        let workers: Vec<_> = (0..10u32)
            .map(|i| {
                let shared = Arc::clone(&cache);
                thread::spawn(move || {
                    let key = format!("key_{}", i);
                    shared.insert(key.clone(), encoding_of(vec![i]));
                    assert!(shared.get(&key).is_some());
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Every worker used a distinct key.
        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_arc_reuse() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), encoding_of(vec![1, 2, 3]));

        // Repeated hits must share one allocation rather than copying it.
        let first = cache.get("test").unwrap();
        let second = cache.get("test").unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }
}