Skip to main content

infiniloom_engine/tokenizer/
core.rs

1//! Core tokenizer implementation
2//!
3//! This module provides the main Tokenizer struct with accurate BPE tokenization
4//! for OpenAI models and estimation-based counting for other models.
5
6use super::counts::TokenCounts;
7use super::models::TokenModel;
8use dashmap::DashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::OnceLock;
11use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
12
/// Global tokenizer instances (lazy initialized, thread-safe)
///
/// `OnceLock` guarantees each BPE table is built at most once even under
/// concurrent first access; subsequent reads are lock-free.
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
16
/// Cache entry with access timestamp for LRU-like eviction
struct CacheEntry {
    /// Cached token count for one (content hash, model) pair.
    count: u32,
    /// Timestamp of last access (seconds since epoch, truncated for space).
    /// Read by `maybe_cleanup_cache` to evict the oldest entries first.
    last_access: u32,
}
23
/// Global token count cache - keyed by (content_hash, model)
/// This provides significant speedup when the same content is tokenized multiple times.
/// NOTE(review): keys are content *hashes* (see `hash_content`), so a 64-bit
/// hash collision would return the count of different content — acceptable
/// since counts are advisory, but worth knowing when debugging odd numbers.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), CacheEntry>> = OnceLock::new();

/// Maximum number of entries in the token cache before eviction.
/// 100K entries ≈ 3.2MB memory (32 bytes per entry with CacheEntry).
/// This prevents unbounded memory growth in long-running processes.
const MAX_CACHE_ENTRIES: usize = 100_000;

/// Number of entries to evict when cache is full (50% = evict half)
/// (used as a divisor: `len / EVICTION_FRACTION` entries are removed)
const EVICTION_FRACTION: usize = 2;
35
36/// Get or initialize the global token cache
37fn get_token_cache() -> &'static DashMap<(u64, TokenModel), CacheEntry> {
38    TOKEN_CACHE.get_or_init(DashMap::new)
39}
40
/// Get current timestamp (seconds since epoch, truncated to u32).
///
/// A clock set before the UNIX epoch yields 0 rather than an error.
fn current_timestamp() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as u32,
        Err(_) => 0,
    }
}
48
49/// Check if cache needs cleanup and evict oldest entries if it exceeds the limit.
50/// Uses an LRU-like strategy: when cache is full, remove the oldest half of entries.
51/// This preserves recently accessed entries while still being fast.
52fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), CacheEntry>) {
53    if cache.len() < MAX_CACHE_ENTRIES {
54        return;
55    }
56
57    // Collect entries with their timestamps for sorting
58    let mut entries: Vec<((u64, TokenModel), u32)> = cache
59        .iter()
60        .map(|entry| (*entry.key(), entry.value().last_access))
61        .collect();
62
63    // Sort by last_access (oldest first)
64    entries.sort_by_key(|(_, ts)| *ts);
65
66    // Remove oldest half
67    let to_remove = entries.len() / EVICTION_FRACTION;
68    for (key, _) in entries.into_iter().take(to_remove) {
69        cache.remove(&key);
70    }
71}
72
/// Compute a fast hash of content for cache keys.
///
/// Uses the standard library's `DefaultHasher`; the value is only meaningful
/// within a single process run and must not be persisted.
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::default();
    content.hash(&mut state);
    state.finish()
}
80
81/// Get or initialize the GPT-4o tokenizer (o200k_base)
82fn get_gpt4o_tokenizer() -> &'static CoreBPE {
83    GPT4O_TOKENIZER.get_or_init(|| {
84        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
85    })
86}
87
88/// Get or initialize the GPT-4 tokenizer (cl100k_base)
89fn get_gpt4_tokenizer() -> &'static CoreBPE {
90    GPT4_TOKENIZER.get_or_init(|| {
91        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
92    })
93}
94
/// Pre-computed statistics for token estimation.
/// Computed once in a single pass, then used for all estimation-based models.
#[derive(Clone, Copy)]
struct EstimationStats {
    /// Total text length in bytes (`str::len`), not chars.
    len: usize,
    /// Count of space and tab bytes (newlines are tracked separately).
    whitespace_count: u32,
    /// Count of `\n` bytes.
    newline_count: u32,
    /// Count of common code punctuation bytes (braces, operators, quotes, …).
    special_char_count: u32,
}
104
/// Accurate token counter with fallback to estimation
///
/// The tokenizer supports caching to avoid re-computing token counts for the same content.
/// This is particularly useful when processing files multiple times or across different
/// operations.
pub struct Tokenizer {
    /// Use exact tokenization when available
    /// (models without a tiktoken encoding fall back to estimation regardless)
    use_exact: bool,
    /// Use global cache for token counts
    /// (shared `TOKEN_CACHE`; disabled by `without_cache` for one-off counts)
    use_cache: bool,
}
116
117impl Default for Tokenizer {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl Tokenizer {
124    /// Create a new tokenizer with exact mode and caching enabled
125    pub fn new() -> Self {
126        Self { use_exact: true, use_cache: true }
127    }
128
129    /// Create a tokenizer that only uses estimation (faster but less accurate)
130    pub fn estimation_only() -> Self {
131        Self { use_exact: false, use_cache: true }
132    }
133
134    /// Create a tokenizer without caching (useful for benchmarks or one-off counts)
135    pub fn without_cache() -> Self {
136        Self { use_exact: true, use_cache: false }
137    }
138
139    /// Count tokens for a specific model.
140    ///
141    /// When caching is enabled, results are stored in a global cache keyed by
142    /// content hash and model. This provides significant speedup for repeated
143    /// tokenization of the same content.
144    ///
145    /// # Returns
146    ///
147    /// The token count for the specified model. For OpenAI models (GPT-4o, GPT-4, etc.),
148    /// this is exact via tiktoken. For other models, it's a calibrated estimation.
149    #[must_use]
150    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
151        if text.is_empty() {
152            return 0;
153        }
154
155        if self.use_cache {
156            let cache = get_token_cache();
157            let content_hash = hash_content(text);
158            let key = (content_hash, model);
159            let now = current_timestamp();
160
161            // Check cache first and update access time
162            if let Some(mut entry) = cache.get_mut(&key) {
163                entry.last_access = now;
164                return entry.count;
165            }
166
167            // Compute and cache (with size limit enforcement)
168            let count = self.count_uncached(text, model);
169            maybe_cleanup_cache(cache);
170            cache.insert(key, CacheEntry { count, last_access: now });
171            count
172        } else {
173            self.count_uncached(text, model)
174        }
175    }
176
177    /// Count tokens without using cache
178    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
179        if self.use_exact && model.has_exact_tokenizer() {
180            self.count_exact(text, model)
181        } else {
182            self.estimate(text, model)
183        }
184    }
185
186    /// Count tokens using exact BPE encoding.
187    /// Falls back to estimation if tiktoken panics (rare edge cases with unusual byte sequences).
188    /// Panic output is suppressed to avoid polluting stderr.
189    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
190        if model.uses_o200k() {
191            // All modern OpenAI models use o200k_base encoding
192            // GPT-5.x, GPT-4o, O1, O3, O4
193            let tokenizer = get_gpt4o_tokenizer();
194            self.tokenize_with_panic_guard(tokenizer, text, model)
195        } else if model.uses_cl100k() {
196            // Legacy OpenAI models use cl100k_base encoding
197            // GPT-4, GPT-3.5-turbo
198            let tokenizer = get_gpt4_tokenizer();
199            self.tokenize_with_panic_guard(tokenizer, text, model)
200        } else {
201            // Non-OpenAI models use estimation
202            self.estimate(text, model)
203        }
204    }
205
206    /// Tokenize text with panic guard that suppresses stderr output.
207    /// This prevents panic stack traces from polluting application logs.
208    fn tokenize_with_panic_guard(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
209        // Temporarily suppress panic output by setting a no-op panic hook
210        let prev_hook = std::panic::take_hook();
211        std::panic::set_hook(Box::new(|_| {
212            // Silently ignore panic - we'll fall back to estimation
213        }));
214
215        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
216            tokenizer.encode_ordinary(text).len() as u32
217        }));
218
219        // Restore the previous panic hook
220        std::panic::set_hook(prev_hook);
221
222        match result {
223            Ok(count) => count,
224            Err(_) => self.estimate(text, model), // Fallback to estimation on panic
225        }
226    }
227
228    /// Estimate tokens using character-based heuristics.
229    /// Uses single-pass character counting for efficiency.
230    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
231        if text.is_empty() {
232            return 0;
233        }
234        let stats = compute_estimation_stats(text);
235        estimate_from_stats(&stats, model)
236    }
237
238    /// Count tokens for all supported models at once.
239    ///
240    /// **Optimized**: Computes hash once, estimation stats once, and reuses them
241    /// for all models. This is ~10x faster than calling count() 10 times.
242    ///
243    /// Returns counts for representative models from each encoding family:
244    /// - `o200k`: GPT-5.x, GPT-4o, O1/O3/O4 (all use same tokenizer)
245    /// - `cl100k`: GPT-4, GPT-3.5-turbo (legacy, same tokenizer)
246    /// - Other vendors use estimation
247    pub fn count_all(&self, text: &str) -> TokenCounts {
248        if text.is_empty() {
249            return TokenCounts::default();
250        }
251
252        // Compute hash once for cache lookups
253        let content_hash = hash_content(text);
254        let cache = if self.use_cache {
255            Some(get_token_cache())
256        } else {
257            None
258        };
259
260        // Helper to get cached or compute exact count
261        let now = current_timestamp();
262        let get_exact = |model: TokenModel, tokenizer: &CoreBPE| -> u32 {
263            if let Some(cache) = cache {
264                let key = (content_hash, model);
265                if let Some(mut entry) = cache.get_mut(&key) {
266                    entry.last_access = now;
267                    return entry.count;
268                }
269                let count = self.tokenize_with_panic_guard(tokenizer, text, model);
270                maybe_cleanup_cache(cache);
271                cache.insert(key, CacheEntry { count, last_access: now });
272                count
273            } else {
274                self.tokenize_with_panic_guard(tokenizer, text, model)
275            }
276        };
277
278        // Compute estimation stats once for all models
279        let stats = compute_estimation_stats(text);
280
281        // Compute exact OpenAI counts (only 2 tokenizer calls needed)
282        let o200k = if self.use_exact {
283            get_exact(TokenModel::Gpt4o, get_gpt4o_tokenizer())
284        } else {
285            estimate_from_stats(&stats, TokenModel::Gpt4o)
286        };
287
288        let cl100k = if self.use_exact {
289            get_exact(TokenModel::Gpt4, get_gpt4_tokenizer())
290        } else {
291            estimate_from_stats(&stats, TokenModel::Gpt4)
292        };
293
294        // Derive all estimation-based counts from same stats (no re-iteration)
295        TokenCounts {
296            o200k,
297            cl100k,
298            claude: estimate_from_stats(&stats, TokenModel::Claude),
299            gemini: estimate_from_stats(&stats, TokenModel::Gemini),
300            llama: estimate_from_stats(&stats, TokenModel::Llama),
301            mistral: estimate_from_stats(&stats, TokenModel::Mistral),
302            deepseek: estimate_from_stats(&stats, TokenModel::DeepSeek),
303            qwen: estimate_from_stats(&stats, TokenModel::Qwen),
304            cohere: estimate_from_stats(&stats, TokenModel::Cohere),
305            grok: estimate_from_stats(&stats, TokenModel::Grok),
306        }
307    }
308
309    /// Estimate which model will have the lowest token count
310    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
311        let counts = self.count_all(text);
312        let models = [
313            (TokenModel::Gpt4o, counts.o200k), // GPT-5.x, GPT-4o, O-series
314            (TokenModel::Gpt4, counts.cl100k), // Legacy GPT-4
315            (TokenModel::Claude, counts.claude),
316            (TokenModel::Gemini, counts.gemini),
317            (TokenModel::Llama, counts.llama),
318            (TokenModel::Mistral, counts.mistral),
319            (TokenModel::DeepSeek, counts.deepseek),
320            (TokenModel::Qwen, counts.qwen),
321            (TokenModel::Cohere, counts.cohere),
322            (TokenModel::Grok, counts.grok),
323        ];
324
325        // Safe: models array is non-empty, so min_by_key always returns Some
326        models
327            .into_iter()
328            .min_by_key(|(_, count)| *count)
329            .unwrap_or((TokenModel::Claude, 0))
330    }
331
332    /// Truncate text to fit within a token budget
333    ///
334    /// # Safety
335    /// All string slicing uses `floor_char_boundary` to ensure valid UTF-8 boundaries.
336    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
337        let current = self.count(text, model);
338        if current <= budget {
339            return text;
340        }
341
342        // Binary search for the right truncation point
343        let mut low = 0usize;
344        let mut high = text.len();
345
346        while low < high {
347            let mid_raw = (low + high).div_ceil(2);
348            // Find valid UTF-8 boundary (rounds down)
349            let mid = text.floor_char_boundary(mid_raw);
350
351            // CRITICAL: Prevent infinite loop when low and high converge within
352            // a multi-byte UTF-8 character. If floor_char_boundary rounds mid
353            // back to low, we can't make progress - break out.
354            if mid <= low {
355                break;
356            }
357
358            let count = self.count(&text[..mid], model);
359
360            if count <= budget {
361                low = mid;
362            } else {
363                high = mid.saturating_sub(1);
364            }
365        }
366
367        // Ensure low is at a valid char boundary before proceeding
368        let low = text.floor_char_boundary(low);
369
370        // Try to truncate at word boundary while staying within char boundaries
371        let mut end = low;
372        while end > 0 {
373            // First, ensure we're at a char boundary
374            let boundary = text.floor_char_boundary(end);
375            if boundary < end {
376                end = boundary;
377                continue;
378            }
379
380            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
381            // Only check ASCII whitespace (won't be in middle of multi-byte char)
382            if c == b' ' || c == b'\n' {
383                break;
384            }
385            end -= 1;
386        }
387
388        // Final safety check: ensure end is at a valid char boundary
389        let end = text.floor_char_boundary(end);
390
391        if end > 0 {
392            &text[..end]
393        } else {
394            // Fall back to the binary search result
395            &text[..low]
396        }
397    }
398
399    /// Check if text exceeds a token budget
400    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
401        self.count(text, model) > budget
402    }
403}
404
405/// Quick estimation without creating a Tokenizer instance
406pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
407    if text.is_empty() {
408        return 0;
409    }
410    let chars_per_token = model.chars_per_token();
411    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
412}
413
414/// Compute estimation stats in a single pass over the text.
415/// This is O(n) and only needs to be done once per text.
416fn compute_estimation_stats(text: &str) -> EstimationStats {
417    let mut whitespace_count = 0u32;
418    let mut newline_count = 0u32;
419    let mut special_char_count = 0u32;
420
421    // Single pass - count all character types at once using bytes for speed
422    for &byte in text.as_bytes() {
423        match byte {
424            b' ' | b'\t' => whitespace_count += 1,
425            b'\n' => newline_count += 1,
426            b'{' | b'}' | b'(' | b')' | b'[' | b']' | b';' | b':' | b',' | b'.' | b'=' | b'+'
427            | b'-' | b'*' | b'/' | b'<' | b'>' | b'!' | b'&' | b'|' | b'@' | b'#' | b'$' | b'%'
428            | b'^' | b'~' | b'`' | b'"' | b'\'' => special_char_count += 1,
429            _ => {},
430        }
431    }
432
433    EstimationStats { len: text.len(), whitespace_count, newline_count, special_char_count }
434}
435
436/// Estimate tokens from pre-computed stats for a specific model.
437fn estimate_from_stats(stats: &EstimationStats, model: TokenModel) -> u32 {
438    let chars_per_token = model.chars_per_token();
439    let len = stats.len as f32;
440
441    // Base estimation
442    let mut estimate = len / chars_per_token;
443
444    // Whitespace adjustment (often merged with adjacent tokens)
445    estimate -= stats.whitespace_count as f32 * 0.3;
446
447    // Newline adjustment (usually single tokens)
448    estimate += stats.newline_count as f32 * 0.5;
449
450    // Code-focused models handle special chars differently
451    if matches!(
452        model,
453        TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
454    ) {
455        estimate += stats.special_char_count as f32 * 0.3;
456    }
457
458    estimate.ceil().max(1.0) as u32
459}