// infiniloom_engine/tokenizer/core.rs

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

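// Lazily-initialized tiktoken BPE tokenizers, shared for the whole process.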
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

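// Global count cache keyed by (content hash, model). Keys are 64-bit hashes,
// so distinct inputs can in principle collide and alias each other's counts.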
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

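/// The global token cache, created on first use.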
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

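/// Hash `content` with the std `DefaultHasher` to build a compact cache key.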
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

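/// The o200k_base tokenizer (GPT-4o family), built once on first access.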
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

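/// The cl100k_base tokenizer (GPT-4 family), built once on first access.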
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

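/// Multi-model token counter with optional exact BPE tokenization and a
/// process-wide count cache.
///
/// ```ignore
/// // Minimal usage sketch; adjust the import paths to wherever the crate
/// // re-exports `Tokenizer` and `TokenModel`.
/// let tok = Tokenizer::new();
/// assert!(tok.count("fn main() {}", TokenModel::Gpt4o) > 0);
/// ```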
pub struct Tokenizer {
    /// Use exact BPE tokenizers when the model has one (vs. estimating).
    use_exact: bool,
    /// Memoize counts in the global cache.
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
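    /// Exact tokenization (where a model has a real tokenizer) with caching.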
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

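    /// Heuristic estimates only; never touches the exact tokenizers.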
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

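    /// Exact tokenization, bypassing the global count cache.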
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

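    /// Count the tokens in `text` for `model`.
    ///
    /// Consults the global cache first when caching is enabled; an empty
    /// string is always 0 tokens.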
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Cache hit: reuse the previously computed count.
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Cache miss: count, then memoize for subsequent calls.
            let count = self.count_uncached(text, model);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

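    /// Dispatch to the exact tokenizer or the heuristic estimate.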
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

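    /// Count with the real BPE tokenizer, falling back to the heuristic
    /// estimate for models without one or if encoding panics.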
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        let tokenizer = if model.uses_o200k() {
            get_gpt4o_tokenizer()
        } else if model.uses_cl100k() {
            get_gpt4_tokenizer()
        } else {
            return self.estimate(text, model);
        };

        // encode_ordinary ignores special tokens; if encoding panics on
        // pathological input, fall back to the heuristic estimate.
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }))
        .unwrap_or_else(|_| self.estimate(text, model))
    }

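    /// Heuristic estimate from byte length, adjusted for whitespace,
    /// newlines, and punctuation density.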
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        let chars_per_token = model.chars_per_token();
        let len = text.len() as f32;

        let mut estimate = len / chars_per_token;

        // Runs of spaces and tabs usually merge into fewer tokens.
        let whitespace_count = text.chars().filter(|c| *c == ' ' || *c == '\t').count() as f32;
        estimate -= whitespace_count * 0.3;

        // Newlines often tokenize separately.
        let newline_count = text.chars().filter(|c| *c == '\n').count() as f32;
        estimate += newline_count * 0.5;

        let special_chars = text
            .chars()
            .filter(|c| {
                matches!(
                    c,
                    '{' | '}' | '(' | ')' | '[' | ']' | ';' | ':' | ',' | '.' | '=' | '+'
                        | '-' | '*' | '/' | '<' | '>' | '!' | '&' | '|' | '@' | '#' | '$'
                        | '%' | '^' | '~' | '`' | '"' | '\''
                )
            })
            .count() as f32;

        // These models tend to spend extra tokens on code punctuation.
        if matches!(
            model,
            TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
        ) {
            estimate += special_chars * 0.3;
        }

        estimate.ceil().max(1.0) as u32
    }

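    /// Count `text` under every supported model.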
    pub fn count_all(&self, text: &str) -> TokenCounts {
        TokenCounts {
            o200k: self.count(text, TokenModel::Gpt4o),
            cl100k: self.count(text, TokenModel::Gpt4),
            claude: self.count(text, TokenModel::Claude),
            gemini: self.count(text, TokenModel::Gemini),
            llama: self.count(text, TokenModel::Llama),
            mistral: self.count(text, TokenModel::Mistral),
            deepseek: self.count(text, TokenModel::DeepSeek),
            qwen: self.count(text, TokenModel::Qwen),
            cohere: self.count(text, TokenModel::Cohere),
            grok: self.count(text, TokenModel::Grok),
        }
    }

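    /// The model that spends the fewest tokens on `text`, with its count.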
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k),
            (TokenModel::Gpt4, counts.cl100k),
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

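    /// Truncate `text` so it fits within `budget` tokens for `model`,
    /// preferring to cut after a space or newline.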
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the longest prefix that stays within budget.
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            let mid = text.floor_char_boundary(mid_raw);

            // floor_char_boundary can pull mid back to low; stop to avoid looping.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Back up to the nearest space or newline so we don't cut mid-word.
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

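    /// Whether `text` costs more than `budget` tokens under `model`.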
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

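/// Fast standalone estimate from byte length alone, with none of the
/// per-character adjustments that `Tokenizer::estimate` applies.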
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}