Skip to main content

agent_sdk/context/
estimator.rs

1//! Token estimation for context size calculation.
2
3use crate::llm::{Content, ContentBlock, Message};
4
/// Estimates token count for messages.
///
/// Uses a simple heuristic of ~4 characters per token, which provides
/// a reasonable approximation for most English text and code.
///
/// The estimator is a stateless unit struct: every operation is an
/// associated function, so no instance is ever constructed.  Text lengths
/// are measured with `str::len`, i.e. in UTF-8 bytes, so non-ASCII text is
/// weighted by its encoded size rather than its character count.
///
/// For more accurate counting, consider using a tokenizer library
/// specific to your model (e.g., tiktoken for `OpenAI` models).
pub struct TokenEstimator;
13
14impl TokenEstimator {
    /// Characters per token estimate.
    /// This is a conservative estimate; actual ratio varies by content.
    ///
    /// Used as the divisor (rounding up) when converting a byte length
    /// into a token count.
    const CHARS_PER_TOKEN: usize = 4;

    /// Overhead tokens per message (role, formatting).
    ///
    /// The same value is also added to image/document block estimates and
    /// is the entire estimate for block kinds the estimator does not
    /// recognise.
    const MESSAGE_OVERHEAD: usize = 4;

    /// Overhead for tool use blocks (id, name, formatting).
    ///
    /// Added on top of the estimated tokens for the tool name and input.
    const TOOL_USE_OVERHEAD: usize = 20;

    /// Overhead for tool result blocks (id, formatting).
    ///
    /// Added on top of the estimated tokens for the result content.
    const TOOL_RESULT_OVERHEAD: usize = 10;

    /// Minimum token estimate for redacted thinking blocks.
    ///
    /// Even small redacted thinking blocks carry significant API token cost
    /// because they contain encrypted reasoning that the model must process.
    /// The size-based estimate for such a block is clamped to at least this
    /// value.
    const REDACTED_THINKING_MIN_TOKENS: usize = 512;
33
34    /// Estimate tokens for a text string.
35    #[must_use]
36    pub const fn estimate_text(text: &str) -> usize {
37        // Simple estimation: ~4 chars per token
38        text.len().div_ceil(Self::CHARS_PER_TOKEN)
39    }
40
41    /// Estimate tokens for a single message.
42    #[must_use]
43    pub fn estimate_message(message: &Message) -> usize {
44        let content_tokens = match &message.content {
45            Content::Text(text) => Self::estimate_text(text),
46            Content::Blocks(blocks) => blocks.iter().map(Self::estimate_block).sum(),
47        };
48
49        content_tokens + Self::MESSAGE_OVERHEAD
50    }
51
52    /// Estimate tokens for a content block.
53    #[must_use]
54    pub fn estimate_block(block: &ContentBlock) -> usize {
55        match block {
56            ContentBlock::Text { text } => Self::estimate_text(text),
57            ContentBlock::Thinking { thinking, .. } => Self::estimate_text(thinking),
58            ContentBlock::RedactedThinking { data } => {
59                // The data field is a base64-encoded encrypted blob whose size
60                // correlates with the original thinking content.  Base64 encodes
61                // 3 bytes into 4 chars, so `data.len() * 3 / 4` approximates
62                // the raw byte count.  Using the same chars-per-token heuristic
63                // on the raw bytes gives a reasonable lower bound.
64                //
65                // A floor of REDACTED_THINKING_MIN_TOKENS prevents tiny blocks
66                // from being under-counted — the API charges substantial token
67                // overhead for every redacted thinking block regardless of size.
68                let raw_bytes = data.len() * 3 / 4;
69                let estimated = raw_bytes.div_ceil(Self::CHARS_PER_TOKEN);
70                estimated.max(Self::REDACTED_THINKING_MIN_TOKENS)
71            }
72            ContentBlock::ToolUse { name, input, .. } => {
73                // Estimate the serialized JSON length without actually
74                // serializing: `needs_compaction` runs before every LLM call,
75                // so allocating a String per tool-use block on every round-trip
76                // is O(n^2) over a session. The recursive estimator also avoids
77                // the silent 0-byte underestimate that `to_string(..)
78                // .unwrap_or_default()` produced on a serialization failure.
79                let input_len = Self::estimate_json_len(input);
80                Self::estimate_text(name)
81                    + input_len.div_ceil(Self::CHARS_PER_TOKEN)
82                    + Self::TOOL_USE_OVERHEAD
83            }
84            ContentBlock::ToolResult { content, .. } => {
85                Self::estimate_text(content) + Self::TOOL_RESULT_OVERHEAD
86            }
87            ContentBlock::Image { source } | ContentBlock::Document { source } => {
88                // Rough estimate: base64 data is ~4/3 of original, 1 token per 4 chars
89                source.data.len() / 4 + Self::MESSAGE_OVERHEAD
90            }
91            // `ContentBlock` is `#[non_exhaustive]`; charge an unknown future
92            // block kind the per-message overhead as a conservative floor.
93            _ => Self::MESSAGE_OVERHEAD,
94        }
95    }
96
97    /// Estimate total tokens for a message history.
98    #[must_use]
99    pub fn estimate_history(messages: &[Message]) -> usize {
100        messages.iter().map(Self::estimate_message).sum()
101    }
102
103    /// Approximate the serialized-JSON byte length of a value without
104    /// allocating a serialized `String`.
105    ///
106    /// Mirrors `serde_json::to_string`'s output length closely enough for token
107    /// estimation: it sums key/string lengths, structural punctuation, and a
108    /// digit count for numbers. It is intentionally slightly conservative
109    /// (over-counts a trailing separator per element) since over-estimating
110    /// context size is safer than under-estimating it.
111    fn estimate_json_len(value: &serde_json::Value) -> usize {
112        match value {
113            serde_json::Value::Null => 4, // "null"
114            serde_json::Value::Bool(b) => {
115                if *b {
116                    4 // "true"
117                } else {
118                    5 // "false"
119                }
120            }
121            serde_json::Value::Number(n) => n.as_u64().map_or_else(
122                || {
123                    n.as_i64().map_or(
124                        // Floating point or arbitrary-precision: a fixed
125                        // estimate is fine for a token heuristic.
126                        8,
127                        |i| Self::decimal_digits(i.unsigned_abs()) + usize::from(i < 0),
128                    )
129                },
130                Self::decimal_digits,
131            ),
132            // String value plus surrounding quotes.
133            serde_json::Value::String(s) => s.len() + 2,
134            serde_json::Value::Array(items) => {
135                // Brackets plus a separator allowance per element.
136                2 + items
137                    .iter()
138                    .map(|item| Self::estimate_json_len(item) + 1)
139                    .sum::<usize>()
140            }
141            serde_json::Value::Object(entries) => {
142                // Braces plus key (quoted) + ':' + value + ',' per entry.
143                2 + entries
144                    .iter()
145                    .map(|(key, val)| key.len() + 2 + 1 + Self::estimate_json_len(val) + 1)
146                    .sum::<usize>()
147            }
148        }
149    }
150
151    /// Count the decimal digits in a `u64` without allocating.
152    const fn decimal_digits(mut n: u64) -> usize {
153        let mut digits = 1;
154        while n >= 10 {
155            n /= 10;
156            digits += 1;
157        }
158        digits
159    }
160}
161
162#[cfg(test)]
163mod tests {
164    use super::*;
165    use crate::llm::Role;
166    use serde_json::json;
167
168    #[test]
169    fn test_estimate_text() {
170        // Empty text
171        assert_eq!(TokenEstimator::estimate_text(""), 0);
172
173        // Short text (less than 4 chars)
174        assert_eq!(TokenEstimator::estimate_text("hi"), 1);
175
176        // Exactly 4 chars
177        assert_eq!(TokenEstimator::estimate_text("test"), 1);
178
179        // 5 chars should be 2 tokens
180        assert_eq!(TokenEstimator::estimate_text("hello"), 2);
181
182        // Longer text
183        assert_eq!(TokenEstimator::estimate_text("hello world!"), 3); // 12 chars / 4 = 3
184    }
185
186    #[test]
187    fn test_estimate_text_message() {
188        let message = Message {
189            role: Role::User,
190            content: Content::Text("Hello, how are you?".to_string()), // 19 chars = 5 tokens
191        };
192
193        let estimate = TokenEstimator::estimate_message(&message);
194        // 5 content tokens + 4 overhead = 9
195        assert_eq!(estimate, 9);
196    }
197
198    #[test]
199    fn test_estimate_blocks_message() {
200        let message = Message {
201            role: Role::Assistant,
202            content: Content::Blocks(vec![
203                ContentBlock::Text {
204                    text: "Let me help.".to_string(), // 12 chars = 3 tokens
205                },
206                ContentBlock::ToolUse {
207                    id: "tool_123".to_string(),
208                    name: "read".to_string(),            // 4 chars = 1 token
209                    input: json!({"path": "/test.txt"}), // ~20 chars = 5 tokens
210                    thought_signature: None,
211                },
212            ]),
213        };
214
215        let estimate = TokenEstimator::estimate_message(&message);
216        // Text: 3 tokens
217        // ToolUse: 1 (name) + 5 (input) + 20 (overhead) = 26 tokens
218        // Message overhead: 4
219        // Total: 3 + 26 + 4 = 33
220        assert!(estimate > 25); // Verify it accounts for tool use
221    }
222
223    #[test]
224    fn test_estimate_tool_result() {
225        let message = Message {
226            role: Role::User,
227            content: Content::Blocks(vec![ContentBlock::ToolResult {
228                tool_use_id: "tool_123".to_string(),
229                content: "File contents here...".to_string(), // 21 chars = 6 tokens
230                is_error: None,
231            }]),
232        };
233
234        let estimate = TokenEstimator::estimate_message(&message);
235        // 6 content + 10 overhead + 4 message overhead = 20
236        assert_eq!(estimate, 20);
237    }
238
239    #[test]
240    fn test_estimate_history() {
241        let messages = vec![
242            Message::user("Hello"),          // 5 chars = 2 tokens + 4 overhead = 6
243            Message::assistant("Hi there!"), // 9 chars = 3 tokens + 4 overhead = 7
244            Message::user("How are you?"),   // 12 chars = 3 tokens + 4 overhead = 7
245        ];
246
247        let estimate = TokenEstimator::estimate_history(&messages);
248        assert_eq!(estimate, 20);
249    }
250
251    #[test]
252    fn test_empty_history() {
253        let messages: Vec<Message> = vec![];
254        assert_eq!(TokenEstimator::estimate_history(&messages), 0);
255    }
256
257    #[test]
258    fn test_estimate_redacted_thinking_uses_data_length() {
259        // Simulate a realistic redacted thinking blob (~8KB base64 data).
260        // 8192 base64 chars → ~6144 raw bytes → 6144/4 = 1536 estimated tokens.
261        let data = "A".repeat(8192);
262        let block = ContentBlock::RedactedThinking { data };
263
264        let estimate = TokenEstimator::estimate_block(&block);
265        assert_eq!(estimate, 1536);
266    }
267
268    #[test]
269    fn test_estimate_redacted_thinking_respects_minimum() {
270        // Tiny data blob: 100 base64 chars → ~75 raw bytes → 75/4 = 19 tokens.
271        // Should be clamped to the minimum (512).
272        let data = "A".repeat(100);
273        let block = ContentBlock::RedactedThinking { data };
274
275        let estimate = TokenEstimator::estimate_block(&block);
276        assert_eq!(estimate, TokenEstimator::REDACTED_THINKING_MIN_TOKENS);
277    }
278
279    #[test]
280    fn test_estimate_redacted_thinking_empty_data() {
281        // Empty data should return the minimum floor.
282        let block = ContentBlock::RedactedThinking {
283            data: String::new(),
284        };
285
286        let estimate = TokenEstimator::estimate_block(&block);
287        assert_eq!(estimate, TokenEstimator::REDACTED_THINKING_MIN_TOKENS);
288    }
289
290    #[test]
291    fn test_estimate_json_len_tracks_serialized_size() {
292        // The no-allocation estimator should track the real serialized length
293        // closely (within the per-element separator slack it intentionally adds).
294        for value in [
295            json!({"path": "/test.txt"}),
296            json!({"a": 1, "b": [1, 2, 3], "c": {"nested": true}}),
297            json!([null, false, "string", 12_345]),
298            json!("plain string"),
299            json!(9_876_543),
300        ] {
301            let estimated = TokenEstimator::estimate_json_len(&value);
302            let actual = serde_json::to_string(&value).map_or(0, |s| s.len());
303            assert!(
304                estimated >= actual,
305                "estimate {estimated} should be >= actual {actual} for {value}"
306            );
307            assert!(
308                estimated <= actual * 2 + 8,
309                "estimate {estimated} wildly exceeds actual {actual} for {value}"
310            );
311        }
312    }
313
314    #[test]
315    fn test_tool_use_estimate_is_nonzero_for_nonempty_input() {
316        // Regression: the old `to_string(..).unwrap_or_default()` path could
317        // silently produce a 0-length input estimate. The recursive estimator
318        // always accounts for the input.
319        let block = ContentBlock::ToolUse {
320            id: "tool_1".to_string(),
321            name: "bash".to_string(),
322            input: json!({"command": "echo hello world"}),
323            thought_signature: None,
324        };
325
326        let estimate = TokenEstimator::estimate_block(&block);
327        // name (1) + overhead (20) is 21; the input must add more on top.
328        assert!(
329            estimate > 21,
330            "input length must contribute to the estimate"
331        );
332    }
333
334    #[test]
335    fn test_redacted_thinking_accumulates_in_history() {
336        // 5 redacted thinking blocks at ~2000 tokens each should produce a
337        // meaningful total that triggers compaction.
338        let blocks: Vec<ContentBlock> = (0..5)
339            .map(|_| ContentBlock::RedactedThinking {
340                data: "B".repeat(10_000), // 10k base64 → 7500 raw → 1875 tokens
341            })
342            .collect();
343        let message = Message {
344            role: Role::Assistant,
345            content: Content::Blocks(blocks),
346        };
347
348        let estimate = TokenEstimator::estimate_message(&message);
349        // 5 × 1875 + 4 message overhead = 9379
350        assert_eq!(estimate, 9379);
351    }
352}