infiniloom_engine/tokenizer/core.rs

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

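// Process-wide tiktoken encoders, built once on first use and shared across calls.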
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

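/// Returns the global token-count cache, keyed by (content hash, model).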
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

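/// Hashes content with the standard library's `DefaultHasher` to produce a cache key.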
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

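/// Lazily initializes and returns the shared `o200k_base` encoder (GPT-4o family).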
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

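/// Lazily initializes and returns the shared `cl100k_base` encoder (GPT-4 family).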
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

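/// Multi-model token counter.
///
/// Counts tokens with exact tiktoken encodings (`o200k_base` / `cl100k_base`) when the
/// model supports them, and with a character-based heuristic otherwise. Counts can be
/// memoized in a process-wide cache keyed by content hash and model.
///
/// Illustrative usage (sketch; exact counts depend on the tiktoken vocabulary):
///
/// ```ignore
/// let tokenizer = Tokenizer::new();
/// let tokens = tokenizer.count("fn main() {}", TokenModel::Gpt4o);
/// assert!(tokens > 0);
/// ```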
pub struct Tokenizer {
    use_exact: bool,
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
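    /// Creates a tokenizer that prefers exact tiktoken counts and uses the global cache.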
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

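    /// Creates a tokenizer that always uses the character-based heuristic, never tiktoken.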
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

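    /// Creates a tokenizer that counts exactly but skips the global cache.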
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

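    /// Counts tokens in `text` for `model`, using the exact encoder when enabled and
    /// available, otherwise the heuristic estimate. Results are memoized in the global
    /// cache when caching is enabled.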
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            if let Some(count) = cache.get(&key) {
                return *count;
            }

            let count = self.count_uncached(text, model);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

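    /// Dispatches to the exact tokenizer or the heuristic, bypassing the cache.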
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

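    /// Counts with the matching tiktoken encoder; falls back to the heuristic for
    /// models without an exact vocabulary.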
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            let tokenizer = get_gpt4o_tokenizer();
            tokenizer.encode_ordinary(text).len() as u32
        } else if model.uses_cl100k() {
            let tokenizer = get_gpt4_tokenizer();
            tokenizer.encode_ordinary(text).len() as u32
        } else {
            self.estimate(text, model)
        }
    }

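    /// Character-based estimate: starts from `len / chars_per_token`, then adjusts
    /// downward for spaces and tabs, upward for newlines, and, for code-oriented
    /// models, upward for punctuation density.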
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        let chars_per_token = model.chars_per_token();
        let len = text.len() as f32;

        let mut estimate = len / chars_per_token;

        let whitespace_count = text.chars().filter(|c| *c == ' ' || *c == '\t').count() as f32;
        estimate -= whitespace_count * 0.3;

        let newline_count = text.chars().filter(|c| *c == '\n').count() as f32;
        estimate += newline_count * 0.5;

        let special_chars = text
            .chars()
            .filter(|c| {
                matches!(
                    c,
                    '{' | '}' | '(' | ')' | '[' | ']' | ';' | ':' | ',' | '.' | '=' | '+' | '-'
                        | '*' | '/' | '<' | '>' | '!' | '&' | '|' | '@' | '#' | '$' | '%' | '^'
                        | '~' | '`' | '"' | '\''
                )
            })
            .count() as f32;

        if matches!(
            model,
            TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
        ) {
            estimate += special_chars * 0.3;
        }

        estimate.ceil().max(1.0) as u32
    }

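    /// Counts `text` once per supported model and returns all results together.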
    pub fn count_all(&self, text: &str) -> TokenCounts {
        TokenCounts {
            o200k: self.count(text, TokenModel::Gpt4o),
            cl100k: self.count(text, TokenModel::Gpt4),
            claude: self.count(text, TokenModel::Claude),
            gemini: self.count(text, TokenModel::Gemini),
            llama: self.count(text, TokenModel::Llama),
            mistral: self.count(text, TokenModel::Mistral),
            deepseek: self.count(text, TokenModel::DeepSeek),
            qwen: self.count(text, TokenModel::Qwen),
            cohere: self.count(text, TokenModel::Cohere),
            grok: self.count(text, TokenModel::Grok),
        }
    }

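    /// Returns the model that tokenizes `text` most cheaply, together with its count.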
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k),
            (TokenModel::Gpt4, counts.cl100k),
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

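    /// Returns the longest prefix of `text` that fits within `budget` tokens for `model`.
    /// Binary-searches over byte offsets (snapped to char boundaries), then backs up to
    /// the nearest space or newline so the cut does not land mid-word.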
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            let mid = text.floor_char_boundary(mid_raw);

            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

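    /// Returns `true` if `text` tokenizes to more than `budget` tokens for `model`.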
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

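/// Fast estimate that only divides byte length by the model's average
/// characters-per-token ratio; useful when an exact count is unnecessary.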
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}