Skip to main content

atomr_agents_eval/
suite.rs

1use std::sync::Arc;
2
3use atomr_agents_callable::Callable;
4use atomr_agents_core::{CallCtx, IterationBudget, MoneyBudget, Result, TimeBudget, TokenBudget, Value};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8use crate::scorer::{AsyncScorer, ScorerOutcome};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct EvalCase {
12    pub id: String,
13    pub input: Value,
14    pub expected: Value,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EvalResult {
19    pub case_id: String,
20    pub outcome: ScorerOutcome,
21    pub elapsed_ms: u64,
22}
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize)]
25pub struct EvalRun {
26    pub passed: u32,
27    pub failed: u32,
28    pub avg_score: f32,
29    pub results: Vec<EvalResult>,
30}
31
32impl EvalRun {
33    pub fn pass_rate(&self) -> f32 {
34        let total = self.passed + self.failed;
35        if total == 0 {
36            return 0.0;
37        }
38        self.passed as f32 / total as f32
39    }
40}
41
42pub struct EvalSuite {
43    pub id: String,
44    pub cases: Vec<EvalCase>,
45    /// Scorer used to grade each case. The field is typed as
46    /// `Arc<dyn AsyncScorer>` so judges that genuinely await (LLM
47    /// judges, retrieval-grounded checks, network probes) compose
48    /// directly without a `block_on` bridge. Sync `Scorer` impls are
49    /// promoted into `AsyncScorer` via the blanket impl in
50    /// `crate::scorer`, so callers can construct with
51    /// `Arc::new(MySyncScorer) as Arc<dyn AsyncScorer>` (or rely on
52    /// type inference at the field-init site).
53    pub scorer: Arc<dyn AsyncScorer>,
54}
55
56impl EvalSuite {
57    pub async fn run(&self, callable: &dyn Callable) -> Result<EvalRun> {
58        let mut run = EvalRun::default();
59        let mut total_score = 0.0f32;
60        for case in &self.cases {
61            let t0 = std::time::Instant::now();
62            let actual = callable.call(case.input.clone(), default_ctx()).await?;
63            let outcome = self.scorer.score(&case.expected, &actual).await;
64            if outcome.passed {
65                run.passed += 1;
66            } else {
67                run.failed += 1;
68            }
69            total_score += outcome.score;
70            run.results.push(EvalResult {
71                case_id: case.id.clone(),
72                outcome,
73                elapsed_ms: t0.elapsed().as_millis() as u64,
74            });
75        }
76        let total = (run.passed + run.failed) as f32;
77        run.avg_score = if total == 0.0 { 0.0 } else { total_score / total };
78        Ok(run)
79    }
80}
81
82fn default_ctx() -> CallCtx {
83    CallCtx {
84        agent_id: None,
85        tokens: TokenBudget::new(8192),
86        time: TimeBudget::new(Duration::from_secs(30)),
87        money: MoneyBudget::from_usd(1.0),
88        iterations: IterationBudget::new(8),
89        trace: vec![],
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scorer::ContainsScorer;
    use atomr_agents_callable::FnCallable;

    /// Builds a case whose expectation is a `{"must_contain": needle}` object.
    fn contains_case(id: &str, input: &str, needle: &str) -> EvalCase {
        EvalCase {
            id: id.into(),
            input: serde_json::json!(input),
            expected: serde_json::json!({ "must_contain": needle }),
        }
    }

    /// Runs a two-case suite against an echoing callable and expects exactly
    /// one pass and one failure, i.e. a pass rate of one half.
    #[tokio::test]
    async fn suite_scores_cases() {
        let suite = EvalSuite {
            id: "demo".into(),
            cases: vec![
                contains_case("c1", "hi", "hi"),
                contains_case("c2", "bye", "hello"),
            ],
            scorer: Arc::new(ContainsScorer),
        };
        // The callable returns its input unchanged.
        let echo = FnCallable::labeled("echo", |v: Value, _ctx| async move { Ok(v) });

        let report = suite.run(&echo).await.unwrap();

        assert_eq!(report.passed, 1);
        assert_eq!(report.failed, 1);
        assert!((report.pass_rate() - 0.5).abs() < 1e-6);
    }
}
123}