// infiniloom_engine/tokenizer/core.rs
use super::counts::TokenCounts;
7use super::models::TokenModel;
8use dashmap::DashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::OnceLock;
11use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
12
// Lazily-initialized BPE encoders, shared process-wide. OnceLock guarantees
// thread-safe one-time construction; loading a vocabulary is expensive, so it
// happens at most once per process.
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
16
/// One memoized token count; keyed externally by `(content hash, model)`.
struct CacheEntry {
    // Token count computed for the hashed content under the keyed model.
    count: u32,
    // Unix time in seconds (truncated to u32) of the last read or write;
    // entries with the lowest values are evicted first by maybe_cleanup_cache().
    last_access: u32,
}
23
// Process-wide cache mapping (content hash, model) -> token count.
// DashMap provides sharded, lock-striped concurrent access.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), CacheEntry>> = OnceLock::new();

// Eviction kicks in once the cache holds this many entries.
const MAX_CACHE_ENTRIES: usize = 100_000;

// Denominator of the fraction evicted per cleanup pass: 2 => drop the oldest half.
const EVICTION_FRACTION: usize = 2;
35
/// Returns the shared token-count cache, creating the empty map on first access.
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), CacheEntry> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}
40
/// Seconds since the Unix epoch, truncated to `u32` (wraps in 2106) — only used
/// for coarse LRU ordering, so precision loss is harmless. A system clock set
/// before the epoch yields 0.
fn current_timestamp() -> u32 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs() as u32,
        Err(_) => 0,
    }
}
48
49fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), CacheEntry>) {
53 if cache.len() < MAX_CACHE_ENTRIES {
54 return;
55 }
56
57 let mut entries: Vec<((u64, TokenModel), u32)> = cache
59 .iter()
60 .map(|entry| (*entry.key(), entry.value().last_access))
61 .collect();
62
63 entries.sort_by_key(|(_, ts)| *ts);
65
66 let to_remove = entries.len() / EVICTION_FRACTION;
68 for (key, _) in entries.into_iter().take(to_remove) {
69 cache.remove(&key);
70 }
71}
72
/// 64-bit hash of `content`, used as the cache key. DefaultHasher (SipHash) is
/// deterministic within a process but not guaranteed stable across Rust
/// versions — fine for an in-memory cache.
fn hash_content(content: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(content, &mut state);
    Hasher::finish(&state)
}
80
81fn get_gpt4o_tokenizer() -> &'static CoreBPE {
83 GPT4O_TOKENIZER.get_or_init(|| {
84 o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
85 })
86}
87
88fn get_gpt4_tokenizer() -> &'static CoreBPE {
90 GPT4_TOKENIZER.get_or_init(|| {
91 cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
92 })
93}
94
/// Byte-level statistics gathered in a single scan of the text, shared by all
/// per-model estimates so the text is traversed only once.
#[derive(Clone, Copy)]
struct EstimationStats {
    // Total byte length of the text (not a char count).
    len: usize,
    // Spaces and tabs (newlines are counted separately).
    whitespace_count: u32,
    // Line-feed bytes.
    newline_count: u32,
    // ASCII punctuation characters common in source code.
    special_char_count: u32,
}
104
/// Token counter supporting exact BPE counts (via tiktoken) and fast heuristic
/// estimates, with an optional process-wide cache keyed by content hash.
pub struct Tokenizer {
    // When true, models with a known vocabulary use the real BPE encoder;
    // otherwise every model is estimated statistically.
    use_exact: bool,
    // When true, counts are memoized in the shared TOKEN_CACHE.
    use_cache: bool,
}
116
117impl Default for Tokenizer {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl Tokenizer {
124 pub fn new() -> Self {
126 Self { use_exact: true, use_cache: true }
127 }
128
129 pub fn estimation_only() -> Self {
131 Self { use_exact: false, use_cache: true }
132 }
133
134 pub fn without_cache() -> Self {
136 Self { use_exact: true, use_cache: false }
137 }
138
139 #[must_use]
150 pub fn count(&self, text: &str, model: TokenModel) -> u32 {
151 if text.is_empty() {
152 return 0;
153 }
154
155 if self.use_cache {
156 let cache = get_token_cache();
157 let content_hash = hash_content(text);
158 let key = (content_hash, model);
159 let now = current_timestamp();
160
161 if let Some(mut entry) = cache.get_mut(&key) {
163 entry.last_access = now;
164 return entry.count;
165 }
166
167 let count = self.count_uncached(text, model);
169 maybe_cleanup_cache(cache);
170 cache.insert(key, CacheEntry { count, last_access: now });
171 count
172 } else {
173 self.count_uncached(text, model)
174 }
175 }
176
177 fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
179 if self.use_exact && model.has_exact_tokenizer() {
180 self.count_exact(text, model)
181 } else {
182 self.estimate(text, model)
183 }
184 }
185
186 fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
190 if model.uses_o200k() {
191 let tokenizer = get_gpt4o_tokenizer();
194 self.tokenize_with_panic_guard(tokenizer, text, model)
195 } else if model.uses_cl100k() {
196 let tokenizer = get_gpt4_tokenizer();
199 self.tokenize_with_panic_guard(tokenizer, text, model)
200 } else {
201 self.estimate(text, model)
203 }
204 }
205
    /// Encodes `text` with `tokenizer`, falling back to the heuristic estimate
    /// if the encoder panics on pathological input.
    ///
    /// The default panic hook is temporarily replaced with a no-op so the
    /// caught panic doesn't spam stderr.
    /// NOTE(review): take_hook/set_hook mutate process-global state; a panic on
    /// another thread during this window would also be silenced, and concurrent
    /// calls can interleave hook swaps — confirm this is acceptable.
    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {
        }));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }));

        std::panic::set_hook(prev_hook);

        match result {
            Ok(count) => count,
            // Encoder panicked: degrade gracefully to the statistical estimate.
            Err(_) => self.estimate(text, model),
        }
    }
227
228 fn estimate(&self, text: &str, model: TokenModel) -> u32 {
231 if text.is_empty() {
232 return 0;
233 }
234 let stats = compute_estimation_stats(text);
235 estimate_from_stats(&stats, model)
236 }
237
    /// Computes token counts for every supported model in one pass.
    ///
    /// Only the two exact tokenizers (o200k / cl100k) are run — and individually
    /// cached — when `use_exact` is set; every other model is always estimated
    /// from shared character statistics, so the text is scanned once for them.
    pub fn count_all(&self, text: &str) -> TokenCounts {
        if text.is_empty() {
            return TokenCounts::default();
        }

        let content_hash = hash_content(text);
        // Optional cache handle so the closure below serves both the cached and
        // uncached configurations.
        let cache = if self.use_cache {
            Some(get_token_cache())
        } else {
            None
        };

        let now = current_timestamp();
        // Exact count for one model, going through the shared cache when enabled.
        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
            if let Some(cache) = cache {
                let key = (content_hash, model);
                if let Some(mut entry) = cache.get_mut(&key) {
                    entry.last_access = now;
                    return entry.count;
                }
                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
                // Evict before inserting so the fresh entry survives a cleanup pass.
                maybe_cleanup_cache(cache);
                cache.insert(key, CacheEntry { count, last_access: now });
                count
            } else {
                self.tokenize_with_panic_guard(tokenizer, text, model)
            }
        };

        let stats = compute_estimation_stats(text);

        let o200k = if self.use_exact {
            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4o)
        };

        let cl100k = if self.use_exact {
            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4)
        };

        TokenCounts {
            o200k,
            cl100k,
            claude: estimate_from_stats(&stats, TokenModel::Claude),
            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
            llama: estimate_from_stats(&stats, TokenModel::Llama),
            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
            grok: estimate_from_stats(&stats, TokenModel::Grok),
        }
    }
308
309 pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
311 let counts = self.count_all(text);
312 let models = [
313 (TokenModel::Gpt4o, counts.o200k), (TokenModel::Gpt4, counts.cl100k), (TokenModel::Claude, counts.claude),
316 (TokenModel::Gemini, counts.gemini),
317 (TokenModel::Llama, counts.llama),
318 (TokenModel::Mistral, counts.mistral),
319 (TokenModel::DeepSeek, counts.deepseek),
320 (TokenModel::Qwen, counts.qwen),
321 (TokenModel::Cohere, counts.cohere),
322 (TokenModel::Grok, counts.grok),
323 ];
324
325 models
327 .into_iter()
328 .min_by_key(|(_, count)| *count)
329 .unwrap_or((TokenModel::Claude, 0))
330 }
331
    /// Returns the longest prefix of `text` that fits within `budget` tokens
    /// for `model`, preferring to cut at a space or newline. Returns `text`
    /// unchanged when it already fits.
    ///
    /// NOTE(review): relies on `str::floor_char_boundary`, which is an unstable
    /// std API — confirm the crate builds on nightly or supplies this method
    /// via an extension trait.
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search over byte offsets for the largest prefix within budget.
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Snap the probe down to a char boundary so slicing can't panic.
            let mid = text.floor_char_boundary(mid_raw);

            // Boundary snapping can pin mid at low; no further progress possible.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        let low = text.floor_char_boundary(low);

        // Walk back from the fitting prefix to the previous space/newline so the
        // cut lands after whitespace rather than mid-word.
        let mut end = low;
        while end > 0 {
            let boundary = text.floor_char_boundary(end);
            if boundary < end {
                end = boundary;
                continue;
            }

            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        let end = text.floor_char_boundary(end);

        // No whitespace found (end walked to 0): fall back to the raw
        // binary-search prefix instead of returning an empty slice.
        if end > 0 {
            &text[..end]
        } else {
            &text[..low]
        }
    }
398
399 pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
401 self.count(text, model) > budget
402 }
403}
404
405pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
407 if text.is_empty() {
408 return 0;
409 }
410 let chars_per_token = model.chars_per_token();
411 (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
412}
413
414fn compute_estimation_stats(text: &str) -> EstimationStats {
417 let mut whitespace_count = 0u32;
418 let mut newline_count = 0u32;
419 let mut special_char_count = 0u32;
420
421 for &byte in text.as_bytes() {
423 match byte {
424 b' ' | b'\t' => whitespace_count += 1,
425 b'\n' => newline_count += 1,
426 b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'=' | b'+'
427 | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#' | b'$' | b'%'
428 | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
429 _ => {},
430 }
431 }
432
433 EstimationStats { len: text.len(), whitespace_count, newline_count, special_char_count }
434}
435
436fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
438 let chars_per_token = model.chars_per_token();
439 let len = stats.len as f32;
440
441 let mut estimate = len / chars_per_token;
443
444 estimate -= stats.whitespace_count as f32 * 0.3;
446
447 estimate += stats.newline_count as f32 * 0.5;
449
450 if matches!(
452 model,
453 TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
454 ) {
455 estimate += stats.special_char_count as f32 * 0.3;
456 }
457
458 estimate.ceil().max(1.0) as u32
459}