1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
/// Tokenizer families that token counts and encodings can be requested for.
///
/// The discriminant values are folded into the token-cache key by `hash_text`
/// (via `family as u64`), so they are spelled out explicitly here.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TokenizerFamily {
    /// Counted with the `o200k_base` encoding; also the fallback chosen by
    /// `detect_tokenizer` for unrecognised client names.
    #[default]
    O200kBase = 0,
    /// Counted with the `cl100k_base` encoding; chosen by `detect_tokenizer`
    /// for Claude/Anthropic style client names.
    Cl100k = 1,
    /// Counted with `o200k_base` and then scaled by `GEMINI_CORRECTION`.
    Gemini = 2,
    /// Counted with the `cl100k_base` encoding.
    Llama = 3,
}
22
23impl std::fmt::Display for TokenizerFamily {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::O200kBase => write!(f, "o200k_base"),
27 Self::Cl100k => write!(f, "cl100k_base"),
28 Self::Gemini => write!(f, "gemini"),
29 Self::Llama => write!(f, "llama"),
30 }
31 }
32}
33
34pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40 let lower = client_name.to_ascii_lowercase();
41 if lower.contains("claude")
42 || lower.contains("anthropic")
43 || lower.contains("sonnet")
44 || lower.contains("opus")
45 || lower.contains("haiku")
46 {
47 TokenizerFamily::Cl100k
48 } else if lower.contains("gemini") || lower.contains("google") {
49 TokenizerFamily::Gemini
50 } else if lower.contains("llama")
51 || lower.contains("codex")
52 || lower.contains("opencode")
53 || lower.contains("mistral")
54 || lower.contains("deepseek")
55 || lower.contains("qwen")
56 {
57 TokenizerFamily::Llama
58 } else {
59 TokenizerFamily::O200kBase
60 }
61}
62
63static BPE_O200K: OnceLock<Option<CoreBPE>> = OnceLock::new();
66static BPE_CL100K: OnceLock<Option<CoreBPE>> = OnceLock::new();
67
68fn get_bpe_o200k() -> Option<&'static CoreBPE> {
69 BPE_O200K
70 .get_or_init(|| {
71 tiktoken_rs::o200k_base()
72 .inspect_err(|e| tracing::error!("failed to load o200k_base tokenizer: {e}"))
73 .ok()
74 })
75 .as_ref()
76}
77
78fn get_bpe_cl100k() -> Option<&'static CoreBPE> {
79 BPE_CL100K
80 .get_or_init(|| {
81 tiktoken_rs::cl100k_base()
82 .inspect_err(|e| tracing::error!("failed to load cl100k_base tokenizer: {e}"))
83 .ok()
84 })
85 .as_ref()
86}
87
88fn bpe_for_family(family: TokenizerFamily) -> Option<&'static CoreBPE> {
89 match family {
90 TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
91 TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
92 }
93}
94
/// Divisor for the fallback estimate in `count_tokens_for`: when no encoder is
/// available, the text's UTF-8 byte length is divided by this value.
const CHARS_PER_TOKEN_ESTIMATE: f64 = 3.5;

/// Multiplier applied to the `o200k_base` count to obtain a Gemini count.
const GEMINI_CORRECTION: f64 = 1.08;

/// Capacity bound handed to the token-count cache builder.
const TOKEN_CACHE_MAX: u64 = 4_096;
103
104fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
105 static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
106 CACHE.get_or_init(|| {
107 moka::sync::Cache::builder()
108 .max_capacity(TOKEN_CACHE_MAX)
109 .build()
110 })
111}
112
113fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
114 let h = blake3::hash(text.as_bytes());
115 let bytes = h.as_bytes();
116 let base = u64::from_le_bytes([
117 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
118 ]);
119 base ^ (family as u64)
120}
121
/// Test helper: returns the largest char boundary of `s` that is `<= idx`.
///
/// `idx` is first clamped to `s.len()`. Offset 0 is always a boundary, so the
/// search always finds a result.
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    let start = idx.min(s.len());
    (0..=start)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
131
/// Test helper: returns the smallest char boundary of `s` that is `>= idx`.
///
/// `idx` is first clamped to `s.len()`. The end of the string is always a
/// boundary, so the search always finds a result.
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let end = s.len();
    let start = idx.min(end);
    (start..=end)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(end)
}
141
142pub fn count_tokens(text: &str) -> usize {
149 count_tokens_for(text, TokenizerFamily::O200kBase)
150}
151
152pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
154 if text.is_empty() {
155 return 0;
156 }
157
158 let key = hash_text(text, family);
159 let cache = token_cache();
160
161 if let Some(cached) = cache.get(&key) {
162 return cached;
163 }
164
165 let Some(bpe) = bpe_for_family(family) else {
166 let estimate = (text.len() as f64 / CHARS_PER_TOKEN_ESTIMATE).ceil() as usize;
167 cache.insert(key, estimate);
168 return estimate;
169 };
170 let raw = bpe.encode_with_special_tokens(text).len();
171 let count = if family == TokenizerFamily::Gemini {
172 (raw as f64 * GEMINI_CORRECTION).ceil() as usize
173 } else {
174 raw
175 };
176
177 cache.insert(key, count);
178 count
179}
180
181pub fn encode_tokens(text: &str) -> Vec<u32> {
183 if text.is_empty() {
184 return Vec::new();
185 }
186 match get_bpe_o200k() {
187 Some(bpe) => bpe.encode_with_special_tokens(text),
188 None => Vec::new(),
189 }
190}
191
192pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
196 if text.is_empty() {
197 return Vec::new();
198 }
199 match bpe_for_family(family) {
200 Some(bpe) => bpe.encode_with_special_tokens(text),
201 None => Vec::new(),
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialises tests that reset and then read the shared token cache.
    /// A poisoned lock is recovered so one failing test does not cascade.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Drops every entry from the shared token cache.
    fn reset_cache() {
        token_cache().invalidate_all();
    }

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        // The second call is served from the cache and must agree.
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        // The message prints the real constant so it cannot drift from the code.
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to {factor}× correction",
            factor = GEMINI_CORRECTION
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        // Only checks that both families yield a count; the two counts are not
        // required to differ for this input.
        assert!(o200k > 0);
        assert!(cl100k > 0);
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        // Gemini is left out: its count is scaled, so it is not a raw length.
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .unwrap()
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        // Cached lookups must return the value stored for their own family.
        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}