atomr_agents_eval/
suite.rs1use std::sync::Arc;
2
3use atomr_agents_callable::Callable;
4use atomr_agents_core::{CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8use crate::scorer::{Scorer, ScorerOutcome};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct EvalCase {
12 pub id: String,
13 pub input: Value,
14 pub expected: Value,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EvalResult {
19 pub case_id: String,
20 pub outcome: ScorerOutcome,
21 pub elapsed_ms: u64,
22}
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize)]
25pub struct EvalRun {
26 pub passed: u32,
27 pub failed: u32,
28 pub avg_score: f32,
29 pub results: Vec<EvalResult>,
30}
31
32impl EvalRun {
33 pub fn pass_rate(&self) -> f32 {
34 let total = self.passed + self.failed;
35 if total == 0 {
36 return 0.0;
37 }
38 self.passed as f32 / total as f32
39 }
40}
41
42pub struct EvalSuite {
43 pub id: String,
44 pub cases: Vec<EvalCase>,
45 pub scorer: Arc<dyn Scorer>,
46}
47
48impl EvalSuite {
49 pub async fn run(&self, callable: &dyn Callable) -> Result<EvalRun> {
50 let mut run = EvalRun::default();
51 let mut total_score = 0.0f32;
52 for case in &self.cases {
53 let t0 = std::time::Instant::now();
54 let actual = callable.call(case.input.clone(), default_ctx()).await?;
55 let outcome = self.scorer.score(&case.expected, &actual);
56 if outcome.passed {
57 run.passed += 1;
58 } else {
59 run.failed += 1;
60 }
61 total_score += outcome.score;
62 run.results.push(EvalResult {
63 case_id: case.id.clone(),
64 outcome,
65 elapsed_ms: t0.elapsed().as_millis() as u64,
66 });
67 }
68 let total = (run.passed + run.failed) as f32;
69 run.avg_score = if total == 0.0 { 0.0 } else { total_score / total };
70 Ok(run)
71 }
72}
73
74fn default_ctx() -> CallCtx {
75 CallCtx {
76 agent_id: None,
77 tokens: TokenBudget::new(8192),
78 time: TimeBudget::new(Duration::from_secs(30)),
79 money: MoneyBudget::from_usd(1.0),
80 iterations: IterationBudget::new(8),
81 trace: vec![],
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scorer::ContainsScorer;
    use atomr_agents_callable::FnCallable;

    /// Builds a suite over `cases` scored by `ContainsScorer`.
    fn contains_suite(cases: Vec<EvalCase>) -> EvalSuite {
        EvalSuite {
            id: "demo".into(),
            cases,
            scorer: Arc::new(ContainsScorer),
        }
    }

    #[tokio::test]
    async fn suite_scores_cases() {
        let suite = contains_suite(vec![
            EvalCase {
                id: "c1".into(),
                input: serde_json::json!("hi"),
                expected: serde_json::json!({"must_contain": "hi"}),
            },
            EvalCase {
                id: "c2".into(),
                input: serde_json::json!("bye"),
                expected: serde_json::json!({"must_contain": "hello"}),
            },
        ]);
        let echo = FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) });
        let r = suite.run(&echo).await.unwrap();
        assert_eq!(r.passed, 1);
        assert_eq!(r.failed, 1);
        assert!((r.pass_rate() - 0.5).abs() < 1e-6);
        // One result per case, recorded in case order.
        let ids: Vec<&str> = r.results.iter().map(|res| res.case_id.as_str()).collect();
        assert_eq!(ids, ["c1", "c2"]);
    }

    #[tokio::test]
    async fn empty_suite_reports_zeroes() {
        let suite = contains_suite(Vec::new());
        let echo = FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) });
        let r = suite.run(&echo).await.unwrap();
        assert_eq!(r.passed, 0);
        assert_eq!(r.failed, 0);
        assert!(r.results.is_empty());
        // Empty runs must not divide by zero.
        assert_eq!(r.avg_score, 0.0);
        assert_eq!(r.pass_rate(), 0.0);
    }
}