Skip to main content

llm_tokenizer/cache/
l0.rs

1//! L0 Cache: Whole-string exact match cache
2//!
3//! This is the simplest and most effective cache layer.
4//! Key: (input string, add_special_tokens) → Value: full encoding result (Arc-wrapped for zero-copy cache hits)
5//!
6//! Expected hit rate: 60-90% for workloads with repeated system prompts
7//!
8//! ## Eviction strategy: Approximate LRU
9//!
10//! Uses an approximate LRU strategy (sample + evict oldest) instead of arbitrary
11//! eviction. This is critical for the main use case of caching system prompts:
12//! - System prompts are inserted early and accessed on every request
13//! - Arbitrary eviction could remove these high-value entries
14//! - FIFO would be even worse: it would evict the oldest entries first, which
15//!   are exactly the system prompts we want to keep
//! - Finding the exact LRU entry from per-entry timestamps needs an O(n) scan
//!   (and a classic linked-list LRU would serialize every cache hit on one
//!   lock); sampling does only O(EVICTION_SAMPLE_SIZE) work per eviction
18//!
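//! Each cache entry tracks a `last_accessed` timestamp taken from a shared
//! monotonic counter. On eviction, we inspect the first few entries yielded by
//! iterating the larger map (a prefix of DashMap's iteration order, not a fresh
//! random sample) and remove the one with the oldest timestamp.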
21
22use std::sync::{
23    atomic::{AtomicU64, Ordering},
24    Arc,
25};
26
27use dashmap::DashMap;
28
29use crate::traits::Encoding;
30
/// Number of entries inspected when looking for an eviction candidate.
///
/// Higher values give a better LRU approximation but cost more per eviction.
/// For a *uniformly random* sample of 8 entries, the chance that at least one
/// sampled entry lies in the oldest 10% of the cache is `1 - 0.9^8 ≈ 57%`.
/// The eviction path in this file does not draw a random sample, though: it
/// inspects the first `EVICTION_SAMPLE_SIZE` entries yielded by
/// `DashMap::iter()` on the larger map. Treat the 57% figure as an idealized
/// estimate, not a guarantee of this implementation.
const EVICTION_SAMPLE_SIZE: usize = 8;
36
37/// A cached encoding entry with access tracking for approximate LRU eviction.
38struct CachedEntry {
39    /// The cached encoding result
40    encoding: Arc<Encoding>,
41    /// Monotonic timestamp of last access (for LRU eviction)
42    last_accessed: AtomicU64,
43}
44
45/// L0 cache implementation using DashMap for lock-free reads.
46///
47/// Uses two separate maps (one per `add_special_tokens` value) so that
48/// lookups can borrow the key as `&str` without allocating a `String`.
49///
50/// Eviction uses approximate LRU: when capacity is reached, sample a few
51/// entries and evict the one with the oldest `last_accessed` timestamp.
52pub struct L0Cache {
53    /// Cache for encode(input, add_special_tokens = false)
54    map_plain: Arc<DashMap<String, CachedEntry>>,
55    /// Cache for encode(input, add_special_tokens = true)
56    map_special: Arc<DashMap<String, CachedEntry>>,
57    /// Maximum number of entries (across both maps) before eviction
58    max_entries: usize,
59    /// Cache hit counter
60    hits: AtomicU64,
61    /// Cache miss counter
62    misses: AtomicU64,
63    /// Monotonic counter for LRU timestamps
64    access_counter: AtomicU64,
65}
66
67impl L0Cache {
68    /// Create a new L0 cache with the specified capacity
69    pub fn new(max_entries: usize) -> Self {
70        let per_map = max_entries.min(1024) / 2 + 1;
71        Self {
72            map_plain: Arc::new(DashMap::with_capacity(per_map)),
73            map_special: Arc::new(DashMap::with_capacity(per_map)),
74            max_entries,
75            hits: AtomicU64::new(0),
76            misses: AtomicU64::new(0),
77            access_counter: AtomicU64::new(0),
78        }
79    }
80
81    #[inline]
82    fn map_for(&self, add_special_tokens: bool) -> &DashMap<String, CachedEntry> {
83        if add_special_tokens {
84            &self.map_special
85        } else {
86            &self.map_plain
87        }
88    }
89
90    /// Get the next monotonic timestamp for access tracking.
91    #[inline]
92    fn next_timestamp(&self) -> u64 {
93        self.access_counter.fetch_add(1, Ordering::Relaxed)
94    }
95
96    /// Get an encoding from the cache (returns Arc for zero-copy access).
97    /// Zero-allocation on the lookup path.
98    #[inline]
99    pub fn get(&self, key: &str, add_special_tokens: bool) -> Option<Arc<Encoding>> {
100        match self.map_for(add_special_tokens).get(key) {
101            Some(entry) => {
102                self.hits.fetch_add(1, Ordering::Relaxed);
103                // Update last-accessed timestamp for LRU tracking.
104                // This is a single atomic store -- no contention on the map lock.
105                let ts = self.next_timestamp();
106                entry.value().last_accessed.store(ts, Ordering::Relaxed);
107                Some(Arc::clone(&entry.value().encoding))
108            }
109            None => {
110                self.misses.fetch_add(1, Ordering::Relaxed);
111                None
112            }
113        }
114    }
115
116    /// Evict the least-recently-used entry (approximately) if total capacity is reached.
117    ///
118    /// Uses approximate LRU via sampling: picks EVICTION_SAMPLE_SIZE entries from
119    /// the larger map and evicts the one with the smallest (oldest) `last_accessed`
120    /// timestamp. This avoids scanning all entries while still providing good LRU
121    /// behavior in practice.
122    fn maybe_evict(&self) {
123        if self.len() >= self.max_entries {
124            let victim_map = if self.map_plain.len() >= self.map_special.len() {
125                &self.map_plain
126            } else {
127                &self.map_special
128            };
129
130            // Sample up to EVICTION_SAMPLE_SIZE entries and find the oldest.
131            // Scope the iterator so all DashMap shard read-locks are released
132            // before we call remove().
133            let key_to_remove = {
134                let mut oldest_key: Option<String> = None;
135                let mut oldest_ts = u64::MAX;
136
137                for (i, entry) in victim_map.iter().enumerate() {
138                    let ts = entry.value().last_accessed.load(Ordering::Relaxed);
139                    if ts < oldest_ts {
140                        oldest_ts = ts;
141                        oldest_key = Some(entry.key().clone());
142                    }
143                    if i + 1 >= EVICTION_SAMPLE_SIZE {
144                        break;
145                    }
146                }
147                oldest_key
148            };
149
150            if let Some(k) = key_to_remove {
151                victim_map.remove(&k);
152            }
153        }
154    }
155
156    /// Insert an encoding into the cache
157    pub fn insert(&self, key: String, add_special_tokens: bool, value: Encoding) {
158        self.maybe_evict();
159        let ts = self.next_timestamp();
160        let entry = CachedEntry {
161            encoding: Arc::new(value),
162            last_accessed: AtomicU64::new(ts),
163        };
164        self.map_for(add_special_tokens).insert(key, entry);
165    }
166
167    /// Insert a pre-wrapped Arc encoding into the cache (avoids double-wrapping)
168    pub fn insert_arc(&self, key: String, add_special_tokens: bool, value: Arc<Encoding>) {
169        self.maybe_evict();
170        let ts = self.next_timestamp();
171        let entry = CachedEntry {
172            encoding: value,
173            last_accessed: AtomicU64::new(ts),
174        };
175        self.map_for(add_special_tokens).insert(key, entry);
176    }
177
178    /// Get the current number of entries in the cache
179    pub fn len(&self) -> usize {
180        self.map_plain.len() + self.map_special.len()
181    }
182
183    /// Check if the cache is empty
184    pub fn is_empty(&self) -> bool {
185        self.map_plain.is_empty() && self.map_special.is_empty()
186    }
187
188    /// Get cache statistics
189    pub fn stats(&self) -> CacheStats {
190        let hits = self.hits.load(Ordering::Relaxed);
191        let misses = self.misses.load(Ordering::Relaxed);
192        let total_requests = hits + misses;
193
194        CacheStats {
195            hits,
196            misses,
197            entries: self.len(),
198            hit_rate: if total_requests > 0 {
199                hits as f64 / total_requests as f64
200            } else {
201                0.0
202            },
203        }
204    }
205
206    /// Clear the cache
207    pub fn clear(&self) {
208        self.map_plain.clear();
209        self.map_special.clear();
210        self.hits.store(0, Ordering::Relaxed);
211        self.misses.store(0, Ordering::Relaxed);
212        self.access_counter.store(0, Ordering::Relaxed);
213    }
214
215    /// Estimate memory usage in bytes
216    pub fn memory_usage(&self) -> usize {
217        // Rough estimate:
218        // - Each entry: key (string) + value (encoding ~250 tokens * 4 bytes) + overhead
219        // - Average: ~2.2KB per entry
220        self.len() * 2200
221    }
222}
223
/// Point-in-time snapshot of L0 cache counters, as produced by `L0Cache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Lookups that found an entry (since creation or the last `clear`).
    pub hits: u64,
    /// Lookups that found nothing (since creation or the last `clear`).
    pub misses: u64,
    /// Entries stored at the time of the snapshot.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
231
#[cfg(test)]
mod tests {
    use crate::{traits::Encoding, *};

    /// Build a plain (no special tokens) encoding from raw token ids.
    fn mock_encoding(tokens: Vec<u32>) -> Encoding {
        Encoding::Plain(tokens)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        // Nothing stored yet, so the lookup must miss.
        assert!(cache.get("hello", false).is_none());

        cache.insert("hello".to_string(), false, mock_encoding(vec![1, 2, 3]));

        // Same key and flag now hit and yield the stored tokens.
        let hit = cache.get("hello", false);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_add_special_tokens_flag_separates_entries() {
        let cache = L0Cache::new(10);

        // Identical text under different flags must occupy separate slots.
        let plain_tokens = vec![1, 2, 3];
        let special_tokens = vec![100, 1, 2, 3, 101];
        cache.insert("hello".to_string(), false, mock_encoding(plain_tokens));
        cache.insert("hello".to_string(), true, mock_encoding(special_tokens));

        let plain = cache.get("hello", false).unwrap();
        let special = cache.get("hello", true).unwrap();
        assert_eq!(plain.token_ids(), &[1, 2, 3]);
        assert_eq!(special.token_ids(), &[100, 1, 2, 3, 101]);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        // The third insert has to evict something to stay within capacity.
        for (key, token) in [("a", 1), ("b", 2), ("c", 3)] {
            cache.insert(key.to_string(), false, mock_encoding(vec![token]));
        }

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_eviction_across_maps() {
        let cache = L0Cache::new(2);

        // Fill the plain map up to the shared capacity.
        for (key, token) in [("a", 1), ("b", 2)] {
            cache.insert(key.to_string(), false, mock_encoding(vec![token]));
        }
        assert_eq!(cache.len(), 2);

        // An insert into the special map must evict from the larger plain map.
        cache.insert("c".to_string(), true, mock_encoding(vec![3]));
        assert_eq!(cache.len(), 2, "total entries must not exceed max_entries");
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

        let _ = cache.get("missing", false); // recorded as a miss
        let _ = cache.get("test", false); // recorded as a hit

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();

        // Both the count and the entry itself are gone.
        assert_eq!(cache.len(), 0);
        assert!(cache.get("test", false).is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let cache = Arc::new(L0Cache::new(1000));

        // Ten workers each insert a distinct key and immediately read it back.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let key = format!("key_{i}");
                    cache.insert(key.clone(), false, mock_encoding(vec![i as u32]));
                    assert!(cache.get(&key, false).is_some());
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_arc_reuse() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

        // Repeated hits must share one allocation rather than copy the encoding.
        let first = cache.get("test", false).unwrap();
        let second = cache.get("test", false).unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Simulates the system-prompt workload. One entry is inserted first and
    /// read on every request, while one-off queries are inserted once and
    /// never read again. Because every `get` refreshes the prompt's stamp,
    /// approximate LRU must keep it and evict the stale queries.
    #[test]
    fn test_lru_eviction_keeps_frequently_accessed() {
        let cache = L0Cache::new(4);

        // The high-value entry goes in first.
        cache.insert(
            "system_prompt".to_string(),
            false,
            mock_encoding(vec![10, 20, 30]),
        );

        // Three one-off queries bring the cache up to its capacity of 4.
        for n in 1..=3u32 {
            cache.insert(format!("query_{n}"), false, mock_encoding(vec![n]));
        }
        assert_eq!(cache.len(), 4);

        // Each simulated request reads the system prompt, then inserts a new
        // one-off query, which forces an eviction.
        for n in 4..12 {
            let hit = cache.get("system_prompt", false);
            assert!(
                hit.is_some(),
                "system_prompt should still be in the cache after query_{} insertion",
                n - 1
            );
            cache.insert(format!("query_{n}"), false, mock_encoding(vec![n]));
        }

        let system_prompt = cache.get("system_prompt", false);
        assert!(
            system_prompt.is_some(),
            "system_prompt should survive eviction because it was recently accessed"
        );
        assert_eq!(system_prompt.unwrap().token_ids(), &[10, 20, 30]);

        assert!(cache.len() <= 4);

        // None of the first batch of one-off queries may remain.
        let early_queries_remaining = (1..=3)
            .filter(|n| cache.get(&format!("query_{n}"), false).is_some())
            .count();
        assert_eq!(
            early_queries_remaining, 0,
            "all early one-off queries should have been evicted"
        );
    }

    /// An entry touched by `get` must outlive entries that were never read,
    /// even though all of them were inserted in the same batch.
    #[test]
    fn test_lru_eviction_prefers_untouched_entries() {
        let cache = L0Cache::new(3);

        for (key, token) in [("keep_me", 1), ("stale_1", 2), ("stale_2", 3)] {
            cache.insert(key.to_string(), false, mock_encoding(vec![token]));
        }

        // Reading "keep_me" makes it the most recently used entry.
        let _ = cache.get("keep_me", false);

        // This insert forces an eviction, which should pick a stale entry.
        cache.insert("new_entry".to_string(), false, mock_encoding(vec![4]));
        assert_eq!(cache.len(), 3);

        assert!(
            cache.get("keep_me", false).is_some(),
            "keep_me should survive eviction because it was recently accessed"
        );

        let stale_remaining = ["stale_1", "stale_2"]
            .iter()
            .filter(|key| cache.get(key, false).is_some())
            .count();
        assert!(
            stale_remaining < 2,
            "at least one stale entry should have been evicted"
        );
    }
}