Skip to main content

atomr_agents_eval/
judge.rs

//! LLM-judge scorer + rubric-based scorer.
//!
//! `JudgeModel` is the trait callers plug a model in through; the
//! scorer simply prompts it and parses the response. The judges
//! implement [`AsyncScorer`] directly so they can `await` without a
//! blocking bridge.
//!
//! Note: these scorers do **not** implement the sync [`Scorer`] trait.
//! The blanket `impl<S: Scorer> AsyncScorer for S` in `crate::scorer`
//! would otherwise conflict with the explicit `AsyncScorer` impls
//! here, and the whole point of the explicit impls is to drop the
//! `tokio::task::block_in_place` workaround that a sync impl would
//! force on us. Callers stuck on a sync surface can wrap the model
//! manually or move to `Arc<dyn AsyncScorer>`.
15
16use std::sync::Arc;
17
18use async_trait::async_trait;
19use atomr_agents_core::Result;
20use serde::{Deserialize, Serialize};
21
22use crate::scorer::{AsyncScorer, ScorerOutcome};
23
/// Abstraction over the model that performs the judging.
///
/// Implementors receive the fully rendered prompt text and return the
/// raw model reply; parsing of that reply is done by the scorers.
#[async_trait]
pub trait JudgeModel: Send + Sync + 'static {
    /// Sends `prompt` to the model and returns its textual reply.
    async fn judge(&self, prompt: &str) -> Result<String>;
}
28
/// Single-criterion graded scorer — "did the actual output answer the
/// expected question correctly?". The judge replies `pass` / `fail`
/// followed by a short justification.
pub struct LlmJudgeScorer {
    /// Model that receives the rendered prompt and returns the verdict.
    pub model: Arc<dyn JudgeModel>,
    /// Prompt text with `{expected}` / `{actual}` placeholders that are
    /// substituted before each call to the judge.
    pub prompt_template: String,
}
36
37impl LlmJudgeScorer {
38    pub fn new(model: Arc<dyn JudgeModel>) -> Self {
39        Self {
40            model,
41            prompt_template: include_str_template_default(),
42        }
43    }
44
45    fn build_prompt(
46        &self,
47        expected: &atomr_agents_core::Value,
48        actual: &atomr_agents_core::Value,
49    ) -> String {
50        self.prompt_template
51            .replace("{expected}", &expected.to_string())
52            .replace("{actual}", &actual.to_string())
53    }
54}
55
/// Default prompt template for [`LlmJudgeScorer`]: asks for a bare
/// `pass`/`fail` verdict on the first line and a one-sentence
/// justification on the second.
fn include_str_template_default() -> String {
    String::from(
        "You are an evaluator. Given the expected outcome and the actual output, reply on the first line with exactly 'pass' or 'fail' and on the next line a one-sentence justification.\n\nExpected:\n{expected}\n\nActual:\n{actual}",
    )
}
59
60fn parse_judge_reply(reply: &str) -> ScorerOutcome {
61    let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
62    let passed = first == "pass";
63    ScorerOutcome {
64        passed,
65        score: if passed { 1.0 } else { 0.0 },
66        note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
67    }
68}
69
70#[async_trait]
71impl AsyncScorer for LlmJudgeScorer {
72    async fn score(
73        &self,
74        expected: &atomr_agents_core::Value,
75        actual: &atomr_agents_core::Value,
76    ) -> ScorerOutcome {
77        let prompt = self.build_prompt(expected, actual);
78        match self.model.judge(&prompt).await {
79            Ok(reply) => parse_judge_reply(&reply),
80            Err(e) => ScorerOutcome {
81                passed: false,
82                score: 0.0,
83                note: format!("judge error: {e}"),
84            },
85        }
86    }
87}
88
89// --------------------------------------------------------------------
90// RubricScorer — multi-criterion grading. Each criterion is judged
91// individually; the final score is the average.
92// --------------------------------------------------------------------
93
/// One rubric line: a named criterion with a relative weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricCriterion {
    /// Short identifier shown in the outcome note (e.g. "correctness").
    pub name: String,
    /// Instruction text handed to the judge for this criterion.
    pub description: String,
    /// Relative weight in the weighted average; larger weights give the
    /// criterion more influence on the final score.
    pub weight: f32,
}
100
/// Multi-criterion judge: each criterion is scored 0–10 individually and
/// the weighted average, normalized to [0, 1], is compared to `pass_at`.
pub struct RubricScorer {
    /// Model used to grade every criterion.
    pub model: Arc<dyn JudgeModel>,
    /// Criteria graded one at a time; see [`RubricCriterion`].
    pub criteria: Vec<RubricCriterion>,
    /// Pass threshold on the (weighted) average score.
    pub pass_at: f32,
}
107
108impl RubricScorer {
109    fn build_criterion_prompt(
110        c: &RubricCriterion,
111        expected: &atomr_agents_core::Value,
112        actual: &atomr_agents_core::Value,
113    ) -> String {
114        format!(
115            "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
116            c.name, c.description, expected, actual
117        )
118    }
119
120    fn aggregate(
121        results: &[(&RubricCriterion, f32)],
122        pass_at: f32,
123    ) -> ScorerOutcome {
124        let mut total = 0.0;
125        let mut total_w = 0.0;
126        let mut notes = Vec::with_capacity(results.len());
127        for (c, score) in results {
128            total += score * c.weight;
129            total_w += c.weight;
130            notes.push(format!("{}={}", c.name, score));
131        }
132        let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
133        let normalized = (avg / 10.0).clamp(0.0, 1.0);
134        ScorerOutcome {
135            passed: normalized >= pass_at,
136            score: normalized,
137            note: notes.join(", "),
138        }
139    }
140}
141
/// Parses the first line of a rubric reply as a score and clamps it to
/// the 0–10 range the prompt demands.
///
/// Returns 0.0 when the first line is missing or unparseable, and also
/// when it parses to a non-finite value — `"NaN"`/`"inf"` parse
/// successfully as `f32` and would otherwise poison the weighted
/// average in `RubricScorer::aggregate`. Out-of-range replies from an
/// over-enthusiastic judge (e.g. `"100"`) are clamped instead of being
/// allowed to dominate the average.
fn parse_rubric_score(reply: &str) -> f32 {
    let raw: f32 = reply
        .lines()
        .next()
        .and_then(|s| s.trim().parse().ok())
        .unwrap_or(0.0);
    if raw.is_finite() {
        raw.clamp(0.0, 10.0)
    } else {
        0.0
    }
}
149
150#[async_trait]
151impl AsyncScorer for RubricScorer {
152    async fn score(
153        &self,
154        expected: &atomr_agents_core::Value,
155        actual: &atomr_agents_core::Value,
156    ) -> ScorerOutcome {
157        let mut scored: Vec<(&RubricCriterion, f32)> = Vec::with_capacity(self.criteria.len());
158        for c in &self.criteria {
159            let prompt = Self::build_criterion_prompt(c, expected, actual);
160            let reply = match self.model.judge(&prompt).await {
161                Ok(r) => r,
162                Err(e) => format!("0\njudge error: {e}"),
163            };
164            scored.push((c, parse_rubric_score(&reply)));
165        }
166        Self::aggregate(&scored, self.pass_at)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    // Judge that pops canned replies in order; once the script is
    // exhausted it answers a deterministic "fail".
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }
    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let mut g = self.replies.lock();
            if g.is_empty() {
                return Ok("fail\nout of replies".into());
            }
            Ok(g.remove(0))
        }
    }

    // A "pass" verdict on the first line yields a passing outcome and
    // the justification line ends up in the note.
    #[tokio::test]
    async fn async_judge_pass_passes() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["pass\nlooks good".into()]),
        });
        let s = LlmJudgeScorer::new(m);
        let r = AsyncScorer::score(
            &s,
            &Value::String("yes".into()),
            &Value::String("yes!".into()),
        )
        .await;
        assert!(r.passed);
        assert!(r.note.contains("looks good"));
    }

    // A judge error must not propagate — it becomes a failed outcome
    // whose note carries the error text.
    #[tokio::test]
    async fn async_judge_propagates_error_into_outcome() {
        struct FailingJudge;
        #[async_trait]
        impl JudgeModel for FailingJudge {
            async fn judge(&self, _prompt: &str) -> Result<String> {
                Err(atomr_agents_core::AgentError::Tool("boom".into()))
            }
        }
        let s = LlmJudgeScorer::new(Arc::new(FailingJudge));
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!(!r.passed);
        assert!(r.note.contains("judge error"));
    }

    // Two equally weighted criteria scored 10 and 5 average to 7.5/10,
    // i.e. a normalized score of 0.75, which clears the 0.6 threshold.
    #[tokio::test]
    async fn async_rubric_averages_weighted_scores() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["10\nperfect".into(), "5\nokay".into()]),
        });
        let s = RubricScorer {
            model: m,
            criteria: vec![
                RubricCriterion {
                    name: "correctness".into(),
                    description: "is the answer correct".into(),
                    weight: 1.0,
                },
                RubricCriterion {
                    name: "concision".into(),
                    description: "is it terse".into(),
                    weight: 1.0,
                },
            ],
            pass_at: 0.6,
        };
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!((r.score - 0.75).abs() < 1e-5);
        assert!(r.passed);
    }
}