agtrace_engine/
state_updates.rs

1use crate::token_usage::ContextWindowUsage;
2use agtrace_types::{AgentEvent, EventPayload};
3use serde_json::Value;
4
5/// Pure data extracted from an AgentEvent to update runtime session state.
6#[derive(Debug, Clone, PartialEq, Default)]
7pub struct StateUpdates {
8    pub model: Option<String>,
9    pub context_window_limit: Option<u64>,
10    pub usage: Option<ContextWindowUsage>,
11    pub reasoning_tokens: Option<i32>,
12    pub is_error: bool,
13    pub is_new_turn: bool,
14}
15
16/// Extract state updates from a single event without performing I/O or side effects.
17pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
18    let mut updates = StateUpdates::default();
19
20    match &event.payload {
21        EventPayload::User(_) => {
22            updates.is_new_turn = true;
23        }
24        EventPayload::TokenUsage(usage) => {
25            let cache_creation = usage
26                .details
27                .as_ref()
28                .and_then(|d| d.cache_creation_input_tokens)
29                .unwrap_or(0);
30            let cache_read = usage
31                .details
32                .as_ref()
33                .and_then(|d| d.cache_read_input_tokens)
34                .unwrap_or(0);
35            let reasoning_tokens = usage
36                .details
37                .as_ref()
38                .and_then(|d| d.reasoning_output_tokens)
39                .unwrap_or(0);
40
41            updates.usage = Some(ContextWindowUsage::from_raw(
42                usage.input_tokens,
43                cache_creation,
44                cache_read,
45                usage.output_tokens,
46            ));
47            updates.reasoning_tokens = Some(reasoning_tokens);
48        }
49        EventPayload::ToolResult(result) => {
50            if result.is_error {
51                updates.is_error = true;
52            } else {
53                // Explicitly mark success so consumers can reset counters if needed.
54                updates.is_error = false;
55            }
56        }
57        _ => {}
58    }
59
60    if let Some(metadata) = &event.metadata {
61        if updates.model.is_none() {
62            updates.model = extract_model(metadata);
63        }
64
65        if updates.context_window_limit.is_none() {
66            updates.context_window_limit = extract_context_window_limit(metadata);
67        }
68    }
69
70    updates
71}
72
73fn extract_model(metadata: &Value) -> Option<String> {
74    metadata
75        .get("message")
76        .and_then(|m| m.get("model"))
77        .and_then(|v| v.as_str())
78        .map(|s| s.to_string())
79        .or_else(|| {
80            metadata
81                .get("model")
82                .and_then(|v| v.as_str())
83                .map(|s| s.to_string())
84        })
85}
86
87fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
88    metadata
89        .get("info")
90        .and_then(|info| info.get("model_context_window"))
91        .and_then(|v| v.as_u64())
92        .or_else(|| {
93            metadata
94                .get("payload")
95                .and_then(|payload| payload.get("info"))
96                .and_then(|info| info.get("model_context_window"))
97                .and_then(|v| v.as_u64())
98        })
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{TokenUsageDetails, TokenUsagePayload, ToolResultPayload, UserPayload};
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Build a minimal event with fixed ids and no metadata around `payload`.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(5),
                cache_read_input_tokens: Some(20),
                reasoning_output_tokens: Some(7),
            }),
        }));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.cache_creation.0, 5);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.context_window_tokens(), 175);

        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn missing_token_details_default_to_zero() {
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 10);
        assert_eq!(usage.output.0, 5);
        assert_eq!(usage.cache_creation.0, 0);
        assert_eq!(usage.cache_read.0, 0);
        assert_eq!(updates.reasoning_tokens, Some(0));
        // No metadata on the event, so nothing can be extracted from it.
        assert_eq!(updates.model, None);
        assert_eq!(updates.context_window_limit, None);
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn prefers_nested_message_model_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));
        event.metadata = Some(serde_json::json!({
            "message": { "model": "nested-model" },
            "model": "top-level-model"
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested-model"));
        assert!(updates.is_new_turn);
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn successful_tool_result_is_not_error() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && updates.is_error {
                    self.error_count += 1;
                }
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 120,
            output_tokens: 30,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(10),
                cache_read_input_tokens: Some(5),
                reasoning_output_tokens: Some(3),
            }),
        }));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.usage.cache_creation.0, 10);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.reasoning_tokens, 3);
    }
}