// ai_agent/utils/tokens.rs

// Source: ~/claudecode/openclaudecode/src/utils/tokens.ts
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct TokenUsage {
6    pub input_tokens: u64,
7    pub output_tokens: u64,
8    #[serde(default)]
9    pub cache_creation_input_tokens: Option<u64>,
10    #[serde(default)]
11    pub cache_read_input_tokens: Option<u64>,
12    #[serde(default)]
13    pub iterations: Option<Vec<IterationUsage>>,
14}
15
16/// Per-iteration usage from the Anthropic API (server-side tool loops)
17#[derive(Debug, Clone, Serialize, Deserialize, Default)]
18pub struct IterationUsage {
19    pub input_tokens: u64,
20    pub output_tokens: u64,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Message {
25    pub msg_type: String,
26    pub message: InnerMessage,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct InnerMessage {
31    pub content: Vec<ContentBlock>,
32    pub usage: Option<TokenUsage>,
33    pub id: Option<String>,
34    pub model: Option<String>,
35    pub uuid: Option<String>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41    #[serde(rename = "text")]
42    Text { text: String },
43    #[serde(rename = "thinking")]
44    Thinking { thinking: String },
45    #[serde(rename = "redacted_thinking")]
46    RedactedThinking { data: String },
47    #[serde(rename = "tool_use")]
48    ToolUse {
49        input: serde_json::Value,
50        name: Option<String>,
51    },
52}
53
/// Model name that marks a message as synthetic; `get_token_usage` ignores the
/// usage of any message whose `model` equals this value.
const SYNTHETIC_MODEL: &str = "synthetic";
55
56pub fn get_token_usage(message: &Message) -> Option<&TokenUsage> {
57    if message.msg_type != "assistant" {
58        return None;
59    }
60
61    let usage = message.message.usage.as_ref()?;
62
63    if message.message.model.as_deref() == Some(SYNTHETIC_MODEL) {
64        return None;
65    }
66
67    if let Some(ContentBlock::Text { text }) = message.message.content.first() {
68        if text.contains("SYNTHETIC") {
69            return None;
70        }
71    }
72
73    Some(usage)
74}
75
76pub fn get_token_count_from_usage(usage: &TokenUsage) -> u32 {
77    let cache_creation = usage.cache_creation_input_tokens.unwrap_or(0);
78    let cache_read = usage.cache_read_input_tokens.unwrap_or(0);
79    (usage.input_tokens + cache_creation + cache_read + usage.output_tokens) as u32
80}
81
82/// Extract the message ID/UUID from an assistant message for sibling detection.
83pub fn get_assistant_message_id(message: &Message) -> Option<&str> {
84    if message.msg_type != "assistant" {
85        return None;
86    }
87    if let Some(ref id) = message.message.id {
88        return Some(id);
89    }
90    message.message.uuid.as_deref()
91}
92
93pub fn token_count_from_last_api_response(messages: &[Message]) -> u32 {
94    for message in messages.iter().rev() {
95        if let Some(usage) = get_token_usage(message) {
96            return get_token_count_from_usage(usage);
97        }
98    }
99    0
100}
101
102/// Final context window size from the last API response's usage.iterations[-1].
103/// Used for task_budget.remaining computation across compaction boundaries.
104/// Falls back to top-level input_tokens + output_tokens when iterations is absent.
105/// Excludes cache tokens to match server-side budget countdown.
106pub fn final_context_tokens_from_last_response(messages: &[Message]) -> u64 {
107    for message in messages.iter().rev() {
108        if let Some(usage) = get_token_usage(message) {
109            if let Some(ref iterations) = usage.iterations {
110                if !iterations.is_empty() {
111                    if let Some(last) = iterations.last() {
112                        return last.input_tokens + last.output_tokens;
113                    }
114                }
115            }
116            // No iterations → no server tool loop → top-level usage IS the final window
117            return usage.input_tokens + usage.output_tokens;
118        }
119    }
120    0
121}
122
123pub fn get_current_usage(messages: &[Message]) -> Option<TokenUsage> {
124    for message in messages.iter().rev() {
125        if let Some(usage) = get_token_usage(message) {
126            return Some(TokenUsage {
127                input_tokens: usage.input_tokens,
128                output_tokens: usage.output_tokens,
129                cache_creation_input_tokens: usage.cache_creation_input_tokens,
130                cache_read_input_tokens: usage.cache_read_input_tokens,
131                iterations: usage.iterations.clone(),
132            });
133        }
134    }
135    None
136}
137
138pub fn does_most_recent_assistant_message_exceed_200k(messages: &[Message]) -> bool {
139    const THRESHOLD: u32 = 200_000;
140
141    let last_asst = messages.iter().rev().find(|m| m.msg_type == "assistant");
142    let last_asst = match last_asst {
143        Some(m) => m,
144        None => return false,
145    };
146
147    match get_token_usage(last_asst) {
148        Some(usage) => get_token_count_from_usage(usage) > THRESHOLD,
149        None => false,
150    }
151}
152
153pub fn get_assistant_message_content_length(message: &Message) -> usize {
154    let mut content_length = 0;
155
156    for block in &message.message.content {
157        match block {
158            ContentBlock::Text { text } => content_length += text.len(),
159            ContentBlock::Thinking { thinking } => content_length += thinking.len(),
160            ContentBlock::RedactedThinking { data } => content_length += data.len(),
161            ContentBlock::ToolUse { input, .. } => {
162                content_length += serde_json::to_string(input).map(|s| s.len()).unwrap_or(0);
163            }
164        }
165    }
166
167    content_length
168}
169
170/// Rough token estimation for a slice of messages (4 chars per token).
171pub fn rough_token_count_estimation_for_messages(messages: &[Message]) -> u32 {
172    messages
173        .iter()
174        .map(|m| {
175            let total_chars: usize = m.message.content.iter().map(|b| match b {
176                ContentBlock::Text { text } => text.len(),
177                ContentBlock::Thinking { thinking } => thinking.len(),
178                ContentBlock::RedactedThinking { data } => data.len(),
179                ContentBlock::ToolUse { input, .. } => {
180                    serde_json::to_string(input).map(|s| s.len()).unwrap_or(0)
181                }
182            }).sum();
183            (total_chars as f64 / 4.0) as u32
184        })
185        .sum()
186}
187
188/// Token count with estimation for trailing messages that haven't seen an API response yet.
189/// Walks backward to find the last usage-bearing message, then walks back further to
190/// find any earlier sibling with the same message.id (parallel tool call splits).
191/// Returns usage count + rough estimate for all messages after the first sibling.
192pub fn token_count_with_estimation(messages: &[Message]) -> u32 {
193    let mut i = messages.len();
194    while i > 0 {
195        i -= 1;
196        let message = &messages[i];
197        if let Some(usage) = get_token_usage(message) {
198            // Walk back past any earlier sibling records split from the same API
199            // response (same message.id) so interleaved tool_results between them
200            // are included in the estimation slice.
201            if let Some(response_id) = get_assistant_message_id(message) {
202                let mut j = i;
203                while j > 0 {
204                    j -= 1;
205                    let prior = &messages[j];
206                    if let Some(prior_id) = get_assistant_message_id(prior) {
207                        if prior_id == response_id {
208                            i = j;
209                        } else {
210                            break;
211                        }
212                    }
213                    // priorId === undefined: user/tool_result/attachment, keep walking
214                }
215            }
216            let trailing = if i + 1 < messages.len() {
217                rough_token_count_estimation_for_messages(&messages[i + 1..])
218            } else {
219                0
220            };
221            return get_token_count_from_usage(usage) + trailing;
222        }
223    }
224    rough_token_count_estimation_for_messages(messages)
225}
226
/// Unit tests for the token accounting helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a usage record with the given counters and optional iterations.
    fn usage(
        input_tokens: u64,
        output_tokens: u64,
        cache_creation_input_tokens: Option<u64>,
        cache_read_input_tokens: Option<u64>,
        iterations: Option<Vec<IterationUsage>>,
    ) -> TokenUsage {
        TokenUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
            iterations,
        }
    }

    /// Builds a message of kind `msg_type`; the id is `msg-1` whenever usage is attached.
    fn message(msg_type: &str, content: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Message {
        let id = usage.as_ref().map(|_| "msg-1".to_string());
        Message {
            msg_type: msg_type.to_string(),
            message: InnerMessage {
                content,
                usage,
                id,
                model: None,
                uuid: None,
            },
        }
    }

    #[test]
    fn test_token_count() {
        let usage = usage(100, 50, Some(20), Some(30), None);
        assert_eq!(get_token_count_from_usage(&usage), 200);
    }

    #[test]
    fn test_final_context_tokens_with_iterations() {
        let iterations = vec![IterationUsage {
            input_tokens: 800,
            output_tokens: 400,
        }];
        let msg = message(
            "assistant",
            vec![],
            Some(usage(1000, 500, Some(200), Some(100), Some(iterations))),
        );
        let tokens = final_context_tokens_from_last_response(&[msg]);
        // Should use iterations[-1].input + output = 800 + 400 = 1200 (no cache)
        assert_eq!(tokens, 1200);
    }

    #[test]
    fn test_final_context_tokens_without_iterations() {
        let msg = message(
            "assistant",
            vec![],
            Some(usage(1000, 500, Some(200), Some(100), None)),
        );
        let tokens = final_context_tokens_from_last_response(&[msg]);
        // Should use input + output = 1500 (no cache, no iterations)
        assert_eq!(tokens, 1500);
    }

    #[test]
    fn test_token_count_with_estimation_basic() {
        let msg = message(
            "assistant",
            vec![ContentBlock::Text { text: "hello".to_string() }],
            Some(usage(100, 50, None, None, None)),
        );
        // Nothing trails the usage-bearing message, so only its usage counts.
        let count = token_count_with_estimation(&[msg]);
        assert_eq!(count, 150);
    }

    #[test]
    fn test_rough_token_estimation_for_messages() {
        let msg = message(
            "user",
            vec![ContentBlock::Text { text: "Hello world".to_string() }],
            None,
        );
        // "Hello world" = 11 chars / 4 = 2.75 → 2 (rounded down)
        let est = rough_token_count_estimation_for_messages(&[msg]);
        assert_eq!(est, 2);
    }
}
331}