1use serde::{Deserialize, Serialize};
3
/// Token accounting attached to a message's `usage` field.
///
/// The three optional fields are `#[serde(default)]`, so they deserialize to
/// `None` when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input tokens; `get_token_count_from_usage` adds this to both cache
    /// counters and to `output_tokens`.
    pub input_tokens: u64,
    /// Output tokens reported for the response.
    pub output_tokens: u64,
    /// Cache-creation input tokens, when reported (`None` counts as 0 in
    /// `get_token_count_from_usage`).
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u64>,
    /// Cache-read input tokens, when reported (`None` counts as 0 in
    /// `get_token_count_from_usage`).
    #[serde(default)]
    pub cache_read_input_tokens: Option<u64>,
    /// Optional per-iteration breakdown. When present and non-empty,
    /// `final_context_tokens_from_last_response` uses the last entry instead
    /// of the top-level counts.
    #[serde(default)]
    pub iterations: Option<Vec<IterationUsage>>,
}
15
/// Input/output token counts for one entry of `TokenUsage::iterations`.
///
/// NOTE(review): presumably one step of a multi-step loop within a single
/// response — confirm against the producer of this data.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct IterationUsage {
    /// Input tokens for this iteration.
    pub input_tokens: u64,
    /// Output tokens for this iteration.
    pub output_tokens: u64,
}
22
/// One conversation record: a kind tag plus its inner message payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Record kind; the helpers in this module only act on `"assistant"`.
    ///
    /// NOTE(review): there is no `#[serde(rename = "type")]`, so this field
    /// (de)serializes under the key `msg_type` — confirm that matches the
    /// input format.
    pub msg_type: String,
    /// Content, usage, and identifiers of the record.
    pub message: InnerMessage,
}
28
/// Payload of a `Message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InnerMessage {
    /// Ordered content blocks; `get_token_usage` inspects only the first one.
    pub content: Vec<ContentBlock>,
    /// Usage reported for this message, if any.
    pub usage: Option<TokenUsage>,
    /// Primary identifier, preferred by `get_assistant_message_id`.
    pub id: Option<String>,
    /// Model name; `SYNTHETIC_MODEL` marks messages whose usage is ignored.
    pub model: Option<String>,
    /// Fallback identifier used by `get_assistant_message_id` when `id` is `None`.
    pub uuid: Option<String>,
}
37
/// A single content block, internally tagged by its `type` field.
///
/// NOTE(review): without a `#[serde(other)]` catch-all variant, any block whose
/// `type` is not one of the four below fails to deserialize — confirm inputs
/// never contain other block kinds.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
    /// Plain text block.
    #[serde(rename = "text")]
    Text { text: String },
    /// Thinking block with its text.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    /// Redacted thinking block; `data` is kept opaque here.
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    /// Tool invocation with arbitrary JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        input: serde_json::Value,
        name: Option<String>,
    },
}
53
/// Model name marking synthetic messages; `get_token_usage` ignores their usage.
const SYNTHETIC_MODEL: &str = "synthetic";
55
56pub fn get_token_usage(message: &Message) -> Option<&TokenUsage> {
57 if message.msg_type != "assistant" {
58 return None;
59 }
60
61 let usage = message.message.usage.as_ref()?;
62
63 if message.message.model.as_deref() == Some(SYNTHETIC_MODEL) {
64 return None;
65 }
66
67 if let Some(ContentBlock::Text { text }) = message.message.content.first() {
68 if text.contains("SYNTHETIC") {
69 return None;
70 }
71 }
72
73 Some(usage)
74}
75
76pub fn get_token_count_from_usage(usage: &TokenUsage) -> u32 {
77 let cache_creation = usage.cache_creation_input_tokens.unwrap_or(0);
78 let cache_read = usage.cache_read_input_tokens.unwrap_or(0);
79 (usage.input_tokens + cache_creation + cache_read + usage.output_tokens) as u32
80}
81
82pub fn get_assistant_message_id(message: &Message) -> Option<&str> {
84 if message.msg_type != "assistant" {
85 return None;
86 }
87 if let Some(ref id) = message.message.id {
88 return Some(id);
89 }
90 message.message.uuid.as_deref()
91}
92
93pub fn token_count_from_last_api_response(messages: &[Message]) -> u32 {
94 for message in messages.iter().rev() {
95 if let Some(usage) = get_token_usage(message) {
96 return get_token_count_from_usage(usage);
97 }
98 }
99 0
100}
101
102pub fn final_context_tokens_from_last_response(messages: &[Message]) -> u64 {
107 for message in messages.iter().rev() {
108 if let Some(usage) = get_token_usage(message) {
109 if let Some(ref iterations) = usage.iterations {
110 if !iterations.is_empty() {
111 if let Some(last) = iterations.last() {
112 return last.input_tokens + last.output_tokens;
113 }
114 }
115 }
116 return usage.input_tokens + usage.output_tokens;
118 }
119 }
120 0
121}
122
123pub fn get_current_usage(messages: &[Message]) -> Option<TokenUsage> {
124 for message in messages.iter().rev() {
125 if let Some(usage) = get_token_usage(message) {
126 return Some(TokenUsage {
127 input_tokens: usage.input_tokens,
128 output_tokens: usage.output_tokens,
129 cache_creation_input_tokens: usage.cache_creation_input_tokens,
130 cache_read_input_tokens: usage.cache_read_input_tokens,
131 iterations: usage.iterations.clone(),
132 });
133 }
134 }
135 None
136}
137
138pub fn does_most_recent_assistant_message_exceed_200k(messages: &[Message]) -> bool {
139 const THRESHOLD: u32 = 200_000;
140
141 let last_asst = messages.iter().rev().find(|m| m.msg_type == "assistant");
142 let last_asst = match last_asst {
143 Some(m) => m,
144 None => return false,
145 };
146
147 match get_token_usage(last_asst) {
148 Some(usage) => get_token_count_from_usage(usage) > THRESHOLD,
149 None => false,
150 }
151}
152
153pub fn get_assistant_message_content_length(message: &Message) -> usize {
154 let mut content_length = 0;
155
156 for block in &message.message.content {
157 match block {
158 ContentBlock::Text { text } => content_length += text.len(),
159 ContentBlock::Thinking { thinking } => content_length += thinking.len(),
160 ContentBlock::RedactedThinking { data } => content_length += data.len(),
161 ContentBlock::ToolUse { input, .. } => {
162 content_length += serde_json::to_string(input).map(|s| s.len()).unwrap_or(0);
163 }
164 }
165 }
166
167 content_length
168}
169
170pub fn rough_token_count_estimation_for_messages(messages: &[Message]) -> u32 {
172 messages
173 .iter()
174 .map(|m| {
175 let total_chars: usize = m.message.content.iter().map(|b| match b {
176 ContentBlock::Text { text } => text.len(),
177 ContentBlock::Thinking { thinking } => thinking.len(),
178 ContentBlock::RedactedThinking { data } => data.len(),
179 ContentBlock::ToolUse { input, .. } => {
180 serde_json::to_string(input).map(|s| s.len()).unwrap_or(0)
181 }
182 }).sum();
183 (total_chars as f64 / 4.0) as u32
184 })
185 .sum()
186}
187
188pub fn token_count_with_estimation(messages: &[Message]) -> u32 {
193 let mut i = messages.len();
194 while i > 0 {
195 i -= 1;
196 let message = &messages[i];
197 if let Some(usage) = get_token_usage(message) {
198 if let Some(response_id) = get_assistant_message_id(message) {
202 let mut j = i;
203 while j > 0 {
204 j -= 1;
205 let prior = &messages[j];
206 if let Some(prior_id) = get_assistant_message_id(prior) {
207 if prior_id == response_id {
208 i = j;
209 } else {
210 break;
211 }
212 }
213 }
215 }
216 let trailing = if i + 1 < messages.len() {
217 rough_token_count_estimation_for_messages(&messages[i + 1..])
218 } else {
219 0
220 };
221 return get_token_count_from_usage(usage) + trailing;
222 }
223 }
224 rough_token_count_estimation_for_messages(messages)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TokenUsage` with no cache counters and no iterations.
    fn usage(input: u64, output: u64) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
            iterations: None,
        }
    }

    /// Builds an assistant message with one text block, the given id and usage.
    fn assistant(id: &str, usage: TokenUsage, text: &str) -> Message {
        Message {
            msg_type: "assistant".to_string(),
            message: InnerMessage {
                content: vec![ContentBlock::Text { text: text.to_string() }],
                usage: Some(usage),
                id: Some(id.to_string()),
                model: None,
                uuid: None,
            },
        }
    }

    /// Builds a user message with one text block and no usage or ids.
    fn user(text: &str) -> Message {
        Message {
            msg_type: "user".to_string(),
            message: InnerMessage {
                content: vec![ContentBlock::Text { text: text.to_string() }],
                usage: None,
                id: None,
                model: None,
                uuid: None,
            },
        }
    }

    #[test]
    fn test_token_count() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: Some(20),
            cache_read_input_tokens: Some(30),
            iterations: None,
        };
        assert_eq!(get_token_count_from_usage(&usage), 200);
    }

    #[test]
    fn test_final_context_tokens_with_iterations() {
        let msg = Message {
            msg_type: "assistant".to_string(),
            message: InnerMessage {
                content: vec![],
                usage: Some(TokenUsage {
                    input_tokens: 1000,
                    output_tokens: 500,
                    cache_creation_input_tokens: Some(200),
                    cache_read_input_tokens: Some(100),
                    iterations: Some(vec![IterationUsage {
                        input_tokens: 800,
                        output_tokens: 400,
                    }]),
                }),
                id: Some("msg-1".to_string()),
                model: None,
                uuid: None,
            },
        };
        let tokens = final_context_tokens_from_last_response(&[msg]);
        // The last iteration's input + output; cache counters are excluded.
        assert_eq!(tokens, 1200);
    }

    #[test]
    fn test_final_context_tokens_without_iterations() {
        let msg = Message {
            msg_type: "assistant".to_string(),
            message: InnerMessage {
                content: vec![],
                usage: Some(TokenUsage {
                    input_tokens: 1000,
                    output_tokens: 500,
                    cache_creation_input_tokens: Some(200),
                    cache_read_input_tokens: Some(100),
                    iterations: None,
                }),
                id: Some("msg-1".to_string()),
                model: None,
                uuid: None,
            },
        };
        let tokens = final_context_tokens_from_last_response(&[msg]);
        // Top-level input + output; cache counters are excluded.
        assert_eq!(tokens, 1500);
    }

    #[test]
    fn test_token_count_with_estimation_basic() {
        let msg = assistant("msg-1", usage(100, 50), "hello");
        let count = token_count_with_estimation(&[msg]);
        assert_eq!(count, 150);
    }

    #[test]
    fn test_token_count_with_estimation_anchors_on_first_split() {
        let messages = [
            assistant("r1", usage(10, 0), "aaaa"),
            user("bbbbbbbb"),
            assistant("r2", usage(100, 50), "cccc"),
            user("dddddddd"),
            assistant("r2", usage(100, 50), "eeee"),
            user("ffff"),
        ];
        // Usage of the latest "r2" record (150), plus estimates for everything
        // after the first "r2" record: 8/4 + 4/4 + 4/4 = 4.
        assert_eq!(token_count_with_estimation(&messages), 154);
    }

    #[test]
    fn test_token_count_with_estimation_without_usage() {
        // No countable usage anywhere, so everything is estimated: 8/4 + 4/4.
        let messages = [user("abcdefgh"), user("abcd")];
        assert_eq!(token_count_with_estimation(&messages), 3);
    }

    #[test]
    fn test_get_token_usage_skips_synthetic_messages() {
        let mut by_model = assistant("msg-1", usage(10, 5), "hello");
        by_model.message.model = Some(SYNTHETIC_MODEL.to_string());
        assert!(get_token_usage(&by_model).is_none());

        let by_text = assistant("msg-2", usage(10, 5), "SYNTHETIC placeholder");
        assert!(get_token_usage(&by_text).is_none());

        assert!(get_token_usage(&user("hello")).is_none());
        assert!(get_token_usage(&assistant("msg-3", usage(10, 5), "hello")).is_some());
    }

    #[test]
    fn test_get_assistant_message_id_falls_back_to_uuid() {
        let mut msg = assistant("msg-1", usage(1, 1), "x");
        assert_eq!(get_assistant_message_id(&msg), Some("msg-1"));
        msg.message.id = None;
        msg.message.uuid = Some("uuid-1".to_string());
        assert_eq!(get_assistant_message_id(&msg), Some("uuid-1"));
        assert_eq!(get_assistant_message_id(&user("x")), None);
    }

    #[test]
    fn test_exceeds_200k_uses_most_recent_assistant() {
        let big = assistant("r1", usage(150_000, 60_000), "x");
        let small = assistant("r2", usage(1_000, 10), "x");
        assert!(does_most_recent_assistant_message_exceed_200k(&[small.clone(), big.clone()]));
        assert!(!does_most_recent_assistant_message_exceed_200k(&[big, small]));
        assert!(!does_most_recent_assistant_message_exceed_200k(&[]));
    }

    #[test]
    fn test_get_current_usage_returns_latest_real_usage() {
        let messages = [assistant("r1", usage(10, 5), "x"), user("y")];
        let current = get_current_usage(&messages).expect("an assistant message carries usage");
        assert_eq!((current.input_tokens, current.output_tokens), (10, 5));
        assert!(get_current_usage(&[user("y")]).is_none());
    }

    #[test]
    fn test_rough_token_estimation_for_messages() {
        let est = rough_token_count_estimation_for_messages(&[user("Hello world")]);
        // 11 bytes at ~4 bytes per token.
        assert!((2..=3).contains(&est));
    }
}