1use agtrace_types::{AgentEvent, ContextWindowUsage, EventPayload};
2use serde_json::Value;
3
/// Session-state changes derived from a single [`AgentEvent`] by
/// [`extract_state_updates`].
///
/// `Option` fields are `None` when the event carried no information for them.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct StateUpdates {
    /// Model name read from the event metadata (`message.model`, falling back to `model`).
    pub model: Option<String>,
    /// Context-window size read from the event metadata (`info.model_context_window`,
    /// falling back to `payload.info.model_context_window`).
    pub context_window_limit: Option<u64>,
    /// Token usage; only set for `TokenUsage` payloads.
    pub usage: Option<ContextWindowUsage>,
    /// Reasoning-token count; only set for `TokenUsage` payloads.
    pub reasoning_tokens: Option<i32>,
    /// Copy of a `ToolResult` payload's `is_error` flag; `false` for all other payloads.
    pub is_error: bool,
    /// `true` for `User` and `SlashCommand` payloads, which start a new turn.
    pub is_new_turn: bool,
}
14
15pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
17 let mut updates = StateUpdates::default();
18
19 match &event.payload {
20 EventPayload::User(_) | EventPayload::SlashCommand(_) => {
21 updates.is_new_turn = true;
22 }
23 EventPayload::TokenUsage(usage) => {
24 updates.usage = Some(ContextWindowUsage::from_raw(
28 usage.input.uncached as i32, 0, usage.input.cached as i32, usage.output.total() as i32, ));
33 updates.reasoning_tokens = Some(usage.output.reasoning as i32);
34 }
35 EventPayload::ToolResult(result) => {
36 if result.is_error {
37 updates.is_error = true;
38 } else {
39 updates.is_error = false;
41 }
42 }
43 _ => {}
44 }
45
46 if let Some(metadata) = &event.metadata {
47 if updates.model.is_none() {
48 updates.model = extract_model(metadata);
49 }
50
51 if updates.context_window_limit.is_none() {
52 updates.context_window_limit = extract_context_window_limit(metadata);
53 }
54 }
55
56 updates
57}
58
59fn extract_model(metadata: &Value) -> Option<String> {
60 metadata
61 .get("message")
62 .and_then(|m| m.get("model"))
63 .and_then(|v| v.as_str())
64 .map(|s| s.to_string())
65 .or_else(|| {
66 metadata
67 .get("model")
68 .and_then(|v| v.as_str())
69 .map(|s| s.to_string())
70 })
71}
72
73fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
74 metadata
75 .get("info")
76 .and_then(|info| info.get("model_context_window"))
77 .and_then(|v| v.as_u64())
78 .or_else(|| {
79 metadata
80 .get("payload")
81 .and_then(|payload| payload.get("info"))
82 .and_then(|info| info.get("model_context_window"))
83 .and_then(|v| v.as_u64())
84 })
85}
86
// Unit tests for `extract_state_updates` and its metadata helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{
        TokenInput, TokenOutput, TokenUsagePayload, ToolResultPayload, UserPayload,
    };
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Builds a main-stream event with fixed ids and no metadata around `payload`.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(20, 100),
            TokenOutput::new(43, 7, 0),
        )));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(170));
        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(0, 10),
            TokenOutput::new(5, 0, 0),
        )));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn prefers_nested_message_model_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": "nested-model" },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested-model"));
        assert_eq!(updates.context_window_limit, None);
    }

    #[test]
    fn falls_back_to_top_level_model_when_nested_model_is_not_a_string() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": 42 },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("top-level-model"));
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
            agent_id: None,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn successful_tool_result_does_not_set_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
            agent_id: None,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
    }

    #[test]
    fn token_usage_conversion_avoids_double_counting_cached_tokens() {
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(20, 100),
            TokenOutput::new(50, 0, 0),
        )));

        let updates = extract_state_updates(&event);
        let usage = updates.usage.expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 100,
            "fresh_input should be uncached tokens only (not total)"
        );
        assert_eq!(usage.cache_read.0, 20, "cache_read should be cached tokens");
        assert_eq!(usage.output.0, 50, "output should match");
        assert_eq!(
            usage.total_tokens(),
            crate::TokenCount::new(170),
            "total should be 100 (fresh) + 20 (cache) + 50 (output) = 170, not 190"
        );
    }

    #[test]
    fn token_usage_conversion_uses_uncached_for_fresh_input() {
        let token_payload =
            TokenUsagePayload::new(TokenInput::new(30, 200), TokenOutput::new(80, 10, 5));

        let event = base_event(EventPayload::TokenUsage(token_payload));
        let updates = extract_state_updates(&event);
        let usage = updates.usage.expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 200,
            "fresh_input must be uncached tokens only (200), not total (230)"
        );
        assert_eq!(usage.cache_read.0, 30);
        assert_eq!(usage.output.0, 95);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(325));
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        /// Minimal stand-in for a session accumulator, used to check that
        /// `StateUpdates` can be folded into state without any I/O.
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && updates.is_error {
                    self.error_count += 1;
                }
                // First observed model / limit wins.
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                // Latest usage / reasoning snapshot wins.
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(5, 120),
            TokenOutput::new(27, 3, 0),
        )));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
            agent_id: None,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.reasoning_tokens, 3);
    }
}