Skip to main content

atomr_agents_eval/
judge.rs

//! LLM-judge scorer + rubric-based scorer.
//!
//! `JudgeModel` is the trait callers plug a model in through; the
//! scorer simply prompts it and parses the response. The judges
//! implement [`AsyncScorer`] directly so they can `await` without a
//! blocking bridge.
//!
//! Note: these scorers do **not** implement the sync [`Scorer`] trait.
//! The blanket `impl<S: Scorer> AsyncScorer for S` in `crate::scorer`
//! would otherwise conflict with the explicit `AsyncScorer` impls
//! here, and the whole point of the explicit impls is to drop the
//! `tokio::task::block_in_place` workaround that a sync impl would
//! force on us. Callers stuck on a sync surface can wrap the model
//! manually or move to `Arc<dyn AsyncScorer>`.

16use std::sync::Arc;
17
18use async_trait::async_trait;
19use atomr_agents_core::Result;
20use serde::{Deserialize, Serialize};
21
22use crate::scorer::{AsyncScorer, ScorerOutcome};
23
/// Abstraction over the LLM used for grading.
///
/// Implementors receive the fully rendered prompt and return the raw
/// model reply as text; all parsing of that reply is left to the
/// scorers in this module.
#[async_trait]
pub trait JudgeModel: Send + Sync + 'static {
    /// Send `prompt` to the model and return its raw text reply.
    async fn judge(&self, prompt: &str) -> Result<String>;
}
28
/// Single-criterion graded scorer — "did the actual output answer the
/// expected question correctly?". The judge replies `pass` / `fail`
/// followed by a short justification.
pub struct LlmJudgeScorer {
    /// Model that performs the grading.
    pub model: Arc<dyn JudgeModel>,
    /// Prompt with `{expected}` / `{actual}` placeholders substituted
    /// before each judge call.
    pub prompt_template: String,
}
36
37impl LlmJudgeScorer {
38    pub fn new(model: Arc<dyn JudgeModel>) -> Self {
39        Self {
40            model,
41            prompt_template: include_str_template_default(),
42        }
43    }
44
45    fn build_prompt(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> String {
46        self.prompt_template
47            .replace("{expected}", &expected.to_string())
48            .replace("{actual}", &actual.to_string())
49    }
50}
51
/// Default pass/fail judging prompt, with `{expected}` / `{actual}`
/// placeholder slots for [`LlmJudgeScorer::build_prompt`].
fn include_str_template_default() -> String {
    String::from(
        "You are an evaluator. Given the expected outcome and the actual output, reply on the first line with exactly 'pass' or 'fail' and on the next line a one-sentence justification.\n\nExpected:\n{expected}\n\nActual:\n{actual}",
    )
}
55
56fn parse_judge_reply(reply: &str) -> ScorerOutcome {
57    let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
58    let passed = first == "pass";
59    ScorerOutcome {
60        passed,
61        score: if passed { 1.0 } else { 0.0 },
62        note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
63    }
64}
65
66#[async_trait]
67impl AsyncScorer for LlmJudgeScorer {
68    async fn score(
69        &self,
70        expected: &atomr_agents_core::Value,
71        actual: &atomr_agents_core::Value,
72    ) -> ScorerOutcome {
73        let prompt = self.build_prompt(expected, actual);
74        match self.model.judge(&prompt).await {
75            Ok(reply) => parse_judge_reply(&reply),
76            Err(e) => ScorerOutcome {
77                passed: false,
78                score: 0.0,
79                note: format!("judge error: {e}"),
80            },
81        }
82    }
83}
84
// --------------------------------------------------------------------
// RubricScorer — multi-criterion grading. Each criterion is judged
// individually; the final score is the weighted average, normalized
// to [0, 1].
// --------------------------------------------------------------------
89
/// One line item of a grading rubric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricCriterion {
    // Short identifier shown in the outcome note (e.g. "correctness").
    pub name: String,
    // Human-readable instruction the judge scores against.
    pub description: String,
    // Relative weight in the weighted average; larger = more important.
    pub weight: f32,
}
96
/// Multi-criterion scorer: each criterion is judged individually and
/// the final score is the weighted average, normalized to [0, 1].
pub struct RubricScorer {
    // Model used to grade every criterion.
    pub model: Arc<dyn JudgeModel>,
    // Rubric line items; each gets its own judge prompt.
    pub criteria: Vec<RubricCriterion>,
    /// Pass threshold on the (weighted) average score.
    pub pass_at: f32,
}
103
104impl RubricScorer {
105    fn build_criterion_prompt(
106        c: &RubricCriterion,
107        expected: &atomr_agents_core::Value,
108        actual: &atomr_agents_core::Value,
109    ) -> String {
110        format!(
111            "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
112            c.name, c.description, expected, actual
113        )
114    }
115
116    fn aggregate(results: &[(&RubricCriterion, f32)], pass_at: f32) -> ScorerOutcome {
117        let mut total = 0.0;
118        let mut total_w = 0.0;
119        let mut notes = Vec::with_capacity(results.len());
120        for (c, score) in results {
121            total += score * c.weight;
122            total_w += c.weight;
123            notes.push(format!("{}={}", c.name, score));
124        }
125        let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
126        let normalized = (avg / 10.0).clamp(0.0, 1.0);
127        ScorerOutcome {
128            passed: normalized >= pass_at,
129            score: normalized,
130            note: notes.join(", "),
131        }
132    }
133}
134
/// Parse the first line of a judge reply as a 0–10 rubric score.
///
/// Unparseable or empty replies score 0.0. The parsed value is clamped
/// to the [0, 10] scale the prompt requests, so a judge that answers
/// e.g. "100" cannot skew the weighted average in
/// [`RubricScorer::aggregate`] past the other criteria (the final
/// normalization clamp there only bounds the *average*, not the
/// individual contributions).
fn parse_rubric_score(reply: &str) -> f32 {
    reply
        .lines()
        .next()
        .and_then(|s| s.trim().parse::<f32>().ok())
        .map_or(0.0, |v| v.clamp(0.0, 10.0))
}
142
143#[async_trait]
144impl AsyncScorer for RubricScorer {
145    async fn score(
146        &self,
147        expected: &atomr_agents_core::Value,
148        actual: &atomr_agents_core::Value,
149    ) -> ScorerOutcome {
150        let mut scored: Vec<(&RubricCriterion, f32)> = Vec::with_capacity(self.criteria.len());
151        for c in &self.criteria {
152            let prompt = Self::build_criterion_prompt(c, expected, actual);
153            let reply = match self.model.judge(&prompt).await {
154                Ok(r) => r,
155                Err(e) => format!("0\njudge error: {e}"),
156            };
157            scored.push((c, parse_rubric_score(&reply)));
158        }
159        Self::aggregate(&scored, self.pass_at)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Judge stub that pops pre-scripted replies in order, falling back
    /// to a failing reply once the script is exhausted.
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }
    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let mut g = self.replies.lock();
            if g.is_empty() {
                return Ok("fail\nout of replies".into());
            }
            Ok(g.remove(0))
        }
    }

    /// A "pass" first line yields a passing outcome, and the
    /// justification line survives into `note`.
    #[tokio::test]
    async fn async_judge_pass_passes() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["pass\nlooks good".into()]),
        });
        let s = LlmJudgeScorer::new(m);
        let r = AsyncScorer::score(&s, &Value::String("yes".into()), &Value::String("yes!".into())).await;
        assert!(r.passed);
        assert!(r.note.contains("looks good"));
    }

    /// A judge error becomes a failing outcome with the error text in
    /// the note — it must not propagate out of `score`.
    #[tokio::test]
    async fn async_judge_propagates_error_into_outcome() {
        struct FailingJudge;
        #[async_trait]
        impl JudgeModel for FailingJudge {
            async fn judge(&self, _prompt: &str) -> Result<String> {
                Err(atomr_agents_core::AgentError::Tool("boom".into()))
            }
        }
        let s = LlmJudgeScorer::new(Arc::new(FailingJudge));
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!(!r.passed);
        assert!(r.note.contains("judge error"));
    }

    /// Scores 10 and 5 with equal weights average to 7.5/10 = 0.75,
    /// which clears the 0.6 pass threshold.
    #[tokio::test]
    async fn async_rubric_averages_weighted_scores() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["10\nperfect".into(), "5\nokay".into()]),
        });
        let s = RubricScorer {
            model: m,
            criteria: vec![
                RubricCriterion {
                    name: "correctness".into(),
                    description: "is the answer correct".into(),
                    weight: 1.0,
                },
                RubricCriterion {
                    name: "concision".into(),
                    description: "is it terse".into(),
                    weight: 1.0,
                },
            ],
            pass_at: 0.6,
        };
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!((r.score - 0.75).abs() < 1e-5);
        assert!(r.passed);
    }
}