agtrace_engine/
state_updates.rs1use crate::token_usage::ContextWindowUsage;
2use agtrace_types::{AgentEvent, EventPayload};
3use serde_json::Value;
4
5#[derive(Debug, Clone, PartialEq, Default)]
7pub struct StateUpdates {
8 pub model: Option<String>,
9 pub context_window_limit: Option<u64>,
10 pub usage: Option<ContextWindowUsage>,
11 pub reasoning_tokens: Option<i32>,
12 pub is_error: bool,
13 pub is_new_turn: bool,
14}
15
16pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
18 let mut updates = StateUpdates::default();
19
20 match &event.payload {
21 EventPayload::User(_) => {
22 updates.is_new_turn = true;
23 }
24 EventPayload::TokenUsage(usage) => {
25 let cache_creation = usage
26 .details
27 .as_ref()
28 .and_then(|d| d.cache_creation_input_tokens)
29 .unwrap_or(0);
30 let cache_read = usage
31 .details
32 .as_ref()
33 .and_then(|d| d.cache_read_input_tokens)
34 .unwrap_or(0);
35 let reasoning_tokens = usage
36 .details
37 .as_ref()
38 .and_then(|d| d.reasoning_output_tokens)
39 .unwrap_or(0);
40
41 updates.usage = Some(ContextWindowUsage::from_raw(
46 usage.input_tokens,
47 cache_creation,
48 cache_read,
49 usage.output_tokens,
50 ));
51 updates.reasoning_tokens = Some(reasoning_tokens);
52 }
53 EventPayload::ToolResult(result) => {
54 if result.is_error {
55 updates.is_error = true;
56 } else {
57 updates.is_error = false;
59 }
60 }
61 _ => {}
62 }
63
64 if let Some(metadata) = &event.metadata {
65 if updates.model.is_none() {
66 updates.model = extract_model(metadata);
67 }
68
69 if updates.context_window_limit.is_none() {
70 updates.context_window_limit = extract_context_window_limit(metadata);
71 }
72 }
73
74 updates
75}
76
77fn extract_model(metadata: &Value) -> Option<String> {
78 metadata
79 .get("message")
80 .and_then(|m| m.get("model"))
81 .and_then(|v| v.as_str())
82 .map(|s| s.to_string())
83 .or_else(|| {
84 metadata
85 .get("model")
86 .and_then(|v| v.as_str())
87 .map(|s| s.to_string())
88 })
89}
90
91fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
92 metadata
93 .get("info")
94 .and_then(|info| info.get("model_context_window"))
95 .and_then(|v| v.as_u64())
96 .or_else(|| {
97 metadata
98 .get("payload")
99 .and_then(|payload| payload.get("info"))
100 .and_then(|info| info.get("model_context_window"))
101 .and_then(|v| v.as_u64())
102 })
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{TokenUsageDetails, TokenUsagePayload, ToolResultPayload, UserPayload};
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Builds a main-stream event with fixed ids, the given payload, and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(5),
                cache_read_input_tokens: Some(20),
                reasoning_output_tokens: Some(7),
            }),
        }));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.cache_creation.0, 5);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(175));

        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn token_usage_without_details_defaults_breakdown_to_zero() {
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 10);
        assert_eq!(usage.output.0, 5);
        assert_eq!(usage.cache_creation.0, 0);
        assert_eq!(usage.cache_read.0, 0);
        assert_eq!(updates.reasoning_tokens, Some(0));
        // No metadata on the event, so nothing is read from it.
        assert_eq!(updates.model, None);
        assert_eq!(updates.context_window_limit, None);
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn successful_tool_result_is_not_an_error() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
    }

    #[test]
    fn prefers_message_model_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": "nested-model" },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested-model"));
        assert_eq!(updates.context_window_limit, None);
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && updates.is_error {
                    self.error_count += 1;
                }
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 120,
            output_tokens: 30,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(10),
                cache_read_input_tokens: Some(5),
                reasoning_output_tokens: Some(3),
            }),
        }));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.usage.cache_creation.0, 10);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.reasoning_tokens, 3);
    }
}