1use std::sync::{
23 atomic::{AtomicU64, Ordering},
24 Arc,
25};
26
27use dashmap::DashMap;
28
29use crate::traits::Encoding;
30
/// Upper bound on how many entries a single eviction pass inspects when looking
/// for the entry with the oldest access timestamp.
const EVICTION_SAMPLE_SIZE: usize = 8;
36
/// A single cached value together with the bookkeeping needed for eviction.
struct CachedEntry {
    /// The cached encoding; handed out to callers as a cloned `Arc`.
    encoding: Arc<Encoding>,
    /// Logical timestamp of the most recent insert or hit for this entry.
    /// Lower values are older; eviction removes the lowest value it samples.
    last_accessed: AtomicU64,
}
44
/// String-keyed cache of encodings, split into two maps by the
/// `add_special_tokens` flag so the same text can be cached once per flag value.
pub struct L0Cache {
    /// Entries stored with `add_special_tokens == false`.
    map_plain: Arc<DashMap<String, CachedEntry>>,
    /// Entries stored with `add_special_tokens == true`.
    map_special: Arc<DashMap<String, CachedEntry>>,
    /// Eviction threshold, compared against the combined size of both maps.
    max_entries: usize,
    /// Number of lookups that found an entry.
    hits: AtomicU64,
    /// Number of lookups that found nothing.
    misses: AtomicU64,
    /// Monotonic counter used as a logical clock for `last_accessed` timestamps.
    access_counter: AtomicU64,
}
66
67impl L0Cache {
68 pub fn new(max_entries: usize) -> Self {
70 let per_map = max_entries.min(1024) / 2 + 1;
71 Self {
72 map_plain: Arc::new(DashMap::with_capacity(per_map)),
73 map_special: Arc::new(DashMap::with_capacity(per_map)),
74 max_entries,
75 hits: AtomicU64::new(0),
76 misses: AtomicU64::new(0),
77 access_counter: AtomicU64::new(0),
78 }
79 }
80
81 #[inline]
82 fn map_for(&self, add_special_tokens: bool) -> &DashMap<String, CachedEntry> {
83 if add_special_tokens {
84 &self.map_special
85 } else {
86 &self.map_plain
87 }
88 }
89
90 #[inline]
92 fn next_timestamp(&self) -> u64 {
93 self.access_counter.fetch_add(1, Ordering::Relaxed)
94 }
95
96 #[inline]
99 pub fn get(&self, key: &str, add_special_tokens: bool) -> Option<Arc<Encoding>> {
100 match self.map_for(add_special_tokens).get(key) {
101 Some(entry) => {
102 self.hits.fetch_add(1, Ordering::Relaxed);
103 let ts = self.next_timestamp();
106 entry.value().last_accessed.store(ts, Ordering::Relaxed);
107 Some(Arc::clone(&entry.value().encoding))
108 }
109 None => {
110 self.misses.fetch_add(1, Ordering::Relaxed);
111 None
112 }
113 }
114 }
115
116 fn maybe_evict(&self) {
123 if self.len() >= self.max_entries {
124 let victim_map = if self.map_plain.len() >= self.map_special.len() {
125 &self.map_plain
126 } else {
127 &self.map_special
128 };
129
130 let key_to_remove = {
134 let mut oldest_key: Option<String> = None;
135 let mut oldest_ts = u64::MAX;
136
137 for (i, entry) in victim_map.iter().enumerate() {
138 let ts = entry.value().last_accessed.load(Ordering::Relaxed);
139 if ts < oldest_ts {
140 oldest_ts = ts;
141 oldest_key = Some(entry.key().clone());
142 }
143 if i + 1 >= EVICTION_SAMPLE_SIZE {
144 break;
145 }
146 }
147 oldest_key
148 };
149
150 if let Some(k) = key_to_remove {
151 victim_map.remove(&k);
152 }
153 }
154 }
155
156 pub fn insert(&self, key: String, add_special_tokens: bool, value: Encoding) {
158 self.maybe_evict();
159 let ts = self.next_timestamp();
160 let entry = CachedEntry {
161 encoding: Arc::new(value),
162 last_accessed: AtomicU64::new(ts),
163 };
164 self.map_for(add_special_tokens).insert(key, entry);
165 }
166
167 pub fn insert_arc(&self, key: String, add_special_tokens: bool, value: Arc<Encoding>) {
169 self.maybe_evict();
170 let ts = self.next_timestamp();
171 let entry = CachedEntry {
172 encoding: value,
173 last_accessed: AtomicU64::new(ts),
174 };
175 self.map_for(add_special_tokens).insert(key, entry);
176 }
177
178 pub fn len(&self) -> usize {
180 self.map_plain.len() + self.map_special.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.map_plain.is_empty() && self.map_special.is_empty()
186 }
187
188 pub fn stats(&self) -> CacheStats {
190 let hits = self.hits.load(Ordering::Relaxed);
191 let misses = self.misses.load(Ordering::Relaxed);
192 let total_requests = hits + misses;
193
194 CacheStats {
195 hits,
196 misses,
197 entries: self.len(),
198 hit_rate: if total_requests > 0 {
199 hits as f64 / total_requests as f64
200 } else {
201 0.0
202 },
203 }
204 }
205
206 pub fn clear(&self) {
208 self.map_plain.clear();
209 self.map_special.clear();
210 self.hits.store(0, Ordering::Relaxed);
211 self.misses.store(0, Ordering::Relaxed);
212 self.access_counter.store(0, Ordering::Relaxed);
213 }
214
215 pub fn memory_usage(&self) -> usize {
217 self.len() * 2200
221 }
222}
223
/// Point-in-time counters reported by `L0Cache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Total entries across both internal maps at the time of the snapshot.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been recorded.
    pub hit_rate: f64,
}
231
#[cfg(test)]
mod tests {
    // NOTE(review): these tests rely on `crate::*` bringing `L0Cache` and `Arc`
    // into scope — confirm against the crate root's exports.
    use crate::{traits::Encoding, *};

    /// Builds a minimal encoding holding only the given token ids.
    fn mock_encoding(tokens: Vec<u32>) -> Encoding {
        Encoding::Plain(tokens)
    }

    /// A key is absent before insertion and returns the stored tokens afterwards.
    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        // Nothing inserted yet: lookup must miss.
        assert!(cache.get("hello", false).is_none());

        cache.insert("hello".to_string(), false, mock_encoding(vec![1, 2, 3]));

        // Same key and flag: lookup must hit with the stored tokens.
        let result = cache.get("hello", false);
        assert!(result.is_some());
        assert_eq!(result.unwrap().token_ids(), &[1, 2, 3]);
    }

    /// The same text cached under both flag values yields two independent entries.
    #[test]
    fn test_add_special_tokens_flag_separates_entries() {
        let cache = L0Cache::new(10);

        cache.insert("hello".to_string(), false, mock_encoding(vec![1, 2, 3]));
        cache.insert(
            "hello".to_string(),
            true,
            mock_encoding(vec![100, 1, 2, 3, 101]),
        );

        let without = cache.get("hello", false).unwrap();
        let with = cache.get("hello", true).unwrap();
        assert_eq!(without.token_ids(), &[1, 2, 3]);
        assert_eq!(with.token_ids(), &[100, 1, 2, 3, 101]);
        assert_eq!(cache.len(), 2);
    }

    /// Inserting a third key into a cache of capacity two keeps the size at two.
    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        cache.insert("a".to_string(), false, mock_encoding(vec![1]));
        cache.insert("b".to_string(), false, mock_encoding(vec![2]));

        // Cache is full: this insert must evict one existing entry first.
        cache.insert("c".to_string(), false, mock_encoding(vec![3]));

        assert_eq!(cache.len(), 2);
    }

    /// The capacity limit applies to both maps combined, not to each map.
    #[test]
    fn test_eviction_across_maps() {
        let cache = L0Cache::new(2);

        // Fill the whole budget using only the plain map.
        cache.insert("a".to_string(), false, mock_encoding(vec![1]));
        cache.insert("b".to_string(), false, mock_encoding(vec![2]));
        assert_eq!(cache.len(), 2);

        // Inserting into the other map must still trigger an eviction.
        cache.insert("c".to_string(), true, mock_encoding(vec![3]));
        assert_eq!(cache.len(), 2, "total entries must not exceed max_entries");
    }

    /// One miss followed by one hit yields hits = 1, misses = 1, hit_rate = 0.5.
    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);

        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

        // Miss.
        let _ = cache.get("missing", false);

        // Hit.
        let _ = cache.get("test", false);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    /// `clear` removes all entries so subsequent lookups miss.
    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);

        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.get("test", false).is_none());
    }

    /// Ten threads each insert and read back a distinct key on a shared cache.
    #[test]
    fn test_concurrent_access() {
        use std::thread;

        // Capacity is far above the ten keys inserted, so no eviction occurs.
        let cache = Arc::new(L0Cache::new(1000));
        let mut handles = vec![];

        for i in 0..10 {
            let cache_clone = cache.clone();
            handles.push(thread::spawn(move || {
                let key = format!("key_{i}");
                cache_clone.insert(key.clone(), false, mock_encoding(vec![i as u32]));

                let result = cache_clone.get(&key, false);
                assert!(result.is_some());
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(cache.len(), 10);
    }

    /// Repeated lookups of one key return handles to the same allocation.
    #[test]
    fn test_arc_reuse() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), false, mock_encoding(vec![1, 2, 3]));

        let arc1 = cache.get("test", false).unwrap();
        let arc2 = cache.get("test", false).unwrap();

        // Pointer equality: no copy of the encoding is made per lookup.
        assert!(Arc::ptr_eq(&arc1, &arc2));
    }

    /// An entry that is read before every insert survives a stream of one-off
    /// inserts, while the untouched early entries are all evicted.
    #[test]
    fn test_lru_eviction_keeps_frequently_accessed() {
        // Capacity 4 is below the eviction sample size, so every eviction sees
        // all entries in the map.
        let cache = L0Cache::new(4);

        cache.insert(
            "system_prompt".to_string(),
            false,
            mock_encoding(vec![10, 20, 30]),
        );

        // Fill the remaining capacity with entries that are never read again.
        cache.insert("query_1".to_string(), false, mock_encoding(vec![1]));
        cache.insert("query_2".to_string(), false, mock_encoding(vec![2]));
        cache.insert("query_3".to_string(), false, mock_encoding(vec![3]));
        assert_eq!(cache.len(), 4);

        for i in 4..12 {
            // Reading system_prompt refreshes its timestamp, making it the newest
            // entry just before the insert below triggers an eviction.
            let result = cache.get("system_prompt", false);
            assert!(
                result.is_some(),
                "system_prompt should still be in the cache after query_{} insertion",
                i - 1
            );

            cache.insert(format!("query_{i}"), false, mock_encoding(vec![i]));
        }

        let system_prompt = cache.get("system_prompt", false);
        assert!(
            system_prompt.is_some(),
            "system_prompt should survive eviction because it was recently accessed"
        );
        assert_eq!(system_prompt.unwrap().token_ids(), &[10, 20, 30]);

        assert!(cache.len() <= 4);

        // query_1..=query_3 were never touched after insertion, so they carry the
        // oldest timestamps and must be gone after eight evictions.
        let early_queries_remaining = (1..=3)
            .filter(|i| cache.get(&format!("query_{i}"), false).is_some())
            .count();
        assert_eq!(
            early_queries_remaining, 0,
            "all early one-off queries should have been evicted"
        );
    }

    /// With one recently read entry and two untouched ones, eviction removes an
    /// untouched entry rather than the one that was read.
    #[test]
    fn test_lru_eviction_prefers_untouched_entries() {
        let cache = L0Cache::new(3);

        // keep_me is inserted first, so without the read below it would be oldest.
        cache.insert("keep_me".to_string(), false, mock_encoding(vec![1]));
        cache.insert("stale_1".to_string(), false, mock_encoding(vec![2]));
        cache.insert("stale_2".to_string(), false, mock_encoding(vec![3]));

        // Refresh keep_me so its timestamp becomes the newest.
        let _ = cache.get("keep_me", false);

        // Cache is full: this insert forces one eviction.
        cache.insert("new_entry".to_string(), false, mock_encoding(vec![4]));

        assert_eq!(cache.len(), 3);

        assert!(
            cache.get("keep_me", false).is_some(),
            "keep_me should survive eviction because it was recently accessed"
        );

        let stale_remaining = ["stale_1", "stale_2"]
            .iter()
            .filter(|k| cache.get(k, false).is_some())
            .count();
        assert!(
            stale_remaining < 2,
            "at least one stale entry should have been evicted"
        );
    }
}