frame_trace/
trace.rs

//! Execution trace and call-graph tracking.
//!
//! Records the execution flow through SAM's pipeline for debugging,
//! transparency, and performance analysis.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, SystemTime, UNIX_EPOCH};
9
10/// Pipeline step type
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum StepType {
13    /// Audio capture started
14    AudioCapture,
15    /// Voice activity detected
16    VoiceActivity,
17    /// Speech-to-text transcription
18    SpeechToText,
19    /// Context retrieval from memory
20    Retrieval,
21    /// LLM generation
22    LlmGeneration,
23    /// Tool/skill execution
24    ToolExecution,
25    /// Text-to-speech synthesis
26    TextToSpeech,
27    /// Audio playback
28    AudioPlayback,
29    /// Error occurred
30    Error,
31}
32
33/// A single step in the execution trace
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TraceStep {
36    /// Step type
37    pub step_type: StepType,
38
39    /// Step name/description
40    pub name: String,
41
42    /// Unix timestamp when step started (milliseconds)
43    pub start_time_ms: u64,
44
45    /// Duration in milliseconds
46    pub duration_ms: u64,
47
48    /// Input data (JSON-serializable)
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub input: Option<serde_json::Value>,
51
52    /// Output data (JSON-serializable)
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub output: Option<serde_json::Value>,
55
56    /// Additional metadata
57    #[serde(skip_serializing_if = "HashMap::is_empty", default)]
58    pub metadata: HashMap<String, String>,
59
60    /// Error message if step failed
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub error: Option<String>,
63}
64
65impl TraceStep {
66    /// Create a new trace step
67    pub fn new(step_type: StepType, name: impl Into<String>) -> Self {
68        let now = SystemTime::now()
69            .duration_since(UNIX_EPOCH)
70            .unwrap_or(Duration::ZERO);
71
72        Self {
73            step_type,
74            name: name.into(),
75            start_time_ms: now.as_millis() as u64,
76            duration_ms: 0,
77            input: None,
78            output: None,
79            metadata: HashMap::new(),
80            error: None,
81        }
82    }
83
84    /// Set input data
85    pub fn with_input(mut self, input: serde_json::Value) -> Self {
86        self.input = Some(input);
87        self
88    }
89
90    /// Set output data
91    pub fn with_output(mut self, output: serde_json::Value) -> Self {
92        self.output = Some(output);
93        self
94    }
95
96    /// Set duration
97    pub fn with_duration(mut self, duration_ms: u64) -> Self {
98        self.duration_ms = duration_ms;
99        self
100    }
101
102    /// Add metadata
103    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
104        self.metadata.insert(key.into(), value.into());
105        self
106    }
107
108    /// Set error
109    pub fn with_error(mut self, error: impl Into<String>) -> Self {
110        self.step_type = StepType::Error;
111        self.error = Some(error.into());
112        self
113    }
114}
115
116/// Complete execution trace for a conversation turn
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ExecutionTrace {
119    /// Unique trace ID
120    pub trace_id: String,
121
122    /// Conversation ID this trace belongs to
123    pub conversation_id: Option<u64>,
124
125    /// Turn number in conversation
126    pub turn_number: Option<u64>,
127
128    /// All steps in execution order
129    pub steps: Vec<TraceStep>,
130
131    /// Total execution time (milliseconds)
132    pub total_duration_ms: u64,
133
134    /// Trace start timestamp
135    pub start_time_ms: u64,
136}
137
138impl ExecutionTrace {
139    /// Create a new execution trace
140    pub fn new(trace_id: impl Into<String>) -> Self {
141        let now = SystemTime::now()
142            .duration_since(UNIX_EPOCH)
143            .unwrap_or(Duration::ZERO);
144
145        Self {
146            trace_id: trace_id.into(),
147            conversation_id: None,
148            turn_number: None,
149            steps: Vec::new(),
150            total_duration_ms: 0,
151            start_time_ms: now.as_millis() as u64,
152        }
153    }
154
155    /// Set conversation context
156    pub fn with_conversation(mut self, conversation_id: u64, turn_number: u64) -> Self {
157        self.conversation_id = Some(conversation_id);
158        self.turn_number = Some(turn_number);
159        self
160    }
161
162    /// Add a step to the trace
163    pub fn add_step(&mut self, step: TraceStep) {
164        self.steps.push(step);
165        self.update_total_duration();
166    }
167
168    /// Finalize the trace
169    pub fn finalize(&mut self) {
170        self.update_total_duration();
171    }
172
173    /// Update total duration based on steps
174    fn update_total_duration(&mut self) {
175        if let (Some(first), Some(last)) = (self.steps.first(), self.steps.last()) {
176            self.total_duration_ms =
177                (last.start_time_ms + last.duration_ms).saturating_sub(first.start_time_ms);
178        }
179    }
180
181    /// Export trace as JSON
182    pub fn to_json(&self) -> Result<String, serde_json::Error> {
183        serde_json::to_string_pretty(self)
184    }
185
186    /// Export trace as DOT graph format
187    pub fn to_dot(&self) -> String {
188        let mut dot = String::from("digraph ExecutionTrace {\n");
189        dot.push_str("  rankdir=LR;\n");
190        dot.push_str("  node [shape=box];\n\n");
191
192        for (i, step) in self.steps.iter().enumerate() {
193            let label = format!("{}\\n{}ms", step.name, step.duration_ms);
194            let color = match step.step_type {
195                StepType::Error => "red",
196                StepType::AudioCapture | StepType::VoiceActivity => "lightblue",
197                StepType::SpeechToText | StepType::TextToSpeech => "lightgreen",
198                StepType::Retrieval => "lightyellow",
199                StepType::LlmGeneration => "orange",
200                StepType::ToolExecution => "pink",
201                StepType::AudioPlayback => "lightgray",
202            };
203
204            dot.push_str(&format!(
205                "  step{} [label=\"{}\", fillcolor={}, style=filled];\n",
206                i, label, color
207            ));
208
209            if i > 0 {
210                dot.push_str(&format!("  step{} -> step{};\n", i - 1, i));
211            }
212        }
213
214        dot.push_str("}\n");
215        dot
216    }
217
218    /// Get performance summary
219    pub fn summary(&self) -> TraceSummary {
220        let mut summary = TraceSummary {
221            total_steps: self.steps.len(),
222            total_duration_ms: self.total_duration_ms,
223            step_durations: HashMap::new(),
224            errors: Vec::new(),
225        };
226
227        for step in &self.steps {
228            let type_name = format!("{:?}", step.step_type);
229            *summary.step_durations.entry(type_name).or_insert(0) += step.duration_ms;
230
231            if let Some(ref error) = step.error {
232                summary.errors.push(format!("{}: {}", step.name, error));
233            }
234        }
235
236        summary
237    }
238}
239
240/// Performance summary of an execution trace
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct TraceSummary {
243    /// Total number of steps
244    pub total_steps: usize,
245
246    /// Total duration (milliseconds)
247    pub total_duration_ms: u64,
248
249    /// Duration by step type
250    pub step_durations: HashMap<String, u64>,
251
252    /// List of errors that occurred
253    pub errors: Vec<String>,
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_step_creation() {
        let step = TraceStep::new(StepType::SpeechToText, "Transcribe audio");

        assert_eq!(step.step_type, StepType::SpeechToText);
        assert_eq!(step.name.as_str(), "Transcribe audio");
        assert_eq!(step.duration_ms, 0);
        assert!(step.input.is_none() && step.output.is_none());
    }

    #[test]
    fn test_trace_step_with_data() {
        let prompt = serde_json::json!({"prompt": "Hello"});
        let reply = serde_json::json!({"response": "Hi there!"});
        let step = TraceStep::new(StepType::LlmGeneration, "Generate response")
            .with_input(prompt)
            .with_output(reply)
            .with_duration(250)
            .with_metadata("model", "qwen-2.5-3b");

        assert_eq!(step.duration_ms, 250);
        assert!(step.input.is_some());
        assert!(step.output.is_some());
        assert_eq!(
            step.metadata.get("model").map(String::as_str),
            Some("qwen-2.5-3b")
        );
    }

    #[test]
    fn test_trace_step_with_error() {
        let step =
            TraceStep::new(StepType::ToolExecution, "Call API").with_error("Network timeout");

        assert_eq!(step.step_type, StepType::Error);
        assert_eq!(step.error.as_deref(), Some("Network timeout"));
    }

    #[test]
    fn test_execution_trace() {
        let trace = ExecutionTrace::new("trace-001").with_conversation(1, 5);

        assert_eq!(trace.trace_id, "trace-001");
        assert_eq!(
            (trace.conversation_id, trace.turn_number),
            (Some(1), Some(5))
        );
        assert!(trace.steps.is_empty());
    }

    #[test]
    fn test_execution_trace_add_steps() {
        let mut trace = ExecutionTrace::new("trace-002");
        let planned = [
            (StepType::SpeechToText, "STT", 100),
            (StepType::LlmGeneration, "LLM", 300),
            (StepType::TextToSpeech, "TTS", 150),
        ];

        // Build every step first, then record them, mirroring a pipeline
        // that reports completed stages in a batch.
        let steps: Vec<TraceStep> = planned
            .iter()
            .map(|&(kind, label, ms)| TraceStep::new(kind, label).with_duration(ms))
            .collect();
        for step in steps {
            trace.add_step(step);
        }

        assert_eq!(trace.steps.len(), planned.len());
        assert!(trace.total_duration_ms > 0);
    }

    #[test]
    fn test_trace_json_serialization() {
        let mut trace = ExecutionTrace::new("trace-003");
        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));

        let json = trace.to_json().expect("trace should serialize to JSON");
        for needle in ["trace-003", "SpeechToText"] {
            assert!(json.contains(needle), "JSON output is missing {:?}", needle);
        }
    }

    #[test]
    fn test_trace_dot_format() {
        let mut trace = ExecutionTrace::new("trace-004");
        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));
        trace.add_step(TraceStep::new(StepType::LlmGeneration, "LLM").with_duration(300));

        let dot = trace.to_dot();
        for needle in ["digraph ExecutionTrace", "step0", "step1", "->"] {
            assert!(dot.contains(needle), "DOT output is missing {:?}", needle);
        }
    }

    #[test]
    fn test_trace_summary() {
        let mut trace = ExecutionTrace::new("trace-005");
        for (kind, label, ms) in [
            (StepType::SpeechToText, "STT", 100),
            (StepType::LlmGeneration, "LLM 1", 200),
            (StepType::LlmGeneration, "LLM 2", 150),
            (StepType::TextToSpeech, "TTS", 120),
        ] {
            trace.add_step(TraceStep::new(kind, label).with_duration(ms));
        }

        let summary = trace.summary();
        assert_eq!(summary.total_steps, 4);
        assert_eq!(
            summary.step_durations.get("LlmGeneration").copied(),
            Some(350)
        );
        assert!(summary.errors.is_empty());
    }

    #[test]
    fn test_trace_summary_with_errors() {
        let mut trace = ExecutionTrace::new("trace-006");
        trace.add_step(TraceStep::new(StepType::SpeechToText, "STT").with_duration(100));

        let failed =
            TraceStep::new(StepType::ToolExecution, "API Call").with_error("Connection refused");
        trace.add_step(failed);

        let summary = trace.summary();
        assert_eq!(summary.total_steps, 2);
        assert_eq!(summary.errors.len(), 1);
        assert!(summary
            .errors
            .iter()
            .any(|line| line.contains("Connection refused")));
    }
}