Skip to main content

ai_agent/services/
token_estimation.rs

1//! Token estimation for text.
2//!
3//! Provides token counting similar to claude code's token estimation.
4
5use crate::types::Message;
6use serde::{Deserialize, Serialize};
7
/// Result of a heuristic token estimate, bundled with the input statistics
/// that were gathered while producing it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenEstimate {
    /// Estimated number of tokens in the input.
    pub tokens: usize,
    /// Length of the input text as measured by the estimator that built this value.
    pub characters: usize,
    /// Number of whitespace-separated words in the input.
    pub words: usize,
    /// Heuristic this estimate is tagged with.
    pub method: EstimationMethod,
}
16
/// Identifies which heuristic a [`TokenEstimate`] was produced with.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum EstimationMethod {
    /// Fast estimation from the text length and a characters-per-token ratio.
    CharacterRatio,
    /// Estimation from the whitespace-separated word count.
    WordBased,
    /// Exact TikToken estimation (if available).
    /// NOTE(review): no estimator in the visible part of this file returns this variant.
    TikToken,
}
27
28// ============================================================================
29// Translation of claude code's tokenEstimation.ts - strictly line by line
30// ============================================================================
31
/// Coarse token estimate: the UTF-8 byte length of `content` divided by
/// `bytes_per_token`, rounded to the nearest whole token.
///
/// Rust counterpart of the TypeScript
/// `export function roughTokenCountEstimation(content: string, bytesPerToken: number = 4): number`.
/// Rust has no default arguments, so the ratio is always passed explicitly.
/// The ratio is not validated; callers are expected to pass a positive value.
pub fn rough_token_count_estimation(content: &str, bytes_per_token: f64) -> usize {
    let byte_len = content.len() as f64;
    let estimate = byte_len / bytes_per_token;
    estimate.round() as usize
}
37
/// Returns the bytes-per-token ratio to assume for a given file extension.
///
/// Rust counterpart of the TypeScript
/// `export function bytesPerTokenForFileType(fileExtension: string): number`.
/// Dense JSON has many single-character tokens, which pushes its ratio
/// closer to 2; every other extension uses the general ratio of 4.
///
/// The lookup is slightly more lenient than the TypeScript original: it
/// ignores ASCII case and one leading dot, so `"JSON"` and `".json"` are
/// recognised the same way as `"json"`.
pub fn bytes_per_token_for_file_type(file_extension: &str) -> f64 {
    // Extensions whose content is treated as dense JSON.
    const DENSE_JSON_EXTENSIONS: [&str; 3] = ["json", "jsonl", "jsonc"];

    // Accept both "json" and ".json" so callers need not normalise first.
    let extension = file_extension
        .strip_prefix('.')
        .unwrap_or(file_extension);

    let is_dense_json = DENSE_JSON_EXTENSIONS
        .iter()
        .any(|known| extension.eq_ignore_ascii_case(known));

    if is_dense_json {
        2.0
    } else {
        4.0
    }
}
48
49/// Like roughTokenCountEstimation but uses more accurate bytes-per-token ratio
50/// when file type is known - matches original TypeScript:
51/// `export function roughTokenCountEstimationForFileType(content: string, fileExtension: string): number`
52pub fn rough_token_count_estimation_for_file_type(content: &str, file_extension: &str) -> usize {
53    rough_token_count_estimation(content, bytes_per_token_for_file_type(file_extension))
54}
55
56/// Estimate tokens for a single message - matches original TypeScript:
57/// `export function roughTokenCountEstimationForMessage(message: {...}): number`
58pub fn rough_token_count_estimation_for_message(message: &Message) -> usize {
59    rough_token_count_estimation_for_content(&message.content)
60}
61
62/// Estimate tokens for message content (string or array) - matches original TypeScript:
63/// `function roughTokenCountEstimationForContent(content: ...): number`
64pub fn rough_token_count_estimation_for_content(content: &str) -> usize {
65    if content.is_empty() {
66        return 0;
67    }
68    rough_token_count_estimation(content, 4.0)
69}
70
71/// Estimate tokens for an array of messages - matches original TypeScript:
72/// `export function roughTokenCountEstimationForMessages(messages: readonly {...}[]): number`
73pub fn rough_token_count_estimation_for_messages(messages: &[Message]) -> usize {
74    messages
75        .iter()
76        .map(|msg| rough_token_count_estimation_for_message(msg))
77        .sum()
78}
79
80// ============================================================================
81// Legacy estimation functions (kept for backward compatibility)
82// ============================================================================
83
84/// Estimate tokens using character ratio method (faster but less accurate)
85/// Average ratio is ~4 characters per token for English
86pub fn estimate_tokens_characters(text: &str) -> TokenEstimate {
87    let characters = text.len();
88    let words = text.split_whitespace().count();
89
90    // Use 4:1 character to token ratio as baseline
91    // Adjust based on text characteristics
92    let ratio = if text.contains("```") {
93        // Code blocks have more characters per token
94        5.5
95    } else if words > 0 {
96        let avg_word_len = characters as f64 / words as f64;
97        if avg_word_len > 8.0 {
98            // Long words = more characters per token
99            5.0
100        } else if avg_word_len < 3.0 {
101            // Short words = fewer characters per token
102            3.5
103        } else {
104            4.0
105        }
106    } else {
107        4.0
108    };
109
110    let tokens = (characters as f64 / ratio).ceil() as usize;
111
112    TokenEstimate {
113        tokens,
114        characters,
115        words,
116        method: EstimationMethod::CharacterRatio,
117    }
118}
119
120/// Estimate tokens using word-based method
121pub fn estimate_tokens_words(text: &str) -> TokenEstimate {
122    let words = text.split_whitespace().count();
123    let characters = text.len();
124
125    // Average ~1.3 words per token for English
126    let tokens = (words as f64 / 1.3).ceil() as usize;
127
128    TokenEstimate {
129        tokens,
130        characters,
131        words,
132        method: EstimationMethod::WordBased,
133    }
134}
135
136/// Estimate tokens using combined method (best balance of speed and accuracy)
137pub fn estimate_tokens(text: &str) -> TokenEstimate {
138    let char_estimate = estimate_tokens_characters(text);
139    let word_estimate = estimate_tokens_words(text);
140
141    // Use the average of both methods for better accuracy
142    let tokens = (char_estimate.tokens + word_estimate.tokens) / 2;
143
144    TokenEstimate {
145        tokens,
146        characters: char_estimate.characters,
147        words: char_estimate.words,
148        method: EstimationMethod::CharacterRatio,
149    }
150}
151
152/// Estimate tokens in messages (handles role/content format)
153pub fn estimate_message_tokens<T: MessageContent>(messages: &[T]) -> usize {
154    messages
155        .iter()
156        .map(|m| {
157            let content = m.content();
158            // Add overhead for role annotation
159            let role_overhead = 4;
160            estimate_tokens(content).tokens + role_overhead
161        })
162        .sum()
163}
164
165/// Estimate tokens in a conversation string
166pub fn estimate_conversation(conversation: &str) -> TokenEstimate {
167    // Count turns by looking for common patterns
168    let turns = conversation
169        .matches("User:")
170        .count()
171        .max(conversation.matches("Assistant:").count());
172
173    // Each turn has overhead for role prefix
174    let turn_overhead = turns * 10;
175
176    let base = estimate_tokens(conversation);
177    TokenEstimate {
178        tokens: base.tokens + turn_overhead,
179        characters: base.characters,
180        words: base.words,
181        method: base.method,
182    }
183}
184
185/// Estimate tokens for tool definitions
186pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> usize {
187    tools
188        .iter()
189        .map(|t| {
190            let name_tokens = estimate_tokens(&t.name).tokens;
191            let desc_tokens = t
192                .description
193                .as_ref()
194                .map(|d| estimate_tokens(d).tokens)
195                .unwrap_or(0);
196            let params_tokens = estimate_tokens(&t.input_schema).tokens;
197            name_tokens + desc_tokens + params_tokens + 20 // overhead
198        })
199        .sum()
200}
201
/// Minimal view of a message for estimation purposes: anything that can
/// expose its textual content as a string slice.
pub trait MessageContent {
    /// Returns the text of the message.
    fn content(&self) -> &str;
}
206
207impl MessageContent for String {
208    fn content(&self) -> &str {
209        self.as_str()
210    }
211}
212
213impl MessageContent for &str {
214    fn content(&self) -> &str {
215        self
216    }
217}
218
/// Message with role: a simple role/content pair usable with the
/// estimation helpers through [`MessageContent`].
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// Role of the message author (the tests in this file use `"user"`).
    pub role: String,
    /// Text of the message.
    pub content: String,
}
225
226impl MessageContent for ChatMessage {
227    fn content(&self) -> &str {
228        &self.content
229    }
230}
231
/// Tool definition for estimation: the textual parts of a tool that are
/// counted by `estimate_tool_definitions`.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Optional description; contributes no tokens when `None`.
    pub description: Option<String>,
    /// Input schema as text (the test in this file passes a JSON Schema string).
    pub input_schema: String,
}
239
/// Calculate how many more input tokens fit into the context window.
///
/// `max_tokens` (the output budget) is first reserved out of
/// `context_limit`; what remains is the room available for input. The
/// result is that room minus `input_tokens`, or 0 when the input already
/// fills or exceeds it. Both subtractions saturate at zero.
pub fn calculate_padding(input_tokens: usize, max_tokens: usize, context_limit: usize) -> usize {
    context_limit
        .saturating_sub(max_tokens)
        .saturating_sub(input_tokens)
}
251
/// Estimate if content fits in context.
///
/// Returns `true` when `content_tokens` plus the reserved output budget
/// `max_tokens` does not exceed `context_limit`.
///
/// The sum is computed with overflow checking: if it does not fit in a
/// `usize` it cannot fit in any context window either, so the answer is
/// `false` (rather than a debug-build panic or a wrapped-around `true` in
/// release builds).
pub fn fits_in_context(content_tokens: usize, max_tokens: usize, context_limit: usize) -> bool {
    match content_tokens.checked_add(max_tokens) {
        Some(required) => required <= context_limit,
        // Overflow: the requirement exceeds every representable limit.
        None => false,
    }
}
256
/// Token encoding utilities
pub mod encoding {
    /// Characters-per-token ratio assumed for English text.
    pub const CHARS_PER_TOKEN_EN: f64 = 4.0;
    /// Characters-per-token ratio assumed for source code.
    pub const CHARS_PER_TOKEN_CODE: f64 = 5.5;
    /// Characters-per-token ratio assumed for Chinese, Japanese and Korean text.
    pub const CHARS_PER_TOKEN_CJK: f64 = 2.0;

    /// Returns `true` when `text` contains any of a fixed list of code
    /// markers (a code fence or a common keyword). This is a plain
    /// substring check, so prose that happens to contain a marker also matches.
    pub fn is_code(text: &str) -> bool {
        const CODE_MARKERS: [&str; 8] = [
            "```", "function", "class ", "def ", "const ", "let ", "var ", "import ",
        ];

        for marker in &CODE_MARKERS {
            if text.contains(*marker) {
                return true;
            }
        }
        false
    }

    /// Returns `true` when `text` contains at least one character from the
    /// ranges checked below; a single such character is enough.
    pub fn is_cjk(text: &str) -> bool {
        text.chars().any(|ch| {
            matches!(
                ch,
                '\u{4E00}'..='\u{9FFF}'     // CJK Unified Ideographs
                | '\u{3040}'..='\u{309F}'   // Hiragana
                | '\u{30A0}'..='\u{30FF}'   // Katakana
                | '\u{AC00}'..='\u{D7AF}'   // Korean
            )
        })
    }

    /// Picks the characters-per-token ratio for `text`. The code check is
    /// made first, then the CJK check; English is the fallback.
    pub fn chars_per_token(text: &str) -> f64 {
        if is_code(text) {
            return CHARS_PER_TOKEN_CODE;
        }
        if is_cjk(text) {
            return CHARS_PER_TOKEN_CJK;
        }
        CHARS_PER_TOKEN_EN
    }
}
293
/// Unit tests for the rough (TypeScript-derived) estimators, the legacy
/// estimators, the context-window helpers and the `encoding` module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MessageRole;

    // ============================================================================
    // Tests for the translated TypeScript functions
    // ============================================================================

    /// Byte length divided by the ratio, rounded to the nearest integer.
    #[test]
    fn test_rough_token_count_estimation() {
        // "Hello world" = 11 bytes, 11/4 = 2.75 rounds to 3
        assert_eq!(rough_token_count_estimation("Hello world", 4.0), 3);
        // 100 bytes / 4 = 25 tokens
        assert_eq!(rough_token_count_estimation(&"a".repeat(100), 4.0), 25);
    }

    /// JSON-family extensions use ratio 2, everything else ratio 4.
    #[test]
    fn test_bytes_per_token_for_file_type() {
        assert_eq!(bytes_per_token_for_file_type("json"), 2.0);
        assert_eq!(bytes_per_token_for_file_type("jsonl"), 2.0);
        assert_eq!(bytes_per_token_for_file_type("rs"), 4.0);
        assert_eq!(bytes_per_token_for_file_type("txt"), 4.0);
    }

    /// The extension selects the ratio that is applied to the content.
    #[test]
    fn test_rough_token_count_estimation_for_file_type() {
        // JSON: 100 bytes / 2 = 50 tokens
        assert_eq!(
            rough_token_count_estimation_for_file_type(&"a".repeat(100), "json"),
            50
        );
        // Rust: 100 bytes / 4 = 25 tokens
        assert_eq!(
            rough_token_count_estimation_for_file_type(&"a".repeat(100), "rs"),
            25
        );
    }

    /// Empty content is 0 tokens; otherwise the fixed ratio of 4 applies.
    #[test]
    fn test_rough_token_count_estimation_for_content() {
        assert_eq!(rough_token_count_estimation_for_content(""), 0);
        // "Hello" = 5 bytes, 5/4 = 1.25 rounds to 1
        assert_eq!(rough_token_count_estimation_for_content("Hello"), 1);
    }

    /// A single message is estimated from its `content` field.
    #[test]
    fn test_rough_token_count_estimation_for_message() {
        let msg = crate::types::Message {
            role: MessageRole::User,
            content: "Hello world".to_string(),
            ..Default::default()
        };
        // "Hello world" = 11 bytes, 11/4 = 2.75 rounds to 3
        assert_eq!(rough_token_count_estimation_for_message(&msg), 3);
    }

    /// A list of messages is the sum of the per-message estimates.
    #[test]
    fn test_rough_token_count_estimation_for_messages() {
        let messages = vec![
            crate::types::Message {
                role: MessageRole::User,
                content: "Hello".to_string(),
                ..Default::default()
            },
            crate::types::Message {
                role: MessageRole::Assistant,
                content: "Hi there".to_string(),
                ..Default::default()
            },
        ];
        // "Hello" = 5 bytes / 4 = 1.25 -> 1 token
        // "Hi there" = 8 bytes / 4 = 2 tokens
        // Total = 3 tokens
        assert_eq!(rough_token_count_estimation_for_messages(&messages), 3);
    }

    // ============================================================================
    // Tests for legacy estimation functions
    // ============================================================================

    /// Character-ratio estimate: lower bound on tokens, exact length for ASCII.
    #[test]
    fn test_estimate_tokens_characters() {
        let result = estimate_tokens_characters("Hello, world!");
        assert!(result.tokens >= 3);
        assert_eq!(result.characters, 13);
    }

    /// Word-based estimate: non-zero tokens and an exact word count.
    #[test]
    fn test_estimate_tokens_words() {
        let result = estimate_tokens_words("Hello world this is a test");
        assert!(result.tokens > 0);
        assert_eq!(result.words, 6);
    }

    /// Combined estimate is non-zero for ordinary prose.
    #[test]
    fn test_estimate_tokens() {
        let result = estimate_tokens("The quick brown fox jumps over the lazy dog");
        assert!(result.tokens > 0);
    }

    /// Conversation estimate is non-zero for a short multi-turn transcript.
    #[test]
    fn test_estimate_conversation() {
        let conv = "User: Hello\nAssistant: Hi there!\nUser: How are you?";
        let result = estimate_conversation(conv);
        assert!(result.tokens > 0);
    }

    /// Tool definitions with name, description and schema yield a non-zero total.
    #[test]
    fn test_estimate_tool_definitions() {
        let tools = vec![ToolDefinition {
            name: "Read".to_string(),
            description: Some("Read a file".to_string()),
            input_schema: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#
                .to_string(),
        }];
        let tokens = estimate_tool_definitions(&tools);
        assert!(tokens > 0);
    }

    /// Padding is the input room left after reserving the output budget.
    #[test]
    fn test_calculate_padding() {
        // 2000 - 500 reserved = 1500 for input; 1000 used leaves 500.
        assert_eq!(calculate_padding(1000, 500, 2000), 500);
        // Input exactly fills the available room.
        assert_eq!(calculate_padding(1500, 500, 2000), 0);
    }

    /// Content fits when content plus output budget stays within the limit.
    #[test]
    fn test_fits_in_context() {
        assert!(fits_in_context(1000, 500, 2000));
        assert!(!fits_in_context(1600, 500, 2000));
    }

    /// Ratio selection: English fallback and code detection.
    #[test]
    fn test_encoding_chars_per_token() {
        assert_eq!(
            encoding::chars_per_token("Hello world"),
            encoding::CHARS_PER_TOKEN_EN
        );
        assert_eq!(
            encoding::chars_per_token("function test() {}"),
            encoding::CHARS_PER_TOKEN_CODE
        );
    }

    /// Code detection matches on a keyword and rejects plain prose.
    #[test]
    fn test_is_code() {
        assert!(encoding::is_code("function foo() { return 1; }"));
        assert!(!encoding::is_code("Hello world"));
    }

    /// CJK detection matches ideographs and rejects ASCII text.
    #[test]
    fn test_is_cjk() {
        assert!(encoding::is_cjk("你好世界"));
        assert!(!encoding::is_cjk("Hello world"));
    }

    /// `ChatMessage` exposes its `content` field through `MessageContent`.
    #[test]
    fn test_message_content_trait() {
        let msg = ChatMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
        };
        assert_eq!(msg.content(), "Hello");
    }
}
458}