Skip to main content

aster/context/
token_estimator.rs

1//! Token Estimator Module
2//!
3//! Provides accurate token estimation for different content types including:
4//! - Asian characters (Chinese, Japanese, Korean)
5//! - Code content
6//! - Regular English text
7//!
8//! # Token Estimation Strategy
9//!
10//! Different content types have different character-to-token ratios:
11//! - Asian text: ~2 characters per token
12//! - Code: ~3 characters per token
13//! - English text: ~3.5 characters per token
14//!
15//! Special characters and newlines add additional weight.
16
17use crate::context::types::{CHARS_PER_TOKEN_ASIAN, CHARS_PER_TOKEN_CODE, CHARS_PER_TOKEN_DEFAULT};
18use crate::conversation::message::{Message, MessageContent};
19
/// Message overhead in tokens (role, formatting, etc.)
/// Added once per message by `estimate_message_tokens` on top of the
/// estimated content tokens.
const MESSAGE_OVERHEAD_TOKENS: usize = 4;
22
/// Token Estimator for different content types.
///
/// Provides methods to estimate token counts for text, messages, and message arrays.
///
/// Stateless: all functionality is exposed through associated functions,
/// so no instance is ever constructed.
pub struct TokenEstimator;
27
28impl TokenEstimator {
29    /// Estimate the number of tokens in a text string.
30    ///
31    /// Uses different character-per-token ratios based on content type:
32    /// - Asian characters: ~2 chars/token
33    /// - Code: ~3 chars/token
34    /// - English text: ~3.5 chars/token
35    ///
36    /// Also adds weight for special characters and newlines.
37    ///
38    /// # Arguments
39    ///
40    /// * `text` - The text to estimate tokens for
41    ///
42    /// # Returns
43    ///
44    /// Estimated number of tokens
45    ///
46    /// # Example
47    ///
48    /// ```
49    /// use aster::context::token_estimator::TokenEstimator;
50    ///
51    /// let english_text = "Hello, world!";
52    /// let tokens = TokenEstimator::estimate_tokens(english_text);
53    /// assert!(tokens > 0);
54    /// ```
55    pub fn estimate_tokens(text: &str) -> usize {
56        if text.is_empty() {
57            return 0;
58        }
59
60        // Determine the primary content type
61        let chars_per_token = if Self::has_asian_chars(text) {
62            CHARS_PER_TOKEN_ASIAN
63        } else if Self::is_code(text) {
64            CHARS_PER_TOKEN_CODE
65        } else {
66            CHARS_PER_TOKEN_DEFAULT
67        };
68
69        // Count base characters
70        let char_count = text.chars().count();
71
72        // Calculate base token estimate
73        let base_tokens = (char_count as f64 / chars_per_token).ceil() as usize;
74
75        // Add weight for special characters and newlines
76        let special_weight = Self::calculate_special_weight(text);
77
78        base_tokens + special_weight
79    }
80
81    /// Check if text contains Asian characters (Chinese, Japanese, Korean).
82    ///
83    /// # Arguments
84    ///
85    /// * `text` - The text to check
86    ///
87    /// # Returns
88    ///
89    /// `true` if the text contains significant Asian characters
90    pub fn has_asian_chars(text: &str) -> bool {
91        let total_chars = text.chars().count();
92        if total_chars == 0 {
93            return false;
94        }
95
96        let asian_count = text.chars().filter(|c| Self::is_asian_char(*c)).count();
97
98        // Consider text as Asian if more than 20% of characters are Asian
99        (asian_count as f64 / total_chars as f64) > 0.2
100    }
101
102    /// Check if a single character is an Asian character.
103    fn is_asian_char(c: char) -> bool {
104        matches!(c,
105            // CJK Unified Ideographs
106            '\u{4E00}'..='\u{9FFF}' |
107            // CJK Unified Ideographs Extension A
108            '\u{3400}'..='\u{4DBF}' |
109            // CJK Unified Ideographs Extension B
110            '\u{20000}'..='\u{2A6DF}' |
111            // CJK Compatibility Ideographs
112            '\u{F900}'..='\u{FAFF}' |
113            // Hiragana
114            '\u{3040}'..='\u{309F}' |
115            // Katakana
116            '\u{30A0}'..='\u{30FF}' |
117            // Hangul Syllables
118            '\u{AC00}'..='\u{D7AF}' |
119            // Hangul Jamo
120            '\u{1100}'..='\u{11FF}' |
121            // Bopomofo
122            '\u{3100}'..='\u{312F}'
123        )
124    }
125
126    /// Check if text appears to be code.
127    ///
128    /// Uses heuristics to detect code content:
129    /// - Presence of code-specific characters ({}, [], ;, etc.)
130    /// - Indentation patterns with code keywords
131    /// - Common code keywords
132    ///
133    /// # Arguments
134    ///
135    /// * `text` - The text to check
136    ///
137    /// # Returns
138    ///
139    /// `true` if the text appears to be code
140    pub fn is_code(text: &str) -> bool {
141        // Check for code block markers
142        if text.contains("```") || text.contains("~~~") {
143            return true;
144        }
145
146        // Count code-specific indicators
147        let code_indicators = [
148            '{', '}', '[', ']', '(', ')', ';', '=', '+', '-', '*', '/', '<', '>', '&', '|', '!',
149        ];
150
151        let total_chars = text.chars().count();
152        if total_chars == 0 {
153            return false;
154        }
155
156        let code_char_count = text.chars().filter(|c| code_indicators.contains(c)).count();
157
158        // Check for common code patterns (keywords followed by specific syntax)
159        let has_code_patterns = text.contains("fn ")
160            || text.contains("def ")
161            || text.contains("function ")
162            || text.contains("class ")
163            || text.contains("const ")
164            || text.contains("let ")
165            || text.contains("var ")
166            || text.contains("import ")
167            || text.contains("pub ")
168            || text.contains("async ")
169            || text.contains("await ")
170            || text.contains("return ")
171            || text.contains("if ")
172            || text.contains("for ")
173            || text.contains("while ");
174
175        // Check for indentation with code patterns (more strict)
176        // Only consider it code if there's indentation AND code patterns
177        let has_indentation_with_code = text.lines().any(|line| {
178            let trimmed = line.trim_start();
179            let indent_size = line.len() - trimmed.len();
180            // Require at least 2 spaces of indentation AND the line must have code-like content
181            indent_size >= 2
182                && (trimmed.contains('{')
183                    || trimmed.contains('}')
184                    || trimmed.contains(';')
185                    || trimmed.starts_with("let ")
186                    || trimmed.starts_with("const ")
187                    || trimmed.starts_with("return ")
188                    || trimmed.starts_with("if ")
189                    || trimmed.starts_with("for ")
190                    || trimmed.starts_with("while ")
191                    || trimmed.starts_with("//")
192                    || trimmed.starts_with("#"))
193        });
194
195        // Consider it code if:
196        // - More than 5% of characters are code indicators, OR
197        // - Has code patterns (keywords), OR
198        // - Has indentation with code-like content
199        (code_char_count as f64 / total_chars as f64) > 0.05
200            || has_code_patterns
201            || has_indentation_with_code
202    }
203
204    /// Calculate additional weight for special characters and newlines.
205    fn calculate_special_weight(text: &str) -> usize {
206        let newline_count = text.chars().filter(|c| *c == '\n').count();
207        let special_count = text
208            .chars()
209            .filter(|c| {
210                matches!(
211                    c,
212                    '\t' | '\r' | '\\' | '"' | '\'' | '`' | '~' | '@' | '#' | '$' | '%' | '^'
213                )
214            })
215            .count();
216
217        // Each newline adds ~0.5 tokens, special chars add ~0.25 tokens
218        (newline_count as f64 * 0.5).ceil() as usize + (special_count as f64 * 0.25).ceil() as usize
219    }
220
221    /// Estimate the number of tokens in a message.
222    ///
223    /// Includes message overhead (role, formatting) plus content tokens.
224    ///
225    /// # Arguments
226    ///
227    /// * `message` - The message to estimate tokens for
228    ///
229    /// # Returns
230    ///
231    /// Estimated number of tokens
232    pub fn estimate_message_tokens(message: &Message) -> usize {
233        let content_tokens: usize = message
234            .content
235            .iter()
236            .map(Self::estimate_content_tokens)
237            .sum();
238
239        content_tokens + MESSAGE_OVERHEAD_TOKENS
240    }
241
242    /// Estimate tokens for a single message content block.
243    fn estimate_content_tokens(content: &MessageContent) -> usize {
244        match content {
245            MessageContent::Text(text_content) => Self::estimate_tokens(&text_content.text),
246            MessageContent::Image(_) => {
247                // Images typically use a fixed token count
248                // Claude uses ~1600 tokens for a typical image
249                1600
250            }
251            MessageContent::ToolRequest(tool_request) => {
252                // Estimate based on tool name and arguments
253                let mut tokens = 10; // Base overhead for tool request structure
254
255                if let Ok(call) = &tool_request.tool_call {
256                    tokens += Self::estimate_tokens(&call.name);
257                    if let Some(args) = &call.arguments {
258                        let args_str = serde_json::to_string(args).unwrap_or_default();
259                        tokens += Self::estimate_tokens(&args_str);
260                    }
261                }
262
263                tokens
264            }
265            MessageContent::ToolResponse(tool_response) => {
266                let mut tokens = 10; // Base overhead
267
268                if let Ok(result) = &tool_response.tool_result {
269                    for content in &result.content {
270                        if let Some(text) = content.as_text() {
271                            tokens += Self::estimate_tokens(&text.text);
272                        }
273                    }
274                }
275
276                tokens
277            }
278            MessageContent::Thinking(thinking) => Self::estimate_tokens(&thinking.thinking),
279            MessageContent::RedactedThinking(_) => 50, // Fixed estimate for redacted thinking
280            MessageContent::ToolConfirmationRequest(req) => {
281                let args_str = serde_json::to_string(&req.arguments).unwrap_or_default();
282                10 + Self::estimate_tokens(&req.tool_name) + Self::estimate_tokens(&args_str)
283            }
284            MessageContent::ActionRequired(action) => {
285                match &action.data {
286                    crate::conversation::message::ActionRequiredData::ToolConfirmation {
287                        tool_name,
288                        arguments,
289                        ..
290                    } => {
291                        let args_str = serde_json::to_string(arguments).unwrap_or_default();
292                        10 + Self::estimate_tokens(tool_name) + Self::estimate_tokens(&args_str)
293                    }
294                    crate::conversation::message::ActionRequiredData::Elicitation {
295                        message,
296                        ..
297                    } => 10 + Self::estimate_tokens(message),
298                    crate::conversation::message::ActionRequiredData::ElicitationResponse {
299                        ..
300                    } => 20, // Fixed estimate
301                }
302            }
303            MessageContent::FrontendToolRequest(req) => {
304                let mut tokens = 10;
305                if let Ok(call) = &req.tool_call {
306                    tokens += Self::estimate_tokens(&call.name);
307                    if let Some(args) = &call.arguments {
308                        let args_str = serde_json::to_string(args).unwrap_or_default();
309                        tokens += Self::estimate_tokens(&args_str);
310                    }
311                }
312                tokens
313            }
314            MessageContent::SystemNotification(notification) => {
315                Self::estimate_tokens(&notification.msg)
316            }
317        }
318    }
319
320    /// Estimate the total number of tokens for an array of messages.
321    ///
322    /// # Arguments
323    ///
324    /// * `messages` - The messages to estimate tokens for
325    ///
326    /// # Returns
327    ///
328    /// Total estimated tokens across all messages
329    pub fn estimate_total_tokens(messages: &[Message]) -> usize {
330        messages.iter().map(Self::estimate_message_tokens).sum()
331    }
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    // Empty input must always cost zero tokens.
    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(TokenEstimator::estimate_tokens(""), 0);
    }

    // English prose should use the default ratio and stay small.
    #[test]
    fn test_estimate_tokens_english() {
        let text = "Hello, world! This is a test.";
        let tokens = TokenEstimator::estimate_tokens(text);
        // ~30 chars / 3.5 ≈ 9 tokens + special weight
        assert!(tokens > 0);
        assert!(tokens < 20);
    }

    // Chinese text should use the denser Asian ratio (~2 chars/token).
    #[test]
    fn test_estimate_tokens_chinese() {
        let text = "你好世界,这是一个测试。";
        let tokens = TokenEstimator::estimate_tokens(text);
        // ~12 chars / 2 ≈ 6 tokens
        assert!(tokens > 0);
        assert!(tokens < 15);
    }

    // Code input goes through the code ratio; only sanity-checks non-zero.
    #[test]
    fn test_estimate_tokens_code() {
        let text = r#"
fn main() {
    println!("Hello, world!");
}
"#;
        let tokens = TokenEstimator::estimate_tokens(text);
        assert!(tokens > 0);
    }

    // Mixed English + Chinese still counts as Asian (>20% threshold).
    #[test]
    fn test_has_asian_chars_chinese() {
        assert!(TokenEstimator::has_asian_chars("你好世界"));
        assert!(TokenEstimator::has_asian_chars("Hello 你好"));
    }

    // Both hiragana and katakana ranges are detected.
    #[test]
    fn test_has_asian_chars_japanese() {
        assert!(TokenEstimator::has_asian_chars("こんにちは"));
        assert!(TokenEstimator::has_asian_chars("カタカナ"));
    }

    // Hangul syllables are detected.
    #[test]
    fn test_has_asian_chars_korean() {
        assert!(TokenEstimator::has_asian_chars("안녕하세요"));
    }

    // Pure ASCII and empty strings are never classified as Asian.
    #[test]
    fn test_has_asian_chars_english() {
        assert!(!TokenEstimator::has_asian_chars("Hello, world!"));
        assert!(!TokenEstimator::has_asian_chars(""));
    }

    // Rust: keywords ("fn ", "let ") plus indented braces/semicolons.
    #[test]
    fn test_is_code_rust() {
        let code = r#"
fn main() {
    let x = 5;
    println!("{}", x);
}
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // JavaScript: "function ", "const ", "return " keywords.
    #[test]
    fn test_is_code_javascript() {
        let code = r#"
function hello() {
    const x = 5;
    return x + 1;
}
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // Python: "def " keyword and indented "return ".
    #[test]
    fn test_is_code_python() {
        let code = r#"
def hello():
    x = 5
    return x + 1
"#;
        assert!(TokenEstimator::is_code(code));
    }

    // Fenced code blocks short-circuit the heuristics.
    #[test]
    fn test_is_code_markdown_block() {
        let text = "```rust\nfn main() {}\n```";
        assert!(TokenEstimator::is_code(text));
    }

    // Prose without code symbols or keywords must not be flagged.
    #[test]
    fn test_is_code_plain_text() {
        let text = "This is just plain English text without any code.";
        assert!(!TokenEstimator::is_code(text));
    }

    // A message always costs at least the fixed per-message overhead.
    #[test]
    fn test_estimate_message_tokens() {
        let message = Message::user().with_text("Hello, world!");
        let tokens = TokenEstimator::estimate_message_tokens(&message);
        // Content tokens + MESSAGE_OVERHEAD_TOKENS
        assert!(tokens >= MESSAGE_OVERHEAD_TOKENS);
    }

    // Totals sum per-message estimates, so overhead applies per message.
    #[test]
    fn test_estimate_total_tokens() {
        let messages = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there!"),
        ];
        let total = TokenEstimator::estimate_total_tokens(&messages);
        assert!(total > 0);
        assert!(total >= MESSAGE_OVERHEAD_TOKENS * 2);
    }

    // Newlines add extra weight on top of the base estimate.
    #[test]
    fn test_estimate_tokens_with_newlines() {
        let text = "Line 1\nLine 2\nLine 3";
        let tokens = TokenEstimator::estimate_tokens(text);
        // Should include weight for newlines
        assert!(tokens > 0);
    }

    // Special characters (@, #, $, %) add extra weight.
    #[test]
    fn test_estimate_tokens_with_special_chars() {
        let text = "Hello @user #tag $var %percent";
        let tokens = TokenEstimator::estimate_tokens(text);
        // Should include weight for special characters
        assert!(tokens > 0);
    }
}
473}