infiniloom_engine/tokenizer/core.rs

//! Core tokenizer implementation
//!
//! This module provides the main `Tokenizer` struct with accurate BPE tokenization
//! for OpenAI models and estimation-based counting for other models.
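//!
//! # Examples
//!
//! A minimal usage sketch. The `use` path is an assumption about how this module
//! is re-exported; adjust it to the crate's actual public API.
//!
//! ```ignore
//! use infiniloom_engine::tokenizer::{TokenModel, Tokenizer};
//!
//! let tokenizer = Tokenizer::new();
//! // Exact BPE count for OpenAI models, calibrated estimation for others.
//! let exact = tokenizer.count("fn main() {}", TokenModel::Gpt4o);
//! let approx = tokenizer.count("fn main() {}", TokenModel::Claude);
//! assert!(exact >= 1 && approx >= 1);
//! ```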

use super::counts::TokenCounts;
use super::models::TokenModel;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

/// Global tokenizer instances (lazily initialized, thread-safe)
static GPT4O_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
static GPT4_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();

/// Global token count cache, keyed by `(content_hash, model)`.
/// This provides significant speedup when the same content is tokenized multiple times.
static TOKEN_CACHE: OnceLock<DashMap<(u64, TokenModel), u32>> = OnceLock::new();

/// Get or initialize the global token cache
fn get_token_cache() -> &'static DashMap<(u64, TokenModel), u32> {
    TOKEN_CACHE.get_or_init(DashMap::new)
}

/// Compute a fast hash of content for cache keys
fn hash_content(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut hasher = DefaultHasher::new();
    content.hash(&mut hasher);
    hasher.finish()
}

/// Get or initialize the GPT-4o tokenizer (o200k_base)
fn get_gpt4o_tokenizer() -> &'static CoreBPE {
    GPT4O_TOKENIZER.get_or_init(|| {
        o200k_base().expect("tiktoken o200k_base initialization failed - please report this bug")
    })
}

/// Get or initialize the GPT-4 tokenizer (cl100k_base)
fn get_gpt4_tokenizer() -> &'static CoreBPE {
    GPT4_TOKENIZER.get_or_init(|| {
        cl100k_base().expect("tiktoken cl100k_base initialization failed - please report this bug")
    })
}

/// Accurate token counter with fallback to estimation
///
/// The tokenizer supports caching to avoid re-computing token counts for the same content.
/// This is particularly useful when processing files multiple times or across different
/// operations.
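///
/// # Examples
///
/// A minimal sketch of the caching behavior described above (assuming `Tokenizer`
/// and `TokenModel` are in scope):
///
/// ```ignore
/// let tokenizer = Tokenizer::new();
/// let first = tokenizer.count("let x = 1;", TokenModel::Gpt4o);
/// // A second call with identical content is served from the global cache.
/// let second = tokenizer.count("let x = 1;", TokenModel::Gpt4o);
/// assert_eq!(first, second);
/// ```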
pub struct Tokenizer {
    /// Use exact tokenization when available
    use_exact: bool,
    /// Use global cache for token counts
    use_cache: bool,
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer {
    /// Create a new tokenizer with exact mode and caching enabled
    pub fn new() -> Self {
        Self { use_exact: true, use_cache: true }
    }

    /// Create a tokenizer that only uses estimation (faster but less accurate)
    pub fn estimation_only() -> Self {
        Self { use_exact: false, use_cache: true }
    }

    /// Create a tokenizer without caching (useful for benchmarks or one-off counts)
    pub fn without_cache() -> Self {
        Self { use_exact: true, use_cache: false }
    }

    /// Count tokens for a specific model.
    ///
    /// When caching is enabled, results are stored in a global cache keyed by
    /// content hash and model. This provides significant speedup for repeated
    /// tokenization of the same content.
    ///
    /// # Returns
    ///
    /// The token count for the specified model. For OpenAI models (GPT-4o, GPT-4, etc.),
    /// this is exact via tiktoken. For other models, it's a calibrated estimation.
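    ///
    /// # Examples
    ///
    /// A minimal sketch (assuming `Tokenizer` and `TokenModel` are in scope): empty
    /// input always yields zero, non-empty input at least one token.
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new();
    /// assert_eq!(tokenizer.count("", TokenModel::Gpt4o), 0);
    /// assert!(tokenizer.count("hello world", TokenModel::Gpt4o) >= 1);
    /// ```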
    #[must_use]
    pub fn count(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        if self.use_cache {
            let cache = get_token_cache();
            let content_hash = hash_content(text);
            let key = (content_hash, model);

            // Check cache first
            if let Some(count) = cache.get(&key) {
                return *count;
            }

            // Compute and cache
            let count = self.count_uncached(text, model);
            cache.insert(key, count);
            count
        } else {
            self.count_uncached(text, model)
        }
    }

    /// Count tokens without using the cache
    fn count_uncached(&self, text: &str, model: TokenModel) -> u32 {
        if self.use_exact && model.has_exact_tokenizer() {
            self.count_exact(text, model)
        } else {
            self.estimate(text, model)
        }
    }

    /// Count tokens using exact BPE encoding.
    /// Falls back to estimation if tiktoken panics (rare edge cases with unusual byte sequences).
    fn count_exact(&self, text: &str, model: TokenModel) -> u32 {
        if model.uses_o200k() {
            // All modern OpenAI models (GPT-5.x, GPT-4o, O1/O3/O4) share o200k_base encoding
            self.count_with_bpe(get_gpt4o_tokenizer(), text, model)
        } else if model.uses_cl100k() {
            // Legacy OpenAI models (GPT-4, GPT-3.5-turbo) share cl100k_base encoding
            self.count_with_bpe(get_gpt4_tokenizer(), text, model)
        } else {
            // Non-OpenAI models use estimation
            self.estimate(text, model)
        }
    }

    /// Encode with the given BPE tokenizer, catching panics from tiktoken on
    /// malformed input and falling back to estimation.
    fn count_with_bpe(&self, tokenizer: &CoreBPE, text: &str, model: TokenModel) -> u32 {
        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tokenizer.encode_ordinary(text).len() as u32
        })) {
            Ok(count) => count,
            Err(_) => self.estimate(text, model), // Fallback to estimation on panic
        }
    }

    /// Estimate tokens using character-based heuristics
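    ///
    /// Illustrative arithmetic (the `chars_per_token()` value of 4.0 here is an
    /// assumed example, not the model's calibrated figure): a 40-byte snippet with
    /// 6 spaces, 1 newline, and 4 special characters estimates to
    /// 40 / 4.0 - 6 * 0.3 + 1 * 0.5 = 8.7, rounded up to 9 tokens; a code-focused
    /// model adds 4 * 0.3 = 1.2 more, giving ceil(9.9) = 10.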
    fn estimate(&self, text: &str, model: TokenModel) -> u32 {
        if text.is_empty() {
            return 0;
        }

        let chars_per_token = model.chars_per_token();
        let len = text.len() as f32;

        // Base estimation
        let mut estimate = len / chars_per_token;

        // Count whitespace (often merged with adjacent tokens)
        let whitespace_count = text.chars().filter(|c| *c == ' ' || *c == '\t').count() as f32;
        estimate -= whitespace_count * 0.3;

        // Count newlines (usually single tokens)
        let newline_count = text.chars().filter(|c| *c == '\n').count() as f32;
        estimate += newline_count * 0.5;

        // Adjust for special characters (often separate tokens)
        let special_chars = text
            .chars()
            .filter(|c| {
                matches!(
                    c,
                    '{' | '}'
                        | '('
                        | ')'
                        | '['
                        | ']'
                        | ';'
                        | ':'
                        | ','
                        | '.'
                        | '='
                        | '+'
                        | '-'
                        | '*'
                        | '/'
                        | '<'
                        | '>'
                        | '!'
                        | '&'
                        | '|'
                        | '@'
                        | '#'
                        | '$'
                        | '%'
                        | '^'
                        | '~'
                        | '`'
                        | '"'
                        | '\''
                )
            })
            .count() as f32;

        // Code-focused models handle special chars differently
        if matches!(
            model,
            TokenModel::CodeLlama | TokenModel::Claude | TokenModel::DeepSeek | TokenModel::Mistral
        ) {
            estimate += special_chars * 0.3;
        }

        estimate.ceil().max(1.0) as u32
    }

    /// Count tokens for all supported models at once
    ///
    /// Returns counts for representative models from each encoding family:
    /// - `o200k`: GPT-5.x, GPT-4o, O1/O3/O4 (all use same tokenizer)
    /// - `cl100k`: GPT-4, GPT-3.5-turbo (legacy, same tokenizer)
    /// - Other vendors use estimation
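    ///
    /// # Examples
    ///
    /// A minimal sketch; the field names come from `TokenCounts` as populated below:
    ///
    /// ```ignore
    /// let counts = Tokenizer::new().count_all("fn add(a: u32, b: u32) -> u32 { a + b }");
    /// assert!(counts.o200k >= 1);
    /// assert!(counts.claude >= 1);
    /// ```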
    pub fn count_all(&self, text: &str) -> TokenCounts {
        TokenCounts {
            // OpenAI o200k_base (GPT-5.x, GPT-4o, O-series all share this)
            o200k: self.count(text, TokenModel::Gpt4o),
            // OpenAI cl100k_base (legacy GPT-4, GPT-3.5)
            cl100k: self.count(text, TokenModel::Gpt4),
            // Other vendors (estimation-based)
            claude: self.count(text, TokenModel::Claude),
            gemini: self.count(text, TokenModel::Gemini),
            llama: self.count(text, TokenModel::Llama),
            mistral: self.count(text, TokenModel::Mistral),
            deepseek: self.count(text, TokenModel::DeepSeek),
            qwen: self.count(text, TokenModel::Qwen),
            cohere: self.count(text, TokenModel::Cohere),
            grok: self.count(text, TokenModel::Grok),
        }
    }

    /// Estimate which model will have the lowest token count
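    ///
    /// # Examples
    ///
    /// A minimal sketch; the winning model depends on the text and on each model's
    /// calibration, so only the count is asserted here:
    ///
    /// ```ignore
    /// let (model, tokens) = Tokenizer::new().most_efficient_model("hello world");
    /// assert!(tokens >= 1);
    /// let _ = model;
    /// ```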
    pub fn most_efficient_model(&self, text: &str) -> (TokenModel, u32) {
        let counts = self.count_all(text);
        let models = [
            (TokenModel::Gpt4o, counts.o200k), // GPT-5.x, GPT-4o, O-series
            (TokenModel::Gpt4, counts.cl100k), // Legacy GPT-4
            (TokenModel::Claude, counts.claude),
            (TokenModel::Gemini, counts.gemini),
            (TokenModel::Llama, counts.llama),
            (TokenModel::Mistral, counts.mistral),
            (TokenModel::DeepSeek, counts.deepseek),
            (TokenModel::Qwen, counts.qwen),
            (TokenModel::Cohere, counts.cohere),
            (TokenModel::Grok, counts.grok),
        ];

        // The models array is non-empty, so min_by_key always returns Some;
        // unwrap_or is just a defensive default.
        models
            .into_iter()
            .min_by_key(|(_, count)| *count)
            .unwrap_or((TokenModel::Claude, 0))
    }

    /// Truncate text to fit within a token budget
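    ///
    /// # Examples
    ///
    /// A minimal sketch: the result is a prefix of the input that fits the budget
    /// (the budget of 8 tokens is an arbitrary example value):
    ///
    /// ```ignore
    /// let tok = Tokenizer::new();
    /// let text = "alpha beta gamma delta epsilon zeta eta theta iota kappa";
    /// let cut = tok.truncate_to_budget(text, TokenModel::Gpt4o, 8);
    /// assert!(text.starts_with(cut));
    /// assert!(tok.count(cut, TokenModel::Gpt4o) <= 8);
    /// ```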
    pub fn truncate_to_budget<'a>(&self, text: &'a str, model: TokenModel, budget: u32) -> &'a str {
        let current = self.count(text, model);
        if current <= budget {
            return text;
        }

        // Binary search for the right truncation point
        let mut low = 0usize;
        let mut high = text.len();

        while low < high {
            let mid_raw = (low + high).div_ceil(2);
            // Find a valid UTF-8 boundary (rounds down)
            let mid = text.floor_char_boundary(mid_raw);

            // CRITICAL: Prevent infinite loop when low and high converge within
            // a multi-byte UTF-8 character. If floor_char_boundary rounds mid
            // back to low, we can't make progress - break out.
            if mid <= low {
                break;
            }

            let count = self.count(&text[..mid], model);

            if count <= budget {
                low = mid;
            } else {
                high = mid.saturating_sub(1);
            }
        }

        // Try to truncate at a word boundary
        let mut end = low;
        while end > 0 {
            let c = text.as_bytes().get(end - 1).copied().unwrap_or(0);
            if c == b' ' || c == b'\n' {
                break;
            }
            end -= 1;
        }

        if end > 0 {
            &text[..end]
        } else {
            let low = text.floor_char_boundary(low);
            &text[..low]
        }
    }

    /// Check if text exceeds a token budget
    pub fn exceeds_budget(&self, text: &str, model: TokenModel, budget: u32) -> bool {
        self.count(text, model) > budget
    }
}

/// Quick estimation without creating a `Tokenizer` instance
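///
/// # Examples
///
/// A minimal sketch (free function; the `use` path is an assumption):
///
/// ```ignore
/// use infiniloom_engine::tokenizer::{quick_estimate, TokenModel};
///
/// assert!(quick_estimate("let total = 3;", TokenModel::Claude) >= 1);
/// ```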
pub fn quick_estimate(text: &str, model: TokenModel) -> u32 {
    if text.is_empty() {
        return 0;
    }
    let chars_per_token = model.chars_per_token();
    (text.len() as f32 / chars_per_token).ceil().max(1.0) as u32
}