// agent_code_lib/services/tokens.rs
//! Token estimation.
//!
//! Estimates token counts for messages and content blocks using a
//! character-based heuristic. Uses actual API usage data when
//! available, falling back to rough estimation for new messages.
//!
//! Default ratio: 4 bytes per token (conservative for most content).

9use crate::llm::message::{ContentBlock, Message};
10
/// Default bytes per token for estimation.
const BYTES_PER_TOKEN: f64 = 4.0;

/// Fixed token estimate for image content blocks.
const IMAGE_TOKEN_ESTIMATE: u64 = 2000;

/// Estimate the token count of a string from its UTF-8 byte length.
///
/// Divides the byte length by [`BYTES_PER_TOKEN`] and rounds to the
/// nearest whole token; an empty string yields 0.
pub fn estimate_tokens(content: &str) -> u64 {
    let byte_len = content.len() as f64;
    (byte_len / BYTES_PER_TOKEN).round() as u64
}
21
22/// Estimate tokens for a single content block.
23pub fn estimate_block_tokens(block: &ContentBlock) -> u64 {
24    match block {
25        ContentBlock::Text { text } => estimate_tokens(text),
26        ContentBlock::ToolUse { name, input, .. } => {
27            let input_str = serde_json::to_string(input).unwrap_or_default();
28            estimate_tokens(name) + estimate_tokens(&input_str)
29        }
30        ContentBlock::ToolResult { content, .. } => estimate_tokens(content),
31        ContentBlock::Thinking { thinking, .. } => estimate_tokens(thinking),
32        ContentBlock::Image { .. } => IMAGE_TOKEN_ESTIMATE,
33        ContentBlock::Document { data, .. } => {
34            // Base64-encoded documents: estimate from decoded size.
35            let decoded_size = (data.len() as f64 * 0.75) as u64;
36            (decoded_size as f64 / BYTES_PER_TOKEN).round() as u64
37        }
38    }
39}
40
41/// Estimate tokens for a single message.
42pub fn estimate_message_tokens(msg: &Message) -> u64 {
43    match msg {
44        Message::User(u) => {
45            // Per-message overhead (role, formatting).
46            let overhead = 4;
47            let content: u64 = u.content.iter().map(estimate_block_tokens).sum();
48            overhead + content
49        }
50        Message::Assistant(a) => {
51            let overhead = 4;
52            let content: u64 = a.content.iter().map(estimate_block_tokens).sum();
53            overhead + content
54        }
55        Message::System(s) => {
56            let overhead = 4;
57            overhead + estimate_tokens(&s.content)
58        }
59    }
60}
61
62/// Estimate total context tokens for a message history.
63///
64/// Uses a hybrid approach: actual API usage counts for the most
65/// recent assistant response, plus rough estimation for any
66/// messages added after that point.
67pub fn estimate_context_tokens(messages: &[Message]) -> u64 {
68    if messages.is_empty() {
69        return 0;
70    }
71
72    // Find the most recent assistant message with usage data.
73    let mut last_usage_idx = None;
74    for (i, msg) in messages.iter().enumerate().rev() {
75        if let Message::Assistant(a) = msg
76            && a.usage.is_some()
77        {
78            last_usage_idx = Some(i);
79            break;
80        }
81    }
82
83    match last_usage_idx {
84        Some(idx) => {
85            // Use actual API token count up to this point.
86            let usage = messages[idx]
87                .as_assistant()
88                .and_then(|a| a.usage.as_ref())
89                .unwrap();
90            let api_tokens = usage.total();
91
92            // Estimate tokens for messages added after the API call.
93            let new_tokens: u64 = messages[idx + 1..]
94                .iter()
95                .map(estimate_message_tokens)
96                .sum();
97
98            api_tokens + new_tokens
99        }
100        None => {
101            // No API usage data — estimate everything.
102            messages.iter().map(estimate_message_tokens).sum()
103        }
104    }
105}
106
/// Get the context window size (in tokens) for a model.
///
/// Matching is case-insensitive and substring-based. Extended-context
/// variants ("1m"/"1000k") take precedence over the model family;
/// unknown models fall back to a 128K window.
pub fn context_window_for_model(model: &str) -> u64 {
    let id = model.to_lowercase();

    // Extended context variants override the family default.
    if id.contains("1m") || id.contains("1000k") {
        return 1_000_000;
    }

    // Claude family models share a 200K window.
    if ["opus", "sonnet", "haiku"].iter().any(|f| id.contains(f)) {
        return 200_000;
    }

    if id.contains("gpt-4") {
        return 128_000;
    }
    if id.contains("gpt-3.5") {
        return 16_384;
    }

    // Unknown models: assume a modern 128K window.
    128_000
}
126
/// Get the max output tokens for a model.
///
/// Case-insensitive substring match on the model name; unknown models
/// default to 16K.
pub fn max_output_tokens_for_model(model: &str) -> u64 {
    let id = model.to_lowercase();
    match () {
        _ if id.contains("opus") || id.contains("sonnet") => 16_384,
        _ if id.contains("haiku") => 8_192,
        _ => 16_384,
    }
}
138
/// Get the maximum thinking token budget for a model.
///
/// Case-insensitive substring match on the model name; unknown models
/// default to a 16K thinking budget.
pub fn max_thinking_tokens_for_model(model: &str) -> u64 {
    let id = model.to_lowercase();
    match () {
        _ if id.contains("opus") => 32_000,
        _ if id.contains("sonnet") => 16_000,
        _ if id.contains("haiku") => 8_000,
        _ => 16_000,
    }
}
152
153// Helper to extract assistant message ref.
154trait AsAssistant {
155    fn as_assistant(&self) -> Option<&crate::llm::message::AssistantMessage>;
156}
157
158impl AsAssistant for Message {
159    fn as_assistant(&self) -> Option<&crate::llm::message::AssistantMessage> {
160        match self {
161            Message::Assistant(a) => Some(a),
162            _ => None,
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: expected values below assume BYTES_PER_TOKEN == 4.0, i.e.
    // estimates are byte_count / 4 rounded to the nearest whole token.

    #[test]
    fn test_estimate_tokens() {
        // 100 chars / 4 = 25 tokens.
        let text = "a".repeat(100);
        assert_eq!(estimate_tokens(&text), 25);
    }

    #[test]
    fn test_empty_messages() {
        // An empty history estimates to zero tokens.
        assert_eq!(estimate_context_tokens(&[]), 0);
    }

    #[test]
    fn test_estimate_block_tokens_text() {
        // 400 bytes / 4 = 100 tokens.
        let block = ContentBlock::Text {
            text: "a".repeat(400),
        };
        assert_eq!(estimate_block_tokens(&block), 100);
    }

    #[test]
    fn test_estimate_block_tokens_image() {
        // Images get a flat cost regardless of payload size.
        let block = ContentBlock::Image {
            media_type: "image/png".into(),
            data: "base64data".into(),
        };
        assert_eq!(estimate_block_tokens(&block), IMAGE_TOKEN_ESTIMATE);
    }

    #[test]
    fn test_estimate_block_tokens_tool_use() {
        // Tool-use counts the name plus the serialized JSON input.
        let block = ContentBlock::ToolUse {
            id: "call_1".into(),
            name: "Bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let tokens = estimate_block_tokens(&block);
        assert!(tokens > 0);
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = crate::llm::message::user_message("hello world");
        let tokens = estimate_message_tokens(&msg);
        // 11 chars / 4 = ~3, + 4 overhead = ~7
        assert!(tokens >= 5);
    }

    #[test]
    fn test_context_window_for_model() {
        assert_eq!(context_window_for_model("claude-opus-4"), 200_000);
        assert_eq!(context_window_for_model("claude-sonnet-4"), 200_000);
        assert_eq!(context_window_for_model("gpt-4"), 128_000);
        // Extended-context variant takes precedence over the family.
        assert_eq!(context_window_for_model("claude-sonnet-1m"), 1_000_000);
    }

    #[test]
    fn test_max_output_tokens() {
        assert_eq!(max_output_tokens_for_model("claude-opus"), 16_384);
        assert_eq!(max_output_tokens_for_model("claude-haiku"), 8_192);
    }

    #[test]
    fn test_max_thinking_tokens() {
        assert_eq!(max_thinking_tokens_for_model("claude-opus"), 32_000);
        assert_eq!(max_thinking_tokens_for_model("claude-sonnet"), 16_000);
        assert_eq!(max_thinking_tokens_for_model("claude-haiku"), 8_000);
    }

    #[test]
    fn test_estimate_tokens_empty_string() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_unicode() {
        // Multi-byte chars: each char may be 2-4 bytes in UTF-8.
        // Estimation is byte-based, not char-based.
        let text = "\u{1F600}\u{1F600}\u{1F600}"; // 3 emoji, 4 bytes each = 12 bytes
        let tokens = estimate_tokens(text);
        // 12 / 4 = 3
        assert_eq!(tokens, 3);
    }

    #[test]
    fn test_estimate_block_tokens_document() {
        let block = ContentBlock::Document {
            media_type: "application/pdf".into(),
            data: "a".repeat(400), // 400 base64 chars -> ~300 decoded bytes -> 300/4 = 75 tokens
            title: Some("test.pdf".into()),
        };
        let tokens = estimate_block_tokens(&block);
        assert!(tokens > 0);
        assert_eq!(tokens, 75);
    }

    #[test]
    fn test_estimate_block_tokens_thinking() {
        let block = ContentBlock::Thinking {
            thinking: "a".repeat(200),
            signature: Some("sig".into()),
        };
        let tokens = estimate_block_tokens(&block);
        // 200 / 4 = 50
        assert_eq!(tokens, 50);
    }

    #[test]
    fn test_estimate_block_tokens_tool_result() {
        let block = ContentBlock::ToolResult {
            tool_use_id: "call_1".into(),
            content: "a".repeat(80),
            is_error: false,
            extra_content: vec![],
        };
        let tokens = estimate_block_tokens(&block);
        // 80 / 4 = 20
        assert_eq!(tokens, 20);
    }

    #[test]
    fn test_estimate_message_tokens_system() {
        // System messages estimate their raw text plus overhead.
        let msg = Message::System(crate::llm::message::SystemMessage {
            uuid: uuid::Uuid::new_v4(),
            timestamp: String::new(),
            subtype: crate::llm::message::SystemMessageType::Informational,
            content: "a".repeat(40),
            level: crate::llm::message::MessageLevel::Info,
        });
        let tokens = estimate_message_tokens(&msg);
        // 40/4 = 10 + 4 overhead = 14
        assert_eq!(tokens, 14);
    }

    #[test]
    fn test_estimate_message_tokens_assistant_with_tool_use() {
        let msg = Message::Assistant(crate::llm::message::AssistantMessage {
            uuid: uuid::Uuid::new_v4(),
            timestamp: String::new(),
            content: vec![
                ContentBlock::Text {
                    text: "Let me run that.".into(),
                },
                ContentBlock::ToolUse {
                    id: "call_1".into(),
                    name: "Bash".into(),
                    input: serde_json::json!({"command": "ls"}),
                },
            ],
            model: None,
            usage: None,
            stop_reason: None,
            request_id: None,
        });
        let tokens = estimate_message_tokens(&msg);
        // Should include overhead + text tokens + tool_use tokens
        assert!(tokens > 4);
    }

    #[test]
    fn test_estimate_context_tokens_only_user_messages() {
        let messages = vec![
            crate::llm::message::user_message("hello world"),
            crate::llm::message::user_message("how are you"),
        ];
        let tokens = estimate_context_tokens(&messages);
        // No usage data, so everything is estimated.
        assert!(tokens > 0);
    }

    #[test]
    fn test_context_window_for_gpt35() {
        assert_eq!(context_window_for_model("gpt-3.5-turbo"), 16_384);
    }

    #[test]
    fn test_context_window_for_unknown_model() {
        // Unknown models default to 128K.
        assert_eq!(context_window_for_model("some-unknown-model"), 128_000);
    }

    #[test]
    fn test_context_window_for_1000k_variant() {
        assert_eq!(context_window_for_model("claude-sonnet-1000k"), 1_000_000);
    }

    #[test]
    fn test_max_output_tokens_for_unknown_model() {
        // Unknown models default to 16384.
        assert_eq!(max_output_tokens_for_model("unknown-llm"), 16_384);
    }

    #[test]
    fn test_max_thinking_tokens_for_unknown_model() {
        // Unknown models default to 16000.
        assert_eq!(max_thinking_tokens_for_model("unknown-llm"), 16_000);
    }
}