Skip to main content

lean_ctx/core/
tokens.rs

1use std::sync::OnceLock;
2use tiktoken_rs::CoreBPE;
3
4// ── Tokenizer Families ─────────────────────────────────────
5
/// Tokenizer families for different LLM providers.
///
/// Different LLM families use different tokenizers, leading to 5–15% variance
/// in token counts. This enum lets callers select the appropriate tokenizer.
/// Only two BPE vocabularies back these variants (o200k_base and cl100k_base);
/// the remaining variants reuse one of them, optionally with a correction factor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TokenizerFamily {
    /// GPT-4o and newer OpenAI models (tiktoken o200k_base, exact).
    /// Older GPT-4 / GPT-4-turbo models natively use cl100k_base, so counts for
    /// them under this family are approximate.
    #[default]
    O200kBase,
    /// Claude / Anthropic (approximated via tiktoken cl100k_base)
    Cl100k,
    /// Gemini / Google (o200k_base count scaled by a 1.08× correction factor, rounded up)
    Gemini,
    /// Llama 3+ (approximated via cl100k_base)
    Llama,
}
22
23impl std::fmt::Display for TokenizerFamily {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::O200kBase => write!(f, "o200k_base"),
27            Self::Cl100k => write!(f, "cl100k_base"),
28            Self::Gemini => write!(f, "gemini"),
29            Self::Llama => write!(f, "llama"),
30        }
31    }
32}
33
34/// Detects the appropriate tokenizer family from a client or model name.
35///
36/// Matches are case-insensitive substrings. Falls back to `O200kBase`.
37/// Accuracy: cl100k is within ~3% of Claude's actual tokenizer;
38/// Gemini correction factor 1.08 is empirically calibrated; o200k is exact for GPT-4o+.
39pub fn detect_tokenizer(client_name: &str) -> TokenizerFamily {
40    let lower = client_name.to_ascii_lowercase();
41    if lower.contains("claude")
42        || lower.contains("anthropic")
43        || lower.contains("sonnet")
44        || lower.contains("opus")
45        || lower.contains("haiku")
46    {
47        TokenizerFamily::Cl100k
48    } else if lower.contains("gemini") || lower.contains("google") {
49        TokenizerFamily::Gemini
50    } else if lower.contains("llama")
51        || lower.contains("codex")
52        || lower.contains("opencode")
53        || lower.contains("mistral")
54        || lower.contains("deepseek")
55        || lower.contains("qwen")
56    {
57        TokenizerFamily::Llama
58    } else {
59        TokenizerFamily::O200kBase
60    }
61}
62
63// ── Tokenizer Instances ────────────────────────────────────
64
65static BPE_O200K: OnceLock<Option<CoreBPE>> = OnceLock::new();
66static BPE_CL100K: OnceLock<Option<CoreBPE>> = OnceLock::new();
67
68fn get_bpe_o200k() -> Option<&'static CoreBPE> {
69    BPE_O200K
70        .get_or_init(|| {
71            tiktoken_rs::o200k_base()
72                .inspect_err(|e| tracing::error!("failed to load o200k_base tokenizer: {e}"))
73                .ok()
74        })
75        .as_ref()
76}
77
78fn get_bpe_cl100k() -> Option<&'static CoreBPE> {
79    BPE_CL100K
80        .get_or_init(|| {
81            tiktoken_rs::cl100k_base()
82                .inspect_err(|e| tracing::error!("failed to load cl100k_base tokenizer: {e}"))
83                .ok()
84        })
85        .as_ref()
86}
87
88fn bpe_for_family(family: TokenizerFamily) -> Option<&'static CoreBPE> {
89    match family {
90        TokenizerFamily::O200kBase | TokenizerFamily::Gemini => get_bpe_o200k(),
91        TokenizerFamily::Cl100k | TokenizerFamily::Llama => get_bpe_cl100k(),
92    }
93}
94
/// Divisor for the fallback token estimate used when a BPE vocabulary failed to load.
///
/// Despite the name, the estimate divides the UTF-8 *byte* length (`str::len`) by
/// this value, so non-ASCII text yields a proportionally higher estimate.
const CHARS_PER_TOKEN_ESTIMATE: f64 = 3.5;

/// Multiplier applied to o200k_base counts to approximate Gemini's tokenizer.
///
/// A factor of 1.08 means Gemini is assumed to produce ~8% *more* tokens than
/// o200k_base for the same text (i.e. its tokens are slightly smaller on average).
/// Empirically calibrated.
const GEMINI_CORRECTION: f64 = 1.08;
99
100// ── Cache ──────────────────────────────────────────────────
101
102const TOKEN_CACHE_MAX: u64 = 4096;
103
104fn token_cache() -> &'static moka::sync::Cache<u64, usize> {
105    static CACHE: std::sync::OnceLock<moka::sync::Cache<u64, usize>> = std::sync::OnceLock::new();
106    CACHE.get_or_init(|| {
107        moka::sync::Cache::builder()
108            .max_capacity(TOKEN_CACHE_MAX)
109            .build()
110    })
111}
112
113fn hash_text(text: &str, family: TokenizerFamily) -> u64 {
114    let h = blake3::hash(text.as_bytes());
115    let bytes = h.as_bytes();
116    let base = u64::from_le_bytes([
117        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
118    ]);
119    base ^ (family as u64)
120}
121
/// Returns the largest UTF-8 char boundary in `s` that is `<= idx`
/// (with `idx` clamped to `s.len()`).
#[cfg(test)]
fn floor_char_boundary(s: &str, idx: usize) -> usize {
    // Offset 0 is always a boundary, so the search always finds a position.
    (0..=idx.min(s.len()))
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0)
}
131
/// Returns the smallest UTF-8 char boundary in `s` that is `>= idx`
/// (with `idx` clamped to `s.len()`).
#[cfg(test)]
fn ceil_char_boundary(s: &str, idx: usize) -> usize {
    let end = s.len();
    // `end` is always a boundary, so the search always finds a position.
    (idx.min(end)..=end)
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(end)
}
141
142// ── Public API ─────────────────────────────────────────────
143
144/// Counts BPE tokens using the default tokenizer (o200k_base).
145///
146/// Backward-compatible — equivalent to
147/// `count_tokens_for(text, TokenizerFamily::O200kBase)`.
148pub fn count_tokens(text: &str) -> usize {
149    count_tokens_for(text, COUNTING_FAMILY)
150}
151
152/// The tokenizer family [`count_tokens`] uses for all read/savings accounting.
153///
154/// Centralised so honesty surfaces (the savings ledger, Wrapped) can record exactly
155/// which tokenizer produced their token counts, rather than assuming the model's own.
156pub const COUNTING_FAMILY: TokenizerFamily = TokenizerFamily::O200kBase;
157
158/// Label of the tokenizer used for counting (e.g. `"o200k_base"`).
159pub fn counting_family_label() -> String {
160    COUNTING_FAMILY.to_string()
161}
162
163/// Counts BPE tokens using the specified tokenizer family.
164pub fn count_tokens_for(text: &str, family: TokenizerFamily) -> usize {
165    if text.is_empty() {
166        return 0;
167    }
168
169    let key = hash_text(text, family);
170    let cache = token_cache();
171
172    if let Some(cached) = cache.get(&key) {
173        return cached;
174    }
175
176    let Some(bpe) = bpe_for_family(family) else {
177        let estimate = (text.len() as f64 / CHARS_PER_TOKEN_ESTIMATE).ceil() as usize;
178        cache.insert(key, estimate);
179        return estimate;
180    };
181    let raw = bpe.encode_with_special_tokens(text).len();
182    let count = if family == TokenizerFamily::Gemini {
183        (raw as f64 * GEMINI_CORRECTION).ceil() as usize
184    } else {
185        raw
186    };
187
188    cache.insert(key, count);
189    count
190}
191
192/// Encodes text into BPE token IDs (o200k_base).
193pub fn encode_tokens(text: &str) -> Vec<u32> {
194    if text.is_empty() {
195        return Vec::new();
196    }
197    match get_bpe_o200k() {
198        Some(bpe) => bpe.encode_with_special_tokens(text),
199        None => Vec::new(),
200    }
201}
202
203/// Encodes text into BPE token IDs for the specified tokenizer family.
204///
205/// Gemini correction is not applied here — this returns raw token IDs.
206pub fn encode_tokens_for(text: &str, family: TokenizerFamily) -> Vec<u32> {
207    if text.is_empty() {
208        return Vec::new();
209    }
210    match bpe_for_family(family) {
211        Some(bpe) => bpe.encode_with_special_tokens(text),
212        None => Vec::new(),
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialises tests that reset or depend on the shared token cache.
    fn token_test_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            // A panicking test must not poison the lock for every later test.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn reset_cache() {
        token_cache().invalidate_all();
    }

    // ── Backward-compatible tests ──────────────────────────

    #[test]
    fn count_tokens_empty_is_zero() {
        assert_eq!(count_tokens(""), 0);
    }

    #[test]
    fn encode_tokens_empty_is_empty() {
        assert!(encode_tokens("").is_empty());
    }

    #[test]
    fn count_tokens_matches_encoded_length() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world, Grüezi 🌍";
        let counted = count_tokens(text);
        let encoded = encode_tokens(text);
        assert_eq!(counted, encoded.len());
        assert_eq!(counted, count_tokens(text));
    }

    #[test]
    fn char_boundary_helpers_handle_multibyte_indices() {
        let s = "aé🙂z";
        let emoji_start = s.find('🙂').expect("emoji exists");
        let middle_of_emoji = emoji_start + 1;

        let floor = floor_char_boundary(s, middle_of_emoji);
        let ceil = ceil_char_boundary(s, middle_of_emoji);

        assert!(s.is_char_boundary(floor));
        assert!(s.is_char_boundary(ceil));
        assert!(floor <= middle_of_emoji);
        assert!(ceil >= middle_of_emoji);
    }

    #[test]
    fn hash_text_is_stable_for_long_strings() {
        let long = "abc🙂".repeat(300);
        let h1 = hash_text(&long, TokenizerFamily::O200kBase);
        let h2 = hash_text(&long, TokenizerFamily::O200kBase);
        assert_eq!(h1, h2);
        assert!(count_tokens(&long) > 0);
    }

    // ── Multi-tokenizer tests ──────────────────────────────

    #[test]
    fn tokenizer_family_default_is_o200k() {
        assert_eq!(TokenizerFamily::default(), TokenizerFamily::O200kBase);
    }

    #[test]
    fn tokenizer_family_display() {
        assert_eq!(TokenizerFamily::O200kBase.to_string(), "o200k_base");
        assert_eq!(TokenizerFamily::Cl100k.to_string(), "cl100k_base");
        assert_eq!(TokenizerFamily::Gemini.to_string(), "gemini");
        assert_eq!(TokenizerFamily::Llama.to_string(), "llama");
    }

    #[test]
    fn detect_tokenizer_openai_variants() {
        assert_eq!(detect_tokenizer("cursor"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("openai"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("gpt-4o"), TokenizerFamily::O200kBase);
        assert_eq!(detect_tokenizer("GPT-4-turbo"), TokenizerFamily::O200kBase);
    }

    #[test]
    fn detect_tokenizer_claude_variants() {
        assert_eq!(detect_tokenizer("claude-3.5"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("anthropic"), TokenizerFamily::Cl100k);
        assert_eq!(detect_tokenizer("Claude"), TokenizerFamily::Cl100k);
    }

    #[test]
    fn detect_tokenizer_gemini_variants() {
        assert_eq!(detect_tokenizer("gemini-pro"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("google"), TokenizerFamily::Gemini);
        assert_eq!(detect_tokenizer("Gemini-1.5"), TokenizerFamily::Gemini);
    }

    #[test]
    fn detect_tokenizer_llama_variants() {
        assert_eq!(detect_tokenizer("llama-3"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("codex"), TokenizerFamily::Llama);
        assert_eq!(detect_tokenizer("opencode"), TokenizerFamily::Llama);
    }

    #[test]
    fn detect_tokenizer_unknown_defaults_to_o200k() {
        assert_eq!(
            detect_tokenizer("unknown-model"),
            TokenizerFamily::O200kBase
        );
        assert_eq!(detect_tokenizer(""), TokenizerFamily::O200kBase);
    }

    #[test]
    fn count_tokens_for_all_families_nonzero() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "fn main() { println!(\"hello\"); }";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            let count = count_tokens_for(text, family);
            assert!(count > 0, "{family} returned 0 tokens");
        }
    }

    #[test]
    fn count_tokens_for_empty_is_zero_all_families() {
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Gemini,
            TokenizerFamily::Llama,
        ] {
            assert_eq!(count_tokens_for("", family), 0);
        }
    }

    #[test]
    fn gemini_count_exceeds_raw_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "The quick brown fox jumps over the lazy dog. ".repeat(20);
        let o200k = count_tokens_for(&text, TokenizerFamily::O200kBase);
        let gemini = count_tokens_for(&text, TokenizerFamily::Gemini);
        // The message derives the factor from the constant so it cannot go stale.
        assert!(
            gemini > o200k,
            "Gemini ({gemini}) should exceed O200kBase ({o200k}) due to {GEMINI_CORRECTION}× correction"
        );
    }

    #[test]
    fn cl100k_differs_from_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text =
            "use std::collections::HashMap;\nfn main() {\n    let mut map = HashMap::new();\n}";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);
        assert!(o200k > 0);
        assert!(cl100k > 0);
        // The two vocabularies assign different IDs, so the encodings must differ
        // even if the counts happen to coincide.
        assert_ne!(
            encode_tokens_for(text, TokenizerFamily::O200kBase),
            encode_tokens_for(text, TokenizerFamily::Cl100k),
            "o200k_base and cl100k_base produced identical token IDs"
        );
    }

    #[test]
    fn encode_tokens_for_matches_count() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "hello world";
        for family in [
            TokenizerFamily::O200kBase,
            TokenizerFamily::Cl100k,
            TokenizerFamily::Llama,
        ] {
            let encoded = encode_tokens_for(text, family);
            let raw_count = bpe_for_family(family)
                .expect("tokenizer should load in tests")
                .encode_with_special_tokens(text)
                .len();
            assert_eq!(encoded.len(), raw_count, "mismatch for {family}");
            // Non-Gemini families apply no correction, so the count equals the encoding length.
            assert_eq!(
                count_tokens_for(text, family),
                encoded.len(),
                "count mismatch for {family}"
            );
        }
    }

    #[test]
    fn cache_distinguishes_families() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "cache test string";
        let o200k = count_tokens_for(text, TokenizerFamily::O200kBase);
        let cl100k = count_tokens_for(text, TokenizerFamily::Cl100k);

        let h_o200k = hash_text(text, TokenizerFamily::O200kBase);
        let h_cl100k = hash_text(text, TokenizerFamily::Cl100k);
        assert_ne!(h_o200k, h_cl100k, "cache keys must differ across families");

        assert_eq!(o200k, count_tokens_for(text, TokenizerFamily::O200kBase));
        assert_eq!(cl100k, count_tokens_for(text, TokenizerFamily::Cl100k));
    }

    #[test]
    fn default_count_tokens_is_o200k() {
        let _lock = token_test_lock();
        reset_cache();

        let text = "backward compat check";
        assert_eq!(
            count_tokens(text),
            count_tokens_for(text, TokenizerFamily::O200kBase)
        );
    }
}