1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3use tiktoken_rs::CoreBPE;
4
/// Tokenizer family used to estimate token counts for a client or model.
///
/// Only two BPE vocabularies are ever loaded (`o200k_base` and `cl100k_base`);
/// `Gemini` and `Llama` are approximations counted with one of those two.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// OpenAI `o200k_base` encoding. This is the default family.
    #[default]
    O200kBase,
    /// OpenAI `cl100k_base` encoding.
    Cl100k,
    /// Gemini models: counted with `o200k_base`, then scaled up by a correction factor.
    Gemini,
    /// Names bucketed as "llama" by `detect_tokenizer`: counted with `cl100k_base`.
    Llama,
}

impl TokenizerFamily {
    /// Canonical lowercase name of the family (the same text `Display` prints).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::O200kBase => "o200k_base",
            Self::Cl100k => "cl100k_base",
            Self::Gemini => "gemini",
            Self::Llama => "llama",
        }
    }
}

impl std::fmt::Display for TokenizerFamily {
    /// Writes the canonical name.
    ///
    /// Uses `Formatter::pad` rather than `write!`, so width, fill, alignment and
    /// precision flags are honoured (e.g. `format!("{:<12}", family)` pads).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
34
35pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
41 let lower = client_name.to_ascii_lowercase();
42 if lower.contains("claude")
43 || lower.contains("anthropic")
44 || lower.contains("sonnet")
45 || lower.contains("opus")
46 || lower.contains("haiku")
47 {
48 TokenizerFamily::Cl100k
49 } else if lower.contains("gemini") || lower.contains("google") {
50 TokenizerFamily::Gemini
51 } else if lower.contains("llama")
52 || lower.contains("codex")
53 || lower.contains("opencode")
54 || lower.contains("mistral")
55 || lower.contains("deepseek")
56 || lower.contains("qwen")
57 {
58 TokenizerFamily::Llama
59 } else {
60 TokenizerFamily::O200kBase
61 }
62}
63
64static BPE_O200K: OnceLock<CoreBPE> = OnceLock::new();
67static BPE_CL100K: OnceLock<CoreBPE> = OnceLock::new();
68
69fn get_bpe_o200k() -> &'static CoreBPE {
70 BPE_O200K
71 .get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base tokenizer"))
72}
73
74fn get_bpe_cl100k() -> &'static CoreBPE {
75 BPE_CL100K
76 .get_or_init(|| tiktoken_rs::cl100k_base().expect("failed to load cl100k_base tokenizer"))
77}
78
79fn bpe_for_family(family: TokenizerFamily) -> &'static CoreBPE {
80 match family {
81 TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
82 TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
83 }
84}
85
/// Multiplier applied to the raw `o200k_base` count when `count_tokens_for` is
/// asked for `TokenizerFamily::Gemini`; the scaled result is rounded up.
/// NOTE(review): presumably an empirical estimate of how Gemini's tokenizer
/// compares to o200k — confirm where 1.08 comes from.
const GEMINI_CORRECTION: f64 = 1.08;

/// Maximum number of entries in `TOKEN_CACHE`: once the map holds this many,
/// it is cleared entirely before the next insert.
const TOKEN_CACHE_MAX: usize = 256;

/// Memoised token counts, keyed by `hash_text(text, family)`.
///
/// Starts as `None` (a `HashMap` cannot be built in a `static` initialiser) and
/// is populated lazily on the first insert in `count_tokens_for`.
static TOKEN_CACHE: Mutex<Option<HashMap<u64, usize>>> = Mutex::new(None);
94
95fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
96 use std::hash::{Hash, Hasher};
97 let mut hasher = std::collections::hash_map::DefaultHasher::new();
98 family.hash(&mut hasher);
99 text.len().hash(&mut hasher);
100 if text.len() <= 512 {
101 text.hash(&mut hasher);
102 } else {
103 let start_end = floor_char_boundary(text, 256);
104 let tail_start = ceil_char_boundary(text, text.len() - 256);
105 text[..start_end].hash(&mut hasher);
106 text[tail_start..].hash(&mut hasher);
107 }
108 hasher.finish()
109}
110
111fn floor_char_boundary(s: &str, idx: usize) -> usize {
112 let idx = idx.min(s.len());
113 let mut i = idx;
114 while i > 0 && !s.is_char_boundary(i) {
115 i -= 1;
116 }
117 i
118}
119
120fn ceil_char_boundary(s: &str, idx: usize) -> usize {
121 let idx = idx.min(s.len());
122 let mut i = idx;
123 while i < s.len() && !s.is_char_boundary(i) {
124 i += 1;
125 }
126 i
127}
128
129pub fn count_tokens(text: &str) -> usize {
136 count_tokens_for(text, TokenizerFamily::O200kBase)
137}
138
139pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
141 if text.is_empty() {
142 return 0;
143 }
144
145 let key = hash_text(text, family);
146
147 if let Ok(guard) = TOKEN_CACHE.lock() {
148 if let Some(ref map) = *guard {
149 if let Some(&cached) = map.get(&key) {
150 return cached;
151 }
152 }
153 }
154
155 let raw = bpe_for_family(family)
156 .encode_with_special_tokens(text)
157 .len();
158 let count = if family == TokenizerFamily::Gemini {
159 (raw as f64 * GEMINI_CORRECTION).ceil() as usize
160 } else {
161 raw
162 };
163
164 if let Ok(mut guard) = TOKEN_CACHE.lock() {
165 let map = guard.get_or_insert_with(HashMap::new);
166 if map.len() >= TOKEN_CACHE_MAX {
167 map.clear();
168 }
169 map.insert(key, count);
170 }
171
172 count
173}
174
175pub fn encode_tokens(text: &str) -> Vec<u32> {
177 if text.is_empty() {
178 return Vec::new();
179 }
180 get_bpe_o200k().encode_with_special_tokens(text)
181}
182
183pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
187 if text.is_empty() {
188 return Vec::new();
189 }
190 bpe_for_family(family).encode_with_special_tokens(text)
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock};

    /// Serialises tests that reset or depend on the shared `TOKEN_CACHE`.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Replaces the global token cache with an empty map.
    fn reset_cache() {
        if let Ok(mut guard) = TOKEN_CACHE.lock() {
            *guard = Some(HashMap::new());
        }
    }

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        // The message reports the actual constant rather than a hard-coded factor.
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to the {GEMINI_CORRECTION}× correction"
        );
    }

    #[test]
    fn cl100k_and_o200k_both_count_code() {
        // Only checks that both vocabularies produce a non-zero count for a code
        // snippet; it deliberately does not assert that the two counts differ.
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}