Skip to main content

llm_tokenizer/cache/
mod.rs

//! Tokenizer Caching Layer
//!
//! Provides a caching wrapper around any tokenizer implementation to speed up
//! repeated tokenization of the same strings (e.g., system prompts).
//!
//! # Architecture
//! - **L0 Cache**: Whole-string exact match (90% of wins), keyed on the input
//!   string together with the `add_special_tokens` flag.
//! - **L1 Cache**: Prefix matching at special-token boundaries; opt-in via
//!   `CacheConfig::enable_l1` (disabled by default).
//!
//! # Usage
//! ```ignore
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
//! let encoding = cached.encode("Hello world", false)?;
//! ```

17mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30    chat_template::{
31        ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
32    },
33    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
34};
35
/// Tunable knobs for the tokenizer cache.
///
/// Each cache tier is switched on or off independently. The size limit that
/// belongs to a disabled tier is never consulted.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Whether the L0 (exact whole-string) cache is created at all.
    pub enable_l0: bool,
    /// Upper bound on the number of entries the L0 cache may hold.
    pub l0_max_entries: usize,
    /// Whether the L1 (prefix) cache is created at all.
    pub enable_l1: bool,
    /// Memory budget for the L1 cache, measured in bytes.
    pub l1_max_memory: usize,
}

49impl Default for CacheConfig {
50    fn default() -> Self {
51        Self {
52            enable_l0: true,
53            l0_max_entries: 10_000, // ~22MB memory for typical prompts
54            enable_l1: false,       // Opt-in for now
55            l1_max_memory: 50 * 1024 * 1024, // 50MB
56        }
57    }
58}
59
60/// A caching wrapper around any tokenizer
61pub struct CachedTokenizer {
62    /// The underlying tokenizer
63    inner: Arc<dyn Tokenizer>,
64    /// L0 cache (whole-string exact match)
65    l0: Option<L0Cache>,
66    /// L1 cache (prefix matching at fixed boundaries)
67    l1: Option<L1Cache>,
68    /// Fingerprint for cache invalidation
69    fingerprint: TokenizerFingerprint,
70    /// Cached special token strings (extracted once at construction)
71    special_token_strings: Vec<String>,
72}
73
74impl CachedTokenizer {
75    /// Create a new cached tokenizer
76    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
77        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
78
79        let l0 = if config.enable_l0 {
80            Some(L0Cache::new(config.l0_max_entries))
81        } else {
82            None
83        };
84
85        let l1 = if config.enable_l1 {
86            Some(L1Cache::new(config.l1_max_memory))
87        } else {
88            None
89        };
90
91        // Extract special tokens once at construction time
92        let special_token_strings = Self::extract_special_token_strings(&inner);
93
94        Self {
95            inner,
96            l0,
97            l1,
98            fingerprint,
99            special_token_strings,
100        }
101    }
102
103    /// Extract all special token strings from the tokenizer (called once at construction)
104    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
105        let special_tokens = tokenizer.get_special_tokens();
106        let mut tokens = Vec::new();
107
108        if let Some(ref token) = special_tokens.bos_token {
109            tokens.push(token.clone());
110        }
111        if let Some(ref token) = special_tokens.eos_token {
112            tokens.push(token.clone());
113        }
114        if let Some(ref token) = special_tokens.unk_token {
115            tokens.push(token.clone());
116        }
117        if let Some(ref token) = special_tokens.sep_token {
118            tokens.push(token.clone());
119        }
120        if let Some(ref token) = special_tokens.pad_token {
121            tokens.push(token.clone());
122        }
123        if let Some(ref token) = special_tokens.cls_token {
124            tokens.push(token.clone());
125        }
126        if let Some(ref token) = special_tokens.mask_token {
127            tokens.push(token.clone());
128        }
129
130        tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
131        tokens
132    }
133
134    /// Get L0 cache statistics
135    pub fn cache_stats(&self) -> Option<CacheStats> {
136        self.l0.as_ref().map(|cache| cache.stats())
137    }
138
139    /// Get L1 cache statistics
140    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
141        self.l1.as_ref().map(|cache| cache.stats())
142    }
143
144    /// Clear the cache
145    pub fn clear_cache(&self) {
146        if let Some(l0) = &self.l0 {
147            l0.clear();
148        }
149        if let Some(l1) = &self.l1 {
150            l1.clear();
151        }
152    }
153
154    /// Get the fingerprint of the underlying tokenizer
155    pub fn fingerprint(&self) -> &TokenizerFingerprint {
156        &self.fingerprint
157    }
158
159    /// Get a reference to the inner (wrapped) tokenizer
160    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
161        &self.inner
162    }
163}
164
165impl Encoder for CachedTokenizer {
166    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
167        // L0 cache lookup (exact match, keyed on input + add_special_tokens)
168        if let Some(l0) = &self.l0 {
169            if let Some(cached) = l0.get(input, add_special_tokens) {
170                return Ok((*cached).clone());
171            }
172        }
173
174        // L1 cache lookup (prefix match at special token boundaries)
175        if let Some(l1) = &self.l1 {
176            let tokens: Vec<&str> = self
177                .special_token_strings
178                .iter()
179                .map(|s| s.as_str())
180                .collect();
181
182            if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
183                let suffix = &input[prefix_len..];
184                if !suffix.is_empty() {
185                    let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
186
187                    let mut merged_tokens = prefix_tokens;
188                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
189
190                    let merged_encoding = Encoding::Plain(merged_tokens);
191
192                    if let Some(l0) = &self.l0 {
193                        l0.insert(
194                            input.to_string(),
195                            add_special_tokens,
196                            merged_encoding.clone(),
197                        );
198                    }
199
200                    return Ok(merged_encoding);
201                }
202            }
203        }
204
205        // Full tokenization (both L0 and L1 miss)
206        let encoding = self.inner.encode(input, add_special_tokens)?;
207
208        // Cache in L0
209        if let Some(l0) = &self.l0 {
210            l0.insert(input.to_string(), add_special_tokens, encoding.clone());
211        }
212
213        // Cache in L1 at special token boundaries
214        if let Some(l1) = &self.l1 {
215            let tokens: Vec<&str> = self
216                .special_token_strings
217                .iter()
218                .map(|s| s.as_str())
219                .collect();
220            let _ =
221                l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
222        }
223
224        Ok(encoding)
225    }
226
227    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
228        // Process each input in parallel, leveraging thread-safe caches
229        // This maintains the parallelism from the underlying HuggingFaceTokenizer
230        inputs
231            .par_iter()
232            .map(|&input| self.encode(input, add_special_tokens))
233            .collect()
234    }
235}
236
237impl Decoder for CachedTokenizer {
238    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
239        // Decoding is not cached (it's fast enough and rarely repeated)
240        self.inner.decode(token_ids, skip_special_tokens)
241    }
242}
243
244impl Tokenizer for CachedTokenizer {
245    fn vocab_size(&self) -> usize {
246        self.inner.vocab_size()
247    }
248
249    fn get_special_tokens(&self) -> &SpecialTokens {
250        self.inner.get_special_tokens()
251    }
252
253    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
254        self.inner.token_to_id(token)
255    }
256
257    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
258        self.inner.id_to_token(id)
259    }
260
261    fn as_any(&self) -> &dyn std::any::Any {
262        self
263    }
264
265    fn apply_chat_template(
266        &self,
267        messages: &[serde_json::Value],
268        params: ChatTemplateParams,
269    ) -> Result<String> {
270        self.inner.apply_chat_template(messages, params)
271    }
272
273    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
274        self.inner.chat_template_content_format()
275    }
276
277    fn thinking_toggle(&self) -> ThinkingToggle {
278        self.inner.thinking_toggle()
279    }
280
281    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
282        self.inner.thinking_key_name()
283    }
284    fn think_in_prefill(&self) -> bool {
285        self.inner.think_in_prefill()
286    }
287}
288
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    #[test]
    fn test_cache_hit() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let input = "Hello world";

        // First call misses and populates L0; second call is served from L0.
        let result1 = cached.encode(input, false).unwrap();
        let result2 = cached.encode(input, false).unwrap();

        // Results should be identical
        assert_eq!(result1.token_ids(), result2.token_ids());

        // Check cache stats
        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // L1 is opt-in and therefore absent under the default config.
        assert!(cached.l1_cache_stats().is_none());
    }

    #[test]
    fn test_cache_disabled() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(tokenizer, config);

        let input = "Hello world";

        // Both calls should work even without cache
        let result1 = cached.encode(input, false).unwrap();
        let result2 = cached.encode(input, false).unwrap();

        assert_eq!(result1.token_ids(), result2.token_ids());

        // No stats for either tier when both are disabled
        assert!(cached.cache_stats().is_none());
        assert!(cached.l1_cache_stats().is_none());
    }

    #[test]
    fn test_encode_batch() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let inputs = ["Hello", "world", "Hello"]; // "Hello" repeated

        let results = cached.encode_batch(&inputs, false).unwrap();

        assert_eq!(results.len(), 3);

        // With parallel execution, duplicate inputs may be processed simultaneously
        // and both see cache misses. Verify results are correct instead.
        assert_eq!(results[0].token_ids(), results[2].token_ids()); // Both "Hello" should match

        // After batch processing, cache should be populated
        // Subsequent calls should hit the cache
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();

        // Should have at least 1 hit from the call above (cache was populated by batch)
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let tokens = [1, 2, 3];
        let decoded = cached.decode(&tokens, false).unwrap();

        // Should just pass through to inner tokenizer
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        // Should pass through to inner tokenizer
        assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}