// agentforge_core/eval.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use uuid::Uuid;
4
5/// An evaluation run of an agent version against a scenario set.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct EvalRun {
8    pub id: Uuid,
9    pub agent_id: Uuid,
10    pub scenario_set_id: Option<Uuid>,
11    pub status: EvalRunStatus,
12    pub scenario_count: u32,
13    pub completed_count: u32,
14    pub error_count: u32,
15    pub aggregate_score: Option<f64>,
16    pub pass_rate: Option<f64>,
17    pub scores: Option<DimensionScores>,
18    pub failure_clusters: Option<Vec<FailureClusterSummary>>,
19    pub seed: u32,
20    pub concurrency: u32,
21    pub error_message: Option<String>,
22    pub started_at: Option<DateTime<Utc>>,
23    pub completed_at: Option<DateTime<Utc>>,
24    pub created_at: DateTime<Utc>,
25    pub updated_at: DateTime<Utc>,
26    // ── Self-improvement loop tracking ────────────────────────────────────────
27    /// Optimization loop state: `running` | `converged` | `no_improvement` |
28    /// `max_iterations` | `failed`. `None` means optimization was not requested.
29    pub opt_status: Option<String>,
30    /// Number of completed optimization rounds (0 = not started).
31    pub opt_rounds: i32,
32    /// Best aggregate score achieved across all optimization rounds.
33    pub opt_best_score: Option<f64>,
34    /// UUID of the best agent version saved during the optimization loop.
35    pub opt_best_agent_id: Option<Uuid>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum EvalRunStatus {
41    Pending,
42    Running,
43    Complete,
44    Error,
45    Cancelled,
46}
47
48impl std::fmt::Display for EvalRunStatus {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        match self {
51            EvalRunStatus::Pending => write!(f, "pending"),
52            EvalRunStatus::Running => write!(f, "running"),
53            EvalRunStatus::Complete => write!(f, "complete"),
54            EvalRunStatus::Error => write!(f, "error"),
55            EvalRunStatus::Cancelled => write!(f, "cancelled"),
56        }
57    }
58}
59
60/// Per-dimension score breakdown.
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
62pub struct DimensionScores {
63    /// Did the agent achieve the goal? Weight: 35%
64    pub task_completion: f64,
65    /// Were the right tools called? Weight: 20%
66    pub tool_selection: f64,
67    /// Were tool arguments valid and correct? Weight: 20%
68    pub argument_correctness: f64,
69    /// Was the output schema compliant? Weight: 15%
70    pub schema_compliance: f64,
71    /// Did the agent follow constraints? Weight: 7%
72    pub instruction_adherence: f64,
73    /// Was the shortest valid path taken? Weight: 3%
74    pub path_efficiency: f64,
75}
76
77impl DimensionScores {
78    /// Scoring weights as defined in the PRD (configurable per run via EvalWeights).
79    pub fn weighted_aggregate(&self, weights: &EvalWeights) -> f64 {
80        self.task_completion * weights.task_completion
81            + self.tool_selection * weights.tool_selection
82            + self.argument_correctness * weights.argument_correctness
83            + self.schema_compliance * weights.schema_compliance
84            + self.instruction_adherence * weights.instruction_adherence
85            + self.path_efficiency * weights.path_efficiency
86    }
87}
88
89/// Configurable weights for the 6 evaluation dimensions.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct EvalWeights {
92    pub task_completion: f64,
93    pub tool_selection: f64,
94    pub argument_correctness: f64,
95    pub schema_compliance: f64,
96    pub instruction_adherence: f64,
97    pub path_efficiency: f64,
98}
99
100impl Default for EvalWeights {
101    fn default() -> Self {
102        Self {
103            task_completion: 0.35,
104            tool_selection: 0.20,
105            argument_correctness: 0.20,
106            schema_compliance: 0.15,
107            instruction_adherence: 0.07,
108            path_efficiency: 0.03,
109        }
110    }
111}
112
113impl EvalWeights {
114    /// Validate that all weights sum to approximately 1.0.
115    pub fn validate(&self) -> bool {
116        let total = self.task_completion
117            + self.tool_selection
118            + self.argument_correctness
119            + self.schema_compliance
120            + self.instruction_adherence
121            + self.path_efficiency;
122        (total - 1.0).abs() < 0.001
123    }
124}
125
126/// Summary of failure clusters across a run.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct FailureClusterSummary {
129    pub cluster: FailureCluster,
130    pub count: u32,
131    pub percentage: f64,
132    pub sample_scenarios: Vec<Uuid>,
133}
134
135/// Root cause categories for trace failures.
136#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
137#[serde(rename_all = "snake_case")]
138pub enum FailureCluster {
139    WrongTool,
140    HallucinatedArgument,
141    Looping,
142    PrematureStop,
143    SchemaViolation,
144    ConstraintBreach,
145    NoFailure,
146    /// Trace failed due to an LLM/API infrastructure error (rate limit, timeout,
147    /// 5xx), not due to agent behaviour. These do not reflect agent quality.
148    ApiError,
149    Unknown,
150}
151
152impl std::fmt::Display for FailureCluster {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            FailureCluster::WrongTool => write!(f, "wrong_tool"),
156            FailureCluster::HallucinatedArgument => write!(f, "hallucinated_argument"),
157            FailureCluster::Looping => write!(f, "looping"),
158            FailureCluster::PrematureStop => write!(f, "premature_stop"),
159            FailureCluster::SchemaViolation => write!(f, "schema_violation"),
160            FailureCluster::ConstraintBreach => write!(f, "constraint_breach"),
161            FailureCluster::NoFailure => write!(f, "no_failure"),
162            FailureCluster::ApiError => write!(f, "api_error"),
163            FailureCluster::Unknown => write!(f, "unknown"),
164        }
165    }
166}
167
168impl EvalRun {
169    /// Convert a completed EvalRun to a Scorecard for display/gatekeeper use.
170    /// Returns None if the run has no scores yet.
171    pub fn to_scorecard(&self) -> Option<crate::Scorecard> {
172        let scores = self.scores.clone()?;
173        let aggregate_score = self.aggregate_score?;
174        let pass_rate = self.pass_rate?;
175        Some(crate::Scorecard {
176            run_id: self.id,
177            agent_id: self.agent_id,
178            agent_name: String::new(), // filled by caller if needed
179            agent_version: String::new(),
180            aggregate_score,
181            pass_rate,
182            total_scenarios: self.scenario_count,
183            passed: (pass_rate * self.scenario_count as f64) as u32,
184            failed: self.error_count,
185            errors: 0,
186            review_needed: 0,
187            dimension_scores: scores,
188            failure_clusters: self.failure_clusters.clone().unwrap_or_default(),
189            duration_seconds: self
190                .completed_at
191                .zip(self.started_at)
192                .map(|(c, s)| (c - s).num_seconds().max(0) as u64)
193                .unwrap_or(0),
194            total_input_tokens: 0,
195            total_output_tokens: 0,
196        })
197    }
198}
199
200/// Request to start a new eval run.
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct EvalRunRequest {
203    pub agent_id: Uuid,
204    pub scenario_count: Option<u32>,
205    pub concurrency: Option<u32>,
206    pub seed: Option<u32>,
207    pub weights: Option<EvalWeights>,
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Absolute tolerance used for float comparisons in this module.
    const EPS: f64 = 1e-9;

    /// True when `actual` lies within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Builds weights from six values given in struct field order.
    fn weights_of(values: [f64; 6]) -> EvalWeights {
        EvalWeights {
            task_completion: values[0],
            tool_selection: values[1],
            argument_correctness: values[2],
            schema_compliance: values[3],
            instruction_adherence: values[4],
            path_efficiency: values[5],
        }
    }

    /// Aggregates `scores` using the default weights.
    fn default_aggregate(scores: &DimensionScores) -> f64 {
        scores.weighted_aggregate(&EvalWeights::default())
    }

    #[test]
    fn default_weights_sum_to_one() {
        assert!(
            EvalWeights::default().validate(),
            "Default weights must sum to 1.0"
        );
    }

    #[test]
    fn weighted_aggregate_perfect_score() {
        let perfect = DimensionScores {
            task_completion: 1.0,
            tool_selection: 1.0,
            argument_correctness: 1.0,
            schema_compliance: 1.0,
            instruction_adherence: 1.0,
            path_efficiency: 1.0,
        };
        assert!(close(default_aggregate(&perfect), 1.0));
    }

    #[test]
    fn weighted_aggregate_zero_score() {
        assert_eq!(default_aggregate(&DimensionScores::default()), 0.0);
    }

    #[test]
    fn failure_cluster_display() {
        assert_eq!(FailureCluster::WrongTool.to_string(), "wrong_tool");
        assert_eq!(
            FailureCluster::HallucinatedArgument.to_string(),
            "hallucinated_argument"
        );
    }

    // ── Additional coverage (20 tests) ───────────────────────────────────────

    #[test]
    fn weights_that_do_not_sum_to_one_fail_validate() {
        // Sum is 0.9, well outside the tolerance.
        let bad = weights_of([0.5, 0.4, 0.0, 0.0, 0.0, 0.0]);
        assert!(!bad.validate());
    }

    #[test]
    fn weights_summing_to_exactly_one_are_valid() {
        let exact = weights_of([0.35, 0.20, 0.20, 0.15, 0.07, 0.03]);
        assert!(exact.validate());
    }

    #[test]
    fn weighted_aggregate_only_task_completion() {
        let scores = DimensionScores {
            task_completion: 1.0,
            ..DimensionScores::default()
        };
        // Only the task_completion term is non-zero → 0.35.
        assert!(close(default_aggregate(&scores), 0.35));
    }

    #[test]
    fn weighted_aggregate_only_tool_selection() {
        let scores = DimensionScores {
            tool_selection: 1.0,
            ..DimensionScores::default()
        };
        assert!(close(default_aggregate(&scores), 0.20));
    }

    #[test]
    fn weighted_aggregate_only_schema_compliance() {
        let scores = DimensionScores {
            schema_compliance: 1.0,
            ..DimensionScores::default()
        };
        assert!(close(default_aggregate(&scores), 0.15));
    }

    #[test]
    fn weighted_aggregate_respects_custom_weights() {
        let scores = DimensionScores {
            task_completion: 1.0,
            ..DimensionScores::default()
        };
        let everything_on_task = weights_of([1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        assert!(close(scores.weighted_aggregate(&everything_on_task), 1.0));
    }

    #[test]
    fn eval_run_status_display_all_variants() {
        let expected = [
            (EvalRunStatus::Pending, "pending"),
            (EvalRunStatus::Running, "running"),
            (EvalRunStatus::Complete, "complete"),
            (EvalRunStatus::Error, "error"),
            (EvalRunStatus::Cancelled, "cancelled"),
        ];
        for (status, label) in expected.iter() {
            assert_eq!(status.to_string(), *label);
        }
    }

    #[test]
    fn eval_run_status_all_variants_distinct() {
        let labels: HashSet<String> = [
            EvalRunStatus::Pending,
            EvalRunStatus::Running,
            EvalRunStatus::Complete,
            EvalRunStatus::Error,
            EvalRunStatus::Cancelled,
        ]
        .iter()
        .map(|status| status.to_string())
        .collect();
        assert_eq!(labels.len(), 5, "All status strings must be distinct");
    }

    #[test]
    fn failure_cluster_display_all_variants() {
        let expected = [
            (FailureCluster::WrongTool, "wrong_tool"),
            (FailureCluster::HallucinatedArgument, "hallucinated_argument"),
            (FailureCluster::Looping, "looping"),
            (FailureCluster::PrematureStop, "premature_stop"),
            (FailureCluster::SchemaViolation, "schema_violation"),
            (FailureCluster::ConstraintBreach, "constraint_breach"),
            (FailureCluster::NoFailure, "no_failure"),
            (FailureCluster::ApiError, "api_error"),
            (FailureCluster::Unknown, "unknown"),
        ];
        for (cluster, label) in expected.iter() {
            assert_eq!(cluster.to_string(), *label);
        }
    }

    #[test]
    fn failure_cluster_serde_roundtrip() {
        let encoded = serde_json::to_string(&FailureCluster::HallucinatedArgument).unwrap();
        let decoded: FailureCluster = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, FailureCluster::HallucinatedArgument);
    }

    #[test]
    fn failure_cluster_api_error_serde() {
        let encoded = serde_json::to_string(&FailureCluster::ApiError).unwrap();
        assert_eq!(encoded, r#""api_error""#);
        let decoded: FailureCluster = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, FailureCluster::ApiError);
    }

    #[test]
    fn dimension_scores_default_are_zero() {
        let DimensionScores {
            task_completion,
            tool_selection,
            argument_correctness,
            schema_compliance,
            instruction_adherence,
            path_efficiency,
        } = DimensionScores::default();
        assert_eq!(task_completion, 0.0);
        assert_eq!(tool_selection, 0.0);
        assert_eq!(argument_correctness, 0.0);
        assert_eq!(schema_compliance, 0.0);
        assert_eq!(instruction_adherence, 0.0);
        assert_eq!(path_efficiency, 0.0);
    }

    #[test]
    fn dimension_scores_serde_roundtrip() {
        let original = DimensionScores {
            task_completion: 0.9,
            tool_selection: 0.8,
            argument_correctness: 0.7,
            schema_compliance: 0.6,
            instruction_adherence: 0.5,
            path_efficiency: 0.4,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DimensionScores = serde_json::from_str(&encoded).unwrap();
        assert!(close(decoded.task_completion, 0.9));
        assert!(close(decoded.path_efficiency, 0.4));
    }

    #[test]
    fn eval_run_status_serde_snake_case() {
        let encoded = serde_json::to_string(&EvalRunStatus::Complete).unwrap();
        assert_eq!(encoded, r#""complete""#);
        let decoded: EvalRunStatus = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, EvalRunStatus::Complete);
    }

    #[test]
    fn eval_run_status_pending_serde() {
        let encoded = serde_json::to_string(&EvalRunStatus::Pending).unwrap();
        assert_eq!(encoded, r#""pending""#);
    }

    #[test]
    fn eval_run_status_cancelled_serde() {
        let encoded = serde_json::to_string(&EvalRunStatus::Cancelled).unwrap();
        assert_eq!(encoded, r#""cancelled""#);
    }

    #[test]
    fn weights_tolerance_accepts_floating_point_imprecision() {
        // Programmatically computed weights can drift slightly from 1.0.
        // Here the sum is 1.000001, inside the 0.001 tolerance.
        let nearly_one = weights_of([0.35 + 0.000001, 0.20, 0.20, 0.15, 0.07, 0.03]);
        assert!(nearly_one.validate());
    }

    #[test]
    fn failure_cluster_summary_stores_fields() {
        let scenario = Uuid::new_v4();
        let summary = FailureClusterSummary {
            cluster: FailureCluster::Looping,
            count: 3,
            percentage: 30.0,
            sample_scenarios: vec![scenario],
        };
        assert_eq!(summary.count, 3);
        assert!(close(summary.percentage, 30.0));
        assert_eq!(summary.sample_scenarios[0], scenario);
    }

    #[test]
    fn failure_cluster_all_variants_are_hash_compatible() {
        let by_cluster: HashMap<FailureCluster, u32> = vec![
            (FailureCluster::WrongTool, 1u32),
            (FailureCluster::Looping, 2u32),
            (FailureCluster::ApiError, 3u32),
        ]
        .into_iter()
        .collect();
        assert_eq!(by_cluster.get(&FailureCluster::WrongTool), Some(&1));
        assert_eq!(by_cluster.get(&FailureCluster::ApiError), Some(&3));
    }

    #[test]
    fn eval_run_request_serde_roundtrip() {
        let request = EvalRunRequest {
            agent_id: Uuid::new_v4(),
            scenario_count: Some(50),
            concurrency: Some(5),
            seed: Some(99),
            weights: None,
        };
        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: EvalRunRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.scenario_count, Some(50));
        assert_eq!(decoded.seed, Some(99));
    }
}