// infiniloom_engine/tokenizer/core.rs

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

/// Lazily initialized BPE tokenizer for o200k-based models (GPT-4o family).
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
/// Lazily initialized BPE tokenizer for cl100k-based models (GPT-4 family).
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

/// Process-wide cache mapping `(content_hash, model)` to a token count.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

/// Maximum number of cached counts before the cache is reset.
const MAX_CACHE_ENTRIES: usize = 100_000;

fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

/// Resets the cache once it reaches `MAX_CACHE_ENTRIES`. Clearing everything
/// is simpler than LRU eviction and keeps memory bounded, at the cost of
/// recomputing counts after a reset.
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

/// Hashes content for cache keying. `DefaultHasher` is sufficient because
/// keys only need to be stable within a single process.
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

/// Multi-model token counter with optional exact tokenization and caching.
pub struct Tokenizer {
    /// Use an exact BPE tokenizer when the model provides one; otherwise
    /// fall back to estimation.
    use_exact: bool,
    /// Cache counts in the process-wide `TOKEN_CACHE`.
    use_cache: bool,
}
73
74impl Default for Tokenizer {
75 fn default() -> Self {
76 Self::new()
77 }
78}

impl Tokenizer {
    /// Exact tokenization where available, with caching.
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

    /// Heuristic estimation only; never loads a BPE tokenizer.
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

    /// Exact tokenization without the process-wide cache.
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

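    /// Counts the tokens in `text` for `model`, consulting the process-wide
    /// cache when enabled.
    ///
    /// A minimal usage sketch; the `infiniloom_engine::tokenizer` import path
    /// is an assumption (not confirmed by this file), so the doctest is
    /// marked `ignore`:
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{Tokenizer, TokenModel};
    ///
    /// let tokenizer = Tokenizer::new();
    /// assert!(tokenizer.count("fn main() {}", TokenModel::Gpt4o) > 0);
    /// ```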
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Fast path: a previous call already counted identical content.
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Slow path: count, then cache. Cleanup runs before the insert
            // so the fresh entry survives a cache reset.
            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    /// Dispatches to exact BPE counting or heuristic estimation.
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    /// Counts with a real BPE tokenizer. Encoding is wrapped in
    /// `catch_unwind` so a panic inside the encoder degrades to an estimate
    /// instead of propagating to the caller.
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            let tokenizer = get_gpt4o_tokenizer();
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                tokenizer.encode_ordinary(text).len() as u32
            })) {
                Ok(count) => count,
                Err(_) => self.estimate(text, model),
            }
        } else if model.uses_cl100k() {
            let tokenizer = get_gpt4_tokenizer();
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                tokenizer.encode_ordinary(text).len() as u32
            })) {
                Ok(count) => count,
                Err(_) => self.estimate(text, model),
            }
        } else {
            // No exact tokenizer for this model; estimate instead.
            self.estimate(text, model)
        }
    }

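    /// Heuristically estimates a token count from character statistics:
    /// start at `len / chars_per_token`, subtract for spaces and tabs (often
    /// absorbed into neighboring tokens), add for newlines, and add for
    /// punctuation density on code-oriented models.
    ///
    /// A rough worked example with an assumed ratio of 4.0 chars per token:
    /// 400 bytes containing 40 spaces and 10 newlines estimates to
    /// `100 - 40 * 0.3 + 10 * 0.5 = 93` tokens before the special-character
    /// adjustment below.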
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        let chars_per_token = model.chars_per_token();
        let len = text.len() as f32;

        // Baseline: byte length over the model's average chars per token.
        let mut estimate = len / chars_per_token;

        // Spaces and tabs are usually merged into neighboring tokens.
        let whitespace_count = text.chars().filter(|c| *c == ' ' || *c == '\t').count() as f32;
        estimate -= whitespace_count * 0.3;

        // Newlines tend to tokenize separately.
        let newline_count = text.chars().filter(|c| *c == '\n').count() as f32;
        estimate += newline_count * 0.5;

        // Count operators and punctuation common in source code.
        let special_chars = text
            .chars()
            .filter(|c| {
                matches!(
                    c,
                    '{' | '}' | '(' | ')' | '[' | ']' | ';' | ':' | ',' | '.' | '=' | '+' | '-'
                        | '*' | '/' | '<' | '>' | '!' | '&' | '|' | '@' | '#' | '$' | '%' | '^'
                        | '~' | '`' | '"' | '\''
                )
            })
            .count() as f32;

        // Apply the punctuation adjustment only to these models.
        if matches!(
            model,
            TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
        ) {
            estimate += special_chars * 0.3;
        }

        // Round up and never report zero tokens for non-empty text.
        estimate.ceil().max(1.0) as u32
    }

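    /// Counts `text` under every supported model in a single call.
    ///
    /// Usage sketch (import path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::Tokenizer;
    ///
    /// let counts = Tokenizer::new().count_all("hello world");
    /// assert!(counts.o200k > 0 && counts.claude > 0);
    /// ```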
    pub fn count_all(&self, text: &str) -> TokenCounts {
        TokenCounts {
            o200k: self.count(text, TokenModel::Gpt4o),
            cl100k: self.count(text, TokenModel::Gpt4),
            claude: self.count(text, TokenModel::Claude),
            gemini: self.count(text, TokenModel::Gemini),
            llama: self.count(text, TokenModel::Llama),
            mistral: self.count(text, TokenModel::Mistral),
            deepseek: self.count(text, TokenModel::DeepSeek),
            qwen: self.count(text, TokenModel::Qwen),
            cohere: self.count(text, TokenModel::Cohere),
            grok: self.count(text, TokenModel::Grok),
        }
    }

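    /// Returns the model that produces the fewest tokens for `text`, together
    /// with that count.
    ///
    /// Usage sketch (import path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::Tokenizer;
    ///
    /// let (model, count) = Tokenizer::new().most_efficient_model("some text");
    /// assert!(count > 0);
    /// let _ = model;
    /// ```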
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k),
            (TokenModel::Gpt4, counts.cl100k),
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

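    /// Truncates `text` so it fits within `budget` tokens for `model`,
    /// binary-searching byte prefixes and then backing up to the previous
    /// whitespace so words are kept whole.
    ///
    /// Usage sketch (import path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{Tokenizer, TokenModel};
    ///
    /// let tokenizer = Tokenizer::new();
    /// let text = "alpha beta gamma ".repeat(1_000);
    /// let cut = tokenizer.truncate_to_budget(&text, TokenModel::Gpt4o, 100);
    /// assert!(tokenizer.count(cut, TokenModel::Gpt4o) <= 100);
    /// ```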
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary-search the largest byte prefix whose token count fits.
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Snap down to a char boundary so the slice below stays valid UTF-8.
            let mid = text.floor_char_boundary(mid_raw);

            // Flooring can stall the search; stop rather than loop forever.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Back up to the previous space or newline so a word is not split.
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            // No whitespace found: return the raw prefix instead.
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    /// Returns `true` if `text` does not fit within `budget` tokens for `model`.
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

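/// Fast character-ratio estimate with no tokenizer, no cache, and none of the
/// per-character adjustments `Tokenizer::estimate` applies; useful as a cheap
/// pre-filter.
///
/// Usage sketch (import path assumed, hence `ignore`):
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{quick_estimate, TokenModel};
///
/// assert!(quick_estimate("hello world", TokenModel::Gpt4o) >= 1);
/// ```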
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}
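
#[cfg(test)]
mod tests {
    //! Smoke-test sketches for the API above; these assert expected behavior
    //! as assumptions, not as existing project tests.
    use super::*;

    #[test]
    fn empty_text_counts_zero() {
        assert_eq!(Tokenizer::new().count("", TokenModel::Gpt4o), 0);
    }

    #[test]
    fn estimates_are_at_least_one_token() {
        assert!(quick_estimate("x", TokenModel::Claude) >= 1);
    }

    #[test]
    fn truncation_respects_budget() {
        let tokenizer = Tokenizer::without_cache();
        let text = "one two three four five ".repeat(200);
        let cut = tokenizer.truncate_to_budget(&text, TokenModel::Gpt4o, 50);
        assert!(tokenizer.count(cut, TokenModel::Gpt4o) <= 50);
    }
}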