llm_tokenizer/cache/
l0.rs1use std::sync::{
9 atomic::{AtomicU64, Ordering},
10 Arc,
11};
12
13use dashmap::DashMap;
14
15use crate::traits::Encoding;
16
17pub struct L0Cache {
20 map: Arc<DashMap<String, Arc<Encoding>>>,
22 max_entries: usize,
24 hits: AtomicU64,
26 misses: AtomicU64,
28}
29
30impl L0Cache {
31 pub fn new(max_entries: usize) -> Self {
33 Self {
34 map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
35 max_entries,
36 hits: AtomicU64::new(0),
37 misses: AtomicU64::new(0),
38 }
39 }
40
41 #[inline]
43 pub fn get(&self, key: &str) -> Option<Arc<Encoding>> {
44 match self.map.get(key) {
45 Some(entry) => {
46 self.hits.fetch_add(1, Ordering::Relaxed);
47 Some(Arc::clone(entry.value()))
49 }
50 None => {
51 self.misses.fetch_add(1, Ordering::Relaxed);
52 None
53 }
54 }
55 }
56
57 pub fn insert(&self, key: String, value: Encoding) {
59 if self.map.len() >= self.max_entries {
62 let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) };
63
64 if let Some(k) = key_to_remove {
66 self.map.remove(&k);
67 }
68 }
69
70 self.map.insert(key, Arc::new(value));
71 }
72
73 pub fn insert_arc(&self, key: String, value: Arc<Encoding>) {
75 if self.map.len() >= self.max_entries {
76 let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) };
77 if let Some(k) = key_to_remove {
78 self.map.remove(&k);
79 }
80 }
81 self.map.insert(key, value);
82 }
83
84 pub fn len(&self) -> usize {
86 self.map.len()
87 }
88
89 pub fn is_empty(&self) -> bool {
91 self.map.is_empty()
92 }
93
94 pub fn stats(&self) -> CacheStats {
96 let hits = self.hits.load(Ordering::Relaxed);
97 let misses = self.misses.load(Ordering::Relaxed);
98 let total_requests = hits + misses;
99
100 CacheStats {
101 hits,
102 misses,
103 entries: self.len(),
104 hit_rate: if total_requests > 0 {
105 hits as f64 / total_requests as f64
106 } else {
107 0.0
108 },
109 }
110 }
111
112 pub fn clear(&self) {
114 self.map.clear();
115 self.hits.store(0, Ordering::Relaxed);
116 self.misses.store(0, Ordering::Relaxed);
117 }
118
119 pub fn memory_usage(&self) -> usize {
121 self.len() * 2200
125 }
126}
127
/// Point-in-time snapshot of an [`L0Cache`]'s statistics, as returned by `L0Cache::stats`.
///
/// `Default` yields the all-zero snapshot, matching what `stats` reports for a fresh cache.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CacheStats {
    /// Lookups that returned an entry.
    pub hits: u64,
    /// Lookups that found no entry.
    pub misses: u64,
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookups have been made.
    pub hit_rate: f64,
}
135
#[cfg(test)]
mod tests {
    // `super::*` brings in `L0Cache`, `Arc` and `Encoding` through this module's own
    // imports, so the tests do not depend on what the crate root re-exports.
    use super::*;

    /// Builds a simple token-only encoding for tests.
    fn mock_encoding(tokens: Vec<u32>) -> Encoding {
        Encoding::Sp(tokens)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        assert!(cache.get("hello").is_none());

        cache.insert("hello".to_string(), mock_encoding(vec![1, 2, 3]));

        let result = cache.get("hello").expect("entry was just inserted");
        assert_eq!(result.token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        cache.insert("a".to_string(), mock_encoding(vec![1]));
        cache.insert("b".to_string(), mock_encoding(vec![2]));
        cache.insert("c".to_string(), mock_encoding(vec![3]));

        assert_eq!(cache.len(), 2);
        // The newest key is inserted after eviction, so it must survive.
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);

        cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));

        let _ = cache.get("missing");
        let _ = cache.get("test");

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);

        cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);
        let _ = cache.get("test");

        cache.clear();
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(cache.len(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let cache = Arc::new(L0Cache::new(1000));

        let handles: Vec<_> = (0..10u32)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let key = format!("key_{}", i);
                    cache.insert(key.clone(), mock_encoding(vec![i]));
                    assert!(cache.get(&key).is_some());
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(cache.len(), 10);
    }

    #[test]
    fn test_arc_reuse() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));

        let arc1 = cache.get("test").unwrap();
        let arc2 = cache.get("test").unwrap();

        assert!(Arc::ptr_eq(&arc1, &arc2));
    }

    #[test]
    fn test_insert_arc_shares_value() {
        let cache = L0Cache::new(10);
        let shared = Arc::new(mock_encoding(vec![7, 8]));

        cache.insert_arc("k".to_string(), Arc::clone(&shared));

        let fetched = cache.get("k").unwrap();
        assert!(Arc::ptr_eq(&fetched, &shared));
    }
}