Skip to main content

lean_ctx/core/
tokens.rs

1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
4// ── Tokenizer Families ─────────────────────────────────────
5
/// Tokenizer families for different LLM providers.
///
/// Different LLM families use different tokenizers, so the same text can produce
/// noticeably different token counts depending on the provider (roughly 5–15% by the
/// original estimate). This enum lets callers choose which tokenizer is used for
/// counting. Only [`TokenizerFamily::O200kBase`] is exact; the other variants are
/// approximations built on tiktoken's `o200k_base` / `cl100k_base` encoders.
///
/// The variant discriminants (0–3, in declaration order) are mixed into the token-cache
/// key by `hash_text`, so keep the declaration order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// GPT-4o and newer OpenAI models (tiktoken o200k_base, exact).
    ///
    /// Note: GPT-4 / GPT-4-turbo use cl100k_base upstream, so counts for those models
    /// are an approximation under this variant.
    #[default]
    O200kBase,
    /// Claude / Anthropic (approximated via tiktoken cl100k_base).
    Cl100k,
    /// Gemini / Google (o200k_base counts scaled up by a 1.08× correction factor).
    Gemini,
    /// Llama 3+ and the other models grouped with it by `detect_tokenizer`
    /// (approximated via cl100k_base).
    Llama,
}
22
23impl std::fmt::Display for TokenizerFamily {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::O200kBase => write!(f, "o200k_base"),
27            Self::Cl100k => write!(f, "cl100k_base"),
28            Self::Gemini => write!(f, "gemini"),
29            Self::Llama => write!(f, "llama"),
30        }
31    }
32}
33
34/// Detects the appropriate tokenizer family from a client or model name.
35///
36/// Matches are case-insensitive substrings. Falls back to `O200kBase`.
37/// Accuracy: cl100k is within ~3% of Claude's actual tokenizer;
38/// Gemini correction factor 1.08 is empirically calibrated; o200k is exact for GPT-4o+.
39pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40    let lower = client_name.to_ascii_lowercase();
41    if lower.contains("claude")
42        || lower.contains("anthropic")
43        || lower.contains("sonnet")
44        || lower.contains("opus")
45        || lower.contains("haiku")
46    {
47        TokenizerFamily::Cl100k
48    } else if lower.contains("gemini") || lower.contains("google") {
49        TokenizerFamily::Gemini
50    } else if lower.contains("llama")
51        || lower.contains("codex")
52        || lower.contains("opencode")
53        || lower.contains("mistral")
54        || lower.contains("deepseek")
55        || lower.contains("qwen")
56    {
57        TokenizerFamily::Llama
58    } else {
59        TokenizerFamily::O200kBase
60    }
61}
62
63// ── Tokenizer Instances ────────────────────────────────────
64
65static BPE_O200K: OnceLock<CoreBPE> = OnceLock::new();
66static BPE_CL100K: OnceLock<CoreBPE> = OnceLock::new();
67
68fn get_bpe_o200k() -> &'static CoreBPE {
69    BPE_O200K
70        .get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base tokenizer"))
71}
72
73fn get_bpe_cl100k() -> &'static CoreBPE {
74    BPE_CL100K
75        .get_or_init(|| tiktoken_rs::cl100k_base().expect("failed to load cl100k_base tokenizer"))
76}
77
78fn bpe_for_family(family: TokenizerFamily) -> &'static CoreBPE {
79    match family {
80        TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
81        TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
82    }
83}
84
/// Multiplier applied to raw o200k_base counts to estimate Gemini token counts.
///
/// Gemini's tokenizer yields roughly 8% *more* tokens than o200k_base for the same
/// text on average, so Gemini counts are the o200k count scaled up by this factor
/// (and rounded up). Described upstream as empirically calibrated —
/// NOTE(review): the calibration data is not part of this file.
const GEMINI_CORRECTION: f64 = 1.08;
87
88// ── Cache ──────────────────────────────────────────────────
89
90const TOKEN_CACHE_MAX: u64 = 4096;
91
92fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
93    static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
94    CACHE.get_or_init(|| {
95        moka::sync::Cache::builder()
96            .max_capacity(TOKEN_CACHE_MAX)
97            .build()
98    })
99}
100
101fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
102    let h = blake3::hash(text.as_bytes());
103    let bytes = h.as_bytes();
104    let base = u64::from_le_bytes([
105        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
106    ]);
107    base ^ (family as u64)
108}
109
/// Returns the largest UTF-8 char boundary in `s` that is `<= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices snap to the end.
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    let upper = idx.min(s.len());
    // Byte 0 is always a boundary, so the fallback is never actually needed.
    (0..=upper)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
119
/// Returns the smallest UTF-8 char boundary in `s` that is `>= idx`.
///
/// `idx` is clamped to `s.len()` first, so out-of-range indices snap to the end.
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let len = s.len();
    // `len` itself is always a boundary, so the fallback is never actually needed.
    (idx.min(len)..=len)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(len)
}
129
130// ── Public API ─────────────────────────────────────────────
131
132/// Counts BPE tokens using the default tokenizer (o200k_base).
133///
134/// Backward-compatible — equivalent to
135/// `count_tokens_for(text, TokenizerFamily::O200kBase)`.
136pub fn count_tokens(text: &str) -> usize {
137    count_tokens_for(text, TokenizerFamily::O200kBase)
138}
139
140/// Counts BPE tokens using the specified tokenizer family.
141pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
142    if text.is_empty() {
143        return 0;
144    }
145
146    let key = hash_text(text, family);
147    let cache = token_cache();
148
149    if let Some(cached) = cache.get(&key) {
150        return cached;
151    }
152
153    let raw = bpe_for_family(family)
154        .encode_with_special_tokens(text)
155        .len();
156    let count = if family == TokenizerFamily::Gemini {
157        (raw as f64 * GEMINI_CORRECTION).ceil() as usize
158    } else {
159        raw
160    };
161
162    cache.insert(key, count);
163    count
164}
165
166/// Encodes text into BPE token IDs (o200k_base).
167pub fn encode_tokens(text: &str) -> Vec<u32> {
168    if text.is_empty() {
169        return Vec::new();
170    }
171    get_bpe_o200k().encode_with_special_tokens(text)
172}
173
174/// Encodes text into BPE token IDs for the specified tokenizer family.
175///
176/// Gemini correction is not applied here — this returns raw token IDs.
177pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
178    if text.is_empty() {
179        return Vec::new();
180    }
181    bpe_for_family(family).encode_with_special_tokens(text)
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Every tokenizer family, for table-driven tests.
    const ALL_FAMILIES: [TokenizerFamily; 4] = [
        TokenizerFamily::O200kBase,
        TokenizerFamily::Cl100k,
        TokenizerFamily::Gemini,
        TokenizerFamily::Llama,
    ];

    /// Serialises tests that reset or depend on the shared token cache.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn reset_cache() {
        token_cache().invalidate_all();
    }

    // ── Backward-compatible tests ──────────────────────────

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let _lock = token_test_lock();

        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    // ── Multi-tokenizer tests ──────────────────────────────

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in ALL_FAMILIES {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in ALL_FAMILIES {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        let factor = GEMINI_CORRECTION;
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to the {factor}× correction"
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
        // The vocabularies differ, so the ID sequences are expected to differ even if
        // the counts happen to coincide.
        assert_ne!(
            encode_tokens_for(text, TokenizerFamily::O200kBase),
            encode_tokens_for(text, TokenizerFamily::Cl100k),
            "o200k_base and cl100k_base produced identical token IDs"
        );
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
            // Outside Gemini there is no correction, so the (cached) count must equal
            // the raw encoded length.
            assert_eq!(
                count_tokens_for(text, family),
                raw_count,
                "count mismatch for {family}"
            );
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}