1use agtrace_types::{AgentEvent, ContextWindowUsage, EventPayload};
2use serde_json::Value;
3
4#[derive(Debug, Clone, PartialEq, Default)]
6pub struct StateUpdates {
7 pub model: Option<String>,
8 pub context_window_limit: Option<u64>,
9 pub usage: Option<ContextWindowUsage>,
10 pub reasoning_tokens: Option<i32>,
11 pub is_error: bool,
12 pub is_new_turn: bool,
13}
14
15pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
17 let mut updates = StateUpdates::default();
18
19 match &event.payload {
20 EventPayload::User(_) => {
21 updates.is_new_turn = true;
22 }
23 EventPayload::TokenUsage(usage) => {
24 updates.usage = Some(ContextWindowUsage::from_raw(
28 usage.input.uncached as i32, 0, usage.input.cached as i32, usage.output.total() as i32, ));
33 updates.reasoning_tokens = Some(usage.output.reasoning as i32);
34 }
35 EventPayload::ToolResult(result) => {
36 if result.is_error {
37 updates.is_error = true;
38 } else {
39 updates.is_error = false;
41 }
42 }
43 _ => {}
44 }
45
46 if let Some(metadata) = &event.metadata {
47 if updates.model.is_none() {
48 updates.model = extract_model(metadata);
49 }
50
51 if updates.context_window_limit.is_none() {
52 updates.context_window_limit = extract_context_window_limit(metadata);
53 }
54 }
55
56 updates
57}
58
59fn extract_model(metadata: &Value) -> Option<String> {
60 metadata
61 .get("message")
62 .and_then(|m| m.get("model"))
63 .and_then(|v| v.as_str())
64 .map(|s| s.to_string())
65 .or_else(|| {
66 metadata
67 .get("model")
68 .and_then(|v| v.as_str())
69 .map(|s| s.to_string())
70 })
71}
72
73fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
74 metadata
75 .get("info")
76 .and_then(|info| info.get("model_context_window"))
77 .and_then(|v| v.as_u64())
78 .or_else(|| {
79 metadata
80 .get("payload")
81 .and_then(|payload| payload.get("info"))
82 .and_then(|info| info.get("model_context_window"))
83 .and_then(|v| v.as_u64())
84 })
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{
        TokenInput, TokenOutput, TokenUsagePayload, ToolResultPayload, UserPayload,
    };
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Builds a main-stream event with fixed ids and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        // Input: cached = 20, uncached = 100. Output: 43 + 7 (reasoning) + 0 = 50.
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(20, 100),
            TokenOutput::new(43, 7, 0),
        )));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(170));
        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(0, 10),
            TokenOutput::new(5, 0, 0),
        )));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn token_usage_conversion_avoids_double_counting_cached_tokens() {
        // Input: cached = 20, uncached = 100. Output total = 50.
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(20, 100),
            TokenOutput::new(50, 0, 0),
        )));

        let updates = extract_state_updates(&event);
        let usage = updates.usage.expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 100,
            "fresh_input should be uncached tokens only (not total)"
        );
        assert_eq!(usage.cache_read.0, 20, "cache_read should be cached tokens");
        assert_eq!(usage.output.0, 50, "output should match");
        assert_eq!(
            usage.total_tokens(),
            crate::TokenCount::new(170),
            "total should be 100 (fresh) + 20 (cache) + 50 (output) = 170, not 190"
        );
    }

    #[test]
    fn token_usage_conversion_uses_uncached_for_fresh_input() {
        // Input: cached = 30, uncached = 200. Output: 80 + 10 (reasoning) + 5 = 95.
        let token_payload = TokenUsagePayload::new(
            TokenInput::new(30, 200),
            TokenOutput::new(80, 10, 5),
        );

        let event = base_event(EventPayload::TokenUsage(token_payload));
        let updates = extract_state_updates(&event);
        let usage = updates.usage.expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 200,
            "fresh_input must be uncached tokens only (200), not total (230)"
        );
        assert_eq!(usage.cache_read.0, 30);
        assert_eq!(usage.output.0, 95);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(325));
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        /// Minimal consumer showing how `StateUpdates` can be folded into session state.
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && updates.is_error {
                    self.error_count += 1;
                }
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        // Input: cached = 5, uncached = 120. Output: 27 + 3 (reasoning) + 0 = 30.
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            TokenInput::new(5, 120),
            TokenOutput::new(27, 3, 0),
        )));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.reasoning_tokens, 3);
    }

    #[test]
    fn tool_result_without_error_leaves_error_flag_clear() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
    }

    #[test]
    fn event_without_metadata_yields_only_payload_fields() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(
            updates,
            StateUpdates {
                is_new_turn: true,
                ..StateUpdates::default()
            }
        );
    }

    #[test]
    fn nested_message_model_takes_precedence_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": "nested-model" },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested-model"));
    }

    #[test]
    fn non_string_nested_model_falls_back_to_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": 42 },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("top-level-model"));
    }

    #[test]
    fn top_level_info_takes_precedence_over_payload_info() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "info": { "model_context_window": 1_000 },
            "payload": { "info": { "model_context_window": 2_000 } }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(1_000));
    }
}