// File: llm_tokenizer/cache/l1.rs
//! L1 Cache: special-token boundary prefix cache.
//!
//! Caches tokenization results at ALL special-token boundaries.
//! Special tokens (like `<|im_start|>`, `<|im_end|>`) are atomic in BPE tokenizers
//! (`special: true, normalized: false`), which makes them the ONLY safe split points that
//! guarantee correctness.
//!
//! **Design**: cache at every special-token boundary (not at fixed granularity intervals).
//! - Simple: no granularity parameter, no search windows.
//! - Efficient: fewer cache entries (about 10 instead of 64 for a typical 8KB prompt).
//! - Natural: aligns with the actual chat-template structure.
//!
//! # Example
//!
//! ```text
//! Template:  "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\n{query}<|im_end|>"
//!
//! Request 1: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>"
//! Request 2: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>"
//!
//! Cache points: after every special token (`<|im_start|>` and `<|im_end|>`), since atomic
//!               tokens are guaranteed-safe split points
//! Result:       tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
//! ```
21
22use std::{
23    mem::size_of,
24    sync::{
25        atomic::{AtomicU64, Ordering},
26        Arc,
27    },
28};
29
30use blake3;
31use dashmap::DashMap;
32
33use crate::traits::TokenIdType;
34
/// Cache key: the 32-byte Blake3 digest of a prefix's UTF-8 bytes.
type Blake3Hash = [u8; 32];

/// Number of independent maps the cache spreads its entries across; a key's shard is
/// chosen from the first byte of its digest.
const NUM_SHARDS: usize = 1 << 4;
40
/// Find ALL special-token boundaries in the text.
///
/// **ONLY uses special tokens** - these are atomic (special: true, normalized: false) in BPE,
/// which is what makes `tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)`
/// hold at these positions.
///
/// No fallback to whitespace/punctuation - better to not cache than risk corruption.
///
/// Common special tokens:
/// - ChatML: `<|im_start|>`, `<|im_end|>`
/// - Llama 3: `<|begin_of_text|>`, `<|end_of_text|>`, `<|eot_id|>`
/// - GPT: `<|endoftext|>`
/// - Custom: `<|reserved_special_token_N|>`
///
/// Returns the byte offsets immediately after each (non-overlapping, left-to-right)
/// occurrence of each special token, sorted ascending and deduplicated. An offset equal to
/// `text.len()` is omitted because such a prefix would leave no suffix to tokenize. Every
/// returned offset is the end of a matched `&str`, so it lies on a UTF-8 char boundary and
/// can be used to slice `text`.
///
/// Empty strings in `special_tokens` are ignored: an empty pattern matches at every
/// position and carries no boundary information (and previously caused an endless scan).
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let mut boundaries = Vec::new();

    for &token in special_tokens {
        if token.is_empty() {
            continue;
        }
        // `match_indices` yields non-overlapping matches scanning left to right.
        boundaries.extend(
            text.match_indices(token)
                .map(|(pos, _)| pos + token.len())
                // Only boundaries that leave some suffix to tokenize are useful.
                .filter(|&end| end < text.len()),
        );
    }

    // Several tokens may end at the same offset; callers expect one sorted entry per offset.
    boundaries.sort_unstable();
    boundaries.dedup();

    boundaries
}
81
82/// A cached prefix entry
83/// Uses Arc<[TokenIdType]> for zero-copy access to tokens
84#[derive(Debug, Clone)]
85struct CachedPrefix {
86    /// The pre-computed token IDs for this prefix (Arc for zero-copy cloning)
87    tokens: Arc<[TokenIdType]>,
88    /// Last access timestamp (for LRU eviction)
89    last_accessed: Arc<AtomicU64>,
90    /// Size in bytes (for memory tracking during eviction)
91    size_bytes: usize,
92}
93
94/// L1 cache implementation with special-token-boundary prefix matching
95pub struct L1Cache {
96    /// Sharded maps for concurrent access
97    /// Key: Blake3 hash of bytes[0..boundary]
98    /// Value: Cached token IDs for that prefix
99    shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
100    /// Maximum memory in bytes
101    max_memory: usize,
102    /// Current memory usage estimate
103    current_memory: AtomicU64,
104    /// Cache hit counter
105    hits: AtomicU64,
106    /// Cache miss counter
107    misses: AtomicU64,
108    /// Monotonic counter for LRU timestamps
109    access_counter: AtomicU64,
110}
111
112impl L1Cache {
113    /// Create a new L1 cache with the specified memory limit
114    pub fn new(max_memory: usize) -> Self {
115        let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
116
117        Self {
118            shards,
119            max_memory,
120            current_memory: AtomicU64::new(0),
121            hits: AtomicU64::new(0),
122            misses: AtomicU64::new(0),
123            access_counter: AtomicU64::new(0),
124        }
125    }
126
127    /// Try to find the longest prefix match at special token boundaries
128    /// Returns (cached_tokens, byte_offset) if found
129    ///
130    /// Uses pre-computed tokens cached during insertion.
131    /// Returns Vec<TokenIdType> as the caller needs to extend it with suffix tokens.
132    pub fn longest_prefix_match(
133        &self,
134        input: &str,
135        special_tokens: &[&str],
136    ) -> Option<(Vec<TokenIdType>, usize)> {
137        let boundaries = find_special_token_boundaries(input, special_tokens);
138
139        if boundaries.is_empty() {
140            self.misses.fetch_add(1, Ordering::Relaxed);
141            return None;
142        }
143
144        // Build all prefix hashes incrementally O(N)
145        let mut hasher = blake3::Hasher::new();
146        let mut prefix_hashes = Vec::with_capacity(boundaries.len());
147        let mut last_pos = 0;
148        let bytes = input.as_bytes();
149        for &boundary_pos in &boundaries {
150            hasher.update(&bytes[last_pos..boundary_pos]);
151            prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
152            last_pos = boundary_pos;
153        }
154
155        // Search from the longest boundary to find the best match
156        for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
157            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
158
159            if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
160                // Update last accessed timestamp for LRU
161                let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
162                entry.last_accessed.store(timestamp, Ordering::Relaxed);
163
164                self.hits.fetch_add(1, Ordering::Relaxed);
165                // Convert Arc<[T]> to Vec<T> - caller will extend with suffix tokens
166                return Some((entry.tokens.to_vec(), boundary_pos));
167            }
168        }
169
170        self.misses.fetch_add(1, Ordering::Relaxed);
171        None
172    }
173
174    /// Insert prefix entries at ALL special token boundaries
175    ///
176    /// Uses incremental hashing and tokenization for O(N) performance.
177    ///
178    /// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts).
179    pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
180        &self,
181        input: &str,
182        tokenizer: &E,
183        special_tokens: &[&str],
184        add_special_tokens: bool,
185    ) -> anyhow::Result<()> {
186        let boundaries = find_special_token_boundaries(input, special_tokens);
187
188        if boundaries.is_empty() {
189            return Ok(());
190        }
191
192        let mut hasher = blake3::Hasher::new();
193        let mut running_tokens = Vec::new();
194        let mut last_pos = 0;
195        let mut entries_to_insert = Vec::with_capacity(boundaries.len());
196        let bytes = input.as_bytes();
197        for (i, &boundary_pos) in boundaries.iter().enumerate() {
198            let delta_text = &input[last_pos..boundary_pos];
199
200            // 1. Incremental Hash update
201            hasher.update(&bytes[last_pos..boundary_pos]);
202            let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();
203
204            // 2. Incremental Tokenization
205            // Only add special tokens (like BOS) for the very first segment to avoid duplicates
206            let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
207            running_tokens.extend_from_slice(segment_encoding.token_ids());
208
209            // 3. Prepare entry
210            // Convert current tokens to Arc<[TokenIdType]> for sharing
211            let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();
212
213            // Size = text bytes + token storage
214            let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
215
216            entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
217
218            last_pos = boundary_pos;
219        }
220
221        if entries_to_insert.is_empty() {
222            return Ok(());
223        }
224
225        let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
226
227        // Evict if necessary
228        let current = self.current_memory.load(Ordering::Relaxed) as usize;
229        if current + total_size_needed > self.max_memory {
230            self.evict_lru(total_size_needed);
231        }
232
233        // Insert all entries, accounting for replaced entries in memory tracking
234        let current_timestamp = self.access_counter.load(Ordering::Relaxed);
235        for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
236            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
237
238            let cached = CachedPrefix {
239                tokens: prefix_tokens,
240                last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
241                size_bytes,
242            };
243
244            if let Some(old) = self.shards[shard_idx].insert(hash_bytes, cached) {
245                // Replaced an existing entry — adjust delta only.
246                // Note: the counter update is not atomic with the shard insert, so
247                // concurrent replacements of the same key can briefly skew the
248                // counter. This is benign — eviction is best-effort and the drift
249                // is bounded to a single entry's size per race.
250                let old_size = old.size_bytes as u64;
251                let new_size = size_bytes as u64;
252                if new_size >= old_size {
253                    self.current_memory
254                        .fetch_add(new_size - old_size, Ordering::Relaxed);
255                } else {
256                    self.current_memory
257                        .fetch_sub(old_size - new_size, Ordering::Relaxed);
258                }
259            } else {
260                self.current_memory
261                    .fetch_add(size_bytes as u64, Ordering::Relaxed);
262            }
263        }
264
265        Ok(())
266    }
267
268    /// Evict least recently used entries using approximate LRU via random sampling
269    ///
270    /// This uses an approximate LRU strategy that's much faster than true LRU:
271    /// - Samples K random entries from the cache (K=32)
272    /// - Evicts the oldest entry among the samples
273    /// - Repeats until enough space is freed
274    ///
275    /// This provides O(samples) complexity instead of O(total_entries * log(total_entries)),
276    /// avoiding latency spikes when eviction is triggered on large caches.
277    ///
278    /// The approximation is excellent in practice - sampling 32 entries from a large cache
279    /// gives high probability of finding very old entries.
280    fn evict_lru(&self, space_needed: usize) {
281        const SAMPLE_SIZE: usize = 32; // Number of entries to sample per eviction round
282        let mut freed = 0usize;
283        let mut iteration = 0usize;
284
285        // Keep evicting until we have enough space
286        while freed < space_needed {
287            // Collect samples from shards
288            let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
289
290            // Sample entries across different shards
291            for i in 0..SAMPLE_SIZE {
292                // Distribute samples across shards using iteration and index for variety
293                let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
294
295                // Get first entry from that shard (DashMap iteration order is arbitrary)
296                if let Some(entry) = self.shards[shard_idx].iter().next() {
297                    let hash = *entry.key();
298                    let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
299                    let size = entry.value().size_bytes;
300                    samples.push((shard_idx, hash, timestamp, size));
301                }
302            }
303
304            if samples.is_empty() {
305                // Cache is empty, nothing to evict
306                break;
307            }
308
309            // Find the oldest entry among samples
310            if let Some((shard_idx, hash, _, _)) =
311                samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
312            {
313                // Remove it
314                if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
315                    freed += removed.size_bytes;
316                    self.current_memory
317                        .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
318                }
319            }
320
321            iteration += 1;
322        }
323    }
324
325    /// Get the number of entries in the cache
326    pub fn len(&self) -> usize {
327        self.shards.iter().map(|s| s.len()).sum()
328    }
329
330    /// Check if the cache is empty
331    pub fn is_empty(&self) -> bool {
332        self.shards.iter().all(|s| s.is_empty())
333    }
334
335    /// Get cache statistics
336    pub fn stats(&self) -> L1CacheStats {
337        let hits = self.hits.load(Ordering::Relaxed);
338        let misses = self.misses.load(Ordering::Relaxed);
339        let total_requests = hits + misses;
340
341        L1CacheStats {
342            hits,
343            misses,
344            entries: self.len(),
345            memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
346            hit_rate: if total_requests > 0 {
347                hits as f64 / total_requests as f64
348            } else {
349                0.0
350            },
351        }
352    }
353
354    /// Clear the cache
355    pub fn clear(&self) {
356        for shard in &self.shards {
357            shard.clear();
358        }
359        self.current_memory.store(0, Ordering::Relaxed);
360        self.hits.store(0, Ordering::Relaxed);
361        self.misses.store(0, Ordering::Relaxed);
362    }
363}
364
/// Point-in-time snapshot of an [`L1Cache`]'s counters.
#[derive(Clone, Debug)]
pub struct L1CacheStats {
    /// Lookups served from the cache.
    pub hits: u64,
    /// Lookups that found no cached prefix.
    pub misses: u64,
    /// Number of prefix entries currently stored.
    pub entries: usize,
    /// Estimated bytes held by the cache.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` before any lookup.
    pub hit_rate: f64,
}
373
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use crate::{mock::MockTokenizer, *};

    /// ChatML delimiters shared by most tests.
    const CHATML: [&str; 2] = ["<|im_start|>", "<|im_end|>"];

    #[test]
    fn test_basic_prefix_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        // Seed with a realistic ChatML conversation; entries land at special-token boundaries.
        let input1 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
        cache
            .insert_at_boundaries(input1, &tokenizer, &CHATML, false)
            .unwrap();
        assert!(!cache.is_empty());

        // Same system prompt, different question: the system-prompt prefix must be reused.
        let input2 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
        let result = cache.longest_prefix_match(input2, &CHATML);
        assert!(result.is_some());

        let (tokens, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_short_input_with_boundaries() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        // Only the `<|im_start|>` boundary leaves a suffix, so exactly that prefix is cached.
        let input = "<|im_start|>user\nHi<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, &CHATML, false)
            .unwrap();
        assert!(!cache.is_empty());

        let result = cache.longest_prefix_match(input, &CHATML);
        assert!(result.is_some());
    }

    #[test]
    fn test_longest_match() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        // Multi-turn conversation (~400 bytes) with many special-token boundaries.
        let input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, &CHATML, false)
            .unwrap();
        assert!(cache.len() >= 2);

        // Truncated after the user turn: must still match at some special-token boundary.
        let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
        let result = cache.longest_prefix_match(partial_input, &CHATML);
        assert!(result.is_some());

        let (_, offset) = result.unwrap();
        assert!(offset > 0);
        assert!(offset <= partial_input.len());
    }

    #[test]
    fn test_stats() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, &CHATML, false)
            .unwrap();

        // A single lookup of the seeded input should be a hit.
        let _ = cache.longest_prefix_match(input, &CHATML);

        let stats = cache.stats();
        assert!(stats.hits >= 1);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[test]
    fn test_clear() {
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(1024 * 1024);

        let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
        cache
            .insert_at_boundaries(input, &tokenizer, &CHATML, false)
            .unwrap();
        assert!(!cache.is_empty());

        cache.clear();
        assert!(cache.is_empty());

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_lru_eviction() {
        // Small budget (5KB) so that conversations compete for space.
        const BUDGET: usize = 5 * 1024;
        let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
        let tokenizer = MockTokenizer::new();
        let cache = L1Cache::new(BUDGET);

        let conversations = [
            "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>",
            "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>",
            "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>",
        ];

        // Insert each of the first two conversations, then touch it so its entries are recent.
        for &input in &conversations[..2] {
            cache
                .insert_at_boundaries(input, &tokenizer, special_tokens, false)
                .unwrap();
            let result = cache.longest_prefix_match(input, special_tokens);
            assert!(result.is_some());
        }

        // The third conversation may push the cache over budget and force eviction.
        let newest = conversations[2];
        cache
            .insert_at_boundaries(newest, &tokenizer, special_tokens, false)
            .unwrap();

        let stats = cache.stats();
        assert!(stats.memory_bytes <= BUDGET);

        // Entries from the latest insert must survive.
        let result = cache.longest_prefix_match(newest, special_tokens);
        assert!(result.is_some());
    }

    #[test]
    fn test_concurrent_access() {
        let cache = Arc::new(L1Cache::new(1024 * 1024));

        // Ten threads each insert a distinct conversation and immediately query it back.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    let tokenizer = MockTokenizer::new();

                    // A per-thread system prompt and user message keep the prefixes distinct.
                    let input = format!(
                        "<|im_start|>system\nYou are assistant number {i}.<|im_end|>\
                         <|im_start|>user\nThread {i} says hello world test token.<|im_end|>"
                    );

                    cache
                        .insert_at_boundaries(&input, &tokenizer, &CHATML, false)
                        .unwrap();

                    let result = cache.longest_prefix_match(&input, &CHATML);
                    assert!(
                        result.is_some(),
                        "Thread {i} expected a prefix match after insertion"
                    );

                    let (tokens, offset) = result.unwrap();
                    assert!(
                        !tokens.is_empty(),
                        "Thread {i} expected non-empty cached tokens"
                    );
                    assert!(offset > 0, "Thread {i} expected positive byte offset");
                    assert!(
                        offset <= input.len(),
                        "Thread {i}: offset {offset} exceeds input length {}",
                        input.len()
                    );
                })
            })
            .collect();

        // Any panic inside a worker surfaces here.
        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!cache.is_empty());

        let stats = cache.stats();
        assert!(
            stats.memory_bytes > 0,
            "Expected non-zero memory tracking after concurrent inserts"
        );
        assert!(
            stats.entries > 0,
            "Expected non-zero cache entries after concurrent inserts"
        );
        // Each worker performed one successful lookup.
        assert!(
            stats.hits >= 10,
            "Expected at least 10 cache hits, got {}",
            stats.hits
        );
    }
}