Skip to main content

llm_tokenizer/cache/
mod.rs

//! Tokenizer Caching Layer
//!
//! Provides a caching wrapper around any tokenizer implementation to speed up
//! repeated tokenization of the same strings (e.g., system prompts).
//!
//! # Architecture
//! - **L0 Cache**: Whole-string exact match (90% of wins)
//! - **L1 Cache**: Prefix matching at special-token boundaries (opt-in, disabled by default)
//!
//! # Usage
//! ```ignore
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
//! let encoding = cached.encode("Hello world", false)?;
//! ```
16
17mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30    chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
31    traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
32};
33
/// Tuning knobs for [`CachedTokenizer`].
///
/// Each cache level can be switched on or off independently and has its own
/// size limit.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Whether the L0 (whole-string, exact-match) cache is active.
    pub enable_l0: bool,
    /// Upper bound on the number of entries held by the L0 cache.
    pub l0_max_entries: usize,
    /// Whether the L1 (prefix) cache is active.
    pub enable_l1: bool,
    /// Memory budget for the L1 cache, in bytes.
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// Returns the stock configuration: L0 enabled with 10 000 entries,
    /// L1 disabled with a 50 MiB budget reserved for when it is turned on.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            // L0 is on by default; 10k entries is roughly 22MB for typical prompts.
            enable_l0: true,
            l0_max_entries: 10_000,
            // L1 stays opt-in for now.
            enable_l1: false,
            l1_max_memory: 50 * MIB,
        }
    }
}
57
58/// A caching wrapper around any tokenizer
59pub struct CachedTokenizer {
60    /// The underlying tokenizer
61    inner: Arc<dyn Tokenizer>,
62    /// L0 cache (whole-string exact match)
63    l0: Option<L0Cache>,
64    /// L1 cache (prefix matching at fixed boundaries)
65    l1: Option<L1Cache>,
66    /// Fingerprint for cache invalidation
67    fingerprint: TokenizerFingerprint,
68    /// Cached special token strings (extracted once at construction)
69    special_token_strings: Vec<String>,
70}
71
72impl CachedTokenizer {
73    /// Create a new cached tokenizer
74    pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
75        let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
76
77        let l0 = if config.enable_l0 {
78            Some(L0Cache::new(config.l0_max_entries))
79        } else {
80            None
81        };
82
83        let l1 = if config.enable_l1 {
84            Some(L1Cache::new(config.l1_max_memory))
85        } else {
86            None
87        };
88
89        // Extract special tokens once at construction time
90        let special_token_strings = Self::extract_special_token_strings(&inner);
91
92        Self {
93            inner,
94            l0,
95            l1,
96            fingerprint,
97            special_token_strings,
98        }
99    }
100
101    /// Extract all special token strings from the tokenizer (called once at construction)
102    fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
103        let special_tokens = tokenizer.get_special_tokens();
104        let mut tokens = Vec::new();
105
106        if let Some(ref token) = special_tokens.bos_token {
107            tokens.push(token.clone());
108        }
109        if let Some(ref token) = special_tokens.eos_token {
110            tokens.push(token.clone());
111        }
112        if let Some(ref token) = special_tokens.unk_token {
113            tokens.push(token.clone());
114        }
115        if let Some(ref token) = special_tokens.sep_token {
116            tokens.push(token.clone());
117        }
118        if let Some(ref token) = special_tokens.pad_token {
119            tokens.push(token.clone());
120        }
121        if let Some(ref token) = special_tokens.cls_token {
122            tokens.push(token.clone());
123        }
124        if let Some(ref token) = special_tokens.mask_token {
125            tokens.push(token.clone());
126        }
127
128        tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
129        tokens
130    }
131
132    /// Get L0 cache statistics
133    pub fn cache_stats(&self) -> Option<CacheStats> {
134        self.l0.as_ref().map(|cache| cache.stats())
135    }
136
137    /// Get L1 cache statistics
138    pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
139        self.l1.as_ref().map(|cache| cache.stats())
140    }
141
142    /// Clear the cache
143    pub fn clear_cache(&self) {
144        if let Some(l0) = &self.l0 {
145            l0.clear();
146        }
147        if let Some(l1) = &self.l1 {
148            l1.clear();
149        }
150    }
151
152    /// Get the fingerprint of the underlying tokenizer
153    pub fn fingerprint(&self) -> &TokenizerFingerprint {
154        &self.fingerprint
155    }
156
157    /// Get a reference to the inner (wrapped) tokenizer
158    pub fn inner(&self) -> &Arc<dyn Tokenizer> {
159        &self.inner
160    }
161}
162
163impl Encoder for CachedTokenizer {
164    fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
165        // L0 cache lookup (exact match, keyed on input + add_special_tokens)
166        if let Some(l0) = &self.l0 {
167            if let Some(cached) = l0.get(input, add_special_tokens) {
168                return Ok((*cached).clone());
169            }
170        }
171
172        // L1 cache lookup (prefix match at special token boundaries)
173        if let Some(l1) = &self.l1 {
174            let tokens: Vec<&str> = self
175                .special_token_strings
176                .iter()
177                .map(|s| s.as_str())
178                .collect();
179
180            if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
181                let suffix = &input[prefix_len..];
182                if !suffix.is_empty() {
183                    let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
184
185                    let mut merged_tokens = prefix_tokens;
186                    merged_tokens.extend_from_slice(suffix_encoding.token_ids());
187
188                    let merged_encoding = Encoding::Plain(merged_tokens);
189
190                    if let Some(l0) = &self.l0 {
191                        l0.insert(
192                            input.to_string(),
193                            add_special_tokens,
194                            merged_encoding.clone(),
195                        );
196                    }
197
198                    return Ok(merged_encoding);
199                }
200            }
201        }
202
203        // Full tokenization (both L0 and L1 miss)
204        let encoding = self.inner.encode(input, add_special_tokens)?;
205
206        // Cache in L0
207        if let Some(l0) = &self.l0 {
208            l0.insert(input.to_string(), add_special_tokens, encoding.clone());
209        }
210
211        // Cache in L1 at special token boundaries
212        if let Some(l1) = &self.l1 {
213            let tokens: Vec<&str> = self
214                .special_token_strings
215                .iter()
216                .map(|s| s.as_str())
217                .collect();
218            let _ =
219                l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
220        }
221
222        Ok(encoding)
223    }
224
225    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
226        // Process each input in parallel, leveraging thread-safe caches
227        // This maintains the parallelism from the underlying HuggingFaceTokenizer
228        inputs
229            .par_iter()
230            .map(|&input| self.encode(input, add_special_tokens))
231            .collect()
232    }
233}
234
235impl Decoder for CachedTokenizer {
236    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
237        // Decoding is not cached (it's fast enough and rarely repeated)
238        self.inner.decode(token_ids, skip_special_tokens)
239    }
240}
241
242impl Tokenizer for CachedTokenizer {
243    fn vocab_size(&self) -> usize {
244        self.inner.vocab_size()
245    }
246
247    fn get_special_tokens(&self) -> &SpecialTokens {
248        self.inner.get_special_tokens()
249    }
250
251    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
252        self.inner.token_to_id(token)
253    }
254
255    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
256        self.inner.id_to_token(id)
257    }
258
259    fn as_any(&self) -> &dyn std::any::Any {
260        self
261    }
262
263    fn apply_chat_template(
264        &self,
265        messages: &[serde_json::Value],
266        params: ChatTemplateParams,
267    ) -> Result<String> {
268        self.inner.apply_chat_template(messages, params)
269    }
270
271    fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
272        self.inner.chat_template_content_format()
273    }
274}
275
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Builds a cached tokenizer over a fresh mock with the default configuration.
    fn default_cached() -> CachedTokenizer {
        let mock = Arc::new(MockTokenizer::new());
        CachedTokenizer::new(mock, CacheConfig::default())
    }

    /// A repeated encode of the same input is served from L0.
    #[test]
    fn test_cache_hit() {
        let cached = default_cached();
        let text = "Hello world";

        // The first call misses and populates the cache; the second one hits.
        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();

        assert_eq!(first.token_ids(), second.token_ids());

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// With both levels disabled, encoding still works and no stats exist.
    #[test]
    fn test_cache_disabled() {
        let everything_off = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(Arc::new(MockTokenizer::new()), everything_off);
        let text = "Hello world";

        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        assert!(cached.cache_stats().is_none());
    }

    /// Batch encoding returns one result per input and warms the cache.
    #[test]
    fn test_encode_batch() {
        let cached = default_cached();

        // "Hello" appears twice on purpose.
        let batch = ["Hello", "world", "Hello"];
        let encodings = cached.encode_batch(&batch, false).unwrap();

        assert_eq!(encodings.len(), 3);

        // Parallel workers may both miss on the duplicate, so compare the
        // outputs rather than asserting on hit counts for the batch itself.
        assert_eq!(encodings[0].token_ids(), encodings[2].token_ids());

        // The batch populated the cache, so a follow-up call must hit.
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    /// Decoding is forwarded to the wrapped tokenizer.
    #[test]
    fn test_decoder_passthrough() {
        let cached = default_cached();

        let text = cached.decode(&[1, 2, 3], false).unwrap();

        assert!(!text.is_empty());
    }

    /// `Tokenizer` trait methods are forwarded to the wrapped tokenizer.
    #[test]
    fn test_tokenizer_trait_methods() {
        let mock = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(mock.clone(), CacheConfig::default());

        assert_eq!(cached.vocab_size(), mock.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}