Skip to main content

cortexai_agents/
trace.rs

1//! Execution trace collection for agent runs
2//!
3//! Captures tool calls, LLM calls, step traces, and memory context
4//! during agent execution for observability and eval frameworks.
5
6use std::sync::Mutex;
7use std::time::Instant;
8
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11
12/// Trace of a single tool call execution.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCallTrace {
15    pub tool_name: String,
16    pub input: serde_json::Value,
17    pub output: serde_json::Value,
18    pub duration_ms: u64,
19    pub error: Option<String>,
20}
21
22/// Trace of a single LLM inference call.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct LlmCallTrace {
25    pub model: String,
26    pub prompt_tokens: u32,
27    pub completion_tokens: u32,
28    pub duration_ms: u64,
29}
30
31/// Trace of a single ReACT step.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct StepTrace {
34    pub step_type: String,
35    pub content: String,
36    pub duration_ms: u64,
37}
38
39/// Mutable trace collector used during agent execution.
40///
41/// Uses interior mutability (`Mutex`) so it can be shared across
42/// async tasks via `Arc<ExecutionTrace>`.
43pub struct ExecutionTrace {
44    trace_id: String,
45    start_time: Instant,
46    inner: Mutex<TraceInner>,
47}
48
49/// Interior mutable state of an `ExecutionTrace`.
50struct TraceInner {
51    tool_calls: Vec<ToolCallTrace>,
52    llm_calls: Vec<LlmCallTrace>,
53    steps: Vec<StepTrace>,
54    memory_context: String,
55}
56
57/// Immutable, serializable snapshot of a completed execution trace.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct FinalizedTrace {
60    pub trace_id: String,
61    pub tool_calls: Vec<ToolCallTrace>,
62    pub llm_calls: Vec<LlmCallTrace>,
63    pub steps: Vec<StepTrace>,
64    pub memory_context: String,
65    pub response: String,
66    pub total_duration_ms: u64,
67}
68
69impl ExecutionTrace {
70    /// Create a new trace collector.
71    pub fn new() -> Self {
72        Self {
73            trace_id: uuid::Uuid::new_v4().to_string(),
74            start_time: Instant::now(),
75            inner: Mutex::new(TraceInner {
76                tool_calls: Vec::new(),
77                llm_calls: Vec::new(),
78                steps: Vec::new(),
79                memory_context: String::new(),
80            }),
81        }
82    }
83
84    /// The unique identifier for this trace.
85    pub fn trace_id(&self) -> &str {
86        &self.trace_id
87    }
88
89    /// Record a tool call execution.
90    pub fn record_tool_call(&self, trace: ToolCallTrace) {
91        let mut inner = self.inner.lock().expect("trace lock poisoned");
92        inner.tool_calls.push(trace);
93    }
94
95    /// Record an LLM inference call.
96    pub fn record_llm_call(&self, trace: LlmCallTrace) {
97        let mut inner = self.inner.lock().expect("trace lock poisoned");
98        inner.llm_calls.push(trace);
99    }
100
101    /// Record a ReACT step.
102    pub fn record_step(&self, trace: StepTrace) {
103        let mut inner = self.inner.lock().expect("trace lock poisoned");
104        inner.steps.push(trace);
105    }
106
107    /// Set the memory context that was loaded for this execution.
108    pub fn set_memory_context(&self, ctx: String) {
109        let mut inner = self.inner.lock().expect("trace lock poisoned");
110        inner.memory_context = ctx;
111    }
112
113    /// Consume the trace and produce an immutable finalized snapshot.
114    pub fn finalize(&self, response: String) -> FinalizedTrace {
115        let total_duration_ms = self.start_time.elapsed().as_millis() as u64;
116        let inner = self.inner.lock().expect("trace lock poisoned");
117        FinalizedTrace {
118            trace_id: self.trace_id.clone(),
119            tool_calls: inner.tool_calls.clone(),
120            llm_calls: inner.llm_calls.clone(),
121            steps: inner.steps.clone(),
122            memory_context: inner.memory_context.clone(),
123            response,
124            total_duration_ms,
125        }
126    }
127}
128
129impl Default for ExecutionTrace {
130    fn default() -> Self {
131        Self::new()
132    }
133}
134
135impl FinalizedTrace {
136    /// Serialize to a JSON value.
137    pub fn to_json(&self) -> serde_json::Value {
138        json!({
139            "trace_id": self.trace_id,
140            "tool_calls": self.tool_calls,
141            "llm_calls": self.llm_calls,
142            "steps": self.steps,
143            "memory_context": self.memory_context,
144            "response": self.response,
145            "total_duration_ms": self.total_duration_ms,
146        })
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tool calls are kept in recording order and appear in both the
    /// finalized struct and its JSON form.
    #[test]
    fn test_create_trace_record_tool_calls_and_finalize() {
        let trace = ExecutionTrace::new();
        let trace_id = trace.trace_id().to_string();
        assert!(!trace_id.is_empty());

        trace.record_tool_call(ToolCallTrace {
            tool_name: "search".to_string(),
            input: json!({"query": "rust"}),
            output: json!({"results": ["a", "b"]}),
            duration_ms: 150,
            error: None,
        });
        trace.record_tool_call(ToolCallTrace {
            tool_name: "fetch".to_string(),
            input: json!({"url": "https://example.com"}),
            output: json!({"status": 200}),
            duration_ms: 300,
            error: None,
        });
        trace.record_tool_call(ToolCallTrace {
            tool_name: "write".to_string(),
            input: json!({"path": "/tmp/out.txt"}),
            output: json!({}),
            duration_ms: 50,
            error: Some("permission denied".to_string()),
        });

        let finalized = trace.finalize("Final answer".to_string());

        assert_eq!(finalized.trace_id, trace_id);
        assert_eq!(finalized.tool_calls.len(), 3);
        assert_eq!(finalized.tool_calls[0].tool_name, "search");
        assert_eq!(finalized.tool_calls[1].tool_name, "fetch");
        assert_eq!(finalized.tool_calls[2].tool_name, "write");
        assert_eq!(
            finalized.tool_calls[2].error.as_deref(),
            Some("permission denied")
        );
        assert_eq!(finalized.response, "Final answer");
        // No assertion on the value of `total_duration_ms` here: any bound on
        // a `u64` wall-clock reading would be either always true or flaky.
        // Its JSON type is checked below and its lower bound is covered by
        // `test_verify_duration_calculation`.

        // Verify JSON structure
        let json_val = finalized.to_json();
        assert_eq!(json_val["trace_id"], trace_id);
        assert_eq!(json_val["tool_calls"].as_array().unwrap().len(), 3);
        assert_eq!(json_val["tool_calls"][0]["tool_name"], "search");
        assert!(json_val["total_duration_ms"].is_u64());
        assert_eq!(json_val["response"], "Final answer");
    }

    /// `total_duration_ms` measures wall-clock time from `new()` to
    /// `finalize()`, independent of the per-call durations recorded.
    #[test]
    fn test_verify_duration_calculation() {
        let trace = ExecutionTrace::new();

        trace.record_llm_call(LlmCallTrace {
            model: "gpt-4".to_string(),
            prompt_tokens: 100,
            completion_tokens: 50,
            duration_ms: 500,
        });

        trace.record_step(StepTrace {
            step_type: "think".to_string(),
            content: "Reasoning about the problem".to_string(),
            duration_ms: 200,
        });

        // Sleep briefly to ensure total_duration_ms > 0
        std::thread::sleep(std::time::Duration::from_millis(5));

        let finalized = trace.finalize("done".to_string());
        // total_duration_ms is wall clock from new() to finalize()
        assert!(finalized.total_duration_ms >= 5);
        assert_eq!(finalized.llm_calls.len(), 1);
        assert_eq!(finalized.llm_calls[0].model, "gpt-4");
        assert_eq!(finalized.llm_calls[0].prompt_tokens, 100);
        assert_eq!(finalized.steps.len(), 1);
        assert_eq!(finalized.steps[0].step_type, "think");
    }

    /// A trace with nothing recorded still serializes every key.
    #[test]
    fn test_empty_trace_serializes_correctly() {
        let trace = ExecutionTrace::new();
        let finalized = trace.finalize("".to_string());

        let json_val = finalized.to_json();
        assert!(json_val["trace_id"].is_string());
        assert_eq!(json_val["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["llm_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["steps"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["memory_context"], "");
        assert_eq!(json_val["response"], "");
        assert!(json_val["total_duration_ms"].is_u64());
    }

    /// The memory context set on the collector is carried into the snapshot.
    #[test]
    fn test_set_memory_context() {
        let trace = ExecutionTrace::new();
        trace.set_memory_context("Previous conversation about Rust".to_string());

        let finalized = trace.finalize("response".to_string());
        assert_eq!(finalized.memory_context, "Previous conversation about Rust");

        let json_val = finalized.to_json();
        assert_eq!(json_val["memory_context"], "Previous conversation about Rust");
    }

    /// `FinalizedTrace` can be cloned and survives a serde round-trip.
    #[test]
    fn test_finalized_trace_is_clone_and_serializable() {
        let trace = ExecutionTrace::new();
        trace.record_tool_call(ToolCallTrace {
            tool_name: "test".to_string(),
            input: json!({}),
            output: json!({}),
            duration_ms: 10,
            error: None,
        });

        let finalized = trace.finalize("ok".to_string());
        let cloned = finalized.clone();
        assert_eq!(finalized.trace_id, cloned.trace_id);

        // Verify serde round-trip
        let serialized = serde_json::to_string(&finalized).unwrap();
        let deserialized: FinalizedTrace = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.trace_id, finalized.trace_id);
        assert_eq!(deserialized.tool_calls.len(), 1);
    }
}