Skip to main content

lean_ctx/core/
tokens.rs

1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
4// ── Tokenizer Families ─────────────────────────────────────
5
/// Tokenizer families for different LLM providers.
///
/// Different LLM families use different tokenizers, leading to 5–15% variance
/// in token counts. This enum lets callers select the appropriate tokenizer.
///
/// The discriminant order is significant: the variant is cast to an integer
/// when building token-cache keys, so reordering variants changes those keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// GPT-4o family (tiktoken o200k_base, exact). Also the default family.
    ///
    /// NOTE(review): older OpenAI models such as GPT-4-turbo use cl100k_base
    /// upstream, so counts for them under this variant are an approximation.
    #[default]
    O200kBase,
    /// Claude / Anthropic (approximated via tiktoken cl100k_base)
    Cl100k,
    /// Gemini / Google (o200k_base with a 1.08× correction factor)
    Gemini,
    /// Llama 3+ (approximated via cl100k_base)
    Llama,
}
22
23impl std::fmt::Display for TokenizerFamily {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::O200kBase => write!(f, "o200k_base"),
27            Self::Cl100k => write!(f, "cl100k_base"),
28            Self::Gemini => write!(f, "gemini"),
29            Self::Llama => write!(f, "llama"),
30        }
31    }
32}
33
34/// Detects the appropriate tokenizer family from a client or model name.
35///
36/// Matches are case-insensitive substrings. Falls back to `O200kBase`.
37/// Accuracy: cl100k is within ~3% of Claude's actual tokenizer;
38/// Gemini correction factor 1.08 is empirically calibrated; o200k is exact for GPT-4o+.
39pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40    let lower = client_name.to_ascii_lowercase();
41    if lower.contains("claude")
42        || lower.contains("anthropic")
43        || lower.contains("sonnet")
44        || lower.contains("opus")
45        || lower.contains("haiku")
46    {
47        TokenizerFamily::Cl100k
48    } else if lower.contains("gemini") || lower.contains("google") {
49        TokenizerFamily::Gemini
50    } else if lower.contains("llama")
51        || lower.contains("codex")
52        || lower.contains("opencode")
53        || lower.contains("mistral")
54        || lower.contains("deepseek")
55        || lower.contains("qwen")
56    {
57        TokenizerFamily::Llama
58    } else {
59        TokenizerFamily::O200kBase
60    }
61}
62
63// ── Tokenizer Instances ────────────────────────────────────
64
65static BPE_O200K: OnceLock<Option<CoreBPE>> = OnceLock::new();
66static BPE_CL100K: OnceLock<Option<CoreBPE>> = OnceLock::new();
67
68fn get_bpe_o200k() -> Option<&'static CoreBPE> {
69    BPE_O200K
70        .get_or_init(|| {
71            tiktoken_rs::o200k_base()
72                .inspect_err(|e| tracing::error!("failed to load o200k_base tokenizer: {e}"))
73                .ok()
74        })
75        .as_ref()
76}
77
78fn get_bpe_cl100k() -> Option<&'static CoreBPE> {
79    BPE_CL100K
80        .get_or_init(|| {
81            tiktoken_rs::cl100k_base()
82                .inspect_err(|e| tracing::error!("failed to load cl100k_base tokenizer: {e}"))
83                .ok()
84        })
85        .as_ref()
86}
87
88fn bpe_for_family(family: TokenizerFamily) -> Option<&'static CoreBPE> {
89    match family {
90        TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
91        TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
92    }
93}
94
/// Divisor for the fallback estimate used when no BPE encoder could be loaded.
///
/// The estimate divides `str::len()` — a UTF-8 *byte* count, not a character
/// count — by this value, so multi-byte text yields a proportionally larger
/// estimate than the constant's name suggests.
const CHARS_PER_TOKEN_ESTIMATE: f64 = 3.5;

/// Multiplier applied to raw o200k_base counts to approximate Gemini.
///
/// Gemini is taken to produce ~8% more tokens than o200k_base for the same
/// text; the factor is empirically calibrated.
const GEMINI_CORRECTION: f64 = 1.08;
99
100// ── Cache ──────────────────────────────────────────────────
101
102const TOKEN_CACHE_MAX: u64 = 4096;
103
104fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
105    static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
106    CACHE.get_or_init(|| {
107        moka::sync::Cache::builder()
108            .max_capacity(TOKEN_CACHE_MAX)
109            .build()
110    })
111}
112
113fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
114    let h = blake3::hash(text.as_bytes());
115    let bytes = h.as_bytes();
116    let base = u64::from_le_bytes([
117        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
118    ]);
119    base ^ (family as u64)
120}
121
/// Returns the largest char boundary in `s` that is `<= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    let upper = idx.min(s.len());
    // Offset 0 is always a boundary, so the search cannot come up empty.
    (0..=upper)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
131
/// Returns the smallest char boundary in `s` that is `>= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices yield
/// `s.len()`.
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let end = s.len();
    let start = idx.min(end);
    // `end` is always a boundary, so the search cannot come up empty.
    (start..=end)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(end)
}
141
142// ── Public API ─────────────────────────────────────────────
143
144/// Counts BPE tokens using the default tokenizer (o200k_base).
145///
146/// Backward-compatible — equivalent to
147/// `count_tokens_for(text, TokenizerFamily::O200kBase)`.
148pub fn count_tokens(text: &str) -> usize {
149    count_tokens_for(text, TokenizerFamily::O200kBase)
150}
151
152/// Counts BPE tokens using the specified tokenizer family.
153pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
154    if text.is_empty() {
155        return 0;
156    }
157
158    let key = hash_text(text, family);
159    let cache = token_cache();
160
161    if let Some(cached) = cache.get(&key) {
162        return cached;
163    }
164
165    let Some(bpe) = bpe_for_family(family) else {
166        let estimate = (text.len() as f64 / CHARS_PER_TOKEN_ESTIMATE).ceil() as usize;
167        cache.insert(key, estimate);
168        return estimate;
169    };
170    let raw = bpe.encode_with_special_tokens(text).len();
171    let count = if family == TokenizerFamily::Gemini {
172        (raw as f64 * GEMINI_CORRECTION).ceil() as usize
173    } else {
174        raw
175    };
176
177    cache.insert(key, count);
178    count
179}
180
181/// Encodes text into BPE token IDs (o200k_base).
182pub fn encode_tokens(text: &str) -> Vec<u32> {
183    if text.is_empty() {
184        return Vec::new();
185    }
186    match get_bpe_o200k() {
187        Some(bpe) => bpe.encode_with_special_tokens(text),
188        None => Vec::new(),
189    }
190}
191
192/// Encodes text into BPE token IDs for the specified tokenizer family.
193///
194/// Gemini correction is not applied here — this returns raw token IDs.
195pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
196    if text.is_empty() {
197        return Vec::new();
198    }
199    match bpe_for_family(family) {
200        Some(bpe) => bpe.encode_with_special_tokens(text),
201        None => Vec::new(),
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Every family, in declaration order, for tests that sweep all variants.
    const ALL_FAMILIES: [TokenizerFamily; 4] = [
        TokenizerFamily::O200kBase,
        TokenizerFamily::Cl100k,
        TokenizerFamily::Gemini,
        TokenizerFamily::Llama,
    ];

    /// Serializes tests that reset or depend on the process-wide token cache.
    ///
    /// A poisoned lock is recovered rather than propagated, so one failing
    /// test does not cascade into the others.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Drops every cached count so a test starts from a cold cache.
    fn reset_cache() {
        token_cache().invalidate_all();
    }

    // ── Backward-compatible tests ──────────────────────────

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        // Second call is served from the cache and must agree.
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    // ── Multi-tokenizer tests ──────────────────────────────

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in ALL_FAMILIES {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in ALL_FAMILIES {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to the {factor}× correction",
            factor = GEMINI_CORRECTION
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        // Only checks that both families produce a count for the same input;
        // despite the test's name, the two counts are not compared.
        assert!(o200k > 0);
        assert!(cl100k > 0);
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .expect("tokenizer should load in the test environment")
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}