1use std::sync::Mutex;
7use std::time::Instant;
8
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11
/// Record of a single tool invocation captured during an execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallTrace {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Arguments passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
    /// Value returned by the tool, as arbitrary JSON.
    pub output: serde_json::Value,
    /// Duration of the call in milliseconds, as supplied by the caller (not measured here).
    pub duration_ms: u64,
    /// Error message when the call failed; `None` otherwise.
    pub error: Option<String>,
}
21
/// Record of a single language-model call captured during an execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmCallTrace {
    /// Identifier of the model that served the call (e.g. `"gpt-4"`).
    pub model: String,
    /// Number of prompt tokens, as supplied by the caller.
    pub prompt_tokens: u32,
    /// Number of completion tokens, as supplied by the caller.
    pub completion_tokens: u32,
    /// Duration of the call in milliseconds, as supplied by the caller (not measured here).
    pub duration_ms: u64,
}
30
/// Record of one intermediate step of an execution (for example a reasoning step).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepTrace {
    /// Free-form category label for the step, e.g. `"think"`; not validated here.
    pub step_type: String,
    /// Text content produced by the step.
    pub content: String,
    /// Duration of the step in milliseconds, as supplied by the caller (not measured here).
    pub duration_ms: u64,
}
38
39pub struct ExecutionTrace {
44 trace_id: String,
45 start_time: Instant,
46 inner: Mutex<TraceInner>,
47}
48
49struct TraceInner {
51 tool_calls: Vec<ToolCallTrace>,
52 llm_calls: Vec<LlmCallTrace>,
53 steps: Vec<StepTrace>,
54 memory_context: String,
55}
56
/// Owned snapshot of an [`ExecutionTrace`], produced by `ExecutionTrace::finalize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinalizedTrace {
    /// Identifier copied from the originating trace.
    pub trace_id: String,
    /// Tool invocations, in the order they were recorded.
    pub tool_calls: Vec<ToolCallTrace>,
    /// Language-model calls, in the order they were recorded.
    pub llm_calls: Vec<LlmCallTrace>,
    /// Intermediate steps, in the order they were recorded.
    pub steps: Vec<StepTrace>,
    /// Memory context last set on the trace (empty if it was never set).
    pub memory_context: String,
    /// Final response text passed to `finalize`.
    pub response: String,
    /// Milliseconds elapsed between creating the trace and finalizing it.
    pub total_duration_ms: u64,
}
68
69impl ExecutionTrace {
70 pub fn new() -> Self {
72 Self {
73 trace_id: uuid::Uuid::new_v4().to_string(),
74 start_time: Instant::now(),
75 inner: Mutex::new(TraceInner {
76 tool_calls: Vec::new(),
77 llm_calls: Vec::new(),
78 steps: Vec::new(),
79 memory_context: String::new(),
80 }),
81 }
82 }
83
84 pub fn trace_id(&self) -> &str {
86 &self.trace_id
87 }
88
89 pub fn record_tool_call(&self, trace: ToolCallTrace) {
91 let mut inner = self.inner.lock().expect("trace lock poisoned");
92 inner.tool_calls.push(trace);
93 }
94
95 pub fn record_llm_call(&self, trace: LlmCallTrace) {
97 let mut inner = self.inner.lock().expect("trace lock poisoned");
98 inner.llm_calls.push(trace);
99 }
100
101 pub fn record_step(&self, trace: StepTrace) {
103 let mut inner = self.inner.lock().expect("trace lock poisoned");
104 inner.steps.push(trace);
105 }
106
107 pub fn set_memory_context(&self, ctx: String) {
109 let mut inner = self.inner.lock().expect("trace lock poisoned");
110 inner.memory_context = ctx;
111 }
112
113 pub fn finalize(&self, response: String) -> FinalizedTrace {
115 let total_duration_ms = self.start_time.elapsed().as_millis() as u64;
116 let inner = self.inner.lock().expect("trace lock poisoned");
117 FinalizedTrace {
118 trace_id: self.trace_id.clone(),
119 tool_calls: inner.tool_calls.clone(),
120 llm_calls: inner.llm_calls.clone(),
121 steps: inner.steps.clone(),
122 memory_context: inner.memory_context.clone(),
123 response,
124 total_duration_ms,
125 }
126 }
127}
128
129impl Default for ExecutionTrace {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135impl FinalizedTrace {
136 pub fn to_json(&self) -> serde_json::Value {
138 json!({
139 "trace_id": self.trace_id,
140 "tool_calls": self.tool_calls,
141 "llm_calls": self.llm_calls,
142 "steps": self.steps,
143 "memory_context": self.memory_context,
144 "response": self.response,
145 "total_duration_ms": self.total_duration_ms,
146 })
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_create_trace_record_tool_calls_and_finalize() {
        // Started before the trace exists, so it bounds the trace's own clock.
        let outer_clock = std::time::Instant::now();
        let trace = ExecutionTrace::new();
        let trace_id = trace.trace_id().to_string();
        assert!(!trace_id.is_empty());

        trace.record_tool_call(ToolCallTrace {
            tool_name: "search".to_string(),
            input: json!({"query": "rust"}),
            output: json!({"results": ["a", "b"]}),
            duration_ms: 150,
            error: None,
        });
        trace.record_tool_call(ToolCallTrace {
            tool_name: "fetch".to_string(),
            input: json!({"url": "https://example.com"}),
            output: json!({"status": 200}),
            duration_ms: 300,
            error: None,
        });
        trace.record_tool_call(ToolCallTrace {
            tool_name: "write".to_string(),
            input: json!({"path": "/tmp/out.txt"}),
            output: json!({}),
            duration_ms: 50,
            error: Some("permission denied".to_string()),
        });

        let finalized = trace.finalize("Final answer".to_string());

        assert_eq!(finalized.trace_id, trace_id);
        assert_eq!(finalized.tool_calls.len(), 3);
        assert_eq!(finalized.tool_calls[0].tool_name, "search");
        assert_eq!(finalized.tool_calls[1].tool_name, "fetch");
        assert_eq!(finalized.tool_calls[2].tool_name, "write");
        assert_eq!(finalized.tool_calls[2].error, Some("permission denied".to_string()));
        assert_eq!(finalized.response, "Final answer");
        // The trace clock started later and stopped earlier than `outer_clock`,
        // so its (floored) duration can never exceed the outer (floored) one.
        let outer_elapsed_ms = outer_clock.elapsed().as_millis();
        assert!(u128::from(finalized.total_duration_ms) <= outer_elapsed_ms);

        let json_val = finalized.to_json();
        assert_eq!(json_val["trace_id"], trace_id);
        assert_eq!(json_val["tool_calls"].as_array().unwrap().len(), 3);
        assert_eq!(json_val["tool_calls"][0]["tool_name"], "search");
        assert!(json_val["total_duration_ms"].is_u64());
        assert_eq!(json_val["response"], "Final answer");
    }

    #[test]
    fn test_verify_duration_calculation() {
        let trace = ExecutionTrace::new();

        trace.record_llm_call(LlmCallTrace {
            model: "gpt-4".to_string(),
            prompt_tokens: 100,
            completion_tokens: 50,
            duration_ms: 500,
        });

        trace.record_step(StepTrace {
            step_type: "think".to_string(),
            content: "Reasoning about the problem".to_string(),
            duration_ms: 200,
        });

        std::thread::sleep(std::time::Duration::from_millis(5));

        let finalized = trace.finalize("done".to_string());
        assert!(finalized.total_duration_ms >= 5);
        assert_eq!(finalized.llm_calls.len(), 1);
        assert_eq!(finalized.llm_calls[0].model, "gpt-4");
        assert_eq!(finalized.llm_calls[0].prompt_tokens, 100);
        assert_eq!(finalized.steps.len(), 1);
        assert_eq!(finalized.steps[0].step_type, "think");
    }

    #[test]
    fn test_empty_trace_serializes_correctly() {
        let trace = ExecutionTrace::new();
        let finalized = trace.finalize("".to_string());

        let json_val = finalized.to_json();
        assert!(json_val["trace_id"].is_string());
        assert_eq!(json_val["tool_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["llm_calls"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["steps"].as_array().unwrap().len(), 0);
        assert_eq!(json_val["memory_context"], "");
        assert_eq!(json_val["response"], "");
        assert!(json_val["total_duration_ms"].is_u64());
    }

    #[test]
    fn test_set_memory_context() {
        let trace = ExecutionTrace::new();
        trace.set_memory_context("Previous conversation about Rust".to_string());

        let finalized = trace.finalize("response".to_string());
        assert_eq!(finalized.memory_context, "Previous conversation about Rust");

        let json_val = finalized.to_json();
        assert_eq!(json_val["memory_context"], "Previous conversation about Rust");
    }

    #[test]
    fn test_finalized_trace_is_clone_and_serializable() {
        let trace = ExecutionTrace::new();
        trace.record_tool_call(ToolCallTrace {
            tool_name: "test".to_string(),
            input: json!({}),
            output: json!({}),
            duration_ms: 10,
            error: None,
        });

        let finalized = trace.finalize("ok".to_string());
        let cloned = finalized.clone();
        assert_eq!(finalized.trace_id, cloned.trace_id);

        let serialized = serde_json::to_string(&finalized).unwrap();
        let deserialized: FinalizedTrace = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.trace_id, finalized.trace_id);
        assert_eq!(deserialized.tool_calls.len(), 1);

        // `to_json` must agree with the derived `Serialize` representation.
        assert_eq!(finalized.to_json(), serde_json::to_value(&finalized).unwrap());
    }

    #[test]
    fn test_finalize_snapshots_without_consuming() {
        let trace = ExecutionTrace::new();
        trace.record_step(StepTrace {
            step_type: "think".to_string(),
            content: "first".to_string(),
            duration_ms: 1,
        });
        let first = trace.finalize("first".to_string());

        trace.record_step(StepTrace {
            step_type: "think".to_string(),
            content: "second".to_string(),
            duration_ms: 2,
        });
        let second = trace.finalize("second".to_string());

        // The earlier snapshot is unaffected by later recording.
        assert_eq!(first.steps.len(), 1);
        assert_eq!(second.steps.len(), 2);
        assert_eq!(second.steps[1].content, "second");
        assert_eq!(first.trace_id, second.trace_id);
        // `Instant` is monotonic, so a later snapshot never reports less time.
        assert!(second.total_duration_ms >= first.total_duration_ms);
    }

    #[test]
    fn test_default_creates_empty_trace_with_unique_id() {
        let a = ExecutionTrace::default();
        let b = ExecutionTrace::default();
        assert!(!a.trace_id().is_empty());
        assert_ne!(a.trace_id(), b.trace_id());

        let finalized = a.finalize(String::new());
        assert!(finalized.tool_calls.is_empty());
        assert!(finalized.llm_calls.is_empty());
        assert!(finalized.steps.is_empty());
        assert!(finalized.memory_context.is_empty());
    }
}