// agentforge_core/trace.rs

1use crate::{DimensionScores, FailureCluster};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
/// A complete execution trace for a single scenario run.
///
/// Bundles the ordered [`TraceStep`]s recorded during the run with the run's
/// final output, scoring results, failure classification, and usage counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trace {
    /// Unique identifier of this trace.
    pub id: Uuid,
    /// Identifier of the run this trace belongs to.
    pub run_id: Uuid,
    /// Identifier of the scenario that was executed.
    pub scenario_id: Uuid,
    /// Outcome of the run; see [`TraceStatus`].
    pub status: TraceStatus,
    /// Ordered list of execution steps.
    pub steps: Vec<TraceStep>,
    /// Final output of the run as raw JSON, or `None` if none was recorded.
    pub final_output: Option<serde_json::Value>,
    /// Per-dimension scores, or `None` if the trace has not been scored.
    pub scores: Option<DimensionScores>,
    /// Single combined score, or `None` if not computed.
    /// NOTE(review): presumably derived from `scores` — confirm the aggregation.
    pub aggregate_score: Option<f64>,
    /// Failure cluster assigned to this trace.
    pub failure_cluster: FailureCluster,
    /// Explanation of why the run failed, if any.
    pub failure_reason: Option<String>,
    /// Whether this trace is flagged for review.
    ///
    /// NOTE(review): overlaps with `status == TraceStatus::ReviewNeeded`; nothing
    /// in this type keeps the two in sync — confirm which one callers treat as
    /// authoritative.
    pub review_needed: bool,
    /// Number of LLM calls made during the run.
    pub llm_calls: u32,
    /// Number of tool invocations made during the run.
    pub tool_invocations: u32,
    /// Input tokens consumed by the run.
    /// NOTE(review): presumably the sum over `LlmCallStep::input_tokens` — confirm.
    pub input_tokens: u32,
    /// Output tokens produced by the run.
    /// NOTE(review): presumably the sum over `LlmCallStep::output_tokens` — confirm.
    pub output_tokens: u32,
    /// Latency of the run, in milliseconds.
    pub latency_ms: u64,
    /// Number of retries performed during the run.
    pub retry_count: u32,
    /// Seed recorded for this run (presumably for reproducibility — confirm).
    pub seed: u32,
    /// Creation time of this trace, in UTC.
    pub created_at: DateTime<Utc>,
}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
32#[serde(rename_all = "snake_case")]
33pub enum TraceStatus {
34    Pass,
35    Fail,
36    Error,
37    ReviewNeeded,
38}
39
40impl std::fmt::Display for TraceStatus {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            TraceStatus::Pass => write!(f, "pass"),
44            TraceStatus::Fail => write!(f, "fail"),
45            TraceStatus::Error => write!(f, "error"),
46            TraceStatus::ReviewNeeded => write!(f, "review_needed"),
47        }
48    }
49}
50
51/// A single step in an execution trace.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "snake_case")]
54pub enum TraceStep {
55    LlmCall(LlmCallStep),
56    ToolCall(ToolCallStep),
57    ToolResult(ToolResultStep),
58    AgentThought(AgentThoughtStep),
59    FinalOutput(FinalOutputStep),
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct LlmCallStep {
64    pub index: u32,
65    pub model: String,
66    pub messages: Vec<serde_json::Value>,
67    pub response: serde_json::Value,
68    pub input_tokens: u32,
69    pub output_tokens: u32,
70    pub latency_ms: u64,
71    pub timestamp: DateTime<Utc>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolCallStep {
76    pub index: u32,
77    pub tool_name: String,
78    pub call_id: String,
79    pub arguments: serde_json::Value,
80    pub timestamp: DateTime<Utc>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct ToolResultStep {
85    pub index: u32,
86    pub tool_name: String,
87    pub call_id: String,
88    pub result: serde_json::Value,
89    pub is_error: bool,
90    pub timestamp: DateTime<Utc>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct AgentThoughtStep {
95    pub index: u32,
96    pub thought: String,
97    pub timestamp: DateTime<Utc>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct FinalOutputStep {
102    pub index: u32,
103    pub output: serde_json::Value,
104    pub timestamp: DateTime<Utc>,
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `TraceStatus` variant, in declaration order.
    const ALL_STATUSES: [TraceStatus; 4] = [
        TraceStatus::Pass,
        TraceStatus::Fail,
        TraceStatus::Error,
        TraceStatus::ReviewNeeded,
    ];

    /// Fixed timestamp with sub-second precision, so roundtrips are deterministic
    /// and also exercise fractional-second serialization.
    fn fixed_ts() -> chrono::DateTime<chrono::Utc> {
        chrono::DateTime::parse_from_rfc3339("2024-05-01T12:30:45.123456789Z")
            .expect("valid RFC 3339 literal")
            .with_timezone(&chrono::Utc)
    }

    /// One step of every variant, paired with the `"type"` tag serde must emit.
    fn sample_steps() -> Vec<(TraceStep, &'static str)> {
        let ts = fixed_ts();
        vec![
            (
                TraceStep::LlmCall(LlmCallStep {
                    index: 0,
                    model: "gpt-4o".to_string(),
                    messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
                    response: serde_json::json!({"content": "hi"}),
                    input_tokens: 10,
                    output_tokens: 5,
                    latency_ms: 300,
                    timestamp: ts,
                }),
                "llm_call",
            ),
            (
                TraceStep::ToolCall(ToolCallStep {
                    index: 1,
                    tool_name: "search".to_string(),
                    call_id: "call_abc".to_string(),
                    arguments: serde_json::json!({"query": "rust"}),
                    timestamp: ts,
                }),
                "tool_call",
            ),
            (
                TraceStep::ToolResult(ToolResultStep {
                    index: 2,
                    tool_name: "search".to_string(),
                    call_id: "call_abc".to_string(),
                    result: serde_json::json!({"hits": 3}),
                    is_error: false,
                    timestamp: ts,
                }),
                "tool_result",
            ),
            (
                TraceStep::AgentThought(AgentThoughtStep {
                    index: 3,
                    thought: "I should use the search tool".to_string(),
                    timestamp: ts,
                }),
                "agent_thought",
            ),
            (
                TraceStep::FinalOutput(FinalOutputStep {
                    index: 4,
                    output: serde_json::json!({"result": "done"}),
                    timestamp: ts,
                }),
                "final_output",
            ),
        ]
    }

    #[test]
    fn trace_status_display_all_variants() {
        assert_eq!(TraceStatus::Pass.to_string(), "pass");
        assert_eq!(TraceStatus::Fail.to_string(), "fail");
        assert_eq!(TraceStatus::Error.to_string(), "error");
        assert_eq!(TraceStatus::ReviewNeeded.to_string(), "review_needed");
    }

    #[test]
    fn trace_status_display_matches_serde_name() {
        // `Display` and serde's `rename_all` are separate sources of truth for the
        // same strings; this catches them drifting apart.
        for status in &ALL_STATUSES {
            let json = serde_json::to_value(status).unwrap();
            assert_eq!(json, serde_json::Value::String(status.to_string()));
        }
    }

    #[test]
    fn trace_status_serde_roundtrip() {
        for status in &ALL_STATUSES {
            let json = serde_json::to_string(status).unwrap();
            let back: TraceStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, status);
        }
    }

    #[test]
    fn trace_status_all_variants_distinct() {
        let names: std::collections::HashSet<String> =
            ALL_STATUSES.iter().map(|s| s.to_string()).collect();
        assert_eq!(names.len(), ALL_STATUSES.len());
    }

    #[test]
    fn tool_call_step_serde_roundtrip() {
        let step = ToolCallStep {
            index: 0,
            tool_name: "search".to_string(),
            call_id: "call_abc".to_string(),
            arguments: serde_json::json!({"query": "rust"}),
            timestamp: fixed_ts(),
        };
        let json = serde_json::to_string(&step).unwrap();
        let back: ToolCallStep = serde_json::from_str(&json).unwrap();
        assert_eq!(back.index, step.index);
        assert_eq!(back.tool_name, "search");
        assert_eq!(back.call_id, "call_abc");
        assert_eq!(back.arguments, step.arguments);
        assert_eq!(back.timestamp, step.timestamp);
    }

    #[test]
    fn llm_call_step_serde_roundtrip() {
        let step = LlmCallStep {
            index: 1,
            model: "gpt-4o".to_string(),
            messages: vec![serde_json::json!({"role": "user", "content": "hello"})],
            response: serde_json::json!({"content": "hi"}),
            input_tokens: 10,
            output_tokens: 5,
            latency_ms: 300,
            timestamp: fixed_ts(),
        };
        let json = serde_json::to_string(&step).unwrap();
        let back: LlmCallStep = serde_json::from_str(&json).unwrap();
        assert_eq!(back.index, step.index);
        assert_eq!(back.model, "gpt-4o");
        assert_eq!(back.messages, step.messages);
        assert_eq!(back.response, step.response);
        assert_eq!(back.input_tokens, 10);
        assert_eq!(back.output_tokens, 5);
        assert_eq!(back.latency_ms, 300);
        assert_eq!(back.timestamp, step.timestamp);
    }

    #[test]
    fn tool_result_step_error_flag() {
        // The flag and error payload must survive serialization, not merely be
        // settable on the struct.
        let step = ToolResultStep {
            index: 2,
            tool_name: "search".to_string(),
            call_id: "call_abc".to_string(),
            result: serde_json::json!({"error": "timeout"}),
            is_error: true,
            timestamp: fixed_ts(),
        };
        let json = serde_json::to_string(&step).unwrap();
        let back: ToolResultStep = serde_json::from_str(&json).unwrap();
        assert!(back.is_error);
        assert_eq!(back.result, step.result);
        assert_eq!(back.call_id, step.call_id);
    }

    #[test]
    fn agent_thought_step_stores_thought() {
        let step = AgentThoughtStep {
            index: 3,
            thought: "I should use the search tool".to_string(),
            timestamp: fixed_ts(),
        };
        let json = serde_json::to_string(&step).unwrap();
        let back: AgentThoughtStep = serde_json::from_str(&json).unwrap();
        assert_eq!(back.index, 3);
        assert_eq!(back.thought, "I should use the search tool");
        assert_eq!(back.timestamp, step.timestamp);
    }

    #[test]
    fn trace_step_tag_type_in_json() {
        // Serde tag="type" means the JSON must include a "type" key
        let step = TraceStep::ToolCall(ToolCallStep {
            index: 0,
            tool_name: "get_order".to_string(),
            call_id: "c1".to_string(),
            arguments: serde_json::json!({}),
            timestamp: fixed_ts(),
        });
        let json = serde_json::to_value(&step).unwrap();
        assert_eq!(json["type"], "tool_call");
        assert_eq!(json["tool_name"], "get_order");
    }

    #[test]
    fn llm_call_step_tag_in_json() {
        let step = TraceStep::LlmCall(LlmCallStep {
            index: 0,
            model: "gpt-4o".to_string(),
            messages: vec![],
            response: serde_json::json!({}),
            input_tokens: 0,
            output_tokens: 0,
            latency_ms: 0,
            timestamp: fixed_ts(),
        });
        let json = serde_json::to_value(&step).unwrap();
        assert_eq!(json["type"], "llm_call");
    }

    #[test]
    fn final_output_step_tag_in_json() {
        let step = TraceStep::FinalOutput(FinalOutputStep {
            index: 5,
            output: serde_json::json!({"result": "done"}),
            timestamp: fixed_ts(),
        });
        let json = serde_json::to_value(&step).unwrap();
        assert_eq!(json["type"], "final_output");
        assert_eq!(json["output"]["result"], "done");
    }

    #[test]
    fn trace_step_roundtrip_all_variants() {
        for (step, tag) in sample_steps() {
            let json = serde_json::to_value(&step).unwrap();
            assert_eq!(json["type"], tag);

            let back: TraceStep = serde_json::from_value(json.clone()).unwrap();
            assert_eq!(
                std::mem::discriminant(&back),
                std::mem::discriminant(&step),
                "variant changed for tag {}",
                tag
            );
            // Re-serializing must reproduce identical JSON, i.e. no field was lost.
            assert_eq!(serde_json::to_value(&back).unwrap(), json);
        }
    }

    #[test]
    fn trace_step_deserializes_from_tagged_json() {
        let json = serde_json::json!({
            "type": "tool_result",
            "index": 2,
            "tool_name": "search",
            "call_id": "call_abc",
            "result": {"error": "timeout"},
            "is_error": true,
            "timestamp": "2024-05-01T12:30:45Z"
        });
        match serde_json::from_value::<TraceStep>(json).unwrap() {
            TraceStep::ToolResult(result) => {
                assert_eq!(result.index, 2);
                assert_eq!(result.call_id, "call_abc");
                assert!(result.is_error);
                assert_eq!(result.result["error"], "timeout");
            }
            other => panic!("expected ToolResult, got {:?}", other),
        }
    }

    #[test]
    fn trace_step_rejects_unknown_type_tag() {
        let json = serde_json::json!({"type": "not_a_step", "index": 0});
        assert!(serde_json::from_value::<TraceStep>(json).is_err());
    }
}