1use std::{
23 mem::size_of,
24 sync::{
25 atomic::{AtomicU64, Ordering},
26 Arc,
27 },
28};
29
30use blake3;
31use dashmap::DashMap;
32
33use crate::traits::TokenIdType;
34
/// Cache key: the 32-byte BLAKE3 digest of a prefix's UTF-8 bytes.
type Blake3Hash = [u8; 32];

/// Number of independent maps the cache is split into; an entry's map is
/// selected from the first byte of its key.
const NUM_SHARDS: usize = 16;
40
/// Returns the byte offsets in `text` that lie just past an occurrence of any
/// special token, sorted ascending and de-duplicated.
///
/// * Occurrences of a single token are found left to right without overlap.
/// * An offset equal to `text.len()` is omitted, so a token that ends the text
///   does not produce a boundary.
/// * Empty tokens are ignored. An empty pattern matches at every position,
///   and the previous hand-rolled scan never advanced past it (infinite loop).
///
/// Every returned offset is the end of a matched `&str`, so it is always a
/// valid UTF-8 char boundary and safe to slice `text` at.
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries: Vec<usize> = special_tokens
        .iter()
        .filter(|token| !token.is_empty())
        .flat_map(|&token| {
            text.match_indices(token)
                .map(|(pos, matched)| pos + matched.len())
        })
        .filter(|&boundary| boundary < text.len())
        .collect();

    // Several tokens (or duplicate entries in `special_tokens`) can end at the
    // same offset; callers expect each boundary once, in order.
    boundaries.sort_unstable();
    boundaries.dedup();
    boundaries
}
81
82#[derive(Debug, Clone)]
85struct CachedPrefix {
86 tokens: Arc<[TokenIdType]>,
88 last_accessed: Arc<AtomicU64>,
90 size_bytes: usize,
92}
93
94pub struct L1Cache {
96 shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
100 max_memory: usize,
102 current_memory: AtomicU64,
104 hits: AtomicU64,
106 misses: AtomicU64,
108 access_counter: AtomicU64,
110}
111
112impl L1Cache {
113 pub fn new(max_memory: usize) -> Self {
115 let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
116
117 Self {
118 shards,
119 max_memory,
120 current_memory: AtomicU64::new(0),
121 hits: AtomicU64::new(0),
122 misses: AtomicU64::new(0),
123 access_counter: AtomicU64::new(0),
124 }
125 }
126
127 pub fn longest_prefix_match(
133 &self,
134 input: &str,
135 special_tokens: &[&str],
136 ) -> Option<(Vec<TokenIdType>, usize)> {
137 let boundaries = find_special_token_boundaries(input, special_tokens);
138
139 if boundaries.is_empty() {
140 self.misses.fetch_add(1, Ordering::Relaxed);
141 return None;
142 }
143
144 let mut hasher = blake3::Hasher::new();
146 let mut prefix_hashes = Vec::with_capacity(boundaries.len());
147 let mut last_pos = 0;
148 let bytes = input.as_bytes();
149 for &boundary_pos in &boundaries {
150 hasher.update(&bytes[last_pos..boundary_pos]);
151 prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
152 last_pos = boundary_pos;
153 }
154
155 for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
157 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
158
159 if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
160 let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
162 entry.last_accessed.store(timestamp, Ordering::Relaxed);
163
164 self.hits.fetch_add(1, Ordering::Relaxed);
165 return Some((entry.tokens.to_vec(), boundary_pos));
167 }
168 }
169
170 self.misses.fetch_add(1, Ordering::Relaxed);
171 None
172 }
173
174 pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
180 &self,
181 input: &str,
182 tokenizer: &E,
183 special_tokens: &[&str],
184 add_special_tokens: bool,
185 ) -> anyhow::Result<()> {
186 let boundaries = find_special_token_boundaries(input, special_tokens);
187
188 if boundaries.is_empty() {
189 return Ok(());
190 }
191
192 let mut hasher = blake3::Hasher::new();
193 let mut running_tokens = Vec::new();
194 let mut last_pos = 0;
195 let mut entries_to_insert = Vec::with_capacity(boundaries.len());
196 let bytes = input.as_bytes();
197 for (i, &boundary_pos) in boundaries.iter().enumerate() {
198 let delta_text = &input[last_pos..boundary_pos];
199
200 hasher.update(&bytes[last_pos..boundary_pos]);
202 let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();
203
204 let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
207 running_tokens.extend_from_slice(segment_encoding.token_ids());
208
209 let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();
212
213 let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
215
216 entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
217
218 last_pos = boundary_pos;
219 }
220
221 if entries_to_insert.is_empty() {
222 return Ok(());
223 }
224
225 let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
226
227 let current = self.current_memory.load(Ordering::Relaxed) as usize;
229 if current + total_size_needed > self.max_memory {
230 self.evict_lru(total_size_needed);
231 }
232
233 let current_timestamp = self.access_counter.load(Ordering::Relaxed);
235 for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
236 let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
237
238 let cached = CachedPrefix {
239 tokens: prefix_tokens,
240 last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
241 size_bytes,
242 };
243
244 if let Some(old) = self.shards[shard_idx].insert(hash_bytes, cached) {
245 let old_size = old.size_bytes as u64;
251 let new_size = size_bytes as u64;
252 if new_size >= old_size {
253 self.current_memory
254 .fetch_add(new_size - old_size, Ordering::Relaxed);
255 } else {
256 self.current_memory
257 .fetch_sub(old_size - new_size, Ordering::Relaxed);
258 }
259 } else {
260 self.current_memory
261 .fetch_add(size_bytes as u64, Ordering::Relaxed);
262 }
263 }
264
265 Ok(())
266 }
267
268 fn evict_lru(&self, space_needed: usize) {
281 const SAMPLE_SIZE: usize = 32; let mut freed = 0usize;
283 let mut iteration = 0usize;
284
285 while freed < space_needed {
287 let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
289
290 for i in 0..SAMPLE_SIZE {
292 let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
294
295 if let Some(entry) = self.shards[shard_idx].iter().next() {
297 let hash = *entry.key();
298 let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
299 let size = entry.value().size_bytes;
300 samples.push((shard_idx, hash, timestamp, size));
301 }
302 }
303
304 if samples.is_empty() {
305 break;
307 }
308
309 if let Some((shard_idx, hash, _, _)) =
311 samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
312 {
313 if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
315 freed += removed.size_bytes;
316 self.current_memory
317 .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
318 }
319 }
320
321 iteration += 1;
322 }
323 }
324
325 pub fn len(&self) -> usize {
327 self.shards.iter().map(|s| s.len()).sum()
328 }
329
330 pub fn is_empty(&self) -> bool {
332 self.shards.iter().all(|s| s.is_empty())
333 }
334
335 pub fn stats(&self) -> L1CacheStats {
337 let hits = self.hits.load(Ordering::Relaxed);
338 let misses = self.misses.load(Ordering::Relaxed);
339 let total_requests = hits + misses;
340
341 L1CacheStats {
342 hits,
343 misses,
344 entries: self.len(),
345 memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
346 hit_rate: if total_requests > 0 {
347 hits as f64 / total_requests as f64
348 } else {
349 0.0
350 },
351 }
352 }
353
354 pub fn clear(&self) {
356 for shard in &self.shards {
357 shard.clear();
358 }
359 self.current_memory.store(0, Ordering::Relaxed);
360 self.hits.store(0, Ordering::Relaxed);
361 self.misses.store(0, Ordering::Relaxed);
362 }
363}
364
/// Point-in-time statistics reported by `L1Cache::stats`.
///
/// `Default` gives the statistics of a cache that has recorded nothing yet.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct L1CacheStats {
    /// Lookups that found a cached prefix.
    pub hits: u64,
    /// Lookups that found no cached prefix.
    pub misses: u64,
    /// Number of cached prefixes at snapshot time.
    pub entries: usize,
    /// Estimated bytes currently charged against the cache's memory budget.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has been recorded.
    pub hit_rate: f64,
}
373
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use super::L1Cache;
    use crate::mock::MockTokenizer;

    /// ChatML delimiters used as cache boundaries by most tests.
    const CHATML: &[&str] = &["<|im_start|>", "<|im_end|>"];

    /// Inserts `input` into `cache` at every special-token boundary, without
    /// tokenizer-added special tokens, panicking on encoder errors.
    fn populate(cache: &L1Cache, tokenizer: &MockTokenizer, input: &str, special_tokens: &[&str]) {
        cache
            .insert_at_boundaries(input, tokenizer, special_tokens, false)
            .unwrap();
    }

    #[test]
    fn test_basic_prefix_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let first_turn = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
        populate(&cache, &tokenizer, first_turn, CHATML);
        assert!(!cache.is_empty());

        // Same system turn, different user turn: the shared prefix should hit.
        let second_turn = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let result = cache.longest_prefix_match(second_turn, CHATML);
        assert!(result.is_some());

        let (tokens, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_short_input_with_boundaries() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let input = "<|im_start|>user\nHi<|im_end|>";
        populate(&cache, &tokenizer, input, CHATML);
        assert!(!cache.is_empty());

        assert!(cache.longest_prefix_match(input, CHATML).is_some());
    }

    #[test]
    fn test_longest_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let full_conversation = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        populate(&cache, &tokenizer, full_conversation, CHATML);
        assert!(cache.len() >= 2);

        let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let result = cache.longest_prefix_match(partial_input, CHATML);
        assert!(result.is_some());

        let (_, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(offset <= partial_input.len());
    }

    #[test]
    fn test_stats() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
        populate(&cache, &tokenizer, input, CHATML);
        let _ = cache.longest_prefix_match(input, CHATML);

        let stats = cache.stats();
        assert!(stats.hits >= 1);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[test]
    fn test_clear() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
        populate(&cache, &tokenizer, input, CHATML);
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_lru_eviction() {
        const BUDGET: usize = 5 * 1024;
        let special_tokens: &[&str] = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(BUDGET);

        let math = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
        let physics = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
        let chemistry = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";

        // Each conversation must be retrievable right after it is inserted.
        for conversation in [math, physics] {
            populate(&cache, &tokenizer, conversation, special_tokens);
            assert!(cache
                .longest_prefix_match(conversation, special_tokens)
                .is_some());
        }

        populate(&cache, &tokenizer, chemistry, special_tokens);
        assert!(cache.stats().memory_bytes <= BUDGET);
        assert!(cache.longest_prefix_match(chemistry, special_tokens).is_some());
    }

    #[test]
    fn test_concurrent_access() {
        let cache = Arc::new(L1Cache::new(1024 * 1024));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let tokenizer = MockTokenizer::new();
                    let input = format!(
                        "<|im_start|>system\nYou are assistant number {i}.<|im_end|>\
                         <|im_start|>user\nThread {i} says hello world test token.<|im_end|>"
                    );

                    populate(&cache, &tokenizer, &input, CHATML);

                    let result = cache.longest_prefix_match(&input, CHATML);
                    assert!(
                        result.is_some(),
                        "Thread {i} expected a prefix match after insertion"
                    );

                    let (tokens, offset) = result.unwrap();
                    assert!(
                        !tokens.is_empty(),
                        "Thread {i} expected non-empty cached tokens"
                    );
                    assert!(offset > 0, "Thread {i} expected positive byte offset");
                    assert!(
                        offset <= input.len(),
                        "Thread {i}: offset {offset} exceeds input length {}",
                        input.len()
                    );
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!cache.is_empty());

        let stats = cache.stats();
        assert!(
            stats.memory_bytes > 0,
            "Expected non-zero memory tracking after concurrent inserts"
        );
        assert!(
            stats.entries > 0,
            "Expected non-zero cache entries after concurrent inserts"
        );
        assert!(
            stats.hits >= 10,
            "Expected at least 10 cache hits, got {}",
            stats.hits
        );
    }
}