// llm_tokenizer/cache/mod.rs

//! Tokenizer Caching Layer
//!
//! Provides a caching wrapper around any tokenizer implementation to speed up
//! repeated tokenization of the same strings (e.g., system prompts).
//!
//! # Architecture
//! - **L0 Cache**: Whole-string exact match (~90% of wins); enabled by default.
//! - **L1 Cache**: Prefix matching at special-token boundaries; opt-in via
//!   [`CacheConfig::enable_l1`].
//!
//! # Usage
//! ```ignore
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
//! let encoding = cached.encode("Hello world", false)?;
//! ```
16
17mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
30
/// Configuration for the tokenizer cache.
///
/// Each cache layer is toggled independently; the size limit of a layer whose
/// `enable_*` flag is `false` is not used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Enable the L0 (whole-string exact match) cache.
    pub enable_l0: bool,
    /// Maximum number of entries held by the L0 cache.
    pub l0_max_entries: usize,
    /// Enable the L1 (prefix match at special-token boundaries) cache.
    pub enable_l1: bool,
    /// Memory budget for the L1 cache, in bytes.
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// L0 enabled with 10,000 entries; L1 disabled (opt-in) with a 50 MiB budget.
    fn default() -> Self {
        Self {
            enable_l0: true,
            l0_max_entries: 10_000, // ~22MB memory for typical prompts
            enable_l1: false,       // Opt-in for now
            l1_max_memory: 50 * 1024 * 1024, // 50MB
        }
    }
}
54
55/// A caching wrapper around any tokenizer
56pub struct CachedTokenizer {
57    /// The underlying tokenizer
58    inner: Arc<dyn Tokenizer>,
59    /// L0 cache (whole-string exact match)
60    l0: Option<L0Cache>,
61    /// L1 cache (prefix matching at fixed boundaries)
62    l1: Option<L1Cache>,
63    /// Configuration
64    #[allow(dead_code)]
65    config: CacheConfig,
66    /// Fingerprint for cache invalidation
67    fingerprint: TokenizerFingerprint,
68    /// Cached special token strings (extracted once at construction)
69    special_token_strings: Vec<String>,
70}
71
72impl CachedTokenizer {
73    /// Create a new cached tokenizer
74    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
75        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
76
77        let l0 = if config.enable_l0 {
78            Some(L0Cache::new(config.l0_max_entries))
79        } else {
80            None
81        };
82
83        let l1 = if config.enable_l1 {
84            Some(L1Cache::new(config.l1_max_memory))
85        } else {
86            None
87        };
88
89        // Extract special tokens once at construction time
90        let special_token_strings = Self::extract_special_token_strings(&inner);
91
92        Self {
93            inner,
94            l0,
95            l1,
96            config,
97            fingerprint,
98            special_token_strings,
99        }
100    }
101
102    /// Extract all special token strings from the tokenizer (called once at construction)
103    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
104        let special_tokens = tokenizer.get_special_tokens();
105        let mut tokens = Vec::new();
106
107        if let Some(ref token) = special_tokens.bos_token {
108            tokens.push(token.clone());
109        }
110        if let Some(ref token) = special_tokens.eos_token {
111            tokens.push(token.clone());
112        }
113        if let Some(ref token) = special_tokens.unk_token {
114            tokens.push(token.clone());
115        }
116        if let Some(ref token) = special_tokens.sep_token {
117            tokens.push(token.clone());
118        }
119        if let Some(ref token) = special_tokens.pad_token {
120            tokens.push(token.clone());
121        }
122        if let Some(ref token) = special_tokens.cls_token {
123            tokens.push(token.clone());
124        }
125        if let Some(ref token) = special_tokens.mask_token {
126            tokens.push(token.clone());
127        }
128
129        tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
130        tokens
131    }
132
133    /// Get L0 cache statistics
134    pub fn cache_stats(&self) -> Option<CacheStats> {
135        self.l0.as_ref().map(|cache| cache.stats())
136    }
137
138    /// Get L1 cache statistics
139    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
140        self.l1.as_ref().map(|cache| cache.stats())
141    }
142
143    /// Clear the cache
144    pub fn clear_cache(&self) {
145        if let Some(l0) = &self.l0 {
146            l0.clear();
147        }
148        if let Some(l1) = &self.l1 {
149            l1.clear();
150        }
151    }
152
153    /// Get the fingerprint of the underlying tokenizer
154    pub fn fingerprint(&self) -> &TokenizerFingerprint {
155        &self.fingerprint
156    }
157
158    /// Get a reference to the inner (wrapped) tokenizer
159    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
160        &self.inner
161    }
162}
163
164impl Encoder for CachedTokenizer {
165    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
166        // L0 cache lookup (exact match) - returns Arc<Encoding> for zero-copy
167        // Note: L0 cache doesn't distinguish by add_special_tokens flag
168        // This is acceptable for the current use case where embeddings always use true
169        // and chat always uses false with different input content
170        if let Some(l0) = &self.l0 {
171            if let Some(cached) = l0.get(input) {
172                // Unwrap the Arc - since Encoding is Clone, we can return the inner value
173                // For callers who need the tokens, they can access via token_ids() which is &[u32]
174                return Ok((*cached).clone());
175            }
176        }
177
178        // L1 cache lookup (prefix match at special token boundaries)
179        if let Some(l1) = &self.l1 {
180            // Use pre-computed special tokens refs (avoids allocation per call)
181            let tokens: Vec<&str> = self
182                .special_token_strings
183                .iter()
184                .map(|s| s.as_str())
185                .collect();
186
187            if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
188                // We have a prefix match - tokenize the suffix
189                let suffix = &input[prefix_len..];
190                if !suffix.is_empty() {
191                    let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
192
193                    // Merge prefix tokens + suffix tokens
194                    // Safe because we're splitting at special token boundaries
195                    let mut merged_tokens = prefix_tokens;
196                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
197
198                    let merged_encoding = Encoding::Sp(merged_tokens);
199
200                    // Cache the full result in L0
201                    if let Some(l0) = &self.l0 {
202                        l0.insert(input.to_string(), merged_encoding.clone());
203                    }
204
205                    return Ok(merged_encoding);
206                }
207            }
208        }
209
210        // Full tokenization (both L0 and L1 miss)
211        let encoding = self.inner.encode(input, add_special_tokens)?;
212
213        // Cache in L0
214        if let Some(l0) = &self.l0 {
215            l0.insert(input.to_string(), encoding.clone());
216        }
217
218        // Cache in L1 at special token boundaries
219        // Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
220        if let Some(l1) = &self.l1 {
221            let tokens: Vec<&str> = self
222                .special_token_strings
223                .iter()
224                .map(|s| s.as_str())
225                .collect();
226            let _ =
227                l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
228            // Ignore errors in cache insertion - cache is best-effort
229        }
230
231        Ok(encoding)
232    }
233
234    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
235        // Process each input in parallel, leveraging thread-safe caches
236        // This maintains the parallelism from the underlying HuggingFaceTokenizer
237        inputs
238            .par_iter()
239            .map(|&input| self.encode(input, add_special_tokens))
240            .collect()
241    }
242}
243
244impl Decoder for CachedTokenizer {
245    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
246        // Decoding is not cached (it's fast enough and rarely repeated)
247        self.inner.decode(token_ids, skip_special_tokens)
248    }
249}
250
251impl Tokenizer for CachedTokenizer {
252    fn vocab_size(&self) -> usize {
253        self.inner.vocab_size()
254    }
255
256    fn get_special_tokens(&self) -> &SpecialTokens {
257        self.inner.get_special_tokens()
258    }
259
260    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
261        self.inner.token_to_id(token)
262    }
263
264    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
265        self.inner.id_to_token(id)
266    }
267
268    fn as_any(&self) -> &dyn std::any::Any {
269        self
270    }
271}
272
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Wrap a fresh mock tokenizer with the given cache configuration.
    fn wrap_mock(config: CacheConfig) -> CachedTokenizer {
        CachedTokenizer::new(Arc::new(MockTokenizer::new()), config)
    }

    #[test]
    fn test_cache_hit() {
        let cached = wrap_mock(CacheConfig::default());
        let text = "Hello world";

        // The first encode misses and populates L0; the second is served from it.
        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_disabled() {
        let cached = wrap_mock(CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        });
        let text = "Hello world";

        // Encoding must still work, and be deterministic, with every cache off.
        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        // With L0 disabled there are no statistics to report.
        assert!(cached.cache_stats().is_none());
    }

    #[test]
    fn test_encode_batch() {
        let cached = wrap_mock(CacheConfig::default());
        let batch = ["Hello", "world", "Hello"]; // "Hello" appears twice

        let encodings = cached.encode_batch(&batch, false).unwrap();
        assert_eq!(encodings.len(), 3);

        // The batch runs in parallel, so both "Hello" entries may miss the cache at
        // the same time; only require that their results agree.
        assert_eq!(encodings[0].token_ids(), encodings[2].token_ids());

        // The batch should have populated L0, so a follow-up encode hits it.
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let cached = wrap_mock(CacheConfig::default());
        let ids = vec![1, 2, 3];

        // Decoding bypasses the caches and goes straight to the mock tokenizer.
        let text = cached.decode(&ids, false).unwrap();
        assert!(!text.is_empty());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let mock = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(mock.clone(), CacheConfig::default());

        // Vocabulary queries are forwarded to the wrapped tokenizer.
        assert_eq!(cached.vocab_size(), mock.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}