agent_sdk/context/estimator.rs

//! Token estimation for context size calculation.

use crate::llm::{Content, ContentBlock, Message};

/// Estimates token count for messages.
///
/// Uses a simple heuristic of ~4 characters per token, which provides
/// a reasonable approximation for most English text and code.
///
/// For more accurate counting, consider using a tokenizer library
/// specific to your model (e.g., tiktoken for `OpenAI` models).
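///
/// # Examples
///
/// A minimal sketch; the import path below is assumed from this file's
/// location (`agent_sdk/context/estimator.rs`) and may differ depending on
/// how `TokenEstimator` is re-exported.
///
/// ```
/// use agent_sdk::context::TokenEstimator;
///
/// // 12 characters at ~4 chars per token rounds up to 3 tokens.
/// assert_eq!(TokenEstimator::estimate_text("hello world!"), 3);
/// ```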
pub struct TokenEstimator;

impl TokenEstimator {
    /// Characters per token estimate.
    /// This is a rough estimate; the actual ratio varies by model and content.
    const CHARS_PER_TOKEN: usize = 4;

    /// Overhead tokens per message (role, formatting).
    const MESSAGE_OVERHEAD: usize = 4;

    /// Overhead for tool use blocks (id, name, formatting).
    const TOOL_USE_OVERHEAD: usize = 20;

    /// Overhead for tool result blocks (id, formatting).
    const TOOL_RESULT_OVERHEAD: usize = 10;

    /// Estimate tokens for a text string.
    #[must_use]
    pub const fn estimate_text(text: &str) -> usize {
        // Simple estimation: byte length / 4, rounded up (byte count equals char count for ASCII).
        text.len().div_ceil(Self::CHARS_PER_TOKEN)
    }

    /// Estimate tokens for a single message.
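    ///
    /// # Examples
    ///
    /// A sketch assuming `Message` (and its `user` constructor, used in the
    /// tests below) is re-exported under `agent_sdk::llm`.
    ///
    /// ```
    /// use agent_sdk::{context::TokenEstimator, llm::Message};
    ///
    /// // "Hello" is 5 chars = 2 tokens, plus 4 tokens of message overhead.
    /// assert_eq!(TokenEstimator::estimate_message(&Message::user("Hello")), 6);
    /// ```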
    #[must_use]
    pub fn estimate_message(message: &Message) -> usize {
        let content_tokens = match &message.content {
            Content::Text(text) => Self::estimate_text(text),
            Content::Blocks(blocks) => blocks.iter().map(Self::estimate_block).sum(),
        };

        content_tokens + Self::MESSAGE_OVERHEAD
    }

    /// Estimate tokens for a content block.
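    ///
    /// # Examples
    ///
    /// A sketch assuming `ContentBlock` is re-exported under `agent_sdk::llm`.
    ///
    /// ```
    /// use agent_sdk::{context::TokenEstimator, llm::ContentBlock};
    ///
    /// let block = ContentBlock::Text {
    ///     text: "Let me help.".to_string(), // 12 chars = 3 tokens
    /// };
    /// assert_eq!(TokenEstimator::estimate_block(&block), 3);
    /// ```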
    #[must_use]
    pub fn estimate_block(block: &ContentBlock) -> usize {
        match block {
            ContentBlock::Text { text } => Self::estimate_text(text),
            ContentBlock::Thinking { thinking } => Self::estimate_text(thinking),
            ContentBlock::ToolUse { name, input, .. } => {
                let input_str = serde_json::to_string(input).unwrap_or_default();
                Self::estimate_text(name)
                    + Self::estimate_text(&input_str)
                    + Self::TOOL_USE_OVERHEAD
            }
            ContentBlock::ToolResult { content, .. } => {
                Self::estimate_text(content) + Self::TOOL_RESULT_OVERHEAD
            }
        }
    }

    /// Estimate total tokens for a message history.
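    ///
    /// # Examples
    ///
    /// A sketch assuming the `Message::user`/`Message::assistant` constructors
    /// used in the tests below; the total is the sum of each per-message
    /// estimate, including per-message overhead.
    ///
    /// ```
    /// use agent_sdk::{context::TokenEstimator, llm::Message};
    ///
    /// let history = vec![
    ///     Message::user("Hello"),          // 2 tokens + 4 overhead = 6
    ///     Message::assistant("Hi there!"), // 3 tokens + 4 overhead = 7
    /// ];
    /// assert_eq!(TokenEstimator::estimate_history(&history), 13);
    /// ```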
    #[must_use]
    pub fn estimate_history(messages: &[Message]) -> usize {
        messages.iter().map(Self::estimate_message).sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Role;
    use serde_json::json;

    #[test]
    fn test_estimate_text() {
        // Empty text
        assert_eq!(TokenEstimator::estimate_text(""), 0);

        // Short text (less than 4 chars)
        assert_eq!(TokenEstimator::estimate_text("hi"), 1);

        // Exactly 4 chars
        assert_eq!(TokenEstimator::estimate_text("test"), 1);

        // 5 chars should be 2 tokens
        assert_eq!(TokenEstimator::estimate_text("hello"), 2);

        // Longer text
        assert_eq!(TokenEstimator::estimate_text("hello world!"), 3); // 12 chars / 4 = 3
    }

    #[test]
    fn test_estimate_text_message() {
        let message = Message {
            role: Role::User,
            content: Content::Text("Hello, how are you?".to_string()), // 19 chars = 5 tokens
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 5 content tokens + 4 overhead = 9
        assert_eq!(estimate, 9);
    }

    #[test]
    fn test_estimate_blocks_message() {
        let message = Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![
                ContentBlock::Text {
                    text: "Let me help.".to_string(), // 12 chars = 3 tokens
                },
                ContentBlock::ToolUse {
                    id: "tool_123".to_string(),
                    name: "read".to_string(),            // 4 chars = 1 token
                    input: json!({"path": "/test.txt"}), // ~20 chars = 5 tokens
                    thought_signature: None,
                },
            ]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // Text: 3 tokens
        // ToolUse: 1 (name) + 5 (input) + 20 (overhead) = 26 tokens
        // Message overhead: 4
        // Total: 3 + 26 + 4 = 33
        assert!(estimate > 25); // Verify it accounts for tool use
    }

    #[test]
    fn test_estimate_tool_result() {
        let message = Message {
            role: Role::User,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool_123".to_string(),
                content: "File contents here...".to_string(), // 21 chars = 6 tokens
                is_error: None,
            }]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 6 content + 10 overhead + 4 message overhead = 20
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_estimate_history() {
        let messages = vec![
            Message::user("Hello"),          // 5 chars = 2 tokens + 4 overhead = 6
            Message::assistant("Hi there!"), // 9 chars = 3 tokens + 4 overhead = 7
            Message::user("How are you?"),   // 12 chars = 3 tokens + 4 overhead = 7
        ];

        let estimate = TokenEstimator::estimate_history(&messages);
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_empty_history() {
        let messages: Vec<Message> = vec![];
        assert_eq!(TokenEstimator::estimate_history(&messages), 0);
    }
}