1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
/// Tokenizer family used to estimate token counts for a client or model.
///
/// The discriminants are spelled out because `hash_text` folds
/// `family as u64` into the cache key; they match the values the compiler
/// would assign implicitly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TokenizerFamily {
    /// The `o200k_base` encoding; also the default family.
    #[default]
    O200kBase = 0,
    /// The `cl100k_base` encoding.
    Cl100k = 1,
    /// Gemini; counted with the `o200k_base` encoder and then scaled by
    /// `GEMINI_CORRECTION` in `count_tokens_for`.
    Gemini = 2,
    /// Llama-style clients; counted with the `cl100k_base` encoder.
    Llama = 3,
}
22
23impl std::fmt::Display for TokenizerFamily {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::O200kBase => write!(f, "o200k_base"),
27 Self::Cl100k => write!(f, "cl100k_base"),
28 Self::Gemini => write!(f, "gemini"),
29 Self::Llama => write!(f, "llama"),
30 }
31 }
32}
33
34pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40 let lower = client_name.to_ascii_lowercase();
41 if lower.contains("claude")
42 || lower.contains("anthropic")
43 || lower.contains("sonnet")
44 || lower.contains("opus")
45 || lower.contains("haiku")
46 {
47 TokenizerFamily::Cl100k
48 } else if lower.contains("gemini") || lower.contains("google") {
49 TokenizerFamily::Gemini
50 } else if lower.contains("llama")
51 || lower.contains("codex")
52 || lower.contains("opencode")
53 || lower.contains("mistral")
54 || lower.contains("deepseek")
55 || lower.contains("qwen")
56 {
57 TokenizerFamily::Llama
58 } else {
59 TokenizerFamily::O200kBase
60 }
61}
62
// Lazily initialized `o200k_base` encoder; filled on first use by
// `get_bpe_o200k`.
static BPE_O200K: OnceLock<CoreBPE> = OnceLock::new();
// Lazily initialized `cl100k_base` encoder; filled on first use by
// `get_bpe_cl100k`.
static BPE_CL100K: OnceLock<CoreBPE> = OnceLock::new();
67
68fn get_bpe_o200k() -> &'static CoreBPE {
69 BPE_O200K
70 .get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base tokenizer"))
71}
72
73fn get_bpe_cl100k() -> &'static CoreBPE {
74 BPE_CL100K
75 .get_or_init(|| tiktoken_rs::cl100k_base().expect("failed to load cl100k_base tokenizer"))
76}
77
78fn bpe_for_family(family: TokenizerFamily) -> &'static CoreBPE {
79 match family {
80 TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
81 TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
82 }
83}
84
/// Multiplier applied in `count_tokens_for` to the raw `o200k_base` token
/// count when the family is `Gemini`; the product is rounded up.
const GEMINI_CORRECTION: f64 = 1.08;

/// Maximum capacity handed to the moka cache builder in `token_cache`.
const TOKEN_CACHE_MAX: u64 = 4096;
91
92fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
93 static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
94 CACHE.get_or_init(|| {
95 moka::sync::Cache::builder()
96 .max_capacity(TOKEN_CACHE_MAX)
97 .build()
98 })
99}
100
101fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
102 let h = blake3::hash(text.as_bytes());
103 let bytes = h.as_bytes();
104 let base = u64::from_le_bytes([
105 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
106 ]);
107 base ^ (family as u64)
108}
109
/// Returns the largest char boundary of `s` that is `<= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    let upper = idx.min(s.len());
    // Offset 0 is always a boundary, so the search cannot come up empty.
    (0..=upper)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
119
/// Returns the smallest char boundary of `s` that is `>= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let end = s.len();
    let start = idx.min(end);
    // `end` is always a boundary, so the search cannot come up empty.
    (start..=end)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(end)
}
129
130pub fn count_tokens(text: &str) -> usize {
137 count_tokens_for(text, TokenizerFamily::O200kBase)
138}
139
140pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
142 if text.is_empty() {
143 return 0;
144 }
145
146 let key = hash_text(text, family);
147 let cache = token_cache();
148
149 if let Some(cached) = cache.get(&key) {
150 return cached;
151 }
152
153 let raw = bpe_for_family(family)
154 .encode_with_special_tokens(text)
155 .len();
156 let count = if family == TokenizerFamily::Gemini {
157 (raw as f64 * GEMINI_CORRECTION).ceil() as usize
158 } else {
159 raw
160 };
161
162 cache.insert(key, count);
163 count
164}
165
166pub fn encode_tokens(text: &str) -> Vec<u32> {
168 if text.is_empty() {
169 return Vec::new();
170 }
171 get_bpe_o200k().encode_with_special_tokens(text)
172}
173
174pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
178 if text.is_empty() {
179 return Vec::new();
180 }
181 bpe_for_family(family).encode_with_special_tokens(text)
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that reset the shared token cache. A poisoned lock is
    /// recovered so that one failing test does not cascade into the others.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Drops every entry from the shared token cache.
    fn reset_cache() {
        token_cache().invalidate_all();
    }

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        // Second call is served from the cache and must agree.
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        // The message reports the real factor instead of a hard-coded figure.
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to {GEMINI_CORRECTION}× correction"
        );
        // Gemini shares the o200k encoder, so its count is exactly the raw
        // count scaled by the correction factor and rounded up.
        assert_eq!(
            gemini,
            (o200k as f64 * GEMINI_CORRECTION).ceil() as usize,
            "Gemini count must be ceil(raw × {GEMINI_CORRECTION})"
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        // NOTE(review): despite the name, this only checks that both
        // families produce a non-zero count; whether the two counts actually
        // differ for this text has not been verified here.
        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        // Gemini is excluded because its count carries a correction factor
        // and therefore does not equal the encoded length.
        let text = "hello world";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let counted = count_tokens_for(text, family);
            assert_eq!(encoded.len(), counted, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}