1mod fingerprint;
18mod l0;
19mod l1;
20
21use std::sync::Arc;
22
23use anyhow::Result;
24pub use fingerprint::TokenizerFingerprint;
25pub use l0::{CacheStats, L0Cache};
26pub use l1::{L1Cache, L1CacheStats};
27use rayon::prelude::*;
28
29use crate::{
30 chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
31 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer},
32};
33
/// Configuration for the caches used by `CachedTokenizer`.
///
/// Each cache layer can be toggled independently; the size limits are only
/// consulted for layers that are enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Enables the L0 exact-match cache.
    pub enable_l0: bool,
    /// Capacity passed to `L0Cache::new` when L0 is enabled.
    pub l0_max_entries: usize,
    /// Enables the L1 prefix cache (split at special-token boundaries).
    pub enable_l1: bool,
    /// Memory budget passed to `L1Cache::new` when L1 is enabled
    /// (presumably bytes — TODO confirm against `L1Cache`).
    pub l1_max_memory: usize,
}

impl Default for CacheConfig {
    /// L0 enabled with 10 000 entries; L1 disabled with a `50 * 1024 * 1024`
    /// budget ready for when it is switched on.
    fn default() -> Self {
        const DEFAULT_L0_MAX_ENTRIES: usize = 10_000;
        const DEFAULT_L1_MAX_MEMORY: usize = 50 * 1024 * 1024;

        Self {
            enable_l0: true,
            l0_max_entries: DEFAULT_L0_MAX_ENTRIES,
            enable_l1: false,
            l1_max_memory: DEFAULT_L1_MAX_MEMORY,
        }
    }
}
57
58pub struct CachedTokenizer {
60 inner: Arc<dyn Tokenizer>,
62 l0: Option<L0Cache>,
64 l1: Option<L1Cache>,
66 fingerprint: TokenizerFingerprint,
68 special_token_strings: Vec<String>,
70}
71
72impl CachedTokenizer {
73 pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
75 let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
76
77 let l0 = if config.enable_l0 {
78 Some(L0Cache::new(config.l0_max_entries))
79 } else {
80 None
81 };
82
83 let l1 = if config.enable_l1 {
84 Some(L1Cache::new(config.l1_max_memory))
85 } else {
86 None
87 };
88
89 let special_token_strings = Self::extract_special_token_strings(&inner);
91
92 Self {
93 inner,
94 l0,
95 l1,
96 fingerprint,
97 special_token_strings,
98 }
99 }
100
101 fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
103 let special_tokens = tokenizer.get_special_tokens();
104 let mut tokens = Vec::new();
105
106 if let Some(ref token) = special_tokens.bos_token {
107 tokens.push(token.clone());
108 }
109 if let Some(ref token) = special_tokens.eos_token {
110 tokens.push(token.clone());
111 }
112 if let Some(ref token) = special_tokens.unk_token {
113 tokens.push(token.clone());
114 }
115 if let Some(ref token) = special_tokens.sep_token {
116 tokens.push(token.clone());
117 }
118 if let Some(ref token) = special_tokens.pad_token {
119 tokens.push(token.clone());
120 }
121 if let Some(ref token) = special_tokens.cls_token {
122 tokens.push(token.clone());
123 }
124 if let Some(ref token) = special_tokens.mask_token {
125 tokens.push(token.clone());
126 }
127
128 tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
129 tokens
130 }
131
132 pub fn cache_stats(&self) -> Option<CacheStats> {
134 self.l0.as_ref().map(|cache| cache.stats())
135 }
136
137 pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
139 self.l1.as_ref().map(|cache| cache.stats())
140 }
141
142 pub fn clear_cache(&self) {
144 if let Some(l0) = &self.l0 {
145 l0.clear();
146 }
147 if let Some(l1) = &self.l1 {
148 l1.clear();
149 }
150 }
151
152 pub fn fingerprint(&self) -> &TokenizerFingerprint {
154 &self.fingerprint
155 }
156
157 pub fn inner(&self) -> &Arc<dyn Tokenizer> {
159 &self.inner
160 }
161}
162
163impl Encoder for CachedTokenizer {
164 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
165 if let Some(l0) = &self.l0 {
167 if let Some(cached) = l0.get(input, add_special_tokens) {
168 return Ok((*cached).clone());
169 }
170 }
171
172 if let Some(l1) = &self.l1 {
174 let tokens: Vec<&str> = self
175 .special_token_strings
176 .iter()
177 .map(|s| s.as_str())
178 .collect();
179
180 if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, &tokens) {
181 let suffix = &input[prefix_len..];
182 if !suffix.is_empty() {
183 let suffix_encoding = self.inner.encode(suffix, add_special_tokens)?;
184
185 let mut merged_tokens = prefix_tokens;
186 merged_tokens.extend_from_slice(suffix_encoding.token_ids());
187
188 let merged_encoding = Encoding::Plain(merged_tokens);
189
190 if let Some(l0) = &self.l0 {
191 l0.insert(
192 input.to_string(),
193 add_special_tokens,
194 merged_encoding.clone(),
195 );
196 }
197
198 return Ok(merged_encoding);
199 }
200 }
201 }
202
203 let encoding = self.inner.encode(input, add_special_tokens)?;
205
206 if let Some(l0) = &self.l0 {
208 l0.insert(input.to_string(), add_special_tokens, encoding.clone());
209 }
210
211 if let Some(l1) = &self.l1 {
213 let tokens: Vec<&str> = self
214 .special_token_strings
215 .iter()
216 .map(|s| s.as_str())
217 .collect();
218 let _ =
219 l1.insert_at_boundaries(input, self.inner.as_ref(), &tokens, add_special_tokens);
220 }
221
222 Ok(encoding)
223 }
224
225 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
226 inputs
229 .par_iter()
230 .map(|&input| self.encode(input, add_special_tokens))
231 .collect()
232 }
233}
234
235impl Decoder for CachedTokenizer {
236 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
237 self.inner.decode(token_ids, skip_special_tokens)
239 }
240}
241
242impl Tokenizer for CachedTokenizer {
243 fn vocab_size(&self) -> usize {
244 self.inner.vocab_size()
245 }
246
247 fn get_special_tokens(&self) -> &SpecialTokens {
248 self.inner.get_special_tokens()
249 }
250
251 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
252 self.inner.token_to_id(token)
253 }
254
255 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
256 self.inner.id_to_token(id)
257 }
258
259 fn as_any(&self) -> &dyn std::any::Any {
260 self
261 }
262
263 fn apply_chat_template(
264 &self,
265 messages: &[serde_json::Value],
266 params: ChatTemplateParams,
267 ) -> Result<String> {
268 self.inner.apply_chat_template(messages, params)
269 }
270
271 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
272 self.inner.chat_template_content_format()
273 }
274}
275
#[cfg(test)]
mod tests {
    use crate::{mock::MockTokenizer, *};

    /// Wraps a fresh `MockTokenizer` in a `CachedTokenizer` built from `config`.
    fn wrap(config: CacheConfig) -> CachedTokenizer {
        CachedTokenizer::new(Arc::new(MockTokenizer::new()), config)
    }

    #[test]
    fn test_cache_hit() {
        let cached = wrap(CacheConfig::default());
        let text = "Hello world";

        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        // The first call misses and fills L0; the second call hits it.
        let stats = cached.cache_stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_disabled() {
        let cached = wrap(CacheConfig {
            enable_l0: false,
            l0_max_entries: 0,
            enable_l1: false,
            l1_max_memory: 0,
        });
        let text = "Hello world";

        let first = cached.encode(text, false).unwrap();
        let second = cached.encode(text, false).unwrap();
        assert_eq!(first.token_ids(), second.token_ids());

        // With L0 disabled there are no statistics to report.
        assert!(cached.cache_stats().is_none());
    }

    #[test]
    fn test_encode_batch() {
        let cached = wrap(CacheConfig::default());
        let inputs = ["Hello", "world", "Hello"];

        let results = cached.encode_batch(&inputs, false).unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].token_ids(), results[2].token_ids());

        // Batch items run in parallel, so the duplicate "Hello" may miss twice;
        // this follow-up call should hit the entry the batch stored.
        let _ = cached.encode("Hello", false).unwrap();

        let stats = cached.cache_stats().unwrap();
        assert!(
            stats.hits >= 1,
            "Expected at least 1 cache hit after batch processing"
        );
    }

    #[test]
    fn test_decoder_passthrough() {
        let cached = wrap(CacheConfig::default());

        let decoded = cached.decode(&[1, 2, 3], false).unwrap();
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_tokenizer_trait_methods() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());

        assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
        assert!(cached.token_to_id("Hello").is_some());
        assert!(cached.id_to_token(1).is_some());
    }
}