Skip to main content

atomr_agents_eval/
judge.rs

1//! LLM-judge scorer + rubric-based scorer.
2//!
3//! `JudgeModel` is the trait callers plug a model in through; the
4//! scorer simply prompts it and parses the response. The
5//! `Scorer` trait is sync; we use a blocking call from `tokio` for
6//! the async judge (or, in unit tests, a stub that returns a fixed
7//! response). For production async use, wrap `LlmJudgeScorer` in an
8//! `OnlineEvaluator` that owns its own runtime — see `online_eval`
9//! below.
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_agents_core::Result;
15use serde::{Deserialize, Serialize};
16
17use crate::scorer::{Scorer, ScorerOutcome};
18
/// Pluggable model that does the actual judging.
///
/// The scorers in this module render a prompt, hand it to `judge`, and parse
/// the raw text that comes back; the trait itself imposes no reply format.
#[async_trait]
pub trait JudgeModel: Send + Sync + 'static {
    /// Sends `prompt` to the judge and returns its raw text reply.
    ///
    /// # Errors
    /// Returns an error when no reply could be produced. The scorers in this
    /// module do not propagate it: they score the affected judgement as a
    /// failure (or as 0 for a rubric criterion).
    async fn judge(&self, prompt: &str) -> Result<String>;
}
23
/// Single-criterion graded scorer — "did the actual output answer the
/// expected question correctly?". The judge replies `pass` / `fail`
/// followed by a short justification.
pub struct LlmJudgeScorer {
    /// Judge model that is queried once per `score` call.
    pub model: Arc<dyn JudgeModel>,
    /// Prompt sent to the judge. The placeholders `{expected}` and `{actual}`
    /// are substituted with the string form of the two values being compared.
    pub prompt_template: String,
}
31
32impl LlmJudgeScorer {
33    pub fn new(model: Arc<dyn JudgeModel>) -> Self {
34        Self {
35            model,
36            prompt_template: include_str_template_default(),
37        }
38    }
39}
40
/// Returns the built-in judge prompt: it asks for `pass` / `fail` on the first
/// line and a one-sentence justification on the next, and carries the
/// `{expected}` and `{actual}` placeholders.
fn include_str_template_default() -> String {
    String::from(concat!(
        "You are an evaluator. Given the expected outcome and the actual output, ",
        "reply on the first line with exactly 'pass' or 'fail' and on the next line ",
        "a one-sentence justification.\n\n",
        "Expected:\n{expected}\n\n",
        "Actual:\n{actual}",
    ))
}
44
45impl Scorer for LlmJudgeScorer {
46    fn score(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> ScorerOutcome {
47        let prompt = self
48            .prompt_template
49            .replace("{expected}", &expected.to_string())
50            .replace("{actual}", &actual.to_string());
51        // Run the async judge synchronously. Callers running inside
52        // a tokio runtime can use `OnlineEvaluator` instead.
53        let model = self.model.clone();
54        let reply = tokio::task::block_in_place(|| {
55            tokio::runtime::Handle::try_current()
56                .map(|h| h.block_on(model.judge(&prompt)))
57                .unwrap_or_else(|_| {
58                    let rt = tokio::runtime::Builder::new_current_thread()
59                        .enable_all()
60                        .build()
61                        .unwrap();
62                    rt.block_on(model.judge(&prompt))
63                })
64        });
65        let reply = reply.unwrap_or_else(|e| format!("fail\n{e}"));
66        let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
67        let passed = first == "pass";
68        ScorerOutcome {
69            passed,
70            score: if passed { 1.0 } else { 0.0 },
71            note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
72        }
73    }
74}
75
76// --------------------------------------------------------------------
77// RubricScorer — multi-criterion grading. Each criterion is judged
78// individually; the final score is the average.
79// --------------------------------------------------------------------
80
/// One dimension of a rubric; `RubricScorer` asks the judge about each
/// criterion separately.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricCriterion {
    /// Short label; it appears in the judge prompt and in the outcome note.
    pub name: String,
    /// What the judge should assess; it follows the name in the judge prompt.
    pub description: String,
    /// Relative weight of this criterion's score in the weighted average.
    pub weight: f32,
}
87
/// Multi-criterion scorer: every criterion is graded by the judge on a
/// 0–10 scale and the grades are combined into a weighted average.
pub struct RubricScorer {
    /// Judge model that is queried once per criterion.
    pub model: Arc<dyn JudgeModel>,
    /// Criteria to grade, each with its own weight.
    pub criteria: Vec<RubricCriterion>,
    /// Pass threshold on the (weighted) average score, expressed on the
    /// normalized `0.0..=1.0` scale (the 0–10 average divided by 10). An
    /// outcome passes when its normalized score is greater than or equal to
    /// this value.
    pub pass_at: f32,
}
94
95impl Scorer for RubricScorer {
96    fn score(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> ScorerOutcome {
97        let mut total = 0.0;
98        let mut total_w = 0.0;
99        let mut notes = Vec::new();
100        for c in &self.criteria {
101            let prompt = format!(
102                "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
103                c.name, c.description, expected, actual
104            );
105            let model = self.model.clone();
106            let reply = tokio::task::block_in_place(|| {
107                tokio::runtime::Handle::try_current()
108                    .map(|h| h.block_on(model.judge(&prompt)))
109                    .unwrap_or_else(|_| {
110                        let rt = tokio::runtime::Builder::new_current_thread()
111                            .enable_all()
112                            .build()
113                            .unwrap();
114                        rt.block_on(model.judge(&prompt))
115                    })
116            });
117            let reply = reply.unwrap_or_else(|e| format!("0\n{e}"));
118            let score: f32 = reply
119                .lines()
120                .next()
121                .and_then(|s| s.trim().parse().ok())
122                .unwrap_or(0.0);
123            total += score * c.weight;
124            total_w += c.weight;
125            notes.push(format!("{}={}", c.name, score));
126        }
127        let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
128        let normalized = (avg / 10.0).clamp(0.0, 1.0);
129        ScorerOutcome {
130            passed: normalized >= self.pass_at,
131            score: normalized,
132            note: notes.join(", "),
133        }
134    }
135}
136
// Unit tests that drive both scorers with a scripted judge.
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Judge stub that hands out pre-recorded replies in order.
    struct ScriptedJudge {
        /// Remaining replies; the front element is returned next.
        replies: Mutex<Vec<String>>,
    }
    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let mut g = self.replies.lock();
            // Once the script is exhausted, answer with a failing verdict.
            if g.is_empty() {
                return Ok("fail\nout of replies".into());
            }
            Ok(g.remove(0))
        }
    }

    // Multi-threaded runtime: `score` calls `tokio::task::block_in_place`,
    // which tokio does not support on a current-thread runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn judge_pass_passes() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["pass\nlooks good".into()]),
        });
        let s = LlmJudgeScorer::new(m);
        let r = s.score(&Value::String("yes".into()), &Value::String("yes!".into()));
        // First reply line is the verdict, second line becomes the note.
        assert!(r.passed);
        assert!(r.note.contains("looks good"));
    }

    // Multi-threaded runtime for the same `block_in_place` reason as above.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn rubric_averages_weighted_scores() {
        // One scripted reply per criterion, consumed in criterion order.
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["10\nperfect".into(), "5\nokay".into()]),
        });
        let s = RubricScorer {
            model: m,
            criteria: vec![
                RubricCriterion {
                    name: "correctness".into(),
                    description: "is the answer correct".into(),
                    weight: 1.0,
                },
                RubricCriterion {
                    name: "concision".into(),
                    description: "is it terse".into(),
                    weight: 1.0,
                },
            ],
            pass_at: 0.6,
        };
        let r = s.score(&Value::Null, &Value::Null);
        // (10*1 + 5*1) / (1+1) / 10 = 0.75
        assert!((r.score - 0.75).abs() < 1e-5);
        assert!(r.passed);
    }
}
194}