1use std::{
23 mem::size_of,
24 sync::{
25 atomic::{AtomicU64, Ordering},
26 Arc,
27 },
28};
29
30use blake3;
31use dashmap::DashMap;
32
33use crate::traits::TokenIdType;
34
/// A 32-byte BLAKE3 digest; the cache key for one input prefix.
type Blake3Hash = [u8; 32];

/// Number of independent maps the cache spreads its entries across
/// (kept a power of two).
const NUM_SHARDS: usize = 1 << 4;
40
/// Returns the sorted, de-duplicated byte offsets that immediately follow each
/// occurrence of a special token in `text`.
///
/// Occurrences of a single token are found left to right without overlap. An
/// offset equal to `text.len()` (a token that ends the input) is not reported,
/// so every returned offset splits `text` into a non-empty prefix and a
/// non-empty remainder. Every offset lies on a `char` boundary because it is the
/// end of a string match.
///
/// Empty tokens are ignored: an empty pattern matches at every position, and
/// the previous hand-rolled search never advanced past it (infinite loop).
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries: Vec<usize> = special_tokens
        .iter()
        .copied()
        .filter(|token| !token.is_empty())
        .flat_map(|token| {
            text.match_indices(token)
                .map(|(pos, matched)| pos + matched.len())
        })
        .filter(|&boundary| boundary < text.len())
        .collect();

    // Several tokens may end at the same offset (or a token may be listed twice).
    boundaries.sort_unstable();
    boundaries.dedup();
    boundaries
}
81
82#[derive(Debug, Clone)]
85struct CachedPrefix {
86 tokens: Arc<[TokenIdType]>,
88 last_accessed: Arc<AtomicU64>,
90 size_bytes: usize,
92}
93
94pub struct L1Cache {
96 shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
100 max_memory: usize,
102 current_memory: AtomicU64,
104 hits: AtomicU64,
106 misses: AtomicU64,
108 access_counter: AtomicU64,
110}
111
112impl L1Cache {
113 pub fn new(max_memory: usize) -> Self {
115 let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
116
117 Self {
118 shards,
119 max_memory,
120 current_memory: AtomicU64::new(0),
121 hits: AtomicU64::new(0),
122 misses: AtomicU64::new(0),
123 access_counter: AtomicU64::new(0),
124 }
125 }
126
127 pub fn longest_prefix_match(
133 &self,
134 input: &str,
135 special_tokens: &[&str],
136 ) -> Option<(Vec<TokenIdType>, usize)> {
137 let boundaries = find_special_token_boundaries(input, special_tokens);
138
139 if boundaries.is_empty() {
140 self.misses.fetch_add(1, Ordering::Relaxed);
141 return None;
142 }
143
144 let mut hasher = blake3::Hasher::new();
146 let mut prefix_hashes = Vec::with_capacity(boundaries.len());
147 let mut last_pos = 0;
148 let bytes = input.as_bytes();
149 for &boundary_pos in &boundaries {
150 hasher.update(&bytes[last_pos..boundary_pos]);
151 prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
152 last_pos = boundary_pos;
153 }
154
155 for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
157 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
158
159 if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
160 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
162 entry.last_accessed.store(timestamp, Ordering::Relaxed);
163
164 self.hits.fetch_add(1, Ordering::Relaxed);
165 return Some((entry.tokens.to_vec(), boundary_pos));
167 }
168 }
169
170 self.misses.fetch_add(1, Ordering::Relaxed);
171 None
172 }
173
174 pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
180 &self,
181 input: &str,
182 tokenizer: &E,
183 special_tokens: &[&str],
184 add_special_tokens: bool,
185 ) -> anyhow::Result<()> {
186 let boundaries = find_special_token_boundaries(input, special_tokens);
187
188 if boundaries.is_empty() {
189 return Ok(());
190 }
191
192 let mut hasher = blake3::Hasher::new();
193 let mut running_tokens = Vec::new();
194 let mut last_pos = 0;
195 let mut entries_to_insert = Vec::with_capacity(boundaries.len());
196 let bytes = input.as_bytes();
197 for (i, &boundary_pos) in boundaries.iter().enumerate() {
198 let delta_text = &input[last_pos..boundary_pos];
199
200 hasher.update(&bytes[last_pos..boundary_pos]);
202 let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();
203
204 let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
207 running_tokens.extend_from_slice(segment_encoding.token_ids());
208
209 let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();
212
213 let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
215
216 entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
217
218 last_pos = boundary_pos;
219 }
220
221 if entries_to_insert.is_empty() {
222 return Ok(());
223 }
224
225 let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
226
227 let current = self.current_memory.load(Ordering::Relaxed) as usize;
229 if current + total_size_needed > self.max_memory {
230 self.evict_lru(total_size_needed);
231 }
232
233 let current_timestamp = self.access_counter.load(Ordering::Relaxed);
235 for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
236 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
237
238 let cached = CachedPrefix {
239 tokens: prefix_tokens, last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
241 size_bytes,
242 };
243
244 self.shards[shard_idx].insert(hash_bytes, cached);
245 self.current_memory
246 .fetch_add(size_bytes as u64, Ordering::Relaxed);
247 }
248
249 Ok(())
250 }
251
252 fn evict_lru(&self, space_needed: usize) {
265 const SAMPLE_SIZE: usize = 32; let mut freed = 0usize;
267 let mut iteration = 0usize;
268
269 while freed < space_needed {
271 let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
273
274 for i in 0..SAMPLE_SIZE {
276 let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
278
279 if let Some(entry) = self.shards[shard_idx].iter().next() {
281 let hash = *entry.key();
282 let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
283 let size = entry.value().size_bytes;
284 samples.push((shard_idx, hash, timestamp, size));
285 }
286 }
287
288 if samples.is_empty() {
289 break;
291 }
292
293 if let Some((shard_idx, hash, _, _)) =
295 samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
296 {
297 if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
299 freed += removed.size_bytes;
300 self.current_memory
301 .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
302 }
303 }
304
305 iteration += 1;
306 }
307 }
308
309 pub fn len(&self) -> usize {
311 self.shards.iter().map(|s| s.len()).sum()
312 }
313
314 pub fn is_empty(&self) -> bool {
316 self.shards.iter().all(|s| s.is_empty())
317 }
318
319 pub fn stats(&self) -> L1CacheStats {
321 let hits = self.hits.load(Ordering::Relaxed);
322 let misses = self.misses.load(Ordering::Relaxed);
323 let total_requests = hits + misses;
324
325 L1CacheStats {
326 hits,
327 misses,
328 entries: self.len(),
329 memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
330 hit_rate: if total_requests > 0 {
331 hits as f64 / total_requests as f64
332 } else {
333 0.0
334 },
335 }
336 }
337
338 pub fn clear(&self) {
340 for shard in &self.shards {
341 shard.clear();
342 }
343 self.current_memory.store(0, Ordering::Relaxed);
344 self.hits.store(0, Ordering::Relaxed);
345 self.misses.store(0, Ordering::Relaxed);
346 }
347}
348
/// Point-in-time counters reported by `L1Cache::stats`.
#[derive(Clone, Debug)]
pub struct L1CacheStats {
    /// Lookups that returned a cached prefix.
    pub hits: u64,
    /// Lookups that returned nothing (including inputs without boundaries).
    pub misses: u64,
    /// Number of cached prefixes when the snapshot was taken.
    pub entries: usize,
    /// Estimated bytes charged by all cached prefixes.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
}
357
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// ChatML markers shared by most tests.
    const CHATML: &[&str] = &["<|im_start|>", "<|im_end|>"];

    /// Caches every boundary prefix of `input` without adding special tokens.
    fn seed(cache: &L1Cache, tokenizer: &MockTokenizer, input: &str, special_tokens: &[&str]) {
        cache
            .insert_at_boundaries(input, tokenizer, special_tokens, false)
            .unwrap();
    }

    #[test]
    fn test_basic_prefix_match() {
        let cache = L1Cache::new(1024 * 1024);
        let tokenizer = MockTokenizer::new();

        let first_turn = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
        seed(&cache, &tokenizer, first_turn, CHATML);
        assert!(!cache.is_empty());

        // Same system prompt, different user message: the system prefix must hit.
        let second_turn = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let (tokens, offset) = cache.longest_prefix_match(second_turn, CHATML).unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_short_input_with_boundaries() {
        let cache = L1Cache::new(1024 * 1024);
        let tokenizer = MockTokenizer::new();

        let short = "<|im_start|>user\nHi<|im_end|>";
        seed(&cache, &tokenizer, short, CHATML);
        assert!(!cache.is_empty());

        assert!(cache.longest_prefix_match(short, CHATML).is_some());
    }

    #[test]
    fn test_longest_match() {
        let cache = L1Cache::new(1024 * 1024);
        let tokenizer = MockTokenizer::new();

        let conversation = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        seed(&cache, &tokenizer, conversation, CHATML);
        assert!(cache.len() >= 2);

        let truncated = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let (_, offset) = cache.longest_prefix_match(truncated, CHATML).unwrap();
        assert!(offset > 0);
        assert!(offset <= truncated.len());
    }

    #[test]
    fn test_stats() {
        let cache = L1Cache::new(1024 * 1024);
        let tokenizer = MockTokenizer::new();

        let prompt = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
        seed(&cache, &tokenizer, prompt, CHATML);
        let _ = cache.longest_prefix_match(prompt, CHATML);

        let stats = cache.stats();
        assert!(stats.hits >= 1);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[test]
    fn test_clear() {
        let cache = L1Cache::new(1024 * 1024);
        let tokenizer = MockTokenizer::new();

        let prompt = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
        seed(&cache, &tokenizer, prompt, CHATML);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_lru_eviction() {
        const BUDGET: usize = 5 * 1024;
        let cache = L1Cache::new(BUDGET);
        let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tokenizer = MockTokenizer::new();

        let mathematics = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
        let physics = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
        let chemistry = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";

        // Each freshly inserted conversation must be immediately retrievable.
        for conversation in [mathematics, physics] {
            seed(&cache, &tokenizer, conversation, special_tokens);
            assert!(cache.longest_prefix_match(conversation, special_tokens).is_some());
        }

        seed(&cache, &tokenizer, chemistry, special_tokens);
        assert!(cache.stats().memory_bytes <= BUDGET);
        assert!(cache.longest_prefix_match(chemistry, special_tokens).is_some());
    }
}