1mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30 chat_template::{
31 ChatTemplateContentFormat, ChatTemplateParams, ThinkingKeyName, ThinkingToggle,
32 },
33 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
34};
35
/// Settings that decide which cache layers a [`CachedTokenizer`] builds.
///
/// Each layer has an on/off switch and a size value that is handed to the
/// layer's constructor when the switch is on.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// When `true`, `CachedTokenizer::new` constructs an [`L0Cache`].
    pub enable_l0: bool,

    /// Value passed to `L0Cache::new` when L0 is enabled.
    pub l0_max_entries: usize,

    /// When `true`, `CachedTokenizer::new` constructs an [`L1Cache`].
    pub enable_l1: bool,

    /// Value passed to `L1Cache::new` when L1 is enabled.
    ///
    /// NOTE(review): presumably a byte budget (the default is written as
    /// `50 * 1024 * 1024`) -- confirm against `L1Cache`.
    pub l1_max_memory: usize,
}
48
49impl Default for CacheConfig {
50 fn default() -> Self {
51 Self {
52 enable_l0: true,
53 l0_max_entries: 10_000, enable_l1: false, l1_max_memory: 50 * 1024 * 1024, }
57 }
58}
59
60pub struct CachedTokenizer {
62 inner: Arc<dyn Tokenizer>,
64 l0: Option<L0Cache>,
66 l1: Option<L1Cache>,
68 fingerprint: TokenizerFingerprint,
70 special_token_strings: Vec<String>,
72}
73
74impl CachedTokenizer {
75 pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
77 let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
78
79 let l0 = if config.enable_l0 {
80 Some(L0Cache::new(config.l0_max_entries))
81 } else {
82 None
83 };
84
85 let l1 = if config.enable_l1 {
86 Some(L1Cache::new(config.l1_max_memory))
87 } else {
88 None
89 };
90
91 let special_token_strings = Self::extract_special_token_strings(&inner);
93
94 Self {
95 inner,
96 l0,
97 l1,
98 fingerprint,
99 special_token_strings,
100 }
101 }
102
103 fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
105 let special_tokens = tokenizer.get_special_tokens();
106 let mut tokens = Vec::new();
107
108 if let Some(ref token) = special_tokens.bos_token {
109 tokens.push(token.clone());
110 }
111 if let Some(ref token) = special_tokens.eos_token {
112 tokens.push(token.clone());
113 }
114 if let Some(ref token) = special_tokens.unk_token {
115 tokens.push(token.clone());
116 }
117 if let Some(ref token) = special_tokens.sep_token {
118 tokens.push(token.clone());
119 }
120 if let Some(ref token) = special_tokens.pad_token {
121 tokens.push(token.clone());
122 }
123 if let Some(ref token) = special_tokens.cls_token {
124 tokens.push(token.clone());
125 }
126 if let Some(ref token) = special_tokens.mask_token {
127 tokens.push(token.clone());
128 }
129
130 tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
131 tokens
132 }
133
134 pub fn cache_stats(&self) -> Option<CacheStats> {
136 self.l0.as_ref().map(|cache| cache.stats())
137 }
138
139 pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
141 self.l1.as_ref().map(|cache| cache.stats())
142 }
143
144 pub fn clear_cache(&self) {
146 if let Some(l0) = &self.l0 {
147 l0.clear();
148 }
149 if let Some(l1) = &self.l1 {
150 l1.clear();
151 }
152 }
153
154 pub fn fingerprint(&self) -> &TokenizerFingerprint {
156 &self.fingerprint
157 }
158
159 pub fn inner(&self) -> &Arc<dyn Tokenizer> {
161 &self.inner
162 }
163}
164
165impl Encoder for CachedTokenizer {
166 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
167 if let Some(l0) = &self.l0 {
169 if let Some(cached) = l0.get(input, add_special_tokens) {
170 return Ok((*cached).clone());
171 }
172 }
173
174 if let Some(l1) = &self.l1 {
176 let tokens: Vec<&str> = self
177 .special_token_strings
178 .iter()
179 .map(|s| s.as_str())
180 .collect();
181
182 if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
183 let suffix = &input[prefix_len..];
184 if !suffix.is_empty() {
185 let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
186
187 let mut merged_tokens = prefix_tokens;
188 merged_tokens.extend_from_slice(suffix_encoding.token_ids());
189
190 let merged_encoding = Encoding::Plain(merged_tokens);
191
192 if let Some(l0) = &self.l0 {
193 l0.insert(
194 input.to_string(),
195 add_special_tokens,
196 merged_encoding.clone(),
197 );
198 }
199
200 return Ok(merged_encoding);
201 }
202 }
203 }
204
205 let encoding = self.inner.encode(input, add_special_tokens)?;
207
208 if let Some(l0) = &self.l0 {
210 l0.insert(input.to_string(), add_special_tokens, encoding.clone());
211 }
212
213 if let Some(l1) = &self.l1 {
215 let tokens: Vec<&str> = self
216 .special_token_strings
217 .iter()
218 .map(|s| s.as_str())
219 .collect();
220 let _ =
221 l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
222 }
223
224 Ok(encoding)
225 }
226
227 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
228 inputs
231 .par_iter()
232 .map(|&input| self.encode(input, add_special_tokens))
233 .collect()
234 }
235}
236
237impl Decoder for CachedTokenizer {
238 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
239 self.inner.decode(token_ids, skip_special_tokens)
241 }
242}
243
244impl Tokenizer for CachedTokenizer {
245 fn vocab_size(&self) -> usize {
246 self.inner.vocab_size()
247 }
248
249 fn get_special_tokens(&self) -> &SpecialTokens {
250 self.inner.get_special_tokens()
251 }
252
253 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
254 self.inner.token_to_id(token)
255 }
256
257 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
258 self.inner.id_to_token(id)
259 }
260
261 fn as_any(&self) -> &dyn std::any::Any {
262 self
263 }
264
265 fn apply_chat_template(
266 &self,
267 messages: &[serde_json::Value],
268 params: ChatTemplateParams,
269 ) -> Result<String> {
270 self.inner.apply_chat_template(messages, params)
271 }
272
273 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
274 self.inner.chat_template_content_format()
275 }
276
277 fn thinking_toggle(&self) -> ThinkingToggle {
278 self.inner.thinking_toggle()
279 }
280
281 fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
282 self.inner.thinking_key_name()
283 }
284 fn think_in_prefill(&self) -> bool {
285 self.inner.think_in_prefill()
286 }
287}
288
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Builds a `CachedTokenizer` over a fresh mock with the default config.
    fn default_cached() -> CachedTokenizer {
        let tokenizer = Arc::new(MockTokenizer::new());
        CachedTokenizer::new(tokenizer, CacheConfig::default())
    }

    /// Encoding the same input twice yields one miss followed by one hit.
    #[test]
    fn test_cache_hit() {
        let cached = default_cached();
        let text = "Hello world";

        // First call populates the cache, second call is served from it.
        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// With both layers disabled, encoding still works and no stats exist.
    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        };
        let cached = CachedTokenizer::new(Arc::new(MockTokenizer::new()), config);
        let text = "Hello world";

        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        assert!(cached.cache_stats().is_none());
    }

    /// Batch encoding returns one result per input and fills the cache.
    #[test]
    fn test_encode_batch() {
        let cached = default_cached();

        let inputs = vec!["Hello", "world", "Hello"];
        let results = cached.encode_batch(&inputs, false).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].token_ids(), results[2].token_ids());

        // The batch runs in parallel, so the duplicate "Hello" is not
        // guaranteed to hit inside the batch; a follow-up encode must.
        let _ = cached.encode("Hello", false).unwrap();

        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    /// Decoding is forwarded to the inner tokenizer and yields text.
    #[test]
    fn test_decoder_passthrough() {
        let cached = default_cached();

        let ids = vec![1, 2, 3];
        let text = cached.decode(&ids, false).unwrap();

        assert!(!text.is_empty());
    }

    /// Trait methods on the wrapper agree with the wrapped tokenizer.
    #[test]
    fn test_tokenizer_trait_methods() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}