// llm_tokenizer/cache/mod.rs

//! Tokenizer Caching Layer
//!
//! Provides a caching wrapper around any tokenizer implementation to speed up
//! repeated tokenization of the same strings (e.g., system prompts).
//!
//! # Architecture
//! - **L0 Cache**: Whole-string exact match, keyed on the input and the
//!   `add_special_tokens` flag (~90% of the wins)
//! - **L1 Cache**: Prefix matching at special-token boundaries (opt-in via
//!   `CacheConfig::enable_l1`)
//!
//! # Usage
//! ```ignore
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
//! let encoding = cached.encode("Hello world", false)?;
//! ```
16
17mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30    chat_template::{
31        ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
32    },
33    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
34};
35
/// Tuning knobs for the tokenizer cache.
///
/// Each cache level is switched on or off independently; a level's size limit
/// only matters when that level is enabled.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Whether the L0 cache (exact whole-string match) is used.
    pub enable_l0: bool,
    /// Upper bound on the number of entries the L0 cache may hold.
    pub l0_max_entries: usize,
    /// Whether the L1 cache (prefix match) is used.
    pub enable_l1: bool,
    /// Memory budget for the L1 cache, in bytes.
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// L0 on with 10,000 entries; L1 off (opt-in for now) with a 50 MiB budget
    /// ready for when it is switched on.
    fn default() -> Self {
        // Author's estimate: roughly 22MB of memory for typical prompts.
        const L0_ENTRIES: usize = 10_000;
        // 50 MiB, i.e. 50 * 1024 * 1024 bytes.
        const L1_BYTES: usize = 50 << 20;

        CacheConfig {
            enable_l0: true,
            l0_max_entries: L0_ENTRIES,
            enable_l1: false,
            l1_max_memory: L1_BYTES,
        }
    }
}
59
60/// A caching wrapper around any tokenizer
61pub struct CachedTokenizer {
62    /// The underlying tokenizer
63    inner: Arc<dyn Tokenizer>,
64    /// L0 cache (whole-string exact match)
65    l0: Option<L0Cache>,
66    /// L1 cache (prefix matching at fixed boundaries)
67    l1: Option<L1Cache>,
68    /// Fingerprint for cache invalidation
69    fingerprint: TokenizerFingerprint,
70    /// Cached special token strings (extracted once at construction)
71    special_token_strings: Vec<String>,
72}
73
74impl CachedTokenizer {
75    /// Create a new cached tokenizer
76    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
77        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
78
79        let l0 = if config.enable_l0 {
80            Some(L0Cache::new(config.l0_max_entries))
81        } else {
82            None
83        };
84
85        let l1 = if config.enable_l1 {
86            Some(L1Cache::new(config.l1_max_memory))
87        } else {
88            None
89        };
90
91        // Extract special tokens once at construction time
92        let special_token_strings = Self::extract_special_token_strings(&inner);
93
94        Self {
95            inner,
96            l0,
97            l1,
98            fingerprint,
99            special_token_strings,
100        }
101    }
102
103    /// Extract all special token strings from the tokenizer (called once at construction)
104    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
105        let special_tokens = tokenizer.get_special_tokens();
106        let mut tokens = Vec::new();
107
108        if let Some(ref token) = special_tokens.bos_token {
109            tokens.push(token.clone());
110        }
111        if let Some(ref token) = special_tokens.eos_token {
112            tokens.push(token.clone());
113        }
114        if let Some(ref token) = special_tokens.unk_token {
115            tokens.push(token.clone());
116        }
117        if let Some(ref token) = special_tokens.sep_token {
118            tokens.push(token.clone());
119        }
120        if let Some(ref token) = special_tokens.pad_token {
121            tokens.push(token.clone());
122        }
123        if let Some(ref token) = special_tokens.cls_token {
124            tokens.push(token.clone());
125        }
126        if let Some(ref token) = special_tokens.mask_token {
127            tokens.push(token.clone());
128        }
129
130        tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
131        tokens
132    }
133
134    /// Get L0 cache statistics
135    pub fn cache_stats(&self) -> Option<CacheStats> {
136        self.l0.as_ref().map(|cache| cache.stats())
137    }
138
139    /// Get L1 cache statistics
140    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
141        self.l1.as_ref().map(|cache| cache.stats())
142    }
143
144    /// Clear the cache
145    pub fn clear_cache(&self) {
146        if let Some(l0) = &self.l0 {
147            l0.clear();
148        }
149        if let Some(l1) = &self.l1 {
150            l1.clear();
151        }
152    }
153
154    /// Get the fingerprint of the underlying tokenizer
155    pub fn fingerprint(&self) -> &TokenizerFingerprint {
156        &self.fingerprint
157    }
158
159    /// Get a reference to the inner (wrapped) tokenizer
160    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
161        &self.inner
162    }
163}
164
165impl Encoder for CachedTokenizer {
166    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
167        // L0 cache lookup (exact match, keyed on input + add_special_tokens)
168        if let Some(l0) = &self.l0 {
169            if let Some(cached) = l0.get(input, add_special_tokens) {
170                return Ok((*cached).clone());
171            }
172        }
173
174        // L1 cache lookup (prefix match at special token boundaries)
175        if let Some(l1) = &self.l1 {
176            let tokens: Vec<&str> = self
177                .special_token_strings
178                .iter()
179                .map(|s| s.as_str())
180                .collect();
181
182            if let Some((prefix_tokens, prefix_len)) =
183                l1.longest_prefix_match(input, &tokens, add_special_tokens)
184            {
185                let suffix = &input[prefix_len..];
186                if !suffix.is_empty() {
187                    let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
188
189                    let mut merged_tokens = prefix_tokens;
190                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
191
192                    let merged_encoding = Encoding::Plain(merged_tokens);
193
194                    if let Some(l0) = &self.l0 {
195                        l0.insert(
196                            input.to_string(),
197                            add_special_tokens,
198                            merged_encoding.clone(),
199                        );
200                    }
201
202                    return Ok(merged_encoding);
203                }
204            }
205        }
206
207        // Full tokenization (both L0 and L1 miss)
208        let encoding = self.inner.encode(input, add_special_tokens)?;
209
210        // Cache in L0
211        if let Some(l0) = &self.l0 {
212            l0.insert(input.to_string(), add_special_tokens, encoding.clone());
213        }
214
215        // Cache in L1 at special token boundaries
216        if let Some(l1) = &self.l1 {
217            let tokens: Vec<&str> = self
218                .special_token_strings
219                .iter()
220                .map(|s| s.as_str())
221                .collect();
222            let _ =
223                l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
224        }
225
226        Ok(encoding)
227    }
228
229    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
230        // Process each input in parallel, leveraging thread-safe caches
231        // This maintains the parallelism from the underlying HuggingFaceTokenizer
232        inputs
233            .par_iter()
234            .map(|&input| self.encode(input, add_special_tokens))
235            .collect()
236    }
237}
238
239impl Decoder for CachedTokenizer {
240    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
241        // Decoding is not cached (it's fast enough and rarely repeated)
242        self.inner.decode(token_ids, skip_special_tokens)
243    }
244}
245
246impl Tokenizer for CachedTokenizer {
247    fn vocab_size(&self) -> usize {
248        self.inner.vocab_size()
249    }
250
251    fn get_special_tokens(&self) -> &SpecialTokens {
252        self.inner.get_special_tokens()
253    }
254
255    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
256        self.inner.token_to_id(token)
257    }
258
259    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
260        self.inner.id_to_token(id)
261    }
262
263    fn as_any(&self) -> &dyn std::any::Any {
264        self
265    }
266
267    fn apply_chat_template(
268        &self,
269        messages: &[serde_json::Value],
270        params: ChatTemplateParams,
271    ) -> Result<String> {
272        self.inner.apply_chat_template(messages, params)
273    }
274
275    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
276        self.inner.chat_template_content_format()
277    }
278
279    fn thinking_toggle(&self) -> ThinkingToggle {
280        self.inner.thinking_toggle()
281    }
282
283    fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
284        self.inner.thinking_key_name()
285    }
286    fn think_in_prefill(&self) -> bool {
287        self.inner.think_in_prefill()
288    }
289
290    fn eos_token_ids(&self) -> &[TokenIdType] {
291        self.inner.eos_token_ids()
292    }
293}
294
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// A repeated encode of the same input is served from L0.
    #[test]
    fn test_cache_hit() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let input = "Hello world";

        // First call - miss
        let result1 = cached.encode(input, false).unwrap();

        // Second call - hit
        let result2 = cached.encode(input, false).unwrap();

        // Results should be identical
        assert_eq!(result1.token_ids(), result2.token_ids());

        // Check cache stats
        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// With both levels disabled, encoding still works and no stats exist.
    #[test]
    fn test_cache_disabled() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(tokenizer, config);

        let input = "Hello world";

        // Both calls should work even without cache
        let result1 = cached.encode(input, false).unwrap();
        let result2 = cached.encode(input, false).unwrap();

        assert_eq!(result1.token_ids(), result2.token_ids());

        // Neither cache level is allocated, so neither reports stats
        assert!(cached.cache_stats().is_none());
        assert!(cached.l1_cache_stats().is_none());
    }

    /// The default configuration enables L0 only; L1 is opt-in.
    #[test]
    fn test_default_config_enables_only_l0() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        assert!(cached.cache_stats().is_some());
        assert!(cached.l1_cache_stats().is_none());
    }

    /// Batch encoding returns one result per input and populates L0.
    #[test]
    fn test_encode_batch() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let inputs = vec!["Hello", "world", "Hello"]; // "Hello" repeated

        let results = cached.encode_batch(&inputs, false).unwrap();

        assert_eq!(results.len(), 3);

        // With parallel execution, duplicate inputs may be processed simultaneously
        // and both see cache misses. Verify results are correct instead.
        assert_eq!(results[0].token_ids(), results[2].token_ids()); // Both "Hello" should match

        // After batch processing, cache should be populated
        // Subsequent calls should hit the cache
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();

        // Should have at least 1 hit from the call above (cache was populated by batch)
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    /// Decoding is forwarded to the inner tokenizer.
    #[test]
    fn test_decoder_passthrough() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let tokens = vec![1, 2, 3];
        let decoded = cached.decode(&tokens, false).unwrap();

        // Should just pass through to inner tokenizer
        assert!(!decoded.is_empty());
    }

    /// Vocabulary queries are forwarded to the inner tokenizer.
    #[test]
    fn test_tokenizer_trait_methods() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        // Should pass through to inner tokenizer
        assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}