1use std::{
23 mem::size_of,
24 sync::{
25 atomic::{AtomicU64, Ordering},
26 Arc,
27 },
28};
29
30use blake3;
31use dashmap::DashMap;
32
33use crate::traits::TokenIdType;
34
/// A 32-byte BLAKE3 digest; the cache uses it as the lookup key for a hashed input prefix.
type Blake3Hash = [u8; 32];
37
/// Number of independent `DashMap` shards the cache spreads its entries across.
/// An entry's shard is chosen from the first byte of its hash, modulo this value.
const NUM_SHARDS: usize = 16;
40
/// Returns the byte offsets in `text` that immediately follow an occurrence of
/// any token in `special_tokens`.
///
/// * Occurrences of a single token are matched left to right without overlap.
/// * An offset equal to `text.len()` is not reported: a boundary is only
///   returned when at least one byte of input follows the token.
/// * Empty tokens are ignored. They match at every position without consuming
///   any input, so they cannot mark a boundary (and would otherwise make the
///   scan loop forever).
///
/// The result is sorted in ascending order and contains no duplicates. It is
/// empty when `special_tokens` is empty or nothing matches.
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries = Vec::new();

    for token in special_tokens.iter().copied().filter(|t| !t.is_empty()) {
        // `match_indices` yields disjoint matches, i.e. the search resumes
        // right after the end of the previous match.
        boundaries.extend(
            text.match_indices(token)
                .map(|(pos, _)| pos + token.len())
                .filter(|&end| end < text.len()),
        );
    }

    // Different tokens may end at the same offset (e.g. one is a suffix of
    // another), so normalise before returning.
    boundaries.sort_unstable();
    boundaries.dedup();

    boundaries
}
81
/// One cache entry: the tokenization of an input prefix that ends at a
/// special-token boundary.
#[derive(Debug, Clone)]
struct CachedPrefix {
    /// Token ids for the whole prefix, shared behind an `Arc` so cloning the
    /// entry does not copy the tokens.
    tokens: Arc<[TokenIdType]>,
    /// Value of the cache's access counter when the entry was inserted or last
    /// returned from a lookup; eviction removes the candidate with the
    /// smallest value.
    last_accessed: Arc<AtomicU64>,
    /// Size charged against the cache's memory budget for this entry
    /// (prefix length in bytes plus the byte size of the token ids).
    size_bytes: usize,
}
93
/// Sharded in-memory cache mapping hashed input prefixes to their token ids.
///
/// Prefixes are cut at special-token boundaries and keyed by a BLAKE3 hash of
/// the `add_special_tokens` flag followed by the prefix bytes.
pub struct L1Cache {
    /// `NUM_SHARDS` concurrent maps; an entry lives in the shard selected by
    /// the first byte of its hash.
    shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
    /// Memory budget in bytes; inserts that would exceed it trigger eviction.
    max_memory: usize,
    /// Running total of `size_bytes` over the entries currently stored.
    current_memory: AtomicU64,
    /// Number of lookups that returned a cached prefix.
    hits: AtomicU64,
    /// Number of lookups that returned nothing.
    misses: AtomicU64,
    /// Logical clock: incremented on every hit and used to timestamp entries.
    access_counter: AtomicU64,
}
111
112impl L1Cache {
113 pub fn new(max_memory: usize) -> Self {
115 let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
116
117 Self {
118 shards,
119 max_memory,
120 current_memory: AtomicU64::new(0),
121 hits: AtomicU64::new(0),
122 misses: AtomicU64::new(0),
123 access_counter: AtomicU64::new(0),
124 }
125 }
126
127 pub fn longest_prefix_match(
133 &self,
134 input: &str,
135 special_tokens: &[&str],
136 add_special_tokens: bool,
137 ) -> Option<(Vec<TokenIdType>, usize)> {
138 let boundaries = find_special_token_boundaries(input, special_tokens);
139
140 if boundaries.is_empty() {
141 self.misses.fetch_add(1, Ordering::Relaxed);
142 return None;
143 }
144
145 let mut hasher = blake3::Hasher::new();
149 hasher.update(&[add_special_tokens as u8]);
150 let mut prefix_hashes = Vec::with_capacity(boundaries.len());
151 let mut last_pos = 0;
152 let bytes = input.as_bytes();
153 for &boundary_pos in &boundaries {
154 hasher.update(&bytes[last_pos..boundary_pos]);
155 prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
156 last_pos = boundary_pos;
157 }
158
159 for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
161 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
162
163 if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
164 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
166 entry.last_accessed.store(timestamp, Ordering::Relaxed);
167
168 self.hits.fetch_add(1, Ordering::Relaxed);
169 return Some((entry.tokens.to_vec(), boundary_pos));
171 }
172 }
173
174 self.misses.fetch_add(1, Ordering::Relaxed);
175 None
176 }
177
178 pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
184 &self,
185 input: &str,
186 tokenizer: &E,
187 special_tokens: &[&str],
188 add_special_tokens: bool,
189 ) -> anyhow::Result<()> {
190 let boundaries = find_special_token_boundaries(input, special_tokens);
191
192 if boundaries.is_empty() {
193 return Ok(());
194 }
195
196 let mut hasher = blake3::Hasher::new();
197 hasher.update(&[add_special_tokens as u8]);
198 let mut running_tokens = Vec::new();
199 let mut last_pos = 0;
200 let mut entries_to_insert = Vec::with_capacity(boundaries.len());
201 let bytes = input.as_bytes();
202 for (i, &boundary_pos) in boundaries.iter().enumerate() {
203 let delta_text = &input[last_pos..boundary_pos];
204
205 hasher.update(&bytes[last_pos..boundary_pos]);
207 let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();
208
209 let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
212 running_tokens.extend_from_slice(segment_encoding.token_ids());
213
214 let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();
217
218 let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
220
221 entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
222
223 last_pos = boundary_pos;
224 }
225
226 if entries_to_insert.is_empty() {
227 return Ok(());
228 }
229
230 let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
231
232 let current = self.current_memory.load(Ordering::Relaxed) as usize;
234 if current + total_size_needed > self.max_memory {
235 self.evict_lru(total_size_needed);
236 }
237
238 let current_timestamp = self.access_counter.load(Ordering::Relaxed);
240 for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
241 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
242
243 let cached = CachedPrefix {
244 tokens: prefix_tokens,
245 last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
246 size_bytes,
247 };
248
249 if let Some(old) = self.shards[shard_idx].insert(hash_bytes, cached) {
250 let old_size = old.size_bytes as u64;
256 let new_size = size_bytes as u64;
257 if new_size >= old_size {
258 self.current_memory
259 .fetch_add(new_size - old_size, Ordering::Relaxed);
260 } else {
261 self.current_memory
262 .fetch_sub(old_size - new_size, Ordering::Relaxed);
263 }
264 } else {
265 self.current_memory
266 .fetch_add(size_bytes as u64, Ordering::Relaxed);
267 }
268 }
269
270 Ok(())
271 }
272
273 fn evict_lru(&self, space_needed: usize) {
286 const SAMPLE_SIZE: usize = 32; let mut freed = 0usize;
288 let mut iteration = 0usize;
289
290 while freed < space_needed {
292 let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
294
295 for i in 0..SAMPLE_SIZE {
297 let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
299
300 if let Some(entry) = self.shards[shard_idx].iter().next() {
302 let hash = *entry.key();
303 let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
304 let size = entry.value().size_bytes;
305 samples.push((shard_idx, hash, timestamp, size));
306 }
307 }
308
309 if samples.is_empty() {
310 break;
312 }
313
314 if let Some((shard_idx, hash, _, _)) =
316 samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
317 {
318 if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
320 freed += removed.size_bytes;
321 self.current_memory
322 .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
323 }
324 }
325
326 iteration += 1;
327 }
328 }
329
330 pub fn len(&self) -> usize {
332 self.shards.iter().map(|s| s.len()).sum()
333 }
334
335 pub fn is_empty(&self) -> bool {
337 self.shards.iter().all(|s| s.is_empty())
338 }
339
340 pub fn stats(&self) -> L1CacheStats {
342 let hits = self.hits.load(Ordering::Relaxed);
343 let misses = self.misses.load(Ordering::Relaxed);
344 let total_requests = hits + misses;
345
346 L1CacheStats {
347 hits,
348 misses,
349 entries: self.len(),
350 memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
351 hit_rate: if total_requests > 0 {
352 hits as f64 / total_requests as f64
353 } else {
354 0.0
355 },
356 }
357 }
358
359 pub fn clear(&self) {
361 for shard in &self.shards {
362 shard.clear();
363 }
364 self.current_memory.store(0, Ordering::Relaxed);
365 self.hits.store(0, Ordering::Relaxed);
366 self.misses.store(0, Ordering::Relaxed);
367 }
368}
369
/// Point-in-time statistics reported by `L1Cache::stats`.
#[derive(Debug, Clone)]
pub struct L1CacheStats {
    /// Number of lookups that returned a cached prefix.
    pub hits: u64,
    /// Number of lookups that returned nothing.
    pub misses: u64,
    /// Number of entries currently stored across all shards.
    pub entries: usize,
    /// Sum of the tracked sizes of the stored entries, in bytes.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
}
378
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Chat-template markers shared by most tests.
    const CHAT_MARKERS: &[&str] = &["<|im_start|>", "<|im_end|>"];

    /// Budget large enough that no test using it ever triggers eviction.
    const ONE_MIB: usize = 1024 * 1024;

    /// Token id that `BosTokenizer` prepends when asked to add special tokens.
    const BOS_ID: TokenIdType = 99;

    /// Byte-level tokenizer whose output visibly depends on
    /// `add_special_tokens`, so tests can tell the two cache keys apart.
    struct BosTokenizer;

    impl Encoder for BosTokenizer {
        fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
            let bos: &[TokenIdType] = if add_special_tokens { &[BOS_ID] } else { &[] };
            let ids: Vec<TokenIdType> = bos
                .iter()
                .copied()
                .chain(input.bytes().map(TokenIdType::from))
                .collect();
            Ok(Encoding::Plain(ids))
        }

        fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
            let mut encodings = Vec::with_capacity(inputs.len());
            for text in inputs {
                encodings.push(self.encode(text, add_special_tokens)?);
            }
            Ok(encodings)
        }
    }

    /// A second conversation sharing the system prompt with a cached one
    /// must hit the cached prefix.
    #[test]
    fn test_basic_prefix_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(ONE_MIB);

        let cached_conversation = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
        cache
            .insert_at_boundaries(cached_conversation, &tokenizer, CHAT_MARKERS, false)
            .unwrap();
        assert!(!cache.is_empty());

        let new_conversation = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let lookup = cache.longest_prefix_match(new_conversation, CHAT_MARKERS, false);
        assert!(lookup.is_some());

        let (tokens, offset) = lookup.unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    /// Even a very short input is cached as long as it contains a boundary.
    #[test]
    fn test_short_input_with_boundaries() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(ONE_MIB);
        let short_input = "<|im_start|>user\nHi<|im_end|>";

        cache
            .insert_at_boundaries(short_input, &tokenizer, CHAT_MARKERS, false)
            .unwrap();
        assert!(!cache.is_empty());

        let lookup = cache.longest_prefix_match(short_input, CHAT_MARKERS, false);
        assert!(lookup.is_some());
    }

    /// A truncated conversation matches a cached prefix whose offset lies
    /// inside the truncated input.
    #[test]
    fn test_longest_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(ONE_MIB);

        let full_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        cache
            .insert_at_boundaries(full_input, &tokenizer, CHAT_MARKERS, false)
            .unwrap();
        assert!(cache.len() >= 2);

        let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let lookup = cache.longest_prefix_match(partial_input, CHAT_MARKERS, false);
        assert!(lookup.is_some());

        let (_, offset) = lookup.unwrap();
        assert!(offset > 0);
        assert!(offset <= partial_input.len());
    }

    /// One insert followed by one successful lookup yields a 100% hit rate.
    #[test]
    fn test_stats() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(ONE_MIB);
        let conversation = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";

        cache
            .insert_at_boundaries(conversation, &tokenizer, CHAT_MARKERS, false)
            .unwrap();
        let _ = cache.longest_prefix_match(conversation, CHAT_MARKERS, false);

        let snapshot = cache.stats();
        assert!(snapshot.hits >= 1);
        assert_eq!(snapshot.hit_rate, 1.0);
    }

    /// `clear` empties the cache and zeroes the hit/miss counters.
    #[test]
    fn test_clear() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(ONE_MIB);
        let conversation = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";

        cache
            .insert_at_boundaries(conversation, &tokenizer, CHAT_MARKERS, false)
            .unwrap();
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 0);
        assert_eq!(snapshot.misses, 0);
    }

    /// With a small budget the cache stays within its limit while the most
    /// recently inserted conversation remains retrievable.
    #[test]
    fn test_lru_eviction() {
        let budget = 5 * 1024;
        let markers = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(budget);

        let maths = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(maths, &tokenizer, markers, false)
            .unwrap();
        let lookup = cache.longest_prefix_match(maths, markers, false);
        assert!(lookup.is_some());

        let physics = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(physics, &tokenizer, markers, false)
            .unwrap();
        let lookup = cache.longest_prefix_match(physics, markers, false);
        assert!(lookup.is_some());

        let chemistry = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";
        cache
            .insert_at_boundaries(chemistry, &tokenizer, markers, false)
            .unwrap();

        let snapshot = cache.stats();
        assert!(snapshot.memory_bytes <= 5 * 1024);

        let lookup = cache.longest_prefix_match(chemistry, markers, false);
        assert!(lookup.is_some());
    }

    /// Ten threads insert and look up their own conversation concurrently;
    /// each must see its own insert, and the shared counters must add up.
    #[test]
    fn test_concurrent_access() {
        use std::{sync::Arc, thread};

        let cache = Arc::new(L1Cache::new(ONE_MIB));
        let shared_markers: Arc<Vec<String>> =
            Arc::new(vec!["<|im_start|>".to_string(), "<|im_end|>".to_string()]);

        let workers: Vec<_> = (0..10)
            .map(|i| {
                let cache = cache.clone();
                let owned_markers = shared_markers.clone();
                thread::spawn(move || {
                    let tokenizer = MockTokenizer::new();
                    let markers: Vec<&str> = owned_markers.iter().map(String::as_str).collect();

                    let input = format!(
                        "<|im_start|>system\nYou are assistant number {i}.<|im_end|>\
                         <|im_start|>user\nThread {i} says hello world test token.<|im_end|>"
                    );

                    cache
                        .insert_at_boundaries(&input, &tokenizer, &markers, false)
                        .unwrap();

                    let lookup = cache.longest_prefix_match(&input, &markers, false);
                    assert!(
                        lookup.is_some(),
                        "Thread {i} expected a prefix match after insertion"
                    );

                    let (tokens, offset) = lookup.unwrap();
                    assert!(
                        !tokens.is_empty(),
                        "Thread {i} expected non-empty cached tokens"
                    );
                    assert!(offset > 0, "Thread {i} expected positive byte offset");
                    assert!(
                        offset <= input.len(),
                        "Thread {i}: offset {offset} exceeds input length {}",
                        input.len()
                    );
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert!(!cache.is_empty());

        let snapshot = cache.stats();
        assert!(
            snapshot.memory_bytes > 0,
            "Expected non-zero memory tracking after concurrent inserts"
        );
        assert!(
            snapshot.entries > 0,
            "Expected non-zero cache entries after concurrent inserts"
        );
        assert!(
            snapshot.hits >= 10,
            "Expected at least 10 cache hits, got {}",
            snapshot.hits
        );
    }

    /// The same text cached under both flag values keeps two separate
    /// entries, each returning the tokens produced for its own flag.
    #[test]
    fn test_add_special_tokens_separates_keys() {
        let tokenizer = BosTokenizer;
        let cache = L1Cache::new(ONE_MIB);
        let input = "<|im_start|>system\nhi<|im_end|><|im_start|>user\nq<|im_end|>";

        cache
            .insert_at_boundaries(input, &tokenizer, CHAT_MARKERS, true)
            .unwrap();
        cache
            .insert_at_boundaries(input, &tokenizer, CHAT_MARKERS, false)
            .unwrap();

        let (with_bos, _) = cache
            .longest_prefix_match(input, CHAT_MARKERS, true)
            .expect("match for add_special_tokens=true");
        let (without_bos, _) = cache
            .longest_prefix_match(input, CHAT_MARKERS, false)
            .expect("match for add_special_tokens=false");

        assert_eq!(with_bos.first(), Some(&BOS_ID));
        assert_ne!(without_bos.first(), Some(&BOS_ID));
    }

    /// An entry cached with one flag value is invisible to lookups using
    /// the other.
    #[test]
    fn test_opposite_flag_does_not_collide() {
        let tokenizer = BosTokenizer;
        let cache = L1Cache::new(ONE_MIB);
        let input = "<|im_start|>system\nhi<|im_end|><|im_start|>user\nq<|im_end|>";

        cache
            .insert_at_boundaries(input, &tokenizer, CHAT_MARKERS, true)
            .unwrap();

        let lookup = cache.longest_prefix_match(input, CHAT_MARKERS, false);
        assert!(lookup.is_none());
    }
}