infiniloom_engine/tokenizer/core.rs

//! Core tokenizer implementation
//!
//! This module provides the main Tokenizer struct with accurate BPE tokenization
//! for OpenAI models and estimation-based counting for other models.

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

/// Global tokenizer instances (lazy initialized, thread-safe)
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

/// Global token count cache - keyed by (content_hash, model)
/// This provides a significant speedup when the same content is tokenized multiple times.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

/// Maximum number of entries in the token cache before eviction.
/// 100K entries ≈ 2.4MB memory (24 bytes per entry: 8 + 8 + 4 + padding).
/// This prevents unbounded memory growth in long-running processes.
const MAX_CACHE_ENTRIES: usize = 100_000;

/// Get or initialize the global token cache
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

/// Check if the cache needs cleanup and clear it if it exceeds the limit.
/// Uses a simple strategy: when the cache is full, clear it entirely.
/// This is fast and avoids complex LRU tracking overhead.
fn maybe_cleanup_cache(cache: &DashMap<(u64, TokenModel), u32>) {
    if cache.len() >= MAX_CACHE_ENTRIES {
        cache.clear();
    }
}

/// Compute a fast hash of content for cache keys
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

/// Get or initialize the GPT-4o tokenizer (o200k_base)
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

/// Get or initialize the GPT-4 tokenizer (cl100k_base)
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

/// Accurate token counter with fallback to estimation
///
/// The tokenizer supports caching to avoid re-computing token counts for the same content.
/// This is particularly useful when processing files multiple times or across different
/// operations.
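///
/// # Examples
///
/// A minimal usage sketch; the `use` path is an assumption about this crate's
/// re-exports, so the example is marked `ignore`:
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
///
/// let tokenizer = Tokenizer::new();
/// // Exact BPE count for OpenAI models, calibrated estimate for the rest.
/// let tokens = tokenizer.count("fn main() {}", TokenModel::Gpt4o);
/// assert!(tokens >= 1);
/// ```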
pub struct Tokenizer {
    /// Use exact tokenization when available
    use_exact: bool,
    /// Use global cache for token counts
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
    /// Create a new tokenizer with exact mode and caching enabled
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

    /// Create a tokenizer that only uses estimation (faster but less accurate)
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

    /// Create a tokenizer without caching (useful for benchmarks or one-off counts)
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

    /// Count tokens for a specific model.
    ///
    /// When caching is enabled, results are stored in a global cache keyed by
    /// content hash and model. This provides significant speedup for repeated
    /// tokenization of the same content.
    ///
    /// # Returns
    ///
    /// The token count for the specified model. For OpenAI models (GPT-4o, GPT-4, etc.),
    /// this is exact via tiktoken. For other models, it's a calibrated estimation.
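    ///
    /// # Examples
    ///
    /// A sketch of repeated counting (the import path is an assumption, hence
    /// `ignore`); the second call is served from the global cache:
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new();
    /// let first = tokenizer.count("let x = 1;", TokenModel::Gpt4);
    /// let second = tokenizer.count("let x = 1;", TokenModel::Gpt4);
    /// assert_eq!(first, second);
    /// ```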
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Check cache first
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Compute and cache (with size limit enforcement)
            let count = self.count_uncached(text, model);
            maybe_cleanup_cache(cache);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    /// Count tokens without using cache
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    /// Count tokens using exact BPE encoding.
    /// Falls back to estimation if tiktoken panics (rare edge cases with unusual byte sequences).
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            // All modern OpenAI models use o200k_base encoding:
            // GPT-5.x, GPT-4o, O1, O3, O4
            let tokenizer = get_gpt4o_tokenizer();
            // Catch panics from tiktoken on malformed input
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                tokenizer.encode_ordinary(text).len() as u32
            })) {
                Ok(count) => count,
                Err(_) => self.estimate(text, model), // Fall back to estimation on panic
            }
        } else if model.uses_cl100k() {
            // Legacy OpenAI models use cl100k_base encoding:
            // GPT-4, GPT-3.5-turbo
            let tokenizer = get_gpt4_tokenizer();
            // Catch panics from tiktoken on malformed input
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                tokenizer.encode_ordinary(text).len() as u32
            })) {
                Ok(count) => count,
                Err(_) => self.estimate(text, model), // Fall back to estimation on panic
            }
        } else {
            // Non-OpenAI models use estimation
            self.estimate(text, model)
        }
    }

    /// Estimate tokens using character-based heuristics
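    ///
    /// Worked example with an assumed ratio of 4.0 chars per token (the real
    /// value comes from `TokenModel::chars_per_token`), for the 14-byte input
    /// `"foo(bar, 42);\n"` with 1 space, 1 newline, and 4 special characters:
    ///
    /// ```text
    /// base     = 14 / 4.0      = 3.5
    /// spaces   = 3.5 - 1 * 0.3 = 3.2
    /// newlines = 3.2 + 1 * 0.5 = 3.7
    /// specials = 3.7 + 4 * 0.3 = 4.9   (code-focused models only)
    /// result   = ceil(4.9) = 5         (otherwise ceil(3.7) = 4)
    /// ```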
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        let chars_per_token = model.chars_per_token();
        let len = text.len() as f32;

        // Base estimation
        let mut estimate = len / chars_per_token;

        // Count whitespace (often merged with adjacent tokens)
        let whitespace_count = text.chars().filter(|c| *c == ' ' || *c == '\t').count() as f32;
        estimate -= whitespace_count * 0.3;

        // Count newlines (usually single tokens)
        let newline_count = text.chars().filter(|c| *c == '\n').count() as f32;
        estimate += newline_count * 0.5;

        // Adjust for special characters (often separate tokens)
        let special_chars = text
            .chars()
            .filter(|c| {
                matches!(
                    c,
                    '{' | '}'
                        | '('
                        | ')'
                        | '['
                        | ']'
                        | ';'
                        | ':'
                        | ','
                        | '.'
                        | '='
                        | '+'
                        | '-'
                        | '*'
                        | '/'
                        | '<'
                        | '>'
                        | '!'
                        | '&'
                        | '|'
                        | '@'
                        | '#'
                        | '$'
                        | '%'
                        | '^'
                        | '~'
                        | '`'
                        | '"'
                        | '\''
                )
            })
            .count() as f32;

        // Code-focused models handle special chars differently
        if matches!(
            model,
            TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
        ) {
            estimate += special_chars * 0.3;
        }

        estimate.ceil().max(1.0) as u32
    }

    /// Count tokens for all supported models at once
    ///
    /// Returns counts for representative models from each encoding family:
    /// - `o200k`: GPT-5.x, GPT-4o, O1/O3/O4 (all use same tokenizer)
    /// - `cl100k`: GPT-4, GPT-3.5-turbo (legacy, same tokenizer)
    /// - Other vendors use estimation
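    ///
    /// A usage sketch for comparing families (the import path is an assumption,
    /// hence `ignore`):
    ///
    /// ```ignore
    /// let counts = Tokenizer::new().count_all("SELECT * FROM users;");
    /// // o200k and cl100k are exact tiktoken counts; the rest are estimates.
    /// println!("gpt-4o: {}, claude: {}", counts.o200k, counts.claude);
    /// ```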
    pub fn count_all(&self, text: &str) -> TokenCounts {
        TokenCounts {
            // OpenAI o200k_base (GPT-5.x, GPT-4o, O-series all share this)
            o200k: self.count(text, TokenModel::Gpt4o),
            // OpenAI cl100k_base (legacy GPT-4, GPT-3.5)
            cl100k: self.count(text, TokenModel::Gpt4),
            // Other vendors (estimation-based)
            claude: self.count(text, TokenModel::Claude),
            gemini: self.count(text, TokenModel::Gemini),
            llama: self.count(text, TokenModel::Llama),
            mistral: self.count(text, TokenModel::Mistral),
            deepseek: self.count(text, TokenModel::DeepSeek),
            qwen: self.count(text, TokenModel::Qwen),
            cohere: self.count(text, TokenModel::Cohere),
            grok: self.count(text, TokenModel::Grok),
        }
    }

    /// Estimate which model will have the lowest token count
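    ///
    /// A usage sketch (marked `ignore` because the surrounding paths are assumed):
    ///
    /// ```ignore
    /// let (_model, tokens) = Tokenizer::new().most_efficient_model("hello world");
    /// assert!(tokens >= 1);
    /// ```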
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k), // GPT-5.x, GPT-4o, O-series
            (TokenModel::Gpt4, counts.cl100k), // Legacy GPT-4
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        // The array is non-empty, so min_by_key always returns Some;
        // unwrap_or is just a defensive fallback.
        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

    /// Truncate text to fit within a token budget
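    ///
    /// The result is always a prefix of `text`; input that already fits the
    /// budget is returned unchanged. A sketch (marked `ignore` because the
    /// surrounding paths are assumed):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new();
    /// let kept = tokenizer.truncate_to_budget("short text", TokenModel::Gpt4o, 1_000);
    /// assert_eq!(kept, "short text"); // already within budget
    /// ```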
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the right truncation point
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Find valid UTF-8 boundary (rounds down)
            let mid = text.floor_char_boundary(mid_raw);

            // CRITICAL: Prevent infinite loop when low and high converge within
            // a multi-byte UTF-8 character. If floor_char_boundary rounds mid
            // back to low, we can't make progress - break out.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Try to truncate at word boundary
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    /// Check if text exceeds a token budget
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

/// Quick estimation without creating a Tokenizer instance
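///
/// A sketch (marked `ignore` because the surrounding paths are assumed):
///
/// ```ignore
/// let approx = quick_estimate("hello world", TokenModel::Claude);
/// assert!(approx >= 1);
/// ```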
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}
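
// A minimal test sketch, not an exhaustive suite: it only asserts invariants
// that hold by construction in this module (empty input counts as zero,
// non-empty input counts as at least one token, repeated counts agree, and
// truncation returns a prefix of the input). The `TokenModel` variants used
// below are the ones already referenced above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_text_counts_zero() {
        let tokenizer = Tokenizer::new();
        assert_eq!(tokenizer.count("", TokenModel::Gpt4o), 0);
        assert_eq!(quick_estimate("", TokenModel::Claude), 0);
    }

    #[test]
    fn non_empty_text_counts_at_least_one() {
        let tokenizer = Tokenizer::estimation_only();
        assert!(tokenizer.count("fn main() {}", TokenModel::Claude) >= 1);
    }

    #[test]
    fn repeated_counts_are_stable() {
        let tokenizer = Tokenizer::new();
        let first = tokenizer.count("let x = 1;", TokenModel::Gpt4);
        let second = tokenizer.count("let x = 1;", TokenModel::Gpt4);
        assert_eq!(first, second);
    }

    #[test]
    fn truncate_returns_prefix_of_input() {
        let tokenizer = Tokenizer::new();
        let text = "one two three four five six seven eight nine ten";
        let truncated = tokenizer.truncate_to_budget(text, TokenModel::Gpt4o, 3);
        assert!(text.starts_with(truncated));
        assert!(truncated.len() <= text.len());
    }
}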