infiniloom_engine/tokenizer/core.rs

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

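// Shared tiktoken encoders, initialized lazily on first use.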
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

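/// Process-wide cache of token counts keyed by (content hash, model).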
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

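/// Upper bound on cached entries; once reached, the whole cache is cleared
/// (see `maybe_cleanup_cache`).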
const MAX_CACHE_ENTRIES: usize = 100_000;

fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

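/// Clears the entire cache once it reaches `MAX_CACHE_ENTRIES`, bounding its
/// memory use without the bookkeeping of an LRU policy.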
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

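/// Hashes text content with the standard library's `DefaultHasher` to form the
/// cache key. Not cryptographic: distinct texts can collide, which is traded
/// off here for a compact `u64` key.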
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

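/// Byte-level statistics collected in a single pass over the text and shared
/// by every model's estimate.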
#[derive(Clone, Copy)]
struct EstimationStats {
    len: usize,
    whitespace_count: u32,
    newline_count: u32,
    special_char_count: u32,
}

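/// Token counter for multiple LLM families. Uses exact tiktoken encodings
/// (o200k/cl100k) when available and falls back to byte-statistics estimation
/// otherwise; results are memoized in a process-wide cache unless disabled.
///
/// A minimal usage sketch, assuming `Tokenizer` and `TokenModel` are both
/// reachable from the crate's `tokenizer` module (the exact re-export paths
/// may differ):
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
///
/// let tok = Tokenizer::new();
/// let tokens = tok.count("fn main() { println!(\"hi\"); }", TokenModel::Gpt4o);
/// assert!(tokens > 0);
/// ```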
pub struct Tokenizer {
    use_exact: bool,
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
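    /// Default configuration: exact tokenization where available, with caching.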
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

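    /// Never loads the exact tokenizers; every model is estimated from byte
    /// statistics instead.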
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

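    /// Exact tokenization with the shared cache disabled; every call re-tokenizes.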
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

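    /// Counts tokens in `text` for a single `model`, consulting the shared
    /// cache first when caching is enabled. Empty text is always zero tokens.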
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            if let Some(count) = cache.get(&key) {
                return *count;
            }

            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            let tokenizer = get_gpt4o_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else if model.uses_cl100k() {
            let tokenizer = get_gpt4_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else {
            self.estimate(text, model)
        }
    }

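    /// Encodes `text` inside `catch_unwind` with the panic hook temporarily
    /// silenced; if the encoder panics, falls back to the estimation path.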
    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        // Temporarily silence the panic hook so a tokenizer panic does not spam stderr.
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }));

        std::panic::set_hook(prev_hook);

        match result {
            Ok(count) => count,
            Err(_) => self.estimate(text, model), // fall back to estimation on panic
        }
    }

    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }
        let stats = compute_estimation_stats(text);
        estimate_from_stats(&stats, model)
    }

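    /// Counts tokens for every supported model in one call: exact counts for
    /// the o200k/cl100k families (when enabled), estimates for the rest, all
    /// sharing a single byte-statistics scan of the text.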
    pub fn count_all(&self, text: &str) -> TokenCounts {
        if text.is_empty() {
            return TokenCounts::default();
        }

        let content_hash = hash_content(text);
        let cache = if self.use_cache { Some(get_token_cache()) } else { None };

        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
            if let Some(cache) = cache {
                let key = (content_hash, model);
                if let Some(count) = cache.get(&key) {
                    return *count;
                }
                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
                maybe_cleanup_cache(cache);
                cache.insert(key, count);
                count
            } else {
                self.tokenize_with_panic_guard(tokenizer, text, model)
            }
        };

        let stats = compute_estimation_stats(text);

        let o200k = if self.use_exact {
            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4o)
        };

        let cl100k = if self.use_exact {
            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4)
        };

        TokenCounts {
            o200k,
            cl100k,
            claude: estimate_from_stats(&stats, TokenModel::Claude),
            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
            llama: estimate_from_stats(&stats, TokenModel::Llama),
            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
            grok: estimate_from_stats(&stats, TokenModel::Grok),
        }
    }

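    /// Returns the model that represents `text` in the fewest tokens, along
    /// with that token count.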
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k),
            (TokenModel::Gpt4, counts.cl100k),
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

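    /// Truncates `text` to a prefix that fits within `budget` tokens for
    /// `model`, binary-searching over byte offsets and then backing up to the
    /// nearest space or newline so a word is not cut in half.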
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the largest byte offset whose prefix still fits the budget.
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            let mid = text.floor_char_boundary(mid_raw);

            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Walk back to the previous space or newline so the cut does not split a word.
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

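/// Cheapest possible estimate: byte length divided by the model's average
/// characters-per-token ratio, with no tokenizer, cache, or byte statistics.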
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}

fn compute_estimation_stats(text: &str) -> EstimationStats {
    let mut whitespace_count = 0u32;
    let mut newline_count = 0u32;
    let mut special_char_count = 0u32;

    for &byte in text.as_bytes() {
        match byte {
            b' ' | b'\t' => whitespace_count += 1,
            b'\n' => newline_count += 1,
            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'='
            | b'+' | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#'
            | b'$' | b'%' | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
            _ => {}
        }
    }

    EstimationStats {
        len: text.len(),
        whitespace_count,
        newline_count,
        special_char_count,
    }
}

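/// Converts byte statistics into a token estimate: length over the model's
/// average characters per token, minus a per-space discount, plus a per-newline
/// surcharge, and extra weight for code punctuation on a few models.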
fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
    let chars_per_token = model.chars_per_token();
    let len = stats.len as f32;

    // Baseline: average characters per token for this model.
    let mut estimate = len / chars_per_token;

    // Spaces and tabs usually attach to neighbouring tokens, so discount them.
    estimate -= stats.whitespace_count as f32 * 0.3;

    // Newlines frequently tokenize separately, so weight them up.
    estimate += stats.newline_count as f32 * 0.5;

    // Code-heavy punctuation costs extra tokens for these models.
    if matches!(
        model,
        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
    ) {
        estimate += stats.special_char_count as f32 * 0.3;
    }

    estimate.ceil().max(1.0) as u32
}