adk_eval/
report.rs

1//! Evaluation result reporting
2//!
3//! Structures for representing and formatting evaluation results.
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::time::Duration;
9
/// Complete evaluation report for a test file or eval set.
///
/// Holds the per-test-case results of one run together with its timing
/// information and an aggregate [`EvaluationSummary`]. The type derives
/// `Serialize`/`Deserialize`, so a report can be written out and read back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationReport {
    /// Unique identifier for this evaluation run
    pub run_id: String,
    /// When the evaluation started (UTC)
    pub started_at: chrono::DateTime<chrono::Utc>,
    /// When the evaluation completed (UTC)
    pub completed_at: chrono::DateTime<chrono::Utc>,
    /// Total duration. `EvaluationReport::new` derives this from
    /// `completed_at - started_at`, falling back to a zero duration when that
    /// span cannot be converted to a `std::time::Duration`.
    pub duration: Duration,
    /// Results for each test case
    pub results: Vec<EvaluationResult>,
    /// Summary statistics. `EvaluationReport::new` computes these from `results`.
    pub summary: EvaluationSummary,
}
26
27impl EvaluationReport {
28    /// Create a new report
29    pub fn new(
30        run_id: &str,
31        results: Vec<EvaluationResult>,
32        started_at: chrono::DateTime<chrono::Utc>,
33    ) -> Self {
34        let completed_at = chrono::Utc::now();
35        let duration = (completed_at - started_at).to_std().unwrap_or_default();
36        let summary = EvaluationSummary::from_results(&results);
37
38        Self { run_id: run_id.to_string(), started_at, completed_at, duration, results, summary }
39    }
40
41    /// Check if all tests passed
42    pub fn all_passed(&self) -> bool {
43        self.summary.failed == 0
44    }
45
46    /// Get failed results only
47    pub fn failures(&self) -> Vec<&EvaluationResult> {
48        self.results.iter().filter(|r| !r.passed).collect()
49    }
50
51    /// Format as a human-readable string
52    pub fn format_summary(&self) -> String {
53        let mut output = String::new();
54        output.push_str(&format!("Evaluation Report: {}\n", self.run_id));
55        output.push_str(&format!("Duration: {:?}\n", self.duration));
56        output.push_str("\nSummary:\n");
57        output.push_str(&format!("  Total: {}\n", self.summary.total));
58        output.push_str(&format!("  Passed: {}\n", self.summary.passed));
59        output.push_str(&format!("  Failed: {}\n", self.summary.failed));
60        output.push_str(&format!("  Pass Rate: {:.1}%\n", self.summary.pass_rate * 100.0));
61
62        if !self.summary.avg_scores.is_empty() {
63            output.push_str("\nAverage Scores:\n");
64            for (criterion, score) in &self.summary.avg_scores {
65                output.push_str(&format!("  {}: {:.3}\n", criterion, score));
66            }
67        }
68
69        if self.summary.failed > 0 {
70            output.push_str("\nFailed Tests:\n");
71            for result in self.failures() {
72                output.push_str(&format!(
73                    "  - {} ({})\n",
74                    result.eval_id,
75                    result
76                        .failures
77                        .iter()
78                        .map(|f| f.criterion.as_str())
79                        .collect::<Vec<_>>()
80                        .join(", ")
81                ));
82            }
83        }
84
85        output
86    }
87
88    /// Export to JSON
89    pub fn to_json(&self) -> Result<String, serde_json::Error> {
90        serde_json::to_string_pretty(self)
91    }
92}
93
/// Summary statistics for an evaluation run.
///
/// The `Default` value has all counts at zero, a pass rate of `0.0` and no
/// average scores.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvaluationSummary {
    /// Total number of test cases
    pub total: usize,
    /// Number of passed test cases
    pub passed: usize,
    /// Number of failed test cases
    pub failed: usize,
    /// Pass rate (0.0 - 1.0); `from_results` reports `0.0` when there are no test cases
    pub pass_rate: f64,
    /// Average scores by criterion. As computed by `from_results`, each average
    /// is taken only over the results that reported a score for that criterion.
    pub avg_scores: HashMap<String, f64>,
}
108
109impl EvaluationSummary {
110    /// Calculate summary from results
111    pub fn from_results(results: &[EvaluationResult]) -> Self {
112        let total = results.len();
113        let passed = results.iter().filter(|r| r.passed).count();
114        let failed = total - passed;
115        let pass_rate = if total > 0 { passed as f64 / total as f64 } else { 0.0 };
116
117        // Calculate average scores
118        let mut score_sums: HashMap<String, (f64, usize)> = HashMap::new();
119        for result in results {
120            for (criterion, score) in &result.scores {
121                let entry = score_sums.entry(criterion.clone()).or_insert((0.0, 0));
122                entry.0 += score;
123                entry.1 += 1;
124            }
125        }
126
127        let avg_scores =
128            score_sums.into_iter().map(|(k, (sum, count))| (k, sum / count as f64)).collect();
129
130        Self { total, passed, failed, pass_rate, avg_scores }
131    }
132}
133
/// Result for a single test case.
///
/// Build instances with [`EvaluationResult::passed`] or
/// [`EvaluationResult::failed`], optionally chaining
/// [`EvaluationResult::with_turn_results`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// Test case identifier
    pub eval_id: String,
    /// Whether the test passed all criteria
    pub passed: bool,
    /// Scores for each criterion, keyed by criterion name
    pub scores: HashMap<String, f64>,
    /// Failures (criteria that didn't meet threshold)
    pub failures: Vec<Failure>,
    /// Execution duration
    pub duration: Duration,
    /// Detailed turn results; treated as empty when the field is missing from
    /// deserialized input
    #[serde(default)]
    pub turn_results: Vec<TurnResult>,
}
151
152impl EvaluationResult {
153    /// Create a passed result
154    pub fn passed(eval_id: &str, scores: HashMap<String, f64>, duration: Duration) -> Self {
155        Self {
156            eval_id: eval_id.to_string(),
157            passed: true,
158            scores,
159            failures: vec![],
160            duration,
161            turn_results: vec![],
162        }
163    }
164
165    /// Create a failed result
166    pub fn failed(
167        eval_id: &str,
168        scores: HashMap<String, f64>,
169        failures: Vec<Failure>,
170        duration: Duration,
171    ) -> Self {
172        Self {
173            eval_id: eval_id.to_string(),
174            passed: false,
175            scores,
176            failures,
177            duration,
178            turn_results: vec![],
179        }
180    }
181
182    /// Add turn results
183    pub fn with_turn_results(mut self, turn_results: Vec<TurnResult>) -> Self {
184        self.turn_results = turn_results;
185        self
186    }
187}
188
/// A single failure in evaluation.
///
/// Records one criterion of a test case, with the expected and actual values
/// and the score/threshold pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Failure {
    /// Criterion that failed
    pub criterion: String,
    /// Expected value (arbitrary JSON)
    pub expected: Value,
    /// Actual value (arbitrary JSON)
    pub actual: Value,
    /// Score achieved
    pub score: f64,
    /// Threshold required
    pub threshold: f64,
    /// Additional details; `None` when the field is missing from deserialized input
    #[serde(default)]
    pub details: Option<String>,
}
206
207impl Failure {
208    /// Create a new failure
209    pub fn new(
210        criterion: &str,
211        expected: Value,
212        actual: Value,
213        score: f64,
214        threshold: f64,
215    ) -> Self {
216        Self { criterion: criterion.to_string(), expected, actual, score, threshold, details: None }
217    }
218
219    /// Add details
220    pub fn with_details(mut self, details: &str) -> Self {
221        self.details = Some(details.to_string());
222        self
223    }
224
225    /// Format as human-readable string
226    pub fn format(&self) -> String {
227        let mut s = format!(
228            "{}: score {:.3} < threshold {:.3}",
229            self.criterion, self.score, self.threshold
230        );
231        if let Some(details) = &self.details {
232            s.push_str(&format!("\n  Details: {}", details));
233        }
234        s
235    }
236}
237
/// Result for a single conversation turn.
///
/// Pairs what the agent actually produced in one invocation (response text and
/// tool calls) with what was expected, plus the scores for that turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurnResult {
    /// Turn/invocation identifier
    pub invocation_id: String,
    /// Actual response from the agent, if any
    pub actual_response: Option<String>,
    /// Expected response, if any
    pub expected_response: Option<String>,
    /// Actual tool calls made
    pub actual_tool_calls: Vec<crate::schema::ToolUse>,
    /// Expected tool calls
    pub expected_tool_calls: Vec<crate::schema::ToolUse>,
    /// Scores for this turn, keyed by criterion name
    pub scores: HashMap<String, f64>,
}
254
/// Result for a single test case (alias for backward compatibility).
///
/// This is the same type as [`EvaluationResult`].
pub type TestCaseResult = EvaluationResult;
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-entry score map for the `tool_trajectory` criterion.
    fn trajectory_score(score: f64) -> HashMap<String, f64> {
        let mut scores = HashMap::new();
        scores.insert("tool_trajectory".to_string(), score);
        scores
    }

    /// Two passing results and one failing result give a 2/3 pass rate.
    #[test]
    fn test_evaluation_summary() {
        let first =
            EvaluationResult::passed("test_1", trajectory_score(1.0), Duration::from_millis(100));
        let second =
            EvaluationResult::passed("test_2", trajectory_score(0.8), Duration::from_millis(150));
        let third = EvaluationResult::failed(
            "test_3",
            trajectory_score(0.5),
            vec![Failure::new("tool_trajectory", Value::Null, Value::Null, 0.5, 0.8)],
            Duration::from_millis(200),
        );

        let summary = EvaluationSummary::from_results(&[first, second, third]);

        assert_eq!(summary.total, 3);
        assert_eq!(summary.passed, 2);
        assert_eq!(summary.failed, 1);
        assert!((summary.pass_rate - 0.666).abs() < 0.01);
    }

    /// The formatted failure names the criterion and shows both numbers to three decimals.
    #[test]
    fn test_failure_format() {
        let rendered = Failure::new(
            "response_similarity",
            Value::String(String::from("expected")),
            Value::String(String::from("actual")),
            0.6,
            0.8,
        )
        .with_details("Responses differ significantly")
        .format();

        for needle in ["response_similarity", "0.600", "0.800"].iter() {
            assert!(rendered.contains(*needle));
        }
    }
}