agent_sdk/context/estimator.rs

//! Token estimation for context size calculation.

use crate::llm::{Content, ContentBlock, Message};

/// Estimates token count for messages.
///
/// Uses a simple heuristic of ~4 characters per token, which provides
/// a reasonable approximation for most English text and code.
///
/// For more accurate counting, consider using a tokenizer library
/// specific to your model (e.g., tiktoken for `OpenAI` models).
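///
/// # Example
///
/// A rough usage sketch; the import path below is assumed from this file's
/// location and may differ in the actual crate:
///
/// ```ignore
/// use agent_sdk::context::TokenEstimator;
///
/// // 5 characters at ~4 chars/token round up to 2 estimated tokens.
/// assert_eq!(TokenEstimator::estimate_text("hello"), 2);
/// // 12 characters divide evenly into 3 estimated tokens.
/// assert_eq!(TokenEstimator::estimate_text("hello world!"), 3);
/// ```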
pub struct TokenEstimator;

impl TokenEstimator {
    /// Characters per token estimate.
    /// This is a conservative estimate; actual ratio varies by content.
    const CHARS_PER_TOKEN: usize = 4;

    /// Overhead tokens per message (role, formatting).
    const MESSAGE_OVERHEAD: usize = 4;

    /// Overhead for tool use blocks (id, name, formatting).
    const TOOL_USE_OVERHEAD: usize = 20;

    /// Overhead for tool result blocks (id, formatting).
    const TOOL_RESULT_OVERHEAD: usize = 10;

    /// Estimate tokens for a text string.
    #[must_use]
    pub const fn estimate_text(text: &str) -> usize {
        // Simple estimation: ~4 chars per token (byte length as a proxy for character count)
        text.len().div_ceil(Self::CHARS_PER_TOKEN)
    }

    /// Estimate tokens for a single message.
    #[must_use]
    pub fn estimate_message(message: &Message) -> usize {
        let content_tokens = match &message.content {
            Content::Text(text) => Self::estimate_text(text),
            Content::Blocks(blocks) => blocks.iter().map(Self::estimate_block).sum(),
        };

        content_tokens + Self::MESSAGE_OVERHEAD
    }

    /// Estimate tokens for a content block.
    #[must_use]
    pub fn estimate_block(block: &ContentBlock) -> usize {
        match block {
            ContentBlock::Text { text } => Self::estimate_text(text),
            ContentBlock::ToolUse { name, input, .. } => {
                let input_str = serde_json::to_string(input).unwrap_or_default();
                Self::estimate_text(name)
                    + Self::estimate_text(&input_str)
                    + Self::TOOL_USE_OVERHEAD
            }
            ContentBlock::ToolResult { content, .. } => {
                Self::estimate_text(content) + Self::TOOL_RESULT_OVERHEAD
            }
        }
    }

    /// Estimate total tokens for a message history.
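    ///
    /// A rough sketch (not compiled here); each message contributes its
    /// content estimate plus `MESSAGE_OVERHEAD`:
    ///
    /// ```ignore
    /// let history = vec![
    ///     Message::user("Hello"),          // 2 content tokens + 4 overhead = 6
    ///     Message::assistant("Hi there!"), // 3 content tokens + 4 overhead = 7
    /// ];
    /// assert_eq!(TokenEstimator::estimate_history(&history), 13);
    /// ```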
    #[must_use]
    pub fn estimate_history(messages: &[Message]) -> usize {
        messages.iter().map(Self::estimate_message).sum()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Role;
    use serde_json::json;

    #[test]
    fn test_estimate_text() {
        // Empty text
        assert_eq!(TokenEstimator::estimate_text(""), 0);

        // Short text (less than 4 chars)
        assert_eq!(TokenEstimator::estimate_text("hi"), 1);

        // Exactly 4 chars
        assert_eq!(TokenEstimator::estimate_text("test"), 1);

        // 5 chars should be 2 tokens
        assert_eq!(TokenEstimator::estimate_text("hello"), 2);

        // Longer text
        assert_eq!(TokenEstimator::estimate_text("hello world!"), 3); // 12 chars / 4 = 3
    }

    #[test]
    fn test_estimate_text_message() {
        let message = Message {
            role: Role::User,
            content: Content::Text("Hello, how are you?".to_string()), // 19 chars = 5 tokens
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 5 content tokens + 4 overhead = 9
        assert_eq!(estimate, 9);
    }

    #[test]
    fn test_estimate_blocks_message() {
        let message = Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![
                ContentBlock::Text {
                    text: "Let me help.".to_string(), // 12 chars = 3 tokens
                },
                ContentBlock::ToolUse {
                    id: "tool_123".to_string(),
                    name: "read".to_string(),            // 4 chars = 1 token
                    input: json!({"path": "/test.txt"}), // ~20 chars = 5 tokens
                },
            ]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // Text: 3 tokens
        // ToolUse: 1 (name) + 5 (input) + 20 (overhead) = 26 tokens
        // Message overhead: 4
        // Total: 3 + 26 + 4 = 33
        assert!(estimate > 25); // Verify it accounts for tool use
    }

    #[test]
    fn test_estimate_tool_result() {
        let message = Message {
            role: Role::User,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool_123".to_string(),
                content: "File contents here...".to_string(), // 21 chars = 6 tokens
                is_error: None,
            }]),
        };

        let estimate = TokenEstimator::estimate_message(&message);
        // 6 content + 10 overhead + 4 message overhead = 20
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_estimate_history() {
        let messages = vec![
            Message::user("Hello"),          // 5 chars = 2 tokens + 4 overhead = 6
            Message::assistant("Hi there!"), // 9 chars = 3 tokens + 4 overhead = 7
            Message::user("How are you?"),   // 12 chars = 3 tokens + 4 overhead = 7
        ];

        let estimate = TokenEstimator::estimate_history(&messages);
        assert_eq!(estimate, 20);
    }

    #[test]
    fn test_empty_history() {
        let messages: Vec<Message> = vec![];
        assert_eq!(TokenEstimator::estimate_history(&messages), 0);
    }
}