Skip to main content

agent_sdk/context/
estimator.rs

1//! Token estimation for context size calculation.
2
3use crate::llm::{Content, ContentBlock, Message};
4
5/// Estimates token count for messages.
6///
7/// Uses a simple heuristic of ~4 characters per token, which provides
8/// a reasonable approximation for most English text and code.
9///
10/// For more accurate counting, consider using a tokenizer library
11/// specific to your model (e.g., tiktoken for `OpenAI` models).
12pub struct TokenEstimator;
13
14impl TokenEstimator {
15    /// Characters per token estimate.
16    /// This is a conservative estimate; actual ratio varies by content.
17    const CHARS_PER_TOKEN: usize = 4;
18
19    /// Overhead tokens per message (role, formatting).
20    const MESSAGE_OVERHEAD: usize = 4;
21
22    /// Overhead for tool use blocks (id, name, formatting).
23    const TOOL_USE_OVERHEAD: usize = 20;
24
25    /// Overhead for tool result blocks (id, formatting).
26    const TOOL_RESULT_OVERHEAD: usize = 10;
27
28    /// Estimate tokens for a text string.
29    #[must_use]
30    pub const fn estimate_text(text: &str) -> usize {
31        // Simple estimation: ~4 chars per token
32        text.len().div_ceil(Self::CHARS_PER_TOKEN)
33    }
34
35    /// Estimate tokens for a single message.
36    #[must_use]
37    pub fn estimate_message(message: &Message) -> usize {
38        let content_tokens = match &message.content {
39            Content::Text(text) => Self::estimate_text(text),
40            Content::Blocks(blocks) => blocks.iter().map(Self::estimate_block).sum(),
41        };
42
43        content_tokens + Self::MESSAGE_OVERHEAD
44    }
45
46    /// Estimate tokens for a content block.
47    #[must_use]
48    pub fn estimate_block(block: &ContentBlock) -> usize {
49        match block {
50            ContentBlock::Text { text } => Self::estimate_text(text),
51            ContentBlock::Thinking { thinking, .. } => Self::estimate_text(thinking),
52            ContentBlock::RedactedThinking { .. } => 10, // Fixed overhead for opaque block
53            ContentBlock::ToolUse { name, input, .. } => {
54                let input_str = serde_json::to_string(input).unwrap_or_default();
55                Self::estimate_text(name)
56                    + Self::estimate_text(&input_str)
57                    + Self::TOOL_USE_OVERHEAD
58            }
59            ContentBlock::ToolResult { content, .. } => {
60                Self::estimate_text(content) + Self::TOOL_RESULT_OVERHEAD
61            }
62            ContentBlock::Image { source } | ContentBlock::Document { source } => {
63                // Rough estimate: base64 data is ~4/3 of original, 1 token per 4 chars
64                source.data.len() / 4 + Self::MESSAGE_OVERHEAD
65            }
66        }
67    }
68
69    /// Estimate total tokens for a message history.
70    #[must_use]
71    pub fn estimate_history(messages: &[Message]) -> usize {
72        messages.iter().map(Self::estimate_message).sum()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Role;
    use serde_json::json;

    #[test]
    fn test_estimate_text() {
        // Empty text
        assert_eq!(TokenEstimator::estimate_text(""), 0);

        // Short text (less than 4 chars)
        assert_eq!(TokenEstimator::estimate_text("hi"), 1);

        // Exactly 4 chars
        assert_eq!(TokenEstimator::estimate_text("test"), 1);

        // 5 chars should be 2 tokens
        assert_eq!(TokenEstimator::estimate_text("hello"), 2);

        // Longer text
        assert_eq!(TokenEstimator::estimate_text("hello world!"), 3); // 12 chars / 4 = 3
    }

    #[test]
    fn test_estimate_text_counts_utf8_bytes() {
        // Estimation is based on UTF-8 byte length, not character count.
        // "ééé": 3 chars, 6 bytes -> 2 tokens.
        assert_eq!(TokenEstimator::estimate_text("ééé"), 2);
        // "日本語": 3 chars, 9 bytes -> 3 tokens.
        assert_eq!(TokenEstimator::estimate_text("日本語"), 3);
    }

    #[test]
    fn test_estimate_text_message() {
        let message = Message {
            role: Role::User,
            content: Content::Text("Hello, how are you?".to_string()), // 19 chars = 5 tokens
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 5 content tokens + 4 overhead = 9
        assert_eq!(estimate, 9);
    }

    #[test]
    fn test_estimate_blocks_message() {
        let message = Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![
                ContentBlock::Text {
                    text: "Let me help.".to_string(), // 12 chars = 3 tokens
                },
                ContentBlock::ToolUse {
                    id: "tool_123".to_string(),
                    name: "read".to_string(), // 4 chars = 1 token
                    // Serializes to {"path":"/test.txt"}: 20 chars = 5 tokens
                    input: json!({"path": "/test.txt"}),
                    thought_signature: None,
                },
            ]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // Text: 3 tokens
        // ToolUse: 1 (name) + 5 (input) + 20 (overhead) = 26 tokens
        // Message overhead: 4
        // Total: 3 + 26 + 4 = 33
        assert_eq!(estimate, 33);
    }

    #[test]
    fn test_estimate_tool_result() {
        let message = Message {
            role: Role::User,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool_123".to_string(),
                content: "File contents here...".to_string(), // 21 chars = 6 tokens
                is_error: None,
            }]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 6 content + 10 overhead + 4 message overhead = 20
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_estimate_history() {
        let messages = vec![
            Message::user("Hello"),          // 5 chars = 2 tokens + 4 overhead = 6
            Message::assistant("Hi there!"), // 9 chars = 3 tokens + 4 overhead = 7
            Message::user("How are you?"),   // 12 chars = 3 tokens + 4 overhead = 7
        ];

        let estimate = TokenEstimator::estimate_history(&messages);
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_empty_history() {
        let messages: Vec<Message> = vec![];
        assert_eq!(TokenEstimator::estimate_history(&messages), 0);
    }
}