llm_tokenizer/cache/
mod.rs1mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
30
/// Settings for the two cache tiers of [`CachedTokenizer`].
///
/// Use `CacheConfig::default()` for L0 only, or struct-update syntax to override
/// individual fields, e.g. `CacheConfig { enable_l1: true, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Enables the L0 tier: whole-input encodings reused on an exact match.
    pub enable_l0: bool,
    /// Capacity handed to the L0 cache constructor (an entry count, going by the name).
    pub l0_max_entries: usize,
    /// Enables the L1 tier: reuse of a cached prefix that ends at a special-token boundary.
    pub enable_l1: bool,
    /// Memory budget handed to the L1 cache constructor. The default of
    /// `50 * 1024 * 1024` reads as 50 MiB expressed in bytes.
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// L0 enabled with room for 10,000 entries; L1 disabled with a 50 MiB budget
    /// pre-set in case it is switched on.
    fn default() -> Self {
        Self {
            enable_l0: true,
            l0_max_entries: 10_000,
            enable_l1: false,
            l1_max_memory: 50 * 1024 * 1024,
        }
    }
}
54
55pub struct CachedTokenizer {
57 inner: Arc<dyn Tokenizer>,
59 l0: Option<L0Cache>,
61 l1: Option<L1Cache>,
63 #[allow(dead_code)]
65 config: CacheConfig,
66 fingerprint: TokenizerFingerprint,
68 special_token_strings: Vec<String>,
70}
71
72impl CachedTokenizer {
73 pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
75 let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
76
77 let l0 = if config.enable_l0 {
78 Some(L0Cache::new(config.l0_max_entries))
79 } else {
80 None
81 };
82
83 let l1 = if config.enable_l1 {
84 Some(L1Cache::new(config.l1_max_memory))
85 } else {
86 None
87 };
88
89 let special_token_strings = Self::extract_special_token_strings(&inner);
91
92 Self {
93 inner,
94 l0,
95 l1,
96 config,
97 fingerprint,
98 special_token_strings,
99 }
100 }
101
102 fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
104 let special_tokens = tokenizer.get_special_tokens();
105 let mut tokens = Vec::new();
106
107 if let Some(ref token) = special_tokens.bos_token {
108 tokens.push(token.clone());
109 }
110 if let Some(ref token) = special_tokens.eos_token {
111 tokens.push(token.clone());
112 }
113 if let Some(ref token) = special_tokens.unk_token {
114 tokens.push(token.clone());
115 }
116 if let Some(ref token) = special_tokens.sep_token {
117 tokens.push(token.clone());
118 }
119 if let Some(ref token) = special_tokens.pad_token {
120 tokens.push(token.clone());
121 }
122 if let Some(ref token) = special_tokens.cls_token {
123 tokens.push(token.clone());
124 }
125 if let Some(ref token) = special_tokens.mask_token {
126 tokens.push(token.clone());
127 }
128
129 tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
130 tokens
131 }
132
133 pub fn cache_stats(&self) -> Option<CacheStats> {
135 self.l0.as_ref().map(|cache| cache.stats())
136 }
137
138 pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
140 self.l1.as_ref().map(|cache| cache.stats())
141 }
142
143 pub fn clear_cache(&self) {
145 if let Some(l0) = &self.l0 {
146 l0.clear();
147 }
148 if let Some(l1) = &self.l1 {
149 l1.clear();
150 }
151 }
152
153 pub fn fingerprint(&self) -> &TokenizerFingerprint {
155 &self.fingerprint
156 }
157
158 pub fn inner(&self) -> &Arc<dyn Tokenizer> {
160 &self.inner
161 }
162}
163
164impl Encoder for CachedTokenizer {
165 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
166 if let Some(l0) = &self.l0 {
171 if let Some(cached) = l0.get(input) {
172 return Ok((*cached).clone());
175 }
176 }
177
178 if let Some(l1) = &self.l1 {
180 let tokens: Vec<&str> = self
182 .special_token_strings
183 .iter()
184 .map(|s| s.as_str())
185 .collect();
186
187 if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
188 let suffix = &input[prefix_len..];
190 if !suffix.is_empty() {
191 let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
192
193 let mut merged_tokens = prefix_tokens;
196 merged_tokens.extend_from_slice(suffix_encoding.token_ids());
197
198 let merged_encoding = Encoding::Sp(merged_tokens);
199
200 if let Some(l0) = &self.l0 {
202 l0.insert(input.to_string(), merged_encoding.clone());
203 }
204
205 return Ok(merged_encoding);
206 }
207 }
208 }
209
210 let encoding = self.inner.encode(input, add_special_tokens)?;
212
213 if let Some(l0) = &self.l0 {
215 l0.insert(input.to_string(), encoding.clone());
216 }
217
218 if let Some(l1) = &self.l1 {
221 let tokens: Vec<&str> = self
222 .special_token_strings
223 .iter()
224 .map(|s| s.as_str())
225 .collect();
226 let _ =
227 l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
228 }
230
231 Ok(encoding)
232 }
233
234 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
235 inputs
238 .par_iter()
239 .map(|&input| self.encode(input, add_special_tokens))
240 .collect()
241 }
242}
243
244impl Decoder for CachedTokenizer {
245 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
246 self.inner.decode(token_ids, skip_special_tokens)
248 }
249}
250
251impl Tokenizer for CachedTokenizer {
252 fn vocab_size(&self) -> usize {
253 self.inner.vocab_size()
254 }
255
256 fn get_special_tokens(&self) -> &SpecialTokens {
257 self.inner.get_special_tokens()
258 }
259
260 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
261 self.inner.token_to_id(token)
262 }
263
264 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
265 self.inner.id_to_token(id)
266 }
267
268 fn as_any(&self) -> &dyn std::any::Any {
269 self
270 }
271}
272
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Wraps a fresh `MockTokenizer` in a `CachedTokenizer` with the default cache settings.
    fn cached_mock() -> CachedTokenizer {
        CachedTokenizer::new(Arc::new(MockTokenizer::new()), CacheConfig::default())
    }

    #[test]
    fn test_cache_hit() {
        let cached = cached_mock();
        let text = "Hello world";

        // The first call misses and fills L0; the second is answered from L0.
        let miss = cached.encode(text, false).unwrap();
        let hit = cached.encode(text, false).unwrap();
        assert_eq!(miss.token_ids(), hit.token_ids());

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_disabled() {
        let no_caches = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(Arc::new(MockTokenizer::new()), no_caches);
        let text = "Hello world";

        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        // With L0 disabled there are no statistics to report.
        assert!(cached.cache_stats().is_none());
    }

    #[test]
    fn test_encode_batch() {
        let cached = cached_mock();
        let inputs = ["Hello", "world", "Hello"];

        let results = cached.encode_batch(&inputs, false).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].token_ids(), results[2].token_ids());

        // The batch runs in parallel, so both "Hello" entries may miss together; one more
        // sequential call is what guarantees a hit.
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let cached = cached_mock();
        let decoded = cached.decode(&[1, 2, 3], false).unwrap();
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let mock = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(mock.clone(), CacheConfig::default());

        assert_eq!(cached.vocab_size(), mock.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}