1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
/// Tokenizer families that token counts and encodings can be requested for.
///
/// The discriminants are written out because `hash_text` folds `family as u64`
/// into the token-cache key; they match the values the compiler would assign
/// implicitly for this declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// Displayed as `o200k_base`; the default family and the fallback of
    /// `detect_tokenizer` for unrecognised client names.
    #[default]
    O200kBase = 0,
    /// Displayed as `cl100k_base`.
    Cl100k = 1,
    /// Displayed as `gemini`; counted with the o200k encoder and then scaled
    /// by `GEMINI_CORRECTION`.
    Gemini = 2,
    /// Displayed as `llama`; counted with the cl100k encoder.
    Llama = 3,
}
22
23impl std::fmt::Display for TokenizerFamily {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::O200kBase => write!(f, "o200k_base"),
27 Self::Cl100k => write!(f, "cl100k_base"),
28 Self::Gemini => write!(f, "gemini"),
29 Self::Llama => write!(f, "llama"),
30 }
31 }
32}
33
34pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40 let lower = client_name.to_ascii_lowercase();
41 if lower.contains("claude")
42 || lower.contains("anthropic")
43 || lower.contains("sonnet")
44 || lower.contains("opus")
45 || lower.contains("haiku")
46 {
47 TokenizerFamily::Cl100k
48 } else if lower.contains("gemini") || lower.contains("google") {
49 TokenizerFamily::Gemini
50 } else if lower.contains("llama")
51 || lower.contains("codex")
52 || lower.contains("opencode")
53 || lower.contains("mistral")
54 || lower.contains("deepseek")
55 || lower.contains("qwen")
56 {
57 TokenizerFamily::Llama
58 } else {
59 TokenizerFamily::O200kBase
60 }
61}
62
/// Lazily initialised `o200k_base` encoder, filled in by `get_bpe_o200k`.
/// Holds `None` when loading failed; the `OnceLock` is initialised only once,
/// so a failed load is not retried.
static BPE_O200K: OnceLock<Option<CoreBPE>> = OnceLock::new();
/// Lazily initialised `cl100k_base` encoder, filled in by `get_bpe_cl100k`.
/// Holds `None` when loading failed; a failed load is not retried.
static BPE_CL100K: OnceLock<Option<CoreBPE>> = OnceLock::new();
67
68fn get_bpe_o200k() -> Option<&'static CoreBPE> {
69 BPE_O200K
70 .get_or_init(|| {
71 tiktoken_rs::o200k_base()
72 .inspect_err(|e| tracing::error!("failed to load o200k_base tokenizer: {e}"))
73 .ok()
74 })
75 .as_ref()
76}
77
78fn get_bpe_cl100k() -> Option<&'static CoreBPE> {
79 BPE_CL100K
80 .get_or_init(|| {
81 tiktoken_rs::cl100k_base()
82 .inspect_err(|e| tracing::error!("failed to load cl100k_base tokenizer: {e}"))
83 .ok()
84 })
85 .as_ref()
86}
87
88fn bpe_for_family(family: TokenizerFamily) -> Option<&'static CoreBPE> {
89 match family {
90 TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
91 TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
92 }
93}
94
/// Divisor used by `count_tokens_for` to estimate a token count when no
/// encoder is available. It is applied to `text.len()`, which is the UTF-8
/// byte length of the text rather than its character count.
const CHARS_PER_TOKEN_ESTIMATE: f64 = 3.5;

/// Multiplier applied to the raw o200k token count for
/// `TokenizerFamily::Gemini`; the product is rounded up.
const GEMINI_CORRECTION: f64 = 1.08;

/// Value passed as `max_capacity` when the token-count cache is built.
const TOKEN_CACHE_MAX: u64 = 4096;
103
104fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
105 static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
106 CACHE.get_or_init(|| {
107 moka::sync::Cache::builder()
108 .max_capacity(TOKEN_CACHE_MAX)
109 .build()
110 })
111}
112
113fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
114 let h = blake3::hash(text.as_bytes());
115 let bytes = h.as_bytes();
116 let base = u64::from_le_bytes([
117 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
118 ]);
119 base ^ (family as u64)
120}
121
/// Returns the largest char boundary of `s` that is `<= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    let upper = idx.min(s.len());
    // Offset 0 is always a char boundary, so the search always succeeds.
    (0..=upper)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
131
/// Returns the smallest char boundary of `s` that is `>= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let end = s.len();
    let start = idx.min(end);
    // If no boundary is found before the end, the end itself is the answer.
    (start..end)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(end)
}
141
/// Counts the tokens in `text` using `COUNTING_FAMILY`.
///
/// Thin wrapper around `count_tokens_for`.
pub fn count_tokens(text: &str) -> usize {
    count_tokens_for(text, COUNTING_FAMILY)
}
151
/// Tokenizer family used by `count_tokens` and reported by
/// `counting_family_label`.
pub const COUNTING_FAMILY: TokenizerFamily = TokenizerFamily::O200kBase;
157
/// Returns the `Display` label of `COUNTING_FAMILY` as an owned string
/// (`"o200k_base"` for the current value).
pub fn counting_family_label() -> String {
    COUNTING_FAMILY.to_string()
}
162
163pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
165 if text.is_empty() {
166 return 0;
167 }
168
169 let key = hash_text(text, family);
170 let cache = token_cache();
171
172 if let Some(cached) = cache.get(&key) {
173 return cached;
174 }
175
176 let Some(bpe) = bpe_for_family(family) else {
177 let estimate = (text.len() as f64 / CHARS_PER_TOKEN_ESTIMATE).ceil() as usize;
178 cache.insert(key, estimate);
179 return estimate;
180 };
181 let raw = bpe.encode_with_special_tokens(text).len();
182 let count = if family == TokenizerFamily::Gemini {
183 (raw as f64 * GEMINI_CORRECTION).ceil() as usize
184 } else {
185 raw
186 };
187
188 cache.insert(key, count);
189 count
190}
191
192pub fn encode_tokens(text: &str) -> Vec<u32> {
194 if text.is_empty() {
195 return Vec::new();
196 }
197 match get_bpe_o200k() {
198 Some(bpe) => bpe.encode_with_special_tokens(text),
199 None => Vec::new(),
200 }
201}
202
203pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
207 if text.is_empty() {
208 return Vec::new();
209 }
210 match bpe_for_family(family) {
211 Some(bpe) => bpe.encode_with_special_tokens(text),
212 None => Vec::new(),
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialises tests that reset the shared token cache and then rely on it.
    ///
    /// A poisoned mutex is recovered rather than propagated, so one failing
    /// test does not cascade into every other test that takes the lock.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Invalidates every entry of the shared token cache.
    fn reset_cache() {
        token_cache().invalidate_all();
    }

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        // The second call is served from the cache and must agree.
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        // The factor is interpolated from the constant so the message cannot
        // drift from the value that is actually applied.
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to {GEMINI_CORRECTION}× correction"
        );
    }

    /// Only checks that both families produce a non-zero count for the same
    /// text; it does not assert that the two counts differ.
    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        // Gemini is left out: its count is scaled, so it would not equal the
        // raw encoded length.
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .unwrap()
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        // Cached lookups must return each family's own count.
        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}