Skip to main content

lean_ctx/core/
tokens.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3use tiktoken_rs::CoreBPE;
4
5// ── Tokenizer Families ─────────────────────────────────────
6
/// Tokenizer families for different LLM providers.
///
/// Different LLM families use different tokenizers, leading to 5–15% variance
/// in token counts. This enum lets callers select the appropriate tokenizer.
/// Only two BPE vocabularies back these families (o200k_base and cl100k_base);
/// every family other than `O200kBase` is an approximation built on one of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// OpenAI GPT-4o and newer (tiktoken o200k_base, exact). This is the default.
    ///
    /// Older OpenAI models such as GPT-4 / GPT-4-turbo use cl100k_base, so counts
    /// for them are only approximate under this family.
    #[default]
    O200kBase,
    /// Claude / Anthropic (approximated via tiktoken cl100k_base).
    Cl100k,
    /// Gemini / Google (o200k_base count scaled by a 1.08× correction factor, rounded up).
    Gemini,
    /// Llama 3+ and similar models (approximated via cl100k_base).
    Llama,
}
23
24impl std::fmt::Display for TokenizerFamily {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Self::O200kBase => write!(f, "o200k_base"),
28            Self::Cl100k => write!(f, "cl100k_base"),
29            Self::Gemini => write!(f, "gemini"),
30            Self::Llama => write!(f, "llama"),
31        }
32    }
33}
34
35/// Detects the appropriate tokenizer family from a client or model name.
36///
37/// Matches are case-insensitive substrings. Falls back to `O200kBase`.
38/// Accuracy: cl100k is within ~3% of Claude's actual tokenizer;
39/// Gemini correction factor 1.08 is empirically calibrated; o200k is exact for GPT-4o+.
40pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
41    let lower = client_name.to_ascii_lowercase();
42    if lower.contains("claude")
43        || lower.contains("anthropic")
44        || lower.contains("sonnet")
45        || lower.contains("opus")
46        || lower.contains("haiku")
47    {
48        TokenizerFamily::Cl100k
49    } else if lower.contains("gemini") || lower.contains("google") {
50        TokenizerFamily::Gemini
51    } else if lower.contains("llama")
52        || lower.contains("codex")
53        || lower.contains("opencode")
54        || lower.contains("mistral")
55        || lower.contains("deepseek")
56        || lower.contains("qwen")
57    {
58        TokenizerFamily::Llama
59    } else {
60        TokenizerFamily::O200kBase
61    }
62}
63
64// ── Tokenizer Instances ────────────────────────────────────
65
66static BPE_O200K: OnceLock<CoreBPE> = OnceLock::new();
67static BPE_CL100K: OnceLock<CoreBPE> = OnceLock::new();
68
69fn get_bpe_o200k() -> &'static CoreBPE {
70    BPE_O200K
71        .get_or_init(|| tiktoken_rs::o200k_base().expect("failed to load o200k_base tokenizer"))
72}
73
74fn get_bpe_cl100k() -> &'static CoreBPE {
75    BPE_CL100K
76        .get_or_init(|| tiktoken_rs::cl100k_base().expect("failed to load cl100k_base tokenizer"))
77}
78
79fn bpe_for_family(family: TokenizerFamily) -> &'static CoreBPE {
80    match family {
81        TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
82        TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
83    }
84}
85
/// Multiplier applied to raw o200k_base counts to estimate Gemini token counts.
///
/// Gemini's tokenizer is estimated to produce about 8% *more* tokens than o200k_base
/// for the same text (its tokens are smaller on average, not larger). The value is an
/// empirical calibration, not an exact conversion.
const GEMINI_CORRECTION: f64 = 1.08;
88
89// ── Cache ──────────────────────────────────────────────────
90
/// Maximum number of memoised counts before the cache is emptied and refilled.
const TOKEN_CACHE_MAX: usize = 256;

/// Memoised token counts, keyed by the `(text, family)` hash.
type TokenCounts = HashMap<u64, usize>;

/// Shared memo of token counts.
///
/// Starts as `None` because `HashMap::new` cannot run in a `static` initialiser;
/// the map is created lazily on first insert.
static TOKEN_CACHE: Mutex<Option<TokenCounts>> = Mutex::new(None);
94
95fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
96    use std::hash::{Hash, Hasher};
97    let mut hasher = std::collections::hash_map::DefaultHasher::new();
98    family.hash(&mut hasher);
99    text.len().hash(&mut hasher);
100    if text.len() <= 512 {
101        text.hash(&mut hasher);
102    } else {
103        let start_end = floor_char_boundary(text, 256);
104        let tail_start = ceil_char_boundary(text, text.len() - 256);
105        text[..start_end].hash(&mut hasher);
106        text[tail_start..].hash(&mut hasher);
107    }
108    hasher.finish()
109}
110
111fn floor_char_boundary(s: &str, idx: usize) -> usize {
112    let idx = idx.min(s.len());
113    let mut i = idx;
114    while i > 0 && !s.is_char_boundary(i) {
115        i -= 1;
116    }
117    i
118}
119
120fn ceil_char_boundary(s: &str, idx: usize) -> usize {
121    let idx = idx.min(s.len());
122    let mut i = idx;
123    while i < s.len() && !s.is_char_boundary(i) {
124        i += 1;
125    }
126    i
127}
128
129// ── Public API ─────────────────────────────────────────────
130
131/// Counts BPE tokens using the default tokenizer (o200k_base).
132///
133/// Backward-compatible — equivalent to
134/// `count_tokens_for(text, TokenizerFamily::O200kBase)`.
135pub fn count_tokens(text: &str) -> usize {
136    count_tokens_for(text, TokenizerFamily::O200kBase)
137}
138
139/// Counts BPE tokens using the specified tokenizer family.
140pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
141    if text.is_empty() {
142        return 0;
143    }
144
145    let key = hash_text(text, family);
146
147    if let Ok(guard) = TOKEN_CACHE.lock() {
148        if let Some(ref map) = *guard {
149            if let Some(&cached) = map.get(&key) {
150                return cached;
151            }
152        }
153    }
154
155    let raw = bpe_for_family(family)
156        .encode_with_special_tokens(text)
157        .len();
158    let count = if family == TokenizerFamily::Gemini {
159        (raw as f64 * GEMINI_CORRECTION).ceil() as usize
160    } else {
161        raw
162    };
163
164    if let Ok(mut guard) = TOKEN_CACHE.lock() {
165        let map = guard.get_or_insert_with(HashMap::new);
166        if map.len() >= TOKEN_CACHE_MAX {
167            map.clear();
168        }
169        map.insert(key, count);
170    }
171
172    count
173}
174
175/// Encodes text into BPE token IDs (o200k_base).
176pub fn encode_tokens(text: &str) -> Vec<u32> {
177    if text.is_empty() {
178        return Vec::new();
179    }
180    get_bpe_o200k().encode_with_special_tokens(text)
181}
182
183/// Encodes text into BPE token IDs for the specified tokenizer family.
184///
185/// Gemini correction is not applied here — this returns raw token IDs.
186pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
187    if text.is_empty() {
188        return Vec::new();
189    }
190    bpe_for_family(family).encode_with_special_tokens(text)
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock};

    /// Every tokenizer family, for checks that must hold across all of them.
    const ALL_FAMILIES: [TokenizerFamily; 4] = [
        TokenizerFamily::O200kBase,
        TokenizerFamily::Cl100k,
        TokenizerFamily::Gemini,
        TokenizerFamily::Llama,
    ];

    /// Serialises tests that reset or rely on the shared `TOKEN_CACHE`.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Replaces the shared cache with an empty map so a test starts cold.
    fn reset_cache() {
        if let Ok(mut guard) = TOKEN_CACHE.lock() {
            *guard = Some(HashMap::new());
        }
    }

    // ── Backward-compatible tests ──────────────────────────

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    // ── Multi-tokenizer tests ──────────────────────────────

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in ALL_FAMILIES {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in ALL_FAMILIES {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to 1.08× correction"
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
        // Counts may coincide for short inputs, but the two vocabularies assign
        // different token IDs, so the encodings themselves must differ.
        assert_ne!(
            encode_tokens_for(text, TokenizerFamily::O200kBase),
            encode_tokens_for(text, TokenizerFamily::Cl100k),
            "o200k and cl100k produced identical encodings"
        );
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        for family in ALL_FAMILIES {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}