Skip to main content

llm_tokenizer/cache/
l1.rs

1//! L1 Cache: Special-token boundary prefix cache
2//!
3//! Caches tokenization results at ALL special token boundaries.
4//! Special tokens (like `<|im_start|>`, `<|im_end|>`) are atomic in BPE tokenizers (special: true, normalized: false),
5//! making them the ONLY safe split points that guarantee correctness.
6//!
7//! **Design**: Cache at every special token boundary (not at fixed granularity intervals)
8//! - Simple: No granularity parameter, no search windows
9//! - Efficient: Fewer cache entries (10 instead of 64 for typical 8KB prompt)
10//! - Natural: Aligns with actual chat template structure
11//!
12//! Example:
13//!
14//! Template: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\n{query}<|im_end|>"
15//!
16//! Request 1: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>"
17//! Request 2: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>"
18//!
//! Cache points: after every special token occurrence (`<|im_start|>` and `<|im_end|>` here) that is followed by more text (atomic tokens, guaranteed safe)
20//! Result: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
21
22use std::{
23    mem::size_of,
24    sync::{
25        atomic::{AtomicU64, Ordering},
26        Arc,
27    },
28};
29
30use blake3;
31use dashmap::DashMap;
32
33use crate::traits::TokenIdType;
34
/// Cache key type: a 32-byte BLAKE3 digest of a prefix.
type Blake3Hash = [u8; 32];

/// How many separate maps the cache is split into for concurrent access.
const NUM_SHARDS: usize = 16;
40
/// Locate every byte offset in `text` that immediately follows a special token.
///
/// **Only special tokens are used as split points.** They are atomic in BPE
/// (special: true, normalized: false), which guarantees
/// `tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)`.
/// There is deliberately no fallback to whitespace or punctuation: it is
/// better to not cache than to risk a corrupted tokenization.
///
/// Common special tokens:
/// - ChatML: `<|im_start|>`, `<|im_end|>`
/// - Llama 3: `<|begin_of_text|>`, `<|end_of_text|>`, `<|eot_id|>`
/// - GPT: `<|endoftext|>`
/// - Custom: `<|reserved_special_token_N|>`
///
/// # Arguments
/// - `text`: the input to scan.
/// - `special_tokens`: literal token strings to search for. Empty strings are
///   ignored because they cannot mark a boundary.
///
/// # Returns
/// Sorted, de-duplicated byte offsets, each pointing just past a special
/// token. Offsets equal to `text.len()` are omitted, since a prefix that
/// leaves no suffix to tokenize is not worth caching.
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
    let text_len = text.len();
    let mut boundaries: Vec<usize> = Vec::new();

    // An empty pattern matches at every position without consuming input, so
    // it must be skipped: it is never a real token and would otherwise yield
    // a bogus "boundary" at every character.
    for token in special_tokens.iter().copied().filter(|t| !t.is_empty()) {
        boundaries.extend(
            text.match_indices(token)
                .map(|(start, matched)| start + matched.len())
                // Keep only boundaries that leave some suffix to tokenize.
                .filter(|&end| end < text_len),
        );
    }

    // Different tokens may end at the same offset; report each offset once.
    boundaries.sort_unstable();
    boundaries.dedup();
    boundaries
}
81
82/// A cached prefix entry
83/// Uses Arc<[TokenIdType]> for zero-copy access to tokens
84#[derive(Debug, Clone)]
85struct CachedPrefix {
86    /// The pre-computed token IDs for this prefix (Arc for zero-copy cloning)
87    tokens: Arc<[TokenIdType]>,
88    /// Last access timestamp (for LRU eviction)
89    last_accessed: Arc<AtomicU64>,
90    /// Size in bytes (for memory tracking during eviction)
91    size_bytes: usize,
92}
93
94/// L1 cache implementation with special-token-boundary prefix matching
95pub struct L1Cache {
96    /// Sharded maps for concurrent access
97    /// Key: Blake3 hash of bytes[0..boundary]
98    /// Value: Cached token IDs for that prefix
99    shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
100    /// Maximum memory in bytes
101    max_memory: usize,
102    /// Current memory usage estimate
103    current_memory: AtomicU64,
104    /// Cache hit counter
105    hits: AtomicU64,
106    /// Cache miss counter
107    misses: AtomicU64,
108    /// Monotonic counter for LRU timestamps
109    access_counter: AtomicU64,
110}
111
112impl L1Cache {
113    /// Create a new L1 cache with the specified memory limit
114    pub fn new(max_memory: usize) -> Self {
115        let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
116
117        Self {
118            shards,
119            max_memory,
120            current_memory: AtomicU64::new(0),
121            hits: AtomicU64::new(0),
122            misses: AtomicU64::new(0),
123            access_counter: AtomicU64::new(0),
124        }
125    }
126
127    /// Try to find the longest prefix match at special token boundaries
128    /// Returns (cached_tokens, byte_offset) if found
129    ///
130    /// Uses pre-computed tokens cached during insertion.
131    /// Returns Vec<TokenIdType> as the caller needs to extend it with suffix tokens.
132    pub fn longest_prefix_match(
133        &self,
134        input: &str,
135        special_tokens: &[&str],
136        add_special_tokens: bool,
137    ) -> Option<(Vec<TokenIdType>, usize)> {
138        let boundaries = find_special_token_boundaries(input, special_tokens);
139
140        if boundaries.is_empty() {
141            self.misses.fetch_add(1, Ordering::Relaxed);
142            return None;
143        }
144
145        // Build all prefix hashes incrementally O(N).
146        // Seed with add_special_tokens so prefixes tokenized with/without a
147        // leading BOS map to distinct keys (the first segment honors this flag).
148        let mut hasher = blake3::Hasher::new();
149        hasher.update(&[add_special_tokens as u8]);
150        let mut prefix_hashes = Vec::with_capacity(boundaries.len());
151        let mut last_pos = 0;
152        let bytes = input.as_bytes();
153        for &boundary_pos in &boundaries {
154            hasher.update(&bytes[last_pos..boundary_pos]);
155            prefix_hashes.push((boundary_pos, *hasher.clone().finalize().as_bytes()));
156            last_pos = boundary_pos;
157        }
158
159        // Search from the longest boundary to find the best match
160        for (boundary_pos, hash_bytes) in prefix_hashes.into_iter().rev() {
161            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
162
163            if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
164                // Update last accessed timestamp for LRU
165                let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
166                entry.last_accessed.store(timestamp, Ordering::Relaxed);
167
168                self.hits.fetch_add(1, Ordering::Relaxed);
169                // Convert Arc<[T]> to Vec<T> - caller will extend with suffix tokens
170                return Some((entry.tokens.to_vec(), boundary_pos));
171            }
172        }
173
174        self.misses.fetch_add(1, Ordering::Relaxed);
175        None
176    }
177
178    /// Insert prefix entries at ALL special token boundaries
179    ///
180    /// Uses incremental hashing and tokenization for O(N) performance.
181    ///
182    /// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts).
183    pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
184        &self,
185        input: &str,
186        tokenizer: &E,
187        special_tokens: &[&str],
188        add_special_tokens: bool,
189    ) -> anyhow::Result<()> {
190        let boundaries = find_special_token_boundaries(input, special_tokens);
191
192        if boundaries.is_empty() {
193            return Ok(());
194        }
195
196        let mut hasher = blake3::Hasher::new();
197        hasher.update(&[add_special_tokens as u8]);
198        let mut running_tokens = Vec::new();
199        let mut last_pos = 0;
200        let mut entries_to_insert = Vec::with_capacity(boundaries.len());
201        let bytes = input.as_bytes();
202        for (i, &boundary_pos) in boundaries.iter().enumerate() {
203            let delta_text = &input[last_pos..boundary_pos];
204
205            // 1. Incremental Hash update
206            hasher.update(&bytes[last_pos..boundary_pos]);
207            let hash_bytes: Blake3Hash = *hasher.clone().finalize().as_bytes();
208
209            // 2. Incremental Tokenization
210            // Only add special tokens (like BOS) for the very first segment to avoid duplicates
211            let segment_encoding = tokenizer.encode(delta_text, (i == 0) && add_special_tokens)?;
212            running_tokens.extend_from_slice(segment_encoding.token_ids());
213
214            // 3. Prepare entry
215            // Convert current tokens to Arc<[TokenIdType]> for sharing
216            let prefix_tokens: Arc<[TokenIdType]> = running_tokens.as_slice().into();
217
218            // Size = text bytes + token storage
219            let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
220
221            entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
222
223            last_pos = boundary_pos;
224        }
225
226        if entries_to_insert.is_empty() {
227            return Ok(());
228        }
229
230        let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
231
232        // Evict if necessary
233        let current = self.current_memory.load(Ordering::Relaxed) as usize;
234        if current + total_size_needed > self.max_memory {
235            self.evict_lru(total_size_needed);
236        }
237
238        // Insert all entries, accounting for replaced entries in memory tracking
239        let current_timestamp = self.access_counter.load(Ordering::Relaxed);
240        for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
241            let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
242
243            let cached = CachedPrefix {
244                tokens: prefix_tokens,
245                last_accessed: Arc::new(AtomicU64::new(current_timestamp)),
246                size_bytes,
247            };
248
249            if let Some(old) = self.shards[shard_idx].insert(hash_bytes, cached) {
250                // Replaced an existing entry — adjust delta only.
251                // Note: the counter update is not atomic with the shard insert, so
252                // concurrent replacements of the same key can briefly skew the
253                // counter. This is benign — eviction is best-effort and the drift
254                // is bounded to a single entry's size per race.
255                let old_size = old.size_bytes as u64;
256                let new_size = size_bytes as u64;
257                if new_size >= old_size {
258                    self.current_memory
259                        .fetch_add(new_size - old_size, Ordering::Relaxed);
260                } else {
261                    self.current_memory
262                        .fetch_sub(old_size - new_size, Ordering::Relaxed);
263                }
264            } else {
265                self.current_memory
266                    .fetch_add(size_bytes as u64, Ordering::Relaxed);
267            }
268        }
269
270        Ok(())
271    }
272
273    /// Evict least recently used entries using approximate LRU via random sampling
274    ///
275    /// This uses an approximate LRU strategy that's much faster than true LRU:
276    /// - Samples K random entries from the cache (K=32)
277    /// - Evicts the oldest entry among the samples
278    /// - Repeats until enough space is freed
279    ///
280    /// This provides O(samples) complexity instead of O(total_entries * log(total_entries)),
281    /// avoiding latency spikes when eviction is triggered on large caches.
282    ///
283    /// The approximation is excellent in practice - sampling 32 entries from a large cache
284    /// gives high probability of finding very old entries.
285    fn evict_lru(&self, space_needed: usize) {
286        const SAMPLE_SIZE: usize = 32; // Number of entries to sample per eviction round
287        let mut freed = 0usize;
288        let mut iteration = 0usize;
289
290        // Keep evicting until we have enough space
291        while freed < space_needed {
292            // Collect samples from shards
293            let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
294
295            // Sample entries across different shards
296            for i in 0..SAMPLE_SIZE {
297                // Distribute samples across shards using iteration and index for variety
298                let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
299
300                // Get first entry from that shard (DashMap iteration order is arbitrary)
301                if let Some(entry) = self.shards[shard_idx].iter().next() {
302                    let hash = *entry.key();
303                    let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
304                    let size = entry.value().size_bytes;
305                    samples.push((shard_idx, hash, timestamp, size));
306                }
307            }
308
309            if samples.is_empty() {
310                // Cache is empty, nothing to evict
311                break;
312            }
313
314            // Find the oldest entry among samples
315            if let Some((shard_idx, hash, _, _)) =
316                samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
317            {
318                // Remove it
319                if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
320                    freed += removed.size_bytes;
321                    self.current_memory
322                        .fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
323                }
324            }
325
326            iteration += 1;
327        }
328    }
329
330    /// Get the number of entries in the cache
331    pub fn len(&self) -> usize {
332        self.shards.iter().map(|s| s.len()).sum()
333    }
334
335    /// Check if the cache is empty
336    pub fn is_empty(&self) -> bool {
337        self.shards.iter().all(|s| s.is_empty())
338    }
339
340    /// Get cache statistics
341    pub fn stats(&self) -> L1CacheStats {
342        let hits = self.hits.load(Ordering::Relaxed);
343        let misses = self.misses.load(Ordering::Relaxed);
344        let total_requests = hits + misses;
345
346        L1CacheStats {
347            hits,
348            misses,
349            entries: self.len(),
350            memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
351            hit_rate: if total_requests > 0 {
352                hits as f64 / total_requests as f64
353            } else {
354                0.0
355            },
356        }
357    }
358
359    /// Clear the cache
360    pub fn clear(&self) {
361        for shard in &self.shards {
362            shard.clear();
363        }
364        self.current_memory.store(0, Ordering::Relaxed);
365        self.hits.store(0, Ordering::Relaxed);
366        self.misses.store(0, Ordering::Relaxed);
367    }
368}
369
/// Point-in-time counters reported by `L1Cache::stats`.
#[derive(Debug, Clone)]
pub struct L1CacheStats {
    /// Lookups that returned a cached prefix.
    pub hits: u64,
    /// Lookups that returned nothing.
    pub misses: u64,
    /// Number of entries currently stored.
    pub entries: usize,
    /// Estimated memory in use, in bytes.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when there were no lookups.
    pub hit_rate: f64,
}
378
379#[cfg(test)]
380mod tests {
381    use crate::{mock::MockTokenizer, *};
382
383    #[test]
384    fn test_basic_prefix_match() {
385        let cache = L1Cache::new(1024 * 1024);
386        let special_tokens = &["<|im_start|>", "<|im_end|>"];
387        let tokenizer = MockTokenizer::new();
388
389        // Realistic ChatML template with special tokens
390        let input1 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
391
392        // Insert at special token boundaries (re-tokenizes prefixes)
393        cache
394            .insert_at_boundaries(input1, &tokenizer, special_tokens, false)
395            .unwrap();
396
397        // Should have cached at special token boundaries
398        assert!(!cache.is_empty());
399
400        // Search with same prefix but different user query
401        let input2 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
402        let result = cache.longest_prefix_match(input2, special_tokens, false);
403
404        // Should find a match at the special token boundary (after system message)
405        assert!(result.is_some());
406        let (tokens, offset) = result.unwrap();
407        assert!(offset > 0);
408        assert!(!tokens.is_empty());
409    }
410
411    #[test]
412    fn test_short_input_with_boundaries() {
413        let cache = L1Cache::new(1024 * 1024);
414        let special_tokens = &["<|im_start|>", "<|im_end|>"];
415        let tokenizer = MockTokenizer::new();
416
417        // Short input with special tokens
418        let input = "<|im_start|>user\nHi<|im_end|>";
419
420        cache
421            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
422            .unwrap();
423
424        // Should cache at <|im_start|> boundary (has suffix left)
425        assert!(!cache.is_empty());
426
427        // Should find a match
428        let result = cache.longest_prefix_match(input, special_tokens, false);
429        assert!(result.is_some());
430    }
431
432    #[test]
433    fn test_longest_match() {
434        let cache = L1Cache::new(1024 * 1024);
435        let special_tokens = &["<|im_start|>", "<|im_end|>"];
436        let tokenizer = MockTokenizer::new();
437
438        // Create multi-turn conversation with multiple special token boundaries (~400 bytes)
439        let input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
440
441        cache
442            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
443            .unwrap();
444
445        // Should have multiple entries at special token boundaries
446        assert!(cache.len() >= 2); // At least 2 boundaries
447
448        // Search with partial conversation - should match at a special token boundary
449        let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
450        let result = cache.longest_prefix_match(partial_input, special_tokens, false);
451
452        // Should find a match at a special token boundary
453        assert!(result.is_some());
454        let (_, offset) = result.unwrap();
455        assert!(offset > 0);
456        assert!(offset <= partial_input.len());
457    }
458
459    #[test]
460    fn test_stats() {
461        let cache = L1Cache::new(1024 * 1024);
462        let special_tokens = &["<|im_start|>", "<|im_end|>"];
463        let tokenizer = MockTokenizer::new();
464
465        // ChatML input with special tokens
466        let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
467
468        cache
469            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
470            .unwrap();
471
472        // Try to find match
473        let _ = cache.longest_prefix_match(input, special_tokens, false);
474
475        let stats = cache.stats();
476        // Should have at least one hit (the longest special token boundary should match)
477        assert!(stats.hits >= 1);
478        assert_eq!(stats.hit_rate, 1.0);
479    }
480
481    #[test]
482    fn test_clear() {
483        let cache = L1Cache::new(1024 * 1024);
484        let special_tokens = &["<|im_start|>", "<|im_end|>"];
485        let tokenizer = MockTokenizer::new();
486
487        // ChatML input with special tokens
488        let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
489
490        cache
491            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
492            .unwrap();
493        assert!(!cache.is_empty());
494
495        cache.clear();
496        assert!(cache.is_empty());
497
498        let stats = cache.stats();
499        assert_eq!(stats.hits, 0);
500        assert_eq!(stats.misses, 0);
501    }
502
503    #[test]
504    fn test_lru_eviction() {
505        // Create a small cache (5KB) to trigger eviction
506        let cache = L1Cache::new(5 * 1024);
507        let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
508        let tokenizer = MockTokenizer::new();
509
510        // Insert first conversation
511        let input1 = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
512        cache
513            .insert_at_boundaries(input1, &tokenizer, special_tokens, false)
514            .unwrap();
515
516        // Access the first entry to update its timestamp
517        let result = cache.longest_prefix_match(input1, special_tokens, false);
518        assert!(result.is_some());
519
520        // Insert second conversation
521        let input2 = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
522        cache
523            .insert_at_boundaries(input2, &tokenizer, special_tokens, false)
524            .unwrap();
525
526        // Access the second entry to make it more recent
527        let result = cache.longest_prefix_match(input2, special_tokens, false);
528        assert!(result.is_some());
529
530        // Insert third conversation (should trigger eviction of oldest)
531        let input3 = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";
532        cache
533            .insert_at_boundaries(input3, &tokenizer, special_tokens, false)
534            .unwrap();
535
536        // Verify cache didn't exceed max memory
537        let stats = cache.stats();
538        assert!(stats.memory_bytes <= 5 * 1024);
539
540        // The most recently accessed entries should still be present
541        let result = cache.longest_prefix_match(input3, special_tokens, false);
542        assert!(result.is_some());
543    }
544
545    #[test]
546    fn test_concurrent_access() {
547        use std::{sync::Arc, thread};
548
549        let cache = Arc::new(L1Cache::new(1024 * 1024));
550        let special_tokens_owned: Vec<String> =
551            vec!["<|im_start|>".to_string(), "<|im_end|>".to_string()];
552        let special_tokens_arc = Arc::new(special_tokens_owned);
553
554        let mut handles = vec![];
555
556        // Spawn 10 threads that each insert different special-token-bounded strings
557        // and query for prefix matches concurrently.
558        for i in 0..10 {
559            let cache_clone = cache.clone();
560            let st_clone = special_tokens_arc.clone();
561            handles.push(thread::spawn(move || {
562                let tokenizer = MockTokenizer::new();
563                let special_tokens: Vec<&str> = st_clone.iter().map(|s| s.as_str()).collect();
564
565                // Each thread uses a unique user message to avoid hash collisions
566                let input = format!(
567                    "<|im_start|>system\nYou are assistant number {i}.<|im_end|>\
568                     <|im_start|>user\nThread {i} says hello world test token.<|im_end|>"
569                );
570
571                // Insert prefix entries at boundaries
572                cache_clone
573                    .insert_at_boundaries(&input, &tokenizer, &special_tokens, false)
574                    .unwrap();
575
576                // Query for the same input - should find a prefix match
577                let result = cache_clone.longest_prefix_match(&input, &special_tokens, false);
578                assert!(
579                    result.is_some(),
580                    "Thread {i} expected a prefix match after insertion"
581                );
582
583                let (tokens, offset) = result.unwrap();
584                assert!(
585                    !tokens.is_empty(),
586                    "Thread {i} expected non-empty cached tokens"
587                );
588                assert!(offset > 0, "Thread {i} expected positive byte offset");
589                assert!(
590                    offset <= input.len(),
591                    "Thread {i}: offset {offset} exceeds input length {}",
592                    input.len()
593                );
594            }));
595        }
596
597        // Wait for all threads to complete (no panics)
598        for handle in handles {
599            handle.join().unwrap();
600        }
601
602        // Cache should contain entries from the concurrent inserts
603        assert!(!cache.is_empty());
604
605        // Memory tracking should be consistent (non-zero after inserts)
606        let stats = cache.stats();
607        assert!(
608            stats.memory_bytes > 0,
609            "Expected non-zero memory tracking after concurrent inserts"
610        );
611        assert!(
612            stats.entries > 0,
613            "Expected non-zero cache entries after concurrent inserts"
614        );
615        // Total hits should be at least 10 (one per thread)
616        assert!(
617            stats.hits >= 10,
618            "Expected at least 10 cache hits, got {}",
619            stats.hits
620        );
621    }
622
623    /// Encoder that prepends a sentinel BOS token when `add_special_tokens` is
624    /// set, so the same text yields different tokens per flag.
625    struct BosTokenizer;
626
627    const BOS_ID: TokenIdType = 99;
628
629    impl Encoder for BosTokenizer {
630        fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
631            let mut ids: Vec<TokenIdType> = Vec::new();
632            if add_special_tokens {
633                ids.push(BOS_ID);
634            }
635            ids.extend(input.bytes().map(TokenIdType::from));
636            Ok(Encoding::Plain(ids))
637        }
638
639        fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
640            inputs
641                .iter()
642                .map(|i| self.encode(i, add_special_tokens))
643                .collect()
644        }
645    }
646
647    #[test]
648    fn test_add_special_tokens_separates_keys() {
649        let cache = L1Cache::new(1024 * 1024);
650        let special_tokens = &["<|im_start|>", "<|im_end|>"];
651        let tokenizer = BosTokenizer;
652        let input = "<|im_start|>system\nhi<|im_end|><|im_start|>user\nq<|im_end|>";
653
654        // Insert the same input under both flags.
655        cache
656            .insert_at_boundaries(input, &tokenizer, special_tokens, true)
657            .unwrap();
658        cache
659            .insert_at_boundaries(input, &tokenizer, special_tokens, false)
660            .unwrap();
661
662        // Each flag must return its own prefix: BOS present only for `true`.
663        let (with_bos, _) = cache
664            .longest_prefix_match(input, special_tokens, true)
665            .expect("match for add_special_tokens=true");
666        let (without_bos, _) = cache
667            .longest_prefix_match(input, special_tokens, false)
668            .expect("match for add_special_tokens=false");
669
670        assert_eq!(with_bos.first(), Some(&BOS_ID));
671        assert_ne!(without_bos.first(), Some(&BOS_ID));
672    }
673
674    #[test]
675    fn test_opposite_flag_does_not_collide() {
676        let cache = L1Cache::new(1024 * 1024);
677        let special_tokens = &["<|im_start|>", "<|im_end|>"];
678        let tokenizer = BosTokenizer;
679        let input = "<|im_start|>system\nhi<|im_end|><|im_start|>user\nq<|im_end|>";
680
681        // Only the `true` flag is populated.
682        cache
683            .insert_at_boundaries(input, &tokenizer, special_tokens, true)
684            .unwrap();
685
686        // A lookup with the opposite flag must miss rather than return BOS tokens.
687        assert!(cache
688            .longest_prefix_match(input, special_tokens, false)
689            .is_none());
690    }
691}