agtrace_engine/
state_updates.rs1use crate::token_usage::ContextWindowUsage;
2use agtrace_types::{AgentEvent, EventPayload};
3use serde_json::Value;
4
5#[derive(Debug, Clone, PartialEq, Default)]
7pub struct StateUpdates {
8 pub model: Option<String>,
9 pub context_window_limit: Option<u64>,
10 pub usage: Option<ContextWindowUsage>,
11 pub reasoning_tokens: Option<i32>,
12 pub is_error: bool,
13 pub is_new_turn: bool,
14}
15
16pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
18 let mut updates = StateUpdates::default();
19
20 match &event.payload {
21 EventPayload::User(_) => {
22 updates.is_new_turn = true;
23 }
24 EventPayload::TokenUsage(usage) => {
25 let cache_creation = usage
26 .details
27 .as_ref()
28 .and_then(|d| d.cache_creation_input_tokens)
29 .unwrap_or(0);
30 let cache_read = usage
31 .details
32 .as_ref()
33 .and_then(|d| d.cache_read_input_tokens)
34 .unwrap_or(0);
35 let reasoning_tokens = usage
36 .details
37 .as_ref()
38 .and_then(|d| d.reasoning_output_tokens)
39 .unwrap_or(0);
40
41 updates.usage = Some(ContextWindowUsage::from_raw(
42 usage.input_tokens,
43 cache_creation,
44 cache_read,
45 usage.output_tokens,
46 ));
47 updates.reasoning_tokens = Some(reasoning_tokens);
48 }
49 EventPayload::ToolResult(result) => {
50 if result.is_error {
51 updates.is_error = true;
52 } else {
53 updates.is_error = false;
55 }
56 }
57 _ => {}
58 }
59
60 if let Some(metadata) = &event.metadata {
61 if updates.model.is_none() {
62 updates.model = extract_model(metadata);
63 }
64
65 if updates.context_window_limit.is_none() {
66 updates.context_window_limit = extract_context_window_limit(metadata);
67 }
68 }
69
70 updates
71}
72
73fn extract_model(metadata: &Value) -> Option<String> {
74 metadata
75 .get("message")
76 .and_then(|m| m.get("model"))
77 .and_then(|v| v.as_str())
78 .map(|s| s.to_string())
79 .or_else(|| {
80 metadata
81 .get("model")
82 .and_then(|v| v.as_str())
83 .map(|s| s.to_string())
84 })
85}
86
87fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
88 metadata
89 .get("info")
90 .and_then(|info| info.get("model_context_window"))
91 .and_then(|v| v.as_u64())
92 .or_else(|| {
93 metadata
94 .get("payload")
95 .and_then(|payload| payload.get("info"))
96 .and_then(|info| info.get("model_context_window"))
97 .and_then(|v| v.as_u64())
98 })
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{TokenUsageDetails, TokenUsagePayload, ToolResultPayload, UserPayload};
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Builds a main-stream event with fixed ids, no parent and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(5),
                cache_read_input_tokens: Some(20),
                reasoning_output_tokens: Some(7),
            }),
        }));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.cache_creation.0, 5);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.context_window_tokens(), 175);

        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn token_usage_without_details_defaults_to_zero() {
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 10);
        assert_eq!(usage.output.0, 5);
        assert_eq!(usage.cache_creation.0, 0);
        assert_eq!(usage.cache_read.0, 0);
        assert_eq!(updates.reasoning_tokens, Some(0));
        assert_eq!(updates.model, None);
        assert_eq!(updates.context_window_limit, None);
    }

    #[test]
    fn prefers_nested_message_model_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": "nested-model" },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested-model"));
    }

    #[test]
    fn falls_back_to_top_level_model_when_nested_is_not_a_string() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": 42 },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("top-level-model"));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn successful_tool_result_is_not_an_error() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
        assert_eq!(updates.reasoning_tokens, None);
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        /// Minimal stand-in for a consumer that folds `StateUpdates` into state.
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                // `is_error` is only ever set by tool-result events, so no
                // separate "is this an error event" flag is needed.
                if updates.is_error {
                    self.error_count += 1;
                }
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 120,
            output_tokens: 30,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(10),
                cache_read_input_tokens: Some(5),
                reasoning_output_tokens: Some(3),
            }),
        }));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user));
        state.apply(extract_state_updates(&usage_event));
        state.apply(extract_state_updates(&tool_err));

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.usage.cache_creation.0, 10);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.reasoning_tokens, 3);
    }
}