agtrace_engine/
state_updates.rs

1use crate::token_usage::ContextWindowUsage;
2use agtrace_types::{AgentEvent, EventPayload};
3use serde_json::Value;
4
5/// Pure data extracted from an AgentEvent to update runtime session state.
6#[derive(Debug, Clone, PartialEq, Default)]
7pub struct StateUpdates {
8    pub model: Option<String>,
9    pub context_window_limit: Option<u64>,
10    pub usage: Option<ContextWindowUsage>,
11    pub reasoning_tokens: Option<i32>,
12    pub is_error: bool,
13    pub is_new_turn: bool,
14}
15
16/// Extract state updates from a single event without performing I/O or side effects.
17pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
18    let mut updates = StateUpdates::default();
19
20    match &event.payload {
21        EventPayload::User(_) => {
22            updates.is_new_turn = true;
23        }
24        EventPayload::TokenUsage(usage) => {
25            let cache_creation = usage
26                .details
27                .as_ref()
28                .and_then(|d| d.cache_creation_input_tokens)
29                .unwrap_or(0);
30            let cache_read = usage
31                .details
32                .as_ref()
33                .and_then(|d| d.cache_read_input_tokens)
34                .unwrap_or(0);
35            let reasoning_tokens = usage
36                .details
37                .as_ref()
38                .and_then(|d| d.reasoning_output_tokens)
39                .unwrap_or(0);
40
41            // TODO: Verify if input_tokens includes cache tokens (potential double counting)
42            // Claude API's input_tokens may already include cache_creation + cache_read.
43            // If so, fresh_input should be: usage.input_tokens - cache_creation - cache_read
44            // Need to verify with actual logs using: agtrace lab grep "TokenUsage" --json
45            updates.usage = Some(ContextWindowUsage::from_raw(
46                usage.input_tokens,
47                cache_creation,
48                cache_read,
49                usage.output_tokens,
50            ));
51            updates.reasoning_tokens = Some(reasoning_tokens);
52        }
53        EventPayload::ToolResult(result) => {
54            if result.is_error {
55                updates.is_error = true;
56            } else {
57                // Explicitly mark success so consumers can reset counters if needed.
58                updates.is_error = false;
59            }
60        }
61        _ => {}
62    }
63
64    if let Some(metadata) = &event.metadata {
65        if updates.model.is_none() {
66            updates.model = extract_model(metadata);
67        }
68
69        if updates.context_window_limit.is_none() {
70            updates.context_window_limit = extract_context_window_limit(metadata);
71        }
72    }
73
74    updates
75}
76
77fn extract_model(metadata: &Value) -> Option<String> {
78    metadata
79        .get("message")
80        .and_then(|m| m.get("model"))
81        .and_then(|v| v.as_str())
82        .map(|s| s.to_string())
83        .or_else(|| {
84            metadata
85                .get("model")
86                .and_then(|v| v.as_str())
87                .map(|s| s.to_string())
88        })
89}
90
91fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
92    metadata
93        .get("info")
94        .and_then(|info| info.get("model_context_window"))
95        .and_then(|v| v.as_u64())
96        .or_else(|| {
97            metadata
98                .get("payload")
99                .and_then(|payload| payload.get("info"))
100                .and_then(|info| info.get("model_context_window"))
101                .and_then(|v| v.as_u64())
102        })
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{TokenUsageDetails, TokenUsagePayload, ToolResultPayload, UserPayload};
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Wraps `payload` in an event with fixed ids, no parent and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: Uuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(),
            session_id: Uuid::from_str("00000000-0000-0000-0000-000000000002").unwrap(),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    #[test]
    fn extracts_user_turn_flag() {
        let event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(5),
                cache_read_input_tokens: Some(20),
                reasoning_output_tokens: Some(7),
            }),
        }));

        let mut meta = serde_json::Map::new();
        meta.insert(
            "model".to_string(),
            serde_json::Value::String("claude-3-5-sonnet-20241022".to_string()),
        );
        meta.insert(
            "info".to_string(),
            serde_json::json!({ "model_context_window": 200000 }),
        );
        event.metadata = Some(Value::Object(meta));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100);
        assert_eq!(usage.output.0, 50);
        assert_eq!(usage.cache_creation.0, 5);
        assert_eq!(usage.cache_read.0, 20);
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(175));

        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(
            updates.model,
            Some("claude-3-5-sonnet-20241022".to_string())
        );
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn token_usage_without_details_defaults_counters_to_zero() {
        let event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 10);
        assert_eq!(usage.output.0, 5);
        assert_eq!(usage.cache_creation.0, 0);
        assert_eq!(usage.cache_read.0, 0);
        assert_eq!(updates.reasoning_tokens, Some(0));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            details: None,
        }));

        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.context_window_limit, Some(123_000));
    }

    #[test]
    fn prefers_message_model_over_top_level_model() {
        let mut event = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        event.metadata = Some(serde_json::json!({
            "model": "top-level",
            "message": { "model": "nested" }
        }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model.as_deref(), Some("nested"));
    }

    #[test]
    fn missing_metadata_yields_no_model_or_limit() {
        let event = base_event(EventPayload::User(UserPayload { text: "hi".into() }));

        let updates = extract_state_updates(&event);
        assert_eq!(updates.model, None);
        assert_eq!(updates.context_window_limit, None);
        assert!(updates.usage.is_none());
        assert_eq!(updates.reasoning_tokens, None);
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000003").unwrap(),
            output: "err".to_string(),
            is_error: true,
        }));

        let updates = extract_state_updates(&event);
        assert!(updates.is_error);
    }

    #[test]
    fn tool_result_success_leaves_error_flag_clear() {
        let event = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000004").unwrap(),
            output: "ok".to_string(),
            is_error: false,
        }));

        let updates = extract_state_updates(&event);
        assert!(!updates.is_error);
        assert!(!updates.is_new_turn);
        assert!(updates.usage.is_none());
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                if updates.is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && updates.is_error {
                    self.error_count += 1;
                }
                if let Some(m) = updates.model {
                    self.model.get_or_insert(m);
                }
                if let Some(limit) = updates.context_window_limit {
                    self.context_window_limit.get_or_insert(limit);
                }
                if let Some(u) = updates.usage {
                    self.usage = u;
                }
                if let Some(rt) = updates.reasoning_tokens {
                    self.reasoning_tokens = rt;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));
        let mut usage_event = base_event(EventPayload::TokenUsage(TokenUsagePayload {
            input_tokens: 120,
            output_tokens: 30,
            total_tokens: 150,
            details: Some(TokenUsageDetails {
                cache_creation_input_tokens: Some(10),
                cache_read_input_tokens: Some(5),
                reasoning_output_tokens: Some(3),
            }),
        }));
        let mut meta = serde_json::Map::new();
        meta.insert("model".into(), serde_json::Value::String("claude-3".into()));
        meta.insert(
            "info".into(),
            serde_json::json!({"model_context_window": 100000}),
        );
        usage_event.metadata = Some(Value::Object(meta));

        let tool_err = base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: Uuid::from_str("00000000-0000-0000-0000-000000000009").unwrap(),
            output: "boom".into(),
            is_error: true,
        }));

        let mut state = SessionState::default();

        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120);
        assert_eq!(state.usage.output.0, 30);
        assert_eq!(state.usage.cache_creation.0, 10);
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.reasoning_tokens, 3);
    }
}