1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, SystemTime, UNIX_EPOCH};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum StepType {
13 AudioCapture,
15 VoiceActivity,
17 SpeechToText,
19 Retrieval,
21 LlmGeneration,
23 ToolExecution,
25 TextToSpeech,
27 AudioPlayback,
29 Error,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TraceStep {
36 pub step_type: StepType,
38
39 pub name: String,
41
42 pub start_time_ms: u64,
44
45 pub duration_ms: u64,
47
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub input: Option<serde_json::Value>,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub output: Option<serde_json::Value>,
55
56 #[serde(skip_serializing_if = "HashMap::is_empty", default)]
58 pub metadata: HashMap<String, String>,
59
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub error: Option<String>,
63}
64
65impl TraceStep {
66 pub fn new(step_type: StepType, name: impl Into<String>) -> Self {
68 let now = SystemTime::now()
69 .duration_since(UNIX_EPOCH)
70 .unwrap_or(Duration::ZERO);
71
72 Self {
73 step_type,
74 name: name.into(),
75 start_time_ms: now.as_millis() as u64,
76 duration_ms: 0,
77 input: None,
78 output: None,
79 metadata: HashMap::new(),
80 error: None,
81 }
82 }
83
84 pub fn with_input(mut self, input: serde_json::Value) -> Self {
86 self.input = Some(input);
87 self
88 }
89
90 pub fn with_output(mut self, output: serde_json::Value) -> Self {
92 self.output = Some(output);
93 self
94 }
95
96 pub fn with_duration(mut self, duration_ms: u64) -> Self {
98 self.duration_ms = duration_ms;
99 self
100 }
101
102 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
104 self.metadata.insert(key.into(), value.into());
105 self
106 }
107
108 pub fn with_error(mut self, error: impl Into<String>) -> Self {
110 self.step_type = StepType::Error;
111 self.error = Some(error.into());
112 self
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ExecutionTrace {
119 pub trace_id: String,
121
122 pub conversation_id: Option<u64>,
124
125 pub turn_number: Option<u64>,
127
128 pub steps: Vec<TraceStep>,
130
131 pub total_duration_ms: u64,
133
134 pub start_time_ms: u64,
136}
137
138impl ExecutionTrace {
139 pub fn new(trace_id: impl Into<String>) -> Self {
141 let now = SystemTime::now()
142 .duration_since(UNIX_EPOCH)
143 .unwrap_or(Duration::ZERO);
144
145 Self {
146 trace_id: trace_id.into(),
147 conversation_id: None,
148 turn_number: None,
149 steps: Vec::new(),
150 total_duration_ms: 0,
151 start_time_ms: now.as_millis() as u64,
152 }
153 }
154
155 pub fn with_conversation(mut self, conversation_id: u64, turn_number: u64) -> Self {
157 self.conversation_id = Some(conversation_id);
158 self.turn_number = Some(turn_number);
159 self
160 }
161
162 pub fn add_step(&mut self, step: TraceStep) {
164 self.steps.push(step);
165 self.update_total_duration();
166 }
167
168 pub fn finalize(&mut self) {
170 self.update_total_duration();
171 }
172
173 fn update_total_duration(&mut self) {
175 if let (Some(first), Some(last)) = (self.steps.first(), self.steps.last()) {
176 self.total_duration_ms =
177 (last.start_time_ms + last.duration_ms).saturating_sub(first.start_time_ms);
178 }
179 }
180
181 pub fn to_json(&self) -> Result<String, serde_json::Error> {
183 serde_json::to_string_pretty(self)
184 }
185
186 pub fn to_dot(&self) -> String {
188 let mut dot = String::from("digraph ExecutionTrace {\n");
189 dot.push_str(" rankdir=LR;\n");
190 dot.push_str(" node [shape=box];\n\n");
191
192 for (i, step) in self.steps.iter().enumerate() {
193 let label = format!("{}\\n{}ms", step.name, step.duration_ms);
194 let color = match step.step_type {
195 StepType::Error => "red",
196 StepType::AudioCapture | StepType::VoiceActivity => "lightblue",
197 StepType::SpeechToText | StepType::TextToSpeech => "lightgreen",
198 StepType::Retrieval => "lightyellow",
199 StepType::LlmGeneration => "orange",
200 StepType::ToolExecution => "pink",
201 StepType::AudioPlayback => "lightgray",
202 };
203
204 dot.push_str(&format!(
205 " step{} [label=\"{}\", fillcolor={}, style=filled];\n",
206 i, label, color
207 ));
208
209 if i > 0 {
210 dot.push_str(&format!(" step{} -> step{};\n", i - 1, i));
211 }
212 }
213
214 dot.push_str("}\n");
215 dot
216 }
217
218 pub fn summary(&self) -> TraceSummary {
220 let mut summary = TraceSummary {
221 total_steps: self.steps.len(),
222 total_duration_ms: self.total_duration_ms,
223 step_durations: HashMap::new(),
224 errors: Vec::new(),
225 };
226
227 for step in &self.steps {
228 let type_name = format!("{:?}", step.step_type);
229 *summary.step_durations.entry(type_name).or_insert(0) += step.duration_ms;
230
231 if let Some(ref error) = step.error {
232 summary.errors.push(format!("{}: {}", step.name, error));
233 }
234 }
235
236 summary
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct TraceSummary {
243 pub total_steps: usize,
245
246 pub total_duration_ms: u64,
248
249 pub step_durations: HashMap<String, u64>,
251
252 pub errors: Vec<String>,
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_step_creation() {
        let step = TraceStep::new(StepType::SpeechToText, "Transcribe audio");

        assert_eq!(step.step_type, StepType::SpeechToText);
        assert_eq!(step.name, "Transcribe audio");
        assert_eq!(step.duration_ms, 0);
        assert!(step.input.is_none());
        assert!(step.output.is_none());
    }

    #[test]
    fn test_trace_step_with_data() {
        let step = TraceStep::new(StepType::LlmGeneration, "Generate response")
            .with_input(serde_json::json!({"prompt": "Hello"}))
            .with_output(serde_json::json!({"response": "Hi there!"}))
            .with_duration(250)
            .with_metadata("model", "qwen-2.5-3b");

        assert_eq!(step.duration_ms, 250);
        assert!(step.input.is_some());
        assert!(step.output.is_some());
        assert_eq!(step.metadata.get("model").unwrap(), "qwen-2.5-3b");
    }

    #[test]
    fn test_trace_step_with_error() {
        let step =
            TraceStep::new(StepType::ToolExecution, "Call API").with_error("Network timeout");

        assert_eq!(step.step_type, StepType::Error);
        assert!(step.error.is_some());
        assert_eq!(step.error.unwrap(), "Network timeout");
    }

    #[test]
    fn test_execution_trace() {
        let trace = ExecutionTrace::new("trace-001").with_conversation(1, 5);

        assert_eq!(trace.trace_id, "trace-001");
        assert_eq!(trace.conversation_id, Some(1));
        assert_eq!(trace.turn_number, Some(5));
        assert_eq!(trace.steps.len(), 0);
    }

    #[test]
    fn test_execution_trace_add_steps() {
        let mut trace = ExecutionTrace::new("trace-002");

        // Pin start times so the expected total does not depend on the wall clock.
        let mut stt = TraceStep::new(StepType::SpeechToText, "STT").with_duration(100);
        stt.start_time_ms = 1_000;
        let mut llm = TraceStep::new(StepType::LlmGeneration, "LLM").with_duration(300);
        llm.start_time_ms = 1_100;
        let mut tts = TraceStep::new(StepType::TextToSpeech, "TTS").with_duration(150);
        tts.start_time_ms = 1_400;

        trace.add_step(stt);
        trace.add_step(llm);
        trace.add_step(tts);

        assert_eq!(trace.steps.len(), 3);
        // Last step ends at 1_400 + 150; first step starts at 1_000.
        assert_eq!(trace.total_duration_ms, 550);
    }

    #[test]
    fn test_trace_json_serialization() {
        let mut trace = ExecutionTrace::new("trace-003");
        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));

        let json = trace.to_json().unwrap();
        assert!(json.contains("trace-003"));
        assert!(json.contains("SpeechToText"));
    }

    #[test]
    fn test_trace_json_round_trip() {
        let mut trace = ExecutionTrace::new("trace-007").with_conversation(2, 3);
        trace.add_step(
            TraceStep::new(StepType::Retrieval, "Search")
                .with_duration(40)
                .with_metadata("k", "v"),
        );

        let json = trace.to_json().unwrap();
        let restored: ExecutionTrace = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.trace_id, "trace-007");
        assert_eq!(restored.conversation_id, Some(2));
        assert_eq!(restored.turn_number, Some(3));
        assert_eq!(restored.steps.len(), 1);
        assert_eq!(restored.steps[0].step_type, StepType::Retrieval);
        assert_eq!(restored.steps[0].duration_ms, 40);
        assert_eq!(
            restored.steps[0].metadata.get("k").map(String::as_str),
            Some("v")
        );
        // Skipped-when-None fields must come back as None.
        assert!(restored.steps[0].input.is_none());
        assert!(restored.steps[0].error.is_none());
    }

    #[test]
    fn test_trace_dot_format() {
        let mut trace = ExecutionTrace::new("trace-004");
        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));
        trace.add_step(TraceStep::new(StepType::LlmGeneration, "LLM").with_duration(300));

        let dot = trace.to_dot();
        assert!(dot.contains("digraph ExecutionTrace"));
        assert!(dot.contains("step0"));
        assert!(dot.contains("step1"));
        assert!(dot.contains("->"));
    }

    #[test]
    fn test_trace_summary() {
        let mut trace = ExecutionTrace::new("trace-005");

        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));
        trace.add_step(TraceStep::new(StepType::LlmGeneration, "LLM 1").with_duration(200));
        trace.add_step(TraceStep::new(StepType::LlmGeneration, "LLM 2").with_duration(150));
        trace.add_step(TraceStep::new(StepType::TextToSpeech, "TTS").with_duration(120));

        let summary = trace.summary();
        assert_eq!(summary.total_steps, 4);
        assert_eq!(*summary.step_durations.get("LlmGeneration").unwrap(), 350);
        assert_eq!(summary.errors.len(), 0);
    }

    #[test]
    fn test_trace_summary_with_errors() {
        let mut trace = ExecutionTrace::new("trace-006");

        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));
        trace.add_step(
            TraceStep::new(StepType::ToolExecution, "API Call").with_error("Connection refused"),
        );

        let summary = trace.summary();
        assert_eq!(summary.total_steps, 2);
        assert_eq!(summary.errors.len(), 1);
        assert!(summary.errors[0].contains("Connection refused"));
    }
}
373}