infiniloom_engine/tokenizer/core.rs

//! Core tokenizer implementation
//!
//! This module provides the main Tokenizer struct with accurate BPE tokenization
//! for OpenAI models and estimation-based counting for other models.

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

/// Global tokenizer instances (lazy initialized, thread-safe)
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

/// Global token count cache - keyed by (content_hash, model)
/// This provides significant speedup when the same content is tokenized multiple times.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

/// Maximum number of entries in the token cache before eviction.
/// 100K entries ≈ 2.4MB memory (24 bytes per entry: 8 + 8 + 4 + padding).
/// This prevents unbounded memory growth in long-running processes.
const MAX_CACHE_ENTRIES: usize = 100_000;

/// Get or initialize the global token cache
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

/// Check if cache needs cleanup and clear if it exceeds the limit.
/// Uses a simple strategy: when cache is full, clear it entirely.
/// This is fast and avoids complex LRU tracking overhead.
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

/// Compute a fast hash of content for cache keys
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

/// Get or initialize the GPT-4o tokenizer (o200k_base)
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

/// Get or initialize the GPT-4 tokenizer (cl100k_base)
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

/// Pre-computed statistics for token estimation.
/// Computed once in a single pass, then used for all estimation-based models.
#[derive(Clone, Copy)]
struct EstimationStats {
    len: usize,
    whitespace_count: u32,
    newline_count: u32,
    special_char_count: u32,
}

/// Accurate token counter with fallback to estimation
///
/// The tokenizer supports caching to avoid re-computing token counts for the same content.
/// This is particularly useful when processing files multiple times or across different
/// operations.
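///
/// # Examples
///
/// A minimal usage sketch; the import path assumes this module's items are
/// re-exported from the crate's `tokenizer` module, so the block is marked `ignore`:
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
///
/// // Exact tokenization with the global cache enabled (the default).
/// let tokenizer = Tokenizer::new();
/// assert!(tokenizer.count("fn main() {}", TokenModel::Gpt4o) > 0);
///
/// // Faster variants for estimation-only or one-off counting.
/// let _fast = Tokenizer::estimation_only();
/// let _uncached = Tokenizer::without_cache();
/// ```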
pub struct Tokenizer {
    /// Use exact tokenization when available
    use_exact: bool,
    /// Use global cache for token counts
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
    /// Create a new tokenizer with exact mode and caching enabled
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

    /// Create a tokenizer that only uses estimation (faster but less accurate)
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

    /// Create a tokenizer without caching (useful for benchmarks or one-off counts)
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

    /// Count tokens for a specific model.
    ///
    /// When caching is enabled, results are stored in a global cache keyed by
    /// content hash and model. This provides significant speedup for repeated
    /// tokenization of the same content.
    ///
    /// # Returns
    ///
    /// The token count for the specified model. For OpenAI models (GPT-4o, GPT-4, etc.),
    /// this is exact via tiktoken. For other models, it's a calibrated estimation.
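    ///
    /// # Examples
    ///
    /// An illustrative sketch (assumed re-export path, hence `ignore`; exact
    /// counts depend on the tiktoken vocabularies):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new();
    /// let exact = tokenizer.count("let x = 42;", TokenModel::Gpt4o);   // exact BPE count
    /// let approx = tokenizer.count("let x = 42;", TokenModel::Claude); // calibrated estimate
    /// assert!(exact > 0 && approx > 0);
    /// ```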
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Check cache first
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Compute and cache (with size limit enforcement)
            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    /// Count tokens without using cache
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    /// Count tokens using exact BPE encoding.
    /// Falls back to estimation if tiktoken panics (rare edge cases with unusual byte sequences).
    /// Panic output is suppressed to avoid polluting stderr.
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            // All modern OpenAI models use o200k_base encoding
            // GPT-5.x, GPT-4o, O1, O3, O4
            let tokenizer = get_gpt4o_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else if model.uses_cl100k() {
            // Legacy OpenAI models use cl100k_base encoding
            // GPT-4, GPT-3.5-turbo
            let tokenizer = get_gpt4_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else {
            // Non-OpenAI models use estimation
            self.estimate(text, model)
        }
    }

    /// Tokenize text with panic guard that suppresses stderr output.
    /// This prevents panic stack traces from polluting application logs.
    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        // Temporarily suppress panic output by setting a no-op panic hook
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {
            // Silently ignore panic - we'll fall back to estimation
        }));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }));

        // Restore the previous panic hook
        std::panic::set_hook(prev_hook);

        match result {
            Ok(count) => count,
            Err(_) => self.estimate(text, model), // Fallback to estimation on panic
        }
    }

    /// Estimate tokens using character-based heuristics.
    /// Uses single-pass character counting for efficiency.
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }
        let stats = compute_estimation_stats(text);
        estimate_from_stats(&stats, model)
    }

    /// Count tokens for all supported models at once.
    ///
    /// **Optimized**: Computes hash once, estimation stats once, and reuses them
    /// for all models. This is ~10x faster than calling count() 10 times.
    ///
    /// Returns counts for representative models from each encoding family:
    /// - `o200k`: GPT-5.x, GPT-4o, O1/O3/O4 (all use same tokenizer)
    /// - `cl100k`: GPT-4, GPT-3.5-turbo (legacy, same tokenizer)
    /// - Other vendors use estimation
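    ///
    /// # Examples
    ///
    /// An illustrative sketch (assumed re-export path, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::Tokenizer;
    ///
    /// let counts = Tokenizer::new().count_all("struct Point { x: f32, y: f32 }");
    /// // One exact count per encoding family, estimates for the rest.
    /// assert!(counts.o200k > 0 && counts.cl100k > 0 && counts.claude > 0);
    /// ```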
    pub fn count_all(&self, text: &str) -> TokenCounts {
        if text.is_empty() {
            return TokenCounts::default();
        }

        // Compute hash once for cache lookups
        let content_hash = hash_content(text);
        let cache = if self.use_cache { Some(get_token_cache()) } else { None };

        // Helper to get cached or compute exact count
        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
            if let Some(cache) = cache {
                let key = (content_hash, model);
                if let Some(count) = cache.get(&key) {
                    return *count;
                }
                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
                maybe_cleanup_cache(cache);
                cache.insert(key, count);
                count
            } else {
                self.tokenize_with_panic_guard(tokenizer, text, model)
            }
        };

        // Compute estimation stats once for all models
        let stats = compute_estimation_stats(text);

        // Compute exact OpenAI counts (only 2 tokenizer calls needed)
        let o200k = if self.use_exact {
            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4o)
        };

        let cl100k = if self.use_exact {
            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4)
        };

        // Derive all estimation-based counts from same stats (no re-iteration)
        TokenCounts {
            o200k,
            cl100k,
            claude: estimate_from_stats(&stats, TokenModel::Claude),
            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
            llama: estimate_from_stats(&stats, TokenModel::Llama),
            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
            grok: estimate_from_stats(&stats, TokenModel::Grok),
        }
    }

    /// Estimate which model will have the lowest token count
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k), // GPT-5.x, GPT-4o, O-series
            (TokenModel::Gpt4, counts.cl100k), // Legacy GPT-4
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        // The models array is non-empty, so min_by_key always returns Some;
        // unwrap_or is a defensive fallback that should never trigger.
        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

    /// Truncate text to fit within a token budget
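    ///
    /// Returns the original text unchanged if it already fits; otherwise returns a
    /// prefix of `text`, preferring to cut at a space or newline.
    ///
    /// # Examples
    ///
    /// An illustrative sketch (assumed re-export path, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new();
    /// let text = "alpha beta gamma delta epsilon zeta eta theta";
    /// let truncated = tokenizer.truncate_to_budget(text, TokenModel::Gpt4o, 3);
    /// assert!(text.starts_with(truncated));
    /// assert!(truncated.len() <= text.len());
    /// ```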
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the right truncation point
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Find valid UTF-8 boundary (rounds down)
            let mid = text.floor_char_boundary(mid_raw);

            // CRITICAL: Prevent infinite loop when low and high converge within
            // a multi-byte UTF-8 character. If floor_char_boundary rounds mid
            // back to low, we can't make progress - break out.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Try to truncate at word boundary
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    /// Check if text exceeds a token budget
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

/// Quick estimation without creating a Tokenizer instance
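///
/// The estimate is simply `ceil(text.len() / chars_per_token)` with a floor of 1.
/// For example, with a hypothetical `chars_per_token()` of 4.0, a 10-byte input
/// yields `ceil(10 / 4) = 3` tokens.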
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}

/// Compute estimation stats in a single pass over the text.
/// This is O(n) and only needs to be done once per text.
fn compute_estimation_stats(text: &str) -> EstimationStats {
    let mut whitespace_count = 0u32;
    let mut newline_count = 0u32;
    let mut special_char_count = 0u32;

    // Single pass - count all character types at once using bytes for speed
    for &byte in text.as_bytes() {
        match byte {
            b' ' | b'\t' => whitespace_count += 1,
            b'\n' => newline_count += 1,
            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'='
            | b'+' | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#'
            | b'$' | b'%' | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
            _ => {}
        }
    }

    EstimationStats {
        len: text.len(),
        whitespace_count,
        newline_count,
        special_char_count,
    }
}

/// Estimate tokens from pre-computed stats for a specific model.
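///
/// The heuristic starts from `len / chars_per_token`, subtracts 0.3 per whitespace
/// byte, adds 0.5 per newline, and adds 0.3 per special character for the
/// code-focused models listed below. For example, with a hypothetical
/// `chars_per_token()` of 4.0, a 100-byte text with 10 whitespace bytes, 2 newlines,
/// and 5 special characters estimates to `ceil(100/4 - 10*0.3 + 2*0.5) = 23` tokens,
/// or `ceil(23.0 + 5*0.3) = 25` for a code-focused model.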
fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
    let chars_per_token = model.chars_per_token();
    let len = stats.len as f32;

    // Base estimation
    let mut estimate = len / chars_per_token;

    // Whitespace adjustment (often merged with adjacent tokens)
    estimate -= stats.whitespace_count as f32 * 0.3;

    // Newline adjustment (usually single tokens)
    estimate += stats.newline_count as f32 * 0.5;

    // Code-focused models handle special chars differently
    if matches!(
        model,
        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
    ) {
        estimate += stats.special_char_count as f32 * 0.3;
    }

    estimate.ceil().max(1.0) as u32
}
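
// Illustrative smoke tests sketching how the public API above is intended to be
// used. Assertions are limited to invariants guaranteed by this module (empty input
// counts as zero, non-empty input counts as non-zero, truncation yields a prefix);
// they do not assert exact token values, which depend on the tiktoken vocabularies.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_text_counts_as_zero() {
        let tokenizer = Tokenizer::new();
        assert_eq!(tokenizer.count("", TokenModel::Gpt4o), 0);
        assert_eq!(quick_estimate("", TokenModel::Claude), 0);
    }

    #[test]
    fn non_empty_text_counts_as_nonzero_for_all_models() {
        let tokenizer = Tokenizer::new();
        let counts = tokenizer.count_all("fn add(a: u32, b: u32) -> u32 { a + b }");
        assert!(counts.o200k > 0);
        assert!(counts.cl100k > 0);
        assert!(counts.claude > 0);
    }

    #[test]
    fn cached_and_uncached_counts_agree() {
        let text = "let greeting = \"hello world\";";
        let cached = Tokenizer::new().count(text, TokenModel::Gpt4o);
        let uncached = Tokenizer::without_cache().count(text, TokenModel::Gpt4o);
        assert_eq!(cached, uncached);
    }

    #[test]
    fn truncation_returns_a_prefix_of_the_original() {
        let tokenizer = Tokenizer::new();
        let text = "alpha beta gamma delta epsilon zeta eta theta iota kappa";
        let truncated = tokenizer.truncate_to_budget(text, TokenModel::Gpt4o, 3);
        assert!(text.starts_with(truncated));
        assert!(truncated.len() <= text.len());
        // A zero budget is always exceeded by non-empty text.
        assert!(tokenizer.exceeds_budget(text, TokenModel::Gpt4o, 0));
    }
}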