agent_sdk/context/estimator.rs
1//! Token estimation for context size calculation.
2
3use crate::llm::{Content, ContentBlock, Message};
4
/// Estimates token counts for messages, content blocks, and message histories.
///
/// Uses a simple heuristic of ~4 bytes of UTF-8 text per token, which is a
/// reasonable approximation for most English text and code. Lengths are
/// measured in bytes (`str::len`), not characters, so non-ASCII text (2–4
/// bytes per character) is charged more per character than ASCII text.
///
/// For more accurate counting, consider using a tokenizer library
/// specific to your model (e.g., tiktoken for `OpenAI` models).
pub struct TokenEstimator;
13
14impl TokenEstimator {
15 /// Characters per token estimate.
16 /// This is a conservative estimate; actual ratio varies by content.
17 const CHARS_PER_TOKEN: usize = 4;
18
19 /// Overhead tokens per message (role, formatting).
20 const MESSAGE_OVERHEAD: usize = 4;
21
22 /// Overhead for tool use blocks (id, name, formatting).
23 const TOOL_USE_OVERHEAD: usize = 20;
24
25 /// Overhead for tool result blocks (id, formatting).
26 const TOOL_RESULT_OVERHEAD: usize = 10;
27
28 /// Minimum token estimate for redacted thinking blocks.
29 ///
30 /// Even small redacted thinking blocks carry significant API token cost
31 /// because they contain encrypted reasoning that the model must process.
32 const REDACTED_THINKING_MIN_TOKENS: usize = 512;
33
34 /// Estimate tokens for a text string.
35 #[must_use]
36 pub const fn estimate_text(text: &str) -> usize {
37 // Simple estimation: ~4 chars per token
38 text.len().div_ceil(Self::CHARS_PER_TOKEN)
39 }
40
41 /// Estimate tokens for a single message.
42 #[must_use]
43 pub fn estimate_message(message: &Message) -> usize {
44 let content_tokens = match &message.content {
45 Content::Text(text) => Self::estimate_text(text),
46 Content::Blocks(blocks) => blocks.iter().map(Self::estimate_block).sum(),
47 };
48
49 content_tokens + Self::MESSAGE_OVERHEAD
50 }
51
52 /// Estimate tokens for a content block.
53 #[must_use]
54 pub fn estimate_block(block: &ContentBlock) -> usize {
55 match block {
56 ContentBlock::Text { text } => Self::estimate_text(text),
57 ContentBlock::Thinking { thinking, .. } => Self::estimate_text(thinking),
58 ContentBlock::RedactedThinking { data } => {
59 // The data field is a base64-encoded encrypted blob whose size
60 // correlates with the original thinking content. Base64 encodes
61 // 3 bytes into 4 chars, so `data.len() * 3 / 4` approximates
62 // the raw byte count. Using the same chars-per-token heuristic
63 // on the raw bytes gives a reasonable lower bound.
64 //
65 // A floor of REDACTED_THINKING_MIN_TOKENS prevents tiny blocks
66 // from being under-counted — the API charges substantial token
67 // overhead for every redacted thinking block regardless of size.
68 let raw_bytes = data.len() * 3 / 4;
69 let estimated = raw_bytes.div_ceil(Self::CHARS_PER_TOKEN);
70 estimated.max(Self::REDACTED_THINKING_MIN_TOKENS)
71 }
72 ContentBlock::ToolUse { name, input, .. } => {
73 // Estimate the serialized JSON length without actually
74 // serializing: `needs_compaction` runs before every LLM call,
75 // so allocating a String per tool-use block on every round-trip
76 // is O(n^2) over a session. The recursive estimator also avoids
77 // the silent 0-byte underestimate that `to_string(..)
78 // .unwrap_or_default()` produced on a serialization failure.
79 let input_len = Self::estimate_json_len(input);
80 Self::estimate_text(name)
81 + input_len.div_ceil(Self::CHARS_PER_TOKEN)
82 + Self::TOOL_USE_OVERHEAD
83 }
84 ContentBlock::ToolResult { content, .. } => {
85 Self::estimate_text(content) + Self::TOOL_RESULT_OVERHEAD
86 }
87 ContentBlock::Image { source } | ContentBlock::Document { source } => {
88 // Rough estimate: base64 data is ~4/3 of original, 1 token per 4 chars
89 source.data.len() / 4 + Self::MESSAGE_OVERHEAD
90 }
91 // `ContentBlock` is `#[non_exhaustive]`; charge an unknown future
92 // block kind the per-message overhead as a conservative floor.
93 _ => Self::MESSAGE_OVERHEAD,
94 }
95 }
96
97 /// Estimate total tokens for a message history.
98 #[must_use]
99 pub fn estimate_history(messages: &[Message]) -> usize {
100 messages.iter().map(Self::estimate_message).sum()
101 }
102
103 /// Approximate the serialized-JSON byte length of a value without
104 /// allocating a serialized `String`.
105 ///
106 /// Mirrors `serde_json::to_string`'s output length closely enough for token
107 /// estimation: it sums key/string lengths, structural punctuation, and a
108 /// digit count for numbers. It is intentionally slightly conservative
109 /// (over-counts a trailing separator per element) since over-estimating
110 /// context size is safer than under-estimating it.
111 fn estimate_json_len(value: &serde_json::Value) -> usize {
112 match value {
113 serde_json::Value::Null => 4, // "null"
114 serde_json::Value::Bool(b) => {
115 if *b {
116 4 // "true"
117 } else {
118 5 // "false"
119 }
120 }
121 serde_json::Value::Number(n) => n.as_u64().map_or_else(
122 || {
123 n.as_i64().map_or(
124 // Floating point or arbitrary-precision: a fixed
125 // estimate is fine for a token heuristic.
126 8,
127 |i| Self::decimal_digits(i.unsigned_abs()) + usize::from(i < 0),
128 )
129 },
130 Self::decimal_digits,
131 ),
132 // String value plus surrounding quotes.
133 serde_json::Value::String(s) => s.len() + 2,
134 serde_json::Value::Array(items) => {
135 // Brackets plus a separator allowance per element.
136 2 + items
137 .iter()
138 .map(|item| Self::estimate_json_len(item) + 1)
139 .sum::<usize>()
140 }
141 serde_json::Value::Object(entries) => {
142 // Braces plus key (quoted) + ':' + value + ',' per entry.
143 2 + entries
144 .iter()
145 .map(|(key, val)| key.len() + 2 + 1 + Self::estimate_json_len(val) + 1)
146 .sum::<usize>()
147 }
148 }
149 }
150
151 /// Count the decimal digits in a `u64` without allocating.
152 const fn decimal_digits(mut n: u64) -> usize {
153 let mut digits = 1;
154 while n >= 10 {
155 n /= 10;
156 digits += 1;
157 }
158 digits
159 }
160}
161
162#[cfg(test)]
163mod tests {
164 use super::*;
165 use crate::llm::Role;
166 use serde_json::json;
167
168 #[test]
169 fn test_estimate_text() {
170 // Empty text
171 assert_eq!(TokenEstimator::estimate_text(""), 0);
172
173 // Short text (less than 4 chars)
174 assert_eq!(TokenEstimator::estimate_text("hi"), 1);
175
176 // Exactly 4 chars
177 assert_eq!(TokenEstimator::estimate_text("test"), 1);
178
179 // 5 chars should be 2 tokens
180 assert_eq!(TokenEstimator::estimate_text("hello"), 2);
181
182 // Longer text
183 assert_eq!(TokenEstimator::estimate_text("hello world!"), 3); // 12 chars / 4 = 3
184 }
185
186 #[test]
187 fn test_estimate_text_message() {
188 let message = Message {
189 role: Role::User,
190 content: Content::Text("Hello, how are you?".to_string()), // 19 chars = 5 tokens
191 };
192
193 let estimate = TokenEstimator::estimate_message(&message);
194 // 5 content tokens + 4 overhead = 9
195 assert_eq!(estimate, 9);
196 }
197
198 #[test]
199 fn test_estimate_blocks_message() {
200 let message = Message {
201 role: Role::Assistant,
202 content: Content::Blocks(vec![
203 ContentBlock::Text {
204 text: "Let me help.".to_string(), // 12 chars = 3 tokens
205 },
206 ContentBlock::ToolUse {
207 id: "tool_123".to_string(),
208 name: "read".to_string(), // 4 chars = 1 token
209 input: json!({"path": "/test.txt"}), // ~20 chars = 5 tokens
210 thought_signature: None,
211 },
212 ]),
213 };
214
215 let estimate = TokenEstimator::estimate_message(&message);
216 // Text: 3 tokens
217 // ToolUse: 1 (name) + 5 (input) + 20 (overhead) = 26 tokens
218 // Message overhead: 4
219 // Total: 3 + 26 + 4 = 33
220 assert!(estimate > 25); // Verify it accounts for tool use
221 }
222
223 #[test]
224 fn test_estimate_tool_result() {
225 let message = Message {
226 role: Role::User,
227 content: Content::Blocks(vec![ContentBlock::ToolResult {
228 tool_use_id: "tool_123".to_string(),
229 content: "File contents here...".to_string(), // 21 chars = 6 tokens
230 is_error: None,
231 }]),
232 };
233
234 let estimate = TokenEstimator::estimate_message(&message);
235 // 6 content + 10 overhead + 4 message overhead = 20
236 assert_eq!(estimate, 20);
237 }
238
239 #[test]
240 fn test_estimate_history() {
241 let messages = vec![
242 Message::user("Hello"), // 5 chars = 2 tokens + 4 overhead = 6
243 Message::assistant("Hi there!"), // 9 chars = 3 tokens + 4 overhead = 7
244 Message::user("How are you?"), // 12 chars = 3 tokens + 4 overhead = 7
245 ];
246
247 let estimate = TokenEstimator::estimate_history(&messages);
248 assert_eq!(estimate, 20);
249 }
250
251 #[test]
252 fn test_empty_history() {
253 let messages: Vec<Message> = vec![];
254 assert_eq!(TokenEstimator::estimate_history(&messages), 0);
255 }
256
257 #[test]
258 fn test_estimate_redacted_thinking_uses_data_length() {
259 // Simulate a realistic redacted thinking blob (~8KB base64 data).
260 // 8192 base64 chars → ~6144 raw bytes → 6144/4 = 1536 estimated tokens.
261 let data = "A".repeat(8192);
262 let block = ContentBlock::RedactedThinking { data };
263
264 let estimate = TokenEstimator::estimate_block(&block);
265 assert_eq!(estimate, 1536);
266 }
267
268 #[test]
269 fn test_estimate_redacted_thinking_respects_minimum() {
270 // Tiny data blob: 100 base64 chars → ~75 raw bytes → 75/4 = 19 tokens.
271 // Should be clamped to the minimum (512).
272 let data = "A".repeat(100);
273 let block = ContentBlock::RedactedThinking { data };
274
275 let estimate = TokenEstimator::estimate_block(&block);
276 assert_eq!(estimate, TokenEstimator::REDACTED_THINKING_MIN_TOKENS);
277 }
278
279 #[test]
280 fn test_estimate_redacted_thinking_empty_data() {
281 // Empty data should return the minimum floor.
282 let block = ContentBlock::RedactedThinking {
283 data: String::new(),
284 };
285
286 let estimate = TokenEstimator::estimate_block(&block);
287 assert_eq!(estimate, TokenEstimator::REDACTED_THINKING_MIN_TOKENS);
288 }
289
290 #[test]
291 fn test_estimate_json_len_tracks_serialized_size() {
292 // The no-allocation estimator should track the real serialized length
293 // closely (within the per-element separator slack it intentionally adds).
294 for value in [
295 json!({"path": "/test.txt"}),
296 json!({"a": 1, "b": [1, 2, 3], "c": {"nested": true}}),
297 json!([null, false, "string", 12_345]),
298 json!("plain string"),
299 json!(9_876_543),
300 ] {
301 let estimated = TokenEstimator::estimate_json_len(&value);
302 let actual = serde_json::to_string(&value).map_or(0, |s| s.len());
303 assert!(
304 estimated >= actual,
305 "estimate {estimated} should be >= actual {actual} for {value}"
306 );
307 assert!(
308 estimated <= actual * 2 + 8,
309 "estimate {estimated} wildly exceeds actual {actual} for {value}"
310 );
311 }
312 }
313
314 #[test]
315 fn test_tool_use_estimate_is_nonzero_for_nonempty_input() {
316 // Regression: the old `to_string(..).unwrap_or_default()` path could
317 // silently produce a 0-length input estimate. The recursive estimator
318 // always accounts for the input.
319 let block = ContentBlock::ToolUse {
320 id: "tool_1".to_string(),
321 name: "bash".to_string(),
322 input: json!({"command": "echo hello world"}),
323 thought_signature: None,
324 };
325
326 let estimate = TokenEstimator::estimate_block(&block);
327 // name (1) + overhead (20) is 21; the input must add more on top.
328 assert!(
329 estimate > 21,
330 "input length must contribute to the estimate"
331 );
332 }
333
334 #[test]
335 fn test_redacted_thinking_accumulates_in_history() {
336 // 5 redacted thinking blocks at ~2000 tokens each should produce a
337 // meaningful total that triggers compaction.
338 let blocks: Vec<ContentBlock> = (0..5)
339 .map(|_| ContentBlock::RedactedThinking {
340 data: "B".repeat(10_000), // 10k base64 → 7500 raw → 1875 tokens
341 })
342 .collect();
343 let message = Message {
344 role: Role::Assistant,
345 content: Content::Blocks(blocks),
346 };
347
348 let estimate = TokenEstimator::estimate_message(&message);
349 // 5 × 1875 + 4 message overhead = 9379
350 assert_eq!(estimate, 9379);
351 }
352}