1mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30 chat_template::{
31 ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
32 },
33 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
34};
35
/// Tuning knobs for the cache layers of `CachedTokenizer`.
///
/// Each layer is toggled independently; the size field of a disabled layer is ignored.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Turn on the L0 exact-match cache.
    pub enable_l0: bool,
    /// Maximum entry count handed to `L0Cache::new`.
    pub l0_max_entries: usize,
    /// Turn on the L1 prefix cache.
    pub enable_l1: bool,
    /// Memory budget handed to `L1Cache::new` (presumably bytes — the default is
    /// `50 * 1024 * 1024`).
    pub l1_max_memory: usize,
}
48
49impl Default for CacheConfig {
50 fn default() -> Self {
51 Self {
52 enable_l0: true,
53 l0_max_entries: 10_000, enable_l1: false, l1_max_memory: 50 * 1024 * 1024, }
57 }
58}
59
60pub struct CachedTokenizer {
62 inner: Arc<dyn Tokenizer>,
64 l0: Option<L0Cache>,
66 l1: Option<L1Cache>,
68 fingerprint: TokenizerFingerprint,
70 special_token_strings: Vec<String>,
72}
73
74impl CachedTokenizer {
75 pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
77 let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
78
79 let l0 = if config.enable_l0 {
80 Some(L0Cache::new(config.l0_max_entries))
81 } else {
82 None
83 };
84
85 let l1 = if config.enable_l1 {
86 Some(L1Cache::new(config.l1_max_memory))
87 } else {
88 None
89 };
90
91 let special_token_strings = Self::extract_special_token_strings(&inner);
93
94 Self {
95 inner,
96 l0,
97 l1,
98 fingerprint,
99 special_token_strings,
100 }
101 }
102
103 fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
105 let special_tokens = tokenizer.get_special_tokens();
106 let mut tokens = Vec::new();
107
108 if let Some(ref token) = special_tokens.bos_token {
109 tokens.push(token.clone());
110 }
111 if let Some(ref token) = special_tokens.eos_token {
112 tokens.push(token.clone());
113 }
114 if let Some(ref token) = special_tokens.unk_token {
115 tokens.push(token.clone());
116 }
117 if let Some(ref token) = special_tokens.sep_token {
118 tokens.push(token.clone());
119 }
120 if let Some(ref token) = special_tokens.pad_token {
121 tokens.push(token.clone());
122 }
123 if let Some(ref token) = special_tokens.cls_token {
124 tokens.push(token.clone());
125 }
126 if let Some(ref token) = special_tokens.mask_token {
127 tokens.push(token.clone());
128 }
129
130 tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
131 tokens
132 }
133
134 pub fn cache_stats(&self) -> Option<CacheStats> {
136 self.l0.as_ref().map(|cache| cache.stats())
137 }
138
139 pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
141 self.l1.as_ref().map(|cache| cache.stats())
142 }
143
144 pub fn clear_cache(&self) {
146 if let Some(l0) = &self.l0 {
147 l0.clear();
148 }
149 if let Some(l1) = &self.l1 {
150 l1.clear();
151 }
152 }
153
154 pub fn fingerprint(&self) -> &TokenizerFingerprint {
156 &self.fingerprint
157 }
158
159 pub fn inner(&self) -> &Arc<dyn Tokenizer> {
161 &self.inner
162 }
163}
164
165impl Encoder for CachedTokenizer {
166 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
167 if let Some(l0) = &self.l0 {
169 if let Some(cached) = l0.get(input, add_special_tokens) {
170 return Ok((*cached).clone());
171 }
172 }
173
174 if let Some(l1) = &self.l1 {
176 let tokens: Vec<&str> = self
177 .special_token_strings
178 .iter()
179 .map(|s| s.as_str())
180 .collect();
181
182 if let Some((prefix_tokens, prefix_len)) =
183 l1.longest_prefix_match(input, &tokens, add_special_tokens)
184 {
185 let suffix = &input[prefix_len..];
186 if !suffix.is_empty() {
187 let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
188
189 let mut merged_tokens = prefix_tokens;
190 merged_tokens.extend_from_slice(suffix_encoding.token_ids());
191
192 let merged_encoding = Encoding::Plain(merged_tokens);
193
194 if let Some(l0) = &self.l0 {
195 l0.insert(
196 input.to_string(),
197 add_special_tokens,
198 merged_encoding.clone(),
199 );
200 }
201
202 return Ok(merged_encoding);
203 }
204 }
205 }
206
207 let encoding = self.inner.encode(input, add_special_tokens)?;
209
210 if let Some(l0) = &self.l0 {
212 l0.insert(input.to_string(), add_special_tokens, encoding.clone());
213 }
214
215 if let Some(l1) = &self.l1 {
217 let tokens: Vec<&str> = self
218 .special_token_strings
219 .iter()
220 .map(|s| s.as_str())
221 .collect();
222 let _ =
223 l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
224 }
225
226 Ok(encoding)
227 }
228
229 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
230 inputs
233 .par_iter()
234 .map(|&input| self.encode(input, add_special_tokens))
235 .collect()
236 }
237}
238
239impl Decoder for CachedTokenizer {
240 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
241 self.inner.decode(token_ids, skip_special_tokens)
243 }
244}
245
246impl Tokenizer for CachedTokenizer {
247 fn vocab_size(&self) -> usize {
248 self.inner.vocab_size()
249 }
250
251 fn get_special_tokens(&self) -> &SpecialTokens {
252 self.inner.get_special_tokens()
253 }
254
255 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
256 self.inner.token_to_id(token)
257 }
258
259 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
260 self.inner.id_to_token(id)
261 }
262
263 fn as_any(&self) -> &dyn std::any::Any {
264 self
265 }
266
267 fn apply_chat_template(
268 &self,
269 messages: &[serde_json::Value],
270 params: ChatTemplateParams,
271 ) -> Result<String> {
272 self.inner.apply_chat_template(messages, params)
273 }
274
275 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
276 self.inner.chat_template_content_format()
277 }
278
279 fn thinking_toggle(&self) -> ThinkingToggle {
280 self.inner.thinking_toggle()
281 }
282
283 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
284 self.inner.thinking_key_name()
285 }
286 fn think_in_prefill(&self) -> bool {
287 self.inner.think_in_prefill()
288 }
289
290 fn eos_token_ids(&self) -> &[TokenIdType] {
291 self.inner.eos_token_ids()
292 }
293}
294
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    #[test]
    fn test_cache_hit() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        let input = "Hello world";
        let expected = tokenizer.encode(input, false).unwrap();

        // First call misses and populates L0; the second must be served from L0.
        let result1 = cached.encode(input, false).unwrap();
        let result2 = cached.encode(input, false).unwrap();

        assert_eq!(result1.token_ids(), expected.token_ids());
        assert_eq!(result1.token_ids(), result2.token_ids());

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_key_includes_add_special_tokens() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let input = "Hello world";
        cached.encode(input, false).unwrap();
        // Same text with a different flag must not be served from the `false` entry.
        cached.encode(input, true).unwrap();

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 2);
    }

    #[test]
    fn test_cache_disabled() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(tokenizer, config);

        let input = "Hello world";

        let result1 = cached.encode(input, false).unwrap();
        let result2 = cached.encode(input, false).unwrap();

        assert_eq!(result1.token_ids(), result2.token_ids());

        assert!(cached.cache_stats().is_none());
        assert!(cached.l1_cache_stats().is_none());
    }

    #[test]
    fn test_l1_only_matches_inner() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: true,
            l1_max_memory: 1024 * 1024,
        };
        let cached = CachedTokenizer::new(tokenizer.clone(), config);

        let input = "Hello world";
        let expected = tokenizer.encode(input, false).unwrap();

        // Repeated calls exercise both the populate and the lookup paths of L1.
        for _ in 0..2 {
            let got = cached.encode(input, false).unwrap();
            assert_eq!(got.token_ids(), expected.token_ids());
        }

        assert!(cached.cache_stats().is_none());
        assert!(cached.l1_cache_stats().is_some());
    }

    #[test]
    fn test_encode_batch() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());

        let inputs = vec!["Hello", "world", "Hello"];
        let results = cached.encode_batch(&inputs, false).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].token_ids(), results[2].token_ids());

        // The batch runs in parallel, so the duplicate "Hello" may miss twice. A follow-up
        // sequential encode is guaranteed to find it in L0.
        let _ = cached.encode("Hello", false).unwrap();
        let stats = cached.cache_stats().unwrap();

        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        let tokens = vec![1, 2, 3];
        let decoded = cached.decode(&tokens, false).unwrap();

        assert!(!decoded.is_empty());
        assert_eq!(decoded, tokenizer.decode(&tokens, false).unwrap());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert_eq!(cached.token_to_id("Hello"), tokenizer.token_to_id("Hello"));
        assert!(cached.id_to_token(1).is_some());
        assert_eq!(cached.id_to_token(1), tokenizer.id_to_token(1));
    }
}