agtrace_engine/
state_updates.rs

1use agtrace_types::{AgentEvent, ContextWindowUsage, EventPayload};
2use serde_json::Value;
3
4/// Pure data extracted from an AgentEvent to update runtime session state.
5#[derive(Debug, Clone, PartialEq, Default)]
6pub struct StateUpdates {
7    pub model: Option<String>,
8    pub context_window_limit: Option<u64>,
9    pub usage: Option<ContextWindowUsage>,
10    pub reasoning_tokens: Option<i32>,
11    pub is_error: bool,
12    pub is_new_turn: bool,
13}
14
15/// Extract state updates from a single event without performing I/O or side effects.
16pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
17    let mut updates = StateUpdates::default();
18
19    match &event.payload {
20        EventPayload::User(_) => {
21            updates.is_new_turn = true;
22        }
23        EventPayload::TokenUsage(usage) => {
24            // Convert normalized TokenUsagePayload to ContextWindowUsage
25            // The new TokenUsagePayload separates input into cached/uncached.
26            // To avoid double-counting, fresh_input = uncached only (not total).
27            updates.usage = Some(ContextWindowUsage::from_raw(
28                usage.input.uncached as i32, // fresh input tokens (not from cache)
29                0,                           // cache_creation - not separately tracked
30                usage.input.cached as i32,   // cache_read tokens (still consume context)
31                usage.output.total() as i32, // total output tokens
32            ));
33            updates.reasoning_tokens = Some(usage.output.reasoning as i32);
34        }
35        EventPayload::ToolResult(result) => {
36            if result.is_error {
37                updates.is_error = true;
38            } else {
39                // Explicitly mark success so consumers can reset counters if needed.
40                updates.is_error = false;
41            }
42        }
43        _ => {}
44    }
45
46    if let Some(metadata) = &event.metadata {
47        if updates.model.is_none() {
48            updates.model = extract_model(metadata);
49        }
50
51        if updates.context_window_limit.is_none() {
52            updates.context_window_limit = extract_context_window_limit(metadata);
53        }
54    }
55
56    updates
57}
58
59fn extract_model(metadata: &Value) -> Option<String> {
60    metadata
61        .get("message")
62        .and_then(|m| m.get("model"))
63        .and_then(|v| v.as_str())
64        .map(|s| s.to_string())
65        .or_else(|| {
66            metadata
67                .get("model")
68                .and_then(|v| v.as_str())
69                .map(|s| s.to_string())
70        })
71}
72
73fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
74    metadata
75        .get("info")
76        .and_then(|info| info.get("model_context_window"))
77        .and_then(|v| v.as_u64())
78        .or_else(|| {
79            metadata
80                .get("payload")
81                .and_then(|payload| payload.get("info"))
82                .and_then(|info| info.get("model_context_window"))
83                .and_then(|v| v.as_u64())
84        })
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{
        TokenInput, TokenOutput, TokenUsagePayload, ToolResultPayload, UserPayload,
    };
    use chrono::Utc;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Parses a hard-coded UUID literal; a malformed fixture is a test bug, so it panics.
    fn fixed_uuid(text: &str) -> Uuid {
        Uuid::from_str(text).unwrap()
    }

    /// Builds a main-stream event with fixed ids, no parent and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: fixed_uuid("00000000-0000-0000-0000-000000000001"),
            session_id: fixed_uuid("00000000-0000-0000-0000-000000000002"),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    /// Builds a `TokenUsage` event from the given input and output buckets.
    fn token_usage_event(input: TokenInput, output: TokenOutput) -> AgentEvent {
        base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            input, output,
        )))
    }

    /// Builds a `ToolResult` event that is flagged as an error.
    fn tool_error_event(call_id: &str, output: &str) -> AgentEvent {
        base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: fixed_uuid(call_id),
            output: output.to_string(),
            is_error: true,
        }))
    }

    #[test]
    fn extracts_user_turn_flag() {
        let payload = UserPayload {
            text: "hi".to_string(),
        };

        let updates = extract_state_updates(&base_event(EventPayload::User(payload)));

        assert!(updates.is_new_turn);
        assert!(!updates.is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = token_usage_event(
            TokenInput::new(20, 100),   // cached=20, uncached=100
            TokenOutput::new(43, 7, 0), // generated=43, reasoning=7, tool=0
        );
        event.metadata = Some(serde_json::json!({
            "model": "claude-3-5-sonnet-20241022",
            "info": { "model_context_window": 200000 }
        }));

        let StateUpdates {
            model,
            context_window_limit,
            usage,
            reasoning_tokens,
            ..
        } = extract_state_updates(&event);

        let usage = usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100); // uncached only (not total)
        assert_eq!(usage.cache_read.0, 20); // cached input
        assert_eq!(usage.output.0, 50); // generated + reasoning + tool = 43 + 7 + 0
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(170)); // 100 + 20 + 50

        assert_eq!(reasoning_tokens, Some(7));
        assert_eq!(model, Some("claude-3-5-sonnet-20241022".to_string()));
        assert_eq!(context_window_limit, Some(200_000));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = token_usage_event(
            TokenInput::new(0, 10),    // cached=0, uncached=10
            TokenOutput::new(5, 0, 0), // generated=5, reasoning=0, tool=0
        );
        // Only the nested `payload.info` location is populated here.
        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let limit = extract_state_updates(&event).context_window_limit;

        assert_eq!(limit, Some(123_000));
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = tool_error_event("00000000-0000-0000-0000-000000000003", "err");

        assert!(extract_state_updates(&event).is_error);
    }

    #[test]
    fn token_usage_conversion_avoids_double_counting_cached_tokens() {
        // Regression test: cached tokens must NOT be counted twice.
        //
        // Given a TokenUsagePayload with:
        //   input:  cached=20, uncached=100 (total input = 120)
        //   output: generated=50 (total output = 50)
        //
        // Expected ContextWindowUsage:
        //   fresh_input:    100 (uncached only)
        //   cache_read:      20 (cached tokens)
        //   output:          50
        //   total_tokens:   170 (100 + 20 + 50)
        //
        // Feeding input.total() into fresh_input instead would give:
        //   fresh_input:    120 (cached + uncached)
        //   cache_read:      20 (same)
        //   total_tokens:   190 (120 + 20 + 50), i.e. cached counted twice

        let event = token_usage_event(
            TokenInput::new(20, 100),   // cached=20, uncached=100
            TokenOutput::new(50, 0, 0), // generated=50, reasoning=0, tool=0
        );

        let usage = extract_state_updates(&event)
            .usage
            .expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 100,
            "fresh_input should be uncached tokens only (not total)"
        );
        assert_eq!(usage.cache_read.0, 20, "cache_read should be cached tokens");
        assert_eq!(usage.output.0, 50, "output should match");
        assert_eq!(
            usage.total_tokens(),
            crate::TokenCount::new(170),
            "total should be 100 (fresh) + 20 (cache) + 50 (output) = 170, not 190"
        );
    }

    #[test]
    fn token_usage_conversion_uses_uncached_for_fresh_input() {
        // Consistency test: fresh_input must come from input.uncached, not
        // input.total(). The original note says this mirrors the conversion done
        // during session assembly (stats::merge_usage), which is not visible here.

        let event = token_usage_event(
            TokenInput::new(30, 200),    // cached=30, uncached=200
            TokenOutput::new(80, 10, 5), // generated=80, reasoning=10, tool=5
        );

        let usage = extract_state_updates(&event)
            .usage
            .expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 200,
            "fresh_input must be uncached tokens only (200), not total (230)"
        );
        assert_eq!(usage.cache_read.0, 30);
        assert_eq!(usage.output.0, 95); // 80 + 10 + 5
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(325)); // 200 + 30 + 95
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        /// Minimal in-memory consumer of `StateUpdates`, local to this test.
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                let StateUpdates {
                    model,
                    context_window_limit,
                    usage,
                    reasoning_tokens,
                    is_error,
                    is_new_turn,
                } = updates;

                // A new turn bumps the turn counter and clears the error streak.
                if is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && is_error {
                    self.error_count += 1;
                }

                // Model and limit: the first value seen wins.
                if self.model.is_none() {
                    self.model = model;
                }
                if self.context_window_limit.is_none() {
                    self.context_window_limit = context_window_limit;
                }

                // Usage and reasoning tokens: the latest value wins.
                if let Some(latest) = usage {
                    self.usage = latest;
                }
                if let Some(latest) = reasoning_tokens {
                    self.reasoning_tokens = latest;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));

        let mut usage_event = token_usage_event(
            TokenInput::new(5, 120),    // cached=5, uncached=120
            TokenOutput::new(27, 3, 0), // generated=27, reasoning=3, tool=0
        );
        usage_event.metadata = Some(serde_json::json!({
            "model": "claude-3",
            "info": { "model_context_window": 100000 }
        }));

        let tool_err = tool_error_event("00000000-0000-0000-0000-000000000009", "boom");

        // Each step pairs an event with the `is_error_event` flag passed to `apply`.
        let steps = [(&user, false), (&usage_event, false), (&tool_err, true)];

        let mut state = SessionState::default();
        for (event, is_error_event) in steps {
            state.apply(extract_state_updates(event), is_error_event);
        }

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120); // uncached only (not total)
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.usage.output.0, 30); // generated + reasoning + tool = 27 + 3 + 0
        assert_eq!(state.reasoning_tokens, 3);
    }
}