agtrace_engine/
state_updates.rs

1use agtrace_types::{AgentEvent, ContextWindowUsage, EventPayload};
2use serde_json::Value;
3
4/// Pure data extracted from an AgentEvent to update runtime session state.
5#[derive(Debug, Clone, PartialEq, Default)]
6pub struct StateUpdates {
7    pub model: Option<String>,
8    pub context_window_limit: Option<u64>,
9    pub usage: Option<ContextWindowUsage>,
10    pub reasoning_tokens: Option<i32>,
11    pub is_error: bool,
12    pub is_new_turn: bool,
13}
14
15/// Extract state updates from a single event without performing I/O or side effects.
16pub fn extract_state_updates(event: &AgentEvent) -> StateUpdates {
17    let mut updates = StateUpdates::default();
18
19    match &event.payload {
20        EventPayload::User(_) | EventPayload::SlashCommand(_) => {
21            updates.is_new_turn = true;
22        }
23        EventPayload::TokenUsage(usage) => {
24            // Convert normalized TokenUsagePayload to ContextWindowUsage
25            // The new TokenUsagePayload separates input into cached/uncached.
26            // To avoid double-counting, fresh_input = uncached only (not total).
27            updates.usage = Some(ContextWindowUsage::from_raw(
28                usage.input.uncached as i32, // fresh input tokens (not from cache)
29                0,                           // cache_creation - not separately tracked
30                usage.input.cached as i32,   // cache_read tokens (still consume context)
31                usage.output.total() as i32, // total output tokens
32            ));
33            updates.reasoning_tokens = Some(usage.output.reasoning as i32);
34        }
35        EventPayload::ToolResult(result) => {
36            if result.is_error {
37                updates.is_error = true;
38            } else {
39                // Explicitly mark success so consumers can reset counters if needed.
40                updates.is_error = false;
41            }
42        }
43        _ => {}
44    }
45
46    if let Some(metadata) = &event.metadata {
47        if updates.model.is_none() {
48            updates.model = extract_model(metadata);
49        }
50
51        if updates.context_window_limit.is_none() {
52            updates.context_window_limit = extract_context_window_limit(metadata);
53        }
54    }
55
56    updates
57}
58
59fn extract_model(metadata: &Value) -> Option<String> {
60    metadata
61        .get("message")
62        .and_then(|m| m.get("model"))
63        .and_then(|v| v.as_str())
64        .map(|s| s.to_string())
65        .or_else(|| {
66            metadata
67                .get("model")
68                .and_then(|v| v.as_str())
69                .map(|s| s.to_string())
70        })
71}
72
73fn extract_context_window_limit(metadata: &Value) -> Option<u64> {
74    metadata
75        .get("info")
76        .and_then(|info| info.get("model_context_window"))
77        .and_then(|v| v.as_u64())
78        .or_else(|| {
79            metadata
80                .get("payload")
81                .and_then(|payload| payload.get("info"))
82                .and_then(|info| info.get("model_context_window"))
83                .and_then(|v| v.as_u64())
84        })
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_types::{
        TokenInput, TokenOutput, TokenUsagePayload, ToolResultPayload, UserPayload,
    };
    use chrono::Utc;
    use uuid::Uuid;

    /// Parses a hard-coded UUID literal used as a fixture id.
    fn fixed_uuid(text: &str) -> Uuid {
        text.parse().expect("fixture UUID literal must be valid")
    }

    /// Builds an event around `payload` with fixed ids, the main stream, and no metadata.
    fn base_event(payload: EventPayload) -> AgentEvent {
        AgentEvent {
            id: fixed_uuid("00000000-0000-0000-0000-000000000001"),
            session_id: fixed_uuid("00000000-0000-0000-0000-000000000002"),
            parent_id: None,
            timestamp: Utc::now(),
            stream_id: agtrace_types::StreamId::Main,
            payload,
            metadata: None,
        }
    }

    /// Wraps the given token counts in a `TokenUsage` event.
    fn token_usage_event(input: TokenInput, output: TokenOutput) -> AgentEvent {
        base_event(EventPayload::TokenUsage(TokenUsagePayload::new(
            input, output,
        )))
    }

    /// Builds a `ToolResult` event that is flagged as an error.
    fn failed_tool_result(call_id: &str, output: &str) -> AgentEvent {
        base_event(EventPayload::ToolResult(ToolResultPayload {
            tool_call_id: fixed_uuid(call_id),
            output: output.to_string(),
            is_error: true,
            agent_id: None,
        }))
    }

    #[test]
    fn extracts_user_turn_flag() {
        let user_event = base_event(EventPayload::User(UserPayload {
            text: "hi".to_string(),
        }));

        let StateUpdates {
            is_new_turn,
            is_error,
            ..
        } = extract_state_updates(&user_event);

        assert!(is_new_turn);
        assert!(!is_error);
    }

    #[test]
    fn extracts_token_usage_and_reasoning() {
        let mut event = token_usage_event(
            TokenInput::new(20, 100),   // cached=20, uncached=100
            TokenOutput::new(43, 7, 0), // generated=43, reasoning=7, tool=0
        );
        event.metadata = Some(serde_json::json!({
            "model": "claude-3-5-sonnet-20241022",
            "info": { "model_context_window": 200000 }
        }));

        let updates = extract_state_updates(&event);

        let usage = updates.usage.expect("usage should be set");
        assert_eq!(usage.fresh_input.0, 100); // uncached only (not total)
        assert_eq!(usage.cache_read.0, 20); // cached input
        assert_eq!(usage.output.0, 50); // generated + reasoning + tool = 43 + 7 + 0
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(170)); // 100 + 20 + 50

        assert_eq!(updates.reasoning_tokens, Some(7));
        assert_eq!(updates.model.as_deref(), Some("claude-3-5-sonnet-20241022"));
        assert_eq!(updates.context_window_limit, Some(200_000));
    }

    #[test]
    fn extracts_context_window_limit_from_payload_info() {
        let mut event = token_usage_event(
            TokenInput::new(0, 10),    // cached=0, uncached=10
            TokenOutput::new(5, 0, 0), // generated=5, reasoning=0, tool=0
        );
        event.metadata = Some(serde_json::json!({
            "payload": {
                "info": { "model_context_window": 123_000 }
            }
        }));

        let limit = extract_state_updates(&event).context_window_limit;
        assert_eq!(limit, Some(123_000));
    }

    #[test]
    fn extracts_tool_result_error_flag() {
        let event = failed_tool_result("00000000-0000-0000-0000-000000000003", "err");

        assert!(extract_state_updates(&event).is_error);
    }

    #[test]
    fn token_usage_conversion_avoids_double_counting_cached_tokens() {
        // Regression test: cached tokens must be counted exactly once.
        //
        // Input payload:
        //   input:  cached=20, uncached=100 (total input = 120)
        //   output: generated=50 (total output = 50)
        //
        // Expected ContextWindowUsage:
        //   fresh_input:    100 (uncached only)
        //   cache_read:      20 (cached tokens)
        //   output:          50
        //   total_tokens:   170 (100 + 20 + 50)
        //
        // Feeding input.total() (120) into fresh_input would instead yield
        // total_tokens = 190, counting the 20 cached tokens twice.
        let event = token_usage_event(
            TokenInput::new(20, 100),   // cached=20, uncached=100
            TokenOutput::new(50, 0, 0), // generated=50, reasoning=0, tool=0
        );

        let usage = extract_state_updates(&event)
            .usage
            .expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 100,
            "fresh_input should be uncached tokens only (not total)"
        );
        assert_eq!(usage.cache_read.0, 20, "cache_read should be cached tokens");
        assert_eq!(usage.output.0, 50, "output should match");
        assert_eq!(
            usage.total_tokens(),
            crate::TokenCount::new(170),
            "total should be 100 (fresh) + 20 (cache) + 50 (output) = 170, not 190"
        );
    }

    #[test]
    fn token_usage_conversion_uses_uncached_for_fresh_input() {
        // Consistency test: the conversion should follow the same semantics as the one
        // done during session assembly (stats::merge_usage), i.e. fresh_input comes from
        // input.uncached rather than input.total().
        let event = token_usage_event(
            TokenInput::new(30, 200),    // cached=30, uncached=200
            TokenOutput::new(80, 10, 5), // generated=80, reasoning=10, tool=5
        );

        let usage = extract_state_updates(&event)
            .usage
            .expect("usage should be set");

        assert_eq!(
            usage.fresh_input.0, 200,
            "fresh_input must be uncached tokens only (200), not total (230)"
        );
        assert_eq!(usage.cache_read.0, 30);
        assert_eq!(usage.output.0, 95); // 80 + 10 + 5
        assert_eq!(usage.total_tokens(), crate::TokenCount::new(325)); // 200 + 30 + 95
    }

    #[test]
    fn applies_updates_to_session_state_without_io() {
        /// Minimal in-memory stand-in for a consumer of `StateUpdates`.
        #[derive(Default)]
        struct SessionState {
            model: Option<String>,
            context_window_limit: Option<u64>,
            usage: ContextWindowUsage,
            reasoning_tokens: i32,
            turn_count: usize,
            error_count: u32,
        }

        impl SessionState {
            /// Folds one set of updates into the state; first-seen model/limit win.
            fn apply(&mut self, updates: StateUpdates, is_error_event: bool) {
                let StateUpdates {
                    model,
                    context_window_limit,
                    usage,
                    reasoning_tokens,
                    is_error,
                    is_new_turn,
                } = updates;

                if is_new_turn {
                    self.turn_count += 1;
                    self.error_count = 0;
                }
                if is_error_event && is_error {
                    self.error_count += 1;
                }
                if self.model.is_none() {
                    self.model = model;
                }
                if self.context_window_limit.is_none() {
                    self.context_window_limit = context_window_limit;
                }
                if let Some(latest) = usage {
                    self.usage = latest;
                }
                if let Some(count) = reasoning_tokens {
                    self.reasoning_tokens = count;
                }
            }
        }

        let user = base_event(EventPayload::User(UserPayload { text: "hi".into() }));

        let mut usage_event = token_usage_event(
            TokenInput::new(5, 120),    // cached=5, uncached=120
            TokenOutput::new(27, 3, 0), // generated=27, reasoning=3, tool=0
        );
        usage_event.metadata = Some(serde_json::json!({
            "model": "claude-3",
            "info": { "model_context_window": 100000 }
        }));

        let tool_err = failed_tool_result("00000000-0000-0000-0000-000000000009", "boom");

        let mut state = SessionState::default();
        state.apply(extract_state_updates(&user), false);
        state.apply(extract_state_updates(&usage_event), false);
        state.apply(extract_state_updates(&tool_err), true);

        assert_eq!(state.turn_count, 1);
        assert_eq!(state.error_count, 1);
        assert_eq!(state.model.as_deref(), Some("claude-3"));
        assert_eq!(state.context_window_limit, Some(100_000));
        assert_eq!(state.usage.fresh_input.0, 120); // uncached only (not total)
        assert_eq!(state.usage.cache_read.0, 5);
        assert_eq!(state.usage.output.0, 30); // generated + reasoning + tool = 27 + 3 + 0
        assert_eq!(state.reasoning_tokens, 3);
    }
}