infiniloom_engine/tokenizer/core.rs

//! Core tokenizer implementation
//!
//! This module provides the main Tokenizer struct with accurate BPE tokenization
//! for OpenAI models and estimation-based counting for other models.

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

/// Global tokenizer instances (lazily initialized, thread-safe)
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

/// Global token count cache - keyed by (content_hash, model).
/// This provides a significant speedup when the same content is tokenized multiple times.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

/// Maximum number of entries in the token cache before eviction.
/// 100K entries ≈ 2.4MB memory (24 bytes per entry: 8 + 8 + 4 + padding).
/// This prevents unbounded memory growth in long-running processes.
const MAX_CACHE_ENTRIES: usize = 100_000;

/// Get or initialize the global token cache
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

/// Check if the cache needs cleanup and clear it if it exceeds the limit.
/// Uses a simple strategy: when the cache is full, clear it entirely.
/// This is fast and avoids complex LRU tracking overhead.
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

/// Compute a fast hash of content for cache keys
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

/// Get or initialize the GPT-4o tokenizer (o200k_base)
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

/// Get or initialize the GPT-4 tokenizer (cl100k_base)
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

/// Pre-computed statistics for token estimation.
/// Computed once in a single pass, then used for all estimation-based models.
#[derive(Clone, Copy)]
struct EstimationStats {
    len: usize,
    whitespace_count: u32,
    newline_count: u32,
    special_char_count: u32,
}

/// Accurate token counter with fallback to estimation
///
/// The tokenizer supports caching to avoid re-computing token counts for the same content.
/// This is particularly useful when processing files multiple times or across different
/// operations.
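///
/// # Example
///
/// A minimal usage sketch. The `infiniloom_engine::tokenizer` import path is an
/// assumption about how this module is re-exported, so the example is marked `ignore`.
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
///
/// // Exact BPE counting with the global cache enabled (the default configuration).
/// let tokenizer = Tokenizer::new();
/// let tokens = tokenizer.count("fn main() {}", TokenModel::Gpt4o);
/// assert!(tokens > 0);
/// ```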
pub struct Tokenizer {
    /// Use exact tokenization when available
    use_exact: bool,
    /// Use global cache for token counts
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
    /// Create a new tokenizer with exact mode and caching enabled
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

    /// Create a tokenizer that only uses estimation (faster but less accurate)
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

    /// Create a tokenizer without caching (useful for benchmarks or one-off counts)
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

    /// Count tokens for a specific model.
    ///
    /// When caching is enabled, results are stored in a global cache keyed by
    /// content hash and model. This provides significant speedup for repeated
    /// tokenization of the same content.
    ///
    /// # Returns
    ///
    /// The token count for the specified model. For OpenAI models (GPT-4o, GPT-4, etc.),
    /// this is exact via tiktoken. For other models, it's a calibrated estimation.
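    ///
    /// # Example
    ///
    /// A short sketch comparing an exact count with an estimated one. The
    /// `infiniloom_engine::tokenizer` import path is assumed, so the example is marked `ignore`.
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new();
    /// let text = "let x = 42;";
    /// let exact = tokenizer.count(text, TokenModel::Gpt4o);   // exact via tiktoken
    /// let approx = tokenizer.count(text, TokenModel::Claude); // calibrated estimation
    /// assert!(exact > 0 && approx > 0);
    /// ```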
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Check cache first
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Compute and cache (with size limit enforcement)
            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    /// Count tokens without using the cache
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    /// Count tokens using exact BPE encoding.
    /// Falls back to estimation if tiktoken panics (rare edge cases with unusual byte sequences).
    /// Panic output is suppressed to avoid polluting stderr.
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            // All modern OpenAI models use o200k_base encoding
            // GPT-5.x, GPT-4o, O1, O3, O4
            let tokenizer = get_gpt4o_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else if model.uses_cl100k() {
            // Legacy OpenAI models use cl100k_base encoding
            // GPT-4, GPT-3.5-turbo
            let tokenizer = get_gpt4_tokenizer();
            self.tokenize_with_panic_guard(tokenizer, text, model)
        } else {
            // Non-OpenAI models use estimation
            self.estimate(text, model)
        }
    }

    /// Tokenize text with a panic guard that suppresses stderr output.
    /// This prevents panic stack traces from polluting application logs.
    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        // Temporarily suppress panic output by setting a no-op panic hook
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {
            // Silently ignore panic - we'll fall back to estimation
        }));

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        }));

        // Restore the previous panic hook
        std::panic::set_hook(prev_hook);

        match result {
            Ok(count) => count,
            Err(_) => self.estimate(text, model), // Fall back to estimation on panic
        }
    }

    /// Estimate tokens using character-based heuristics.
    /// Uses single-pass character counting for efficiency.
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }
        let stats = compute_estimation_stats(text);
        estimate_from_stats(&stats, model)
    }

    /// Count tokens for all supported models at once.
    ///
    /// **Optimized**: Computes the hash once and the estimation stats once, then reuses them
    /// for all models. This is ~10x faster than calling count() 10 times.
    ///
    /// Returns counts for representative models from each encoding family:
    /// - `o200k`: GPT-5.x, GPT-4o, O1/O3/O4 (all use the same tokenizer)
    /// - `cl100k`: GPT-4, GPT-3.5-turbo (legacy, same tokenizer)
    /// - Other vendors use estimation
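    ///
    /// # Example
    ///
    /// A sketch of counting for every supported encoding in one pass (import path assumed,
    /// hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::Tokenizer;
    ///
    /// let counts = Tokenizer::new().count_all("SELECT * FROM users;");
    /// // o200k and cl100k are exact; the remaining fields are calibrated estimations.
    /// assert!(counts.o200k > 0 && counts.claude > 0);
    /// ```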
    pub fn count_all(&self, text: &str) -> TokenCounts {
        if text.is_empty() {
            return TokenCounts::default();
        }

        // Compute hash once for cache lookups
        let content_hash = hash_content(text);
        let cache = if self.use_cache {
            Some(get_token_cache())
        } else {
            None
        };

        // Helper to get cached or compute exact count
        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
            if let Some(cache) = cache {
                let key = (content_hash, model);
                if let Some(count) = cache.get(&key) {
                    return *count;
                }
                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
                maybe_cleanup_cache(cache);
                cache.insert(key, count);
                count
            } else {
                self.tokenize_with_panic_guard(tokenizer, text, model)
            }
        };

        // Compute estimation stats once for all models
        let stats = compute_estimation_stats(text);

        // Compute exact OpenAI counts (only 2 tokenizer calls needed)
        let o200k = if self.use_exact {
            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4o)
        };

        let cl100k = if self.use_exact {
            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
        } else {
            estimate_from_stats(&stats, TokenModel::Gpt4)
        };

        // Derive all estimation-based counts from same stats (no re-iteration)
        TokenCounts {
            o200k,
            cl100k,
            claude: estimate_from_stats(&stats, TokenModel::Claude),
            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
            llama: estimate_from_stats(&stats, TokenModel::Llama),
            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
            grok: estimate_from_stats(&stats, TokenModel::Grok),
        }
    }

    /// Estimate which model will have the lowest token count
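    ///
    /// # Example
    ///
    /// A sketch (import path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::Tokenizer;
    ///
    /// let (_model, count) = Tokenizer::new().most_efficient_model("hello world");
    /// assert!(count > 0);
    /// ```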
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k), // GPT-5.x, GPT-4o, O-series
            (TokenModel::Gpt4, counts.cl100k), // Legacy GPT-4
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        // The models array is non-empty, so min_by_key always returns Some;
        // the unwrap_or fallback is purely defensive.
        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

    /// Truncate text to fit within a token budget
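    ///
    /// # Example
    ///
    /// A sketch showing that the returned slice stays within the budget (import path
    /// assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new();
    /// let text = "one two three four five six seven eight nine ten";
    /// let truncated = tokenizer.truncate_to_budget(text, TokenModel::Gpt4o, 5);
    /// assert!(tokenizer.count(truncated, TokenModel::Gpt4o) <= 5);
    /// ```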
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the right truncation point
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Find valid UTF-8 boundary (rounds down)
            let mid = text.floor_char_boundary(mid_raw);

            // CRITICAL: Prevent infinite loop when low and high converge within
            // a multi-byte UTF-8 character. If floor_char_boundary rounds mid
            // back to low, we can't make progress - break out.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Try to truncate at word boundary
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    /// Check if text exceeds a token budget
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

/// Quick estimation without creating a Tokenizer instance
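///
/// # Example
///
/// A rough length-only sketch (import path assumed, hence `ignore`). With a hypothetical
/// `chars_per_token` of 4.0, a 40-byte string would estimate to 10 tokens.
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{quick_estimate, TokenModel};
///
/// let estimate = quick_estimate("some text to size up", TokenModel::Claude);
/// assert!(estimate >= 1);
/// ```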
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}

/// Compute estimation stats in a single pass over the text.
/// This is O(n) and only needs to be done once per text.
fn compute_estimation_stats(text: &str) -> EstimationStats {
    let mut whitespace_count = 0u32;
    let mut newline_count = 0u32;
    let mut special_char_count = 0u32;

    // Single pass - count all character types at once using bytes for speed
    for &byte in text.as_bytes() {
        match byte {
            b' ' | b'\t' => whitespace_count += 1,
            b'\n' => newline_count += 1,
            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'=' | b'+'
            | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#' | b'$' | b'%'
            | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
            _ => {},
        }
    }

    EstimationStats { len: text.len(), whitespace_count, newline_count, special_char_count }
}

/// Estimate tokens from pre-computed stats for a specific model.
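///
/// For example, with a hypothetical `chars_per_token` of 4.0, a 200-byte snippet with
/// 20 spaces/tabs and 5 newlines estimates, for a non-code-focused model, to
/// 200 / 4.0 - 20 * 0.3 + 5 * 0.5 = 50 - 6 + 2.5 = 46.5, which rounds up to 47 tokens.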
fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
    let chars_per_token = model.chars_per_token();
    let len = stats.len as f32;

    // Base estimation
    let mut estimate = len / chars_per_token;

    // Whitespace adjustment (often merged with adjacent tokens)
    estimate -= stats.whitespace_count as f32 * 0.3;

    // Newline adjustment (usually single tokens)
    estimate += stats.newline_count as f32 * 0.5;

    // Code-focused models handle special chars differently
    if matches!(
        model,
        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
    ) {
        estimate += stats.special_char_count as f32 * 0.3;
    }

    estimate.ceil().max(1.0) as u32
}