// infiniloom_engine/tokenizer/core.rs

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

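/// Lazily initialized tiktoken o200k_base encoder, shared process-wide.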
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
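/// Lazily initialized tiktoken cl100k_base encoder, shared process-wide.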
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

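/// Global cache mapping (content hash, model) pairs to previously computed token counts.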
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

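/// Upper bound on cached entries; the cache is cleared wholesale once this is reached.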
const MAX_CACHE_ENTRIES: usize = 100_000;

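/// Returns the global token-count cache, initializing it on first use.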
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

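/// Crude eviction policy: drop every entry once the cache hits `MAX_CACHE_ENTRIES`.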
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

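/// Hashes text with the standard library's `DefaultHasher` to build cache keys.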
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

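/// Returns the shared o200k_base tokenizer (GPT-4o family), initializing it on first use.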
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

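/// Returns the shared cl100k_base tokenizer (GPT-4 family), initializing it on first use.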
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

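/// Character statistics gathered in one pass over the text and reused by every per-model estimate.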
#[derive(Clone, Copy)]
struct EstimationStats {
    len: usize,
    whitespace_count: u32,
    newline_count: u32,
    special_char_count: u32,
}

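/// Token counter with exact tiktoken counts for OpenAI models and heuristic
/// estimates for everything else, backed by an optional global cache.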
pub struct Tokenizer {
    /// Prefer exact tiktoken counts over estimation when the model supports it.
    use_exact: bool,
    /// Consult and populate the global token-count cache.
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
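    /// Creates a tokenizer with exact counting and caching enabled.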
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

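    /// Creates a tokenizer that always estimates and never invokes tiktoken.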
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

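    /// Creates a tokenizer with exact counting but without the global cache.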
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

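    /// Counts tokens in `text` for `model`, consulting the global cache when enabled.
    ///
    /// A minimal usage sketch (crate paths and re-exports assumed, not verified here):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new();
    /// let tokens = tokenizer.count("fn main() {}", TokenModel::Gpt4o);
    /// assert!(tokens > 0);
    /// ```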
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            if let Some(count) = cache.get(&key) {
                return *count;
            }

            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

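    /// Counts without touching the cache: exact when the model has a tiktoken encoder, estimated otherwise.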
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

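    /// Dispatches to the matching tiktoken encoder, falling back to estimation for models without one.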
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            let tokenizer = get_gpt4o_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else if model.uses_cl100k() {
            let tokenizer = get_gpt4_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else {
            self.estimate(text, model)
        }
    }

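    /// Runs tiktoken inside `catch_unwind`, silencing the panic hook for the duration;
    /// if encoding panics, falls back to the heuristic estimate.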
    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        // Temporarily install a silent panic hook so a tokenizer panic does not
        // print a backtrace; the previous hook is restored afterwards.
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }));

        std::panic::set_hook(prev_hook);

        // If encoding panicked, fall back to the heuristic estimate.
        match result {
            Ok(count) => count,
            Err(_) => self.estimate(text, model),
        }
    }

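    /// Heuristic token estimate derived from single-pass character statistics.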
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }
        let stats = compute_estimation_stats(text);
        estimate_from_stats(&stats, model)
    }

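    /// Counts tokens for every supported model at once. OpenAI models get exact counts
    /// (when enabled); the remaining models are estimated from one shared set of statistics.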
    pub fn count_all(&self, text: &str) -> TokenCounts {
        if text.is_empty() {
            return TokenCounts::default();
        }

        let content_hash = hash_content(text);
        let cache = if self.use_cache {
            Some(get_token_cache())
        } else {
            None
        };

        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
            if let Some(cache) = cache {
                let key = (content_hash, model);
                if let Some(count) = cache.get(&key) {
                    return *count;
                }
                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
                maybe_cleanup_cache(cache);
                cache.insert(key, count);
                count
            } else {
                self.tokenize_with_panic_guard(tokenizer, text, model)
            }
        };

        let stats = compute_estimation_stats(text);

        let o200k = if self.use_exact {
            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4o)
        };

        let cl100k = if self.use_exact {
            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4)
        };

        TokenCounts {
            o200k,
            cl100k,
            claude: estimate_from_stats(&stats, TokenModel::Claude),
            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
            llama: estimate_from_stats(&stats, TokenModel::Llama),
            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
            grok: estimate_from_stats(&stats, TokenModel::Grok),
        }
    }

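    /// Returns the model that tokenizes `text` most cheaply, together with its count.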
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k),
            (TokenModel::Gpt4, counts.cl100k),
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

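    /// Returns the longest prefix of `text` that fits within `budget` tokens,
    /// preferring to cut at a space or newline.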
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the longest byte prefix whose token count fits the budget.
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            let mid = text.floor_char_boundary(mid_raw);

            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Back up to the nearest space or newline so the cut lands on a word boundary.
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            // No whitespace boundary found: fall back to the raw prefix.
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

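    /// Returns true if `text` needs more than `budget` tokens under `model`.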
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

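/// Cheap length-only estimate that skips character statistics entirely.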
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}

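/// Single byte-level pass collecting the counts consumed by `estimate_from_stats`.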
fn compute_estimation_stats(text: &str) -> EstimationStats {
    let mut whitespace_count = 0u32;
    let mut newline_count = 0u32;
    let mut special_char_count = 0u32;

    for &byte in text.as_bytes() {
        match byte {
            b' ' | b'\t' => whitespace_count += 1,
            b'\n' => newline_count += 1,
            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'=' | b'+'
            | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#' | b'$' | b'%'
            | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
            _ => {},
        }
    }

    EstimationStats { len: text.len(), whitespace_count, newline_count, special_char_count }
}

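/// Converts gathered statistics into a per-model token estimate, adjusting for
/// whitespace, newlines, and (for code-heavy tokenizers) special characters.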
fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
    let chars_per_token = model.chars_per_token();
    let len = stats.len as f32;

    let mut estimate = len / chars_per_token;

    // Whitespace tends to merge into neighboring tokens, so it lowers the estimate.
    estimate -= stats.whitespace_count as f32 * 0.3;

    // Newlines usually tokenize separately, so they add to the estimate.
    estimate += stats.newline_count as f32 * 0.5;

    // Code-oriented tokenizers spend extra tokens on punctuation and operators.
    if matches!(
        model,
        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
    ) {
        estimate += stats.special_char_count as f32 * 0.3;
    }

    estimate.ceil().max(1.0) as u32
}