// atomr_agents_eval/judge.rs

use std::sync::Arc;
18use async_trait::async_trait;
19use atomr_agents_core::Result;
20use serde::{Deserialize, Serialize};
21
22use crate::scorer::{AsyncScorer, ScorerOutcome};
23
/// Abstraction over an LLM used as an evaluator ("judge").
///
/// Implementations send `prompt` to a model and return its raw text reply;
/// callers in this module (`LlmJudgeScorer`, `RubricScorer`) parse that
/// reply themselves, so implementations should not post-process it.
#[async_trait]
pub trait JudgeModel: Send + Sync + 'static {
    /// Sends `prompt` to the judge model and returns its verbatim reply.
    async fn judge(&self, prompt: &str) -> Result<String>;
}
28
/// Binary pass/fail scorer that delegates the verdict to an LLM judge.
///
/// The judging prompt is produced from `prompt_template` by substituting the
/// `{expected}` and `{actual}` placeholders (see `build_prompt`).
pub struct LlmJudgeScorer {
    // Judge model invoked once per `score` call.
    pub model: Arc<dyn JudgeModel>,
    // Template containing `{expected}` / `{actual}` placeholders.
    pub prompt_template: String,
}
36
37impl LlmJudgeScorer {
38 pub fn new(model: Arc<dyn JudgeModel>) -> Self {
39 Self {
40 model,
41 prompt_template: include_str_template_default(),
42 }
43 }
44
45 fn build_prompt(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> String {
46 self.prompt_template
47 .replace("{expected}", &expected.to_string())
48 .replace("{actual}", &actual.to_string())
49 }
50}
51
/// Default judge prompt: instructs the model to answer 'pass'/'fail' on the
/// first line and give a one-sentence justification on the second, with the
/// `{expected}` / `{actual}` placeholders filled in by `build_prompt`.
fn include_str_template_default() -> String {
    String::from(
        "You are an evaluator. Given the expected outcome and the actual output, reply on the first line with exactly 'pass' or 'fail' and on the next line a one-sentence justification.\n\nExpected:\n{expected}\n\nActual:\n{actual}",
    )
}
55
56fn parse_judge_reply(reply: &str) -> ScorerOutcome {
57 let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
58 let passed = first == "pass";
59 ScorerOutcome {
60 passed,
61 score: if passed { 1.0 } else { 0.0 },
62 note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
63 }
64}
65
66#[async_trait]
67impl AsyncScorer for LlmJudgeScorer {
68 async fn score(
69 &self,
70 expected: &atomr_agents_core::Value,
71 actual: &atomr_agents_core::Value,
72 ) -> ScorerOutcome {
73 let prompt = self.build_prompt(expected, actual);
74 match self.model.judge(&prompt).await {
75 Ok(reply) => parse_judge_reply(&reply),
76 Err(e) => ScorerOutcome {
77 passed: false,
78 score: 0.0,
79 note: format!("judge error: {e}"),
80 },
81 }
82 }
83}
84
/// One weighted criterion in a rubric evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricCriterion {
    // Short identifier, echoed into the outcome note as "name=score".
    pub name: String,
    // Human-readable description inserted into the per-criterion prompt.
    pub description: String,
    // Relative weight used in the weighted average (see `RubricScorer::aggregate`).
    pub weight: f32,
}
96
/// Multi-criterion scorer: asks the judge to score each criterion from 0 to
/// 10, then passes when the weighted average, normalized to [0, 1], reaches
/// `pass_at`.
pub struct RubricScorer {
    // Judge model invoked once per criterion.
    pub model: Arc<dyn JudgeModel>,
    // Criteria to evaluate; each contributes its `weight` to the average.
    pub criteria: Vec<RubricCriterion>,
    // Pass threshold in [0, 1] compared against the normalized score.
    pub pass_at: f32,
}
103
104impl RubricScorer {
105 fn build_criterion_prompt(
106 c: &RubricCriterion,
107 expected: &atomr_agents_core::Value,
108 actual: &atomr_agents_core::Value,
109 ) -> String {
110 format!(
111 "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
112 c.name, c.description, expected, actual
113 )
114 }
115
116 fn aggregate(results: &[(&RubricCriterion, f32)], pass_at: f32) -> ScorerOutcome {
117 let mut total = 0.0;
118 let mut total_w = 0.0;
119 let mut notes = Vec::with_capacity(results.len());
120 for (c, score) in results {
121 total += score * c.weight;
122 total_w += c.weight;
123 notes.push(format!("{}={}", c.name, score));
124 }
125 let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
126 let normalized = (avg / 10.0).clamp(0.0, 1.0);
127 ScorerOutcome {
128 passed: normalized >= pass_at,
129 score: normalized,
130 note: notes.join(", "),
131 }
132 }
133}
134
/// Parses the first line of a rubric judge reply as an `f32` score.
///
/// Returns 0.0 when the line is missing or unparsable, and clamps the result
/// into the 0–10 range the rubric prompt requests: `aggregate` divides the
/// weighted average by 10 and only clamps *after* averaging, so without this
/// clamp a single out-of-range reply (e.g. "100") would skew the whole
/// rubric's weighted average.
fn parse_rubric_score(reply: &str) -> f32 {
    reply
        .lines()
        .next()
        .and_then(|s| s.trim().parse::<f32>().ok())
        .map_or(0.0, |score| score.clamp(0.0, 10.0))
}
142
143#[async_trait]
144impl AsyncScorer for RubricScorer {
145 async fn score(
146 &self,
147 expected: &atomr_agents_core::Value,
148 actual: &atomr_agents_core::Value,
149 ) -> ScorerOutcome {
150 let mut scored: Vec<(&RubricCriterion, f32)> = Vec::with_capacity(self.criteria.len());
151 for c in &self.criteria {
152 let prompt = Self::build_criterion_prompt(c, expected, actual);
153 let reply = match self.model.judge(&prompt).await {
154 Ok(r) => r,
155 Err(e) => format!("0\njudge error: {e}"),
156 };
157 scored.push((c, parse_rubric_score(&reply)));
158 }
159 Self::aggregate(&scored, self.pass_at)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Test double: hands out pre-scripted replies in order, then a canned
    /// "fail" reply once the script is exhausted (so an under-scripted test
    /// fails with a readable note instead of panicking).
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }
    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let mut g = self.replies.lock();
            if g.is_empty() {
                return Ok("fail\nout of replies".into());
            }
            Ok(g.remove(0))
        }
    }

    // A "pass" first line should yield a passing outcome, and the second
    // line of the reply should surface in the note.
    #[tokio::test]
    async fn async_judge_pass_passes() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["pass\nlooks good".into()]),
        });
        let s = LlmJudgeScorer::new(m);
        let r = AsyncScorer::score(&s, &Value::String("yes".into()), &Value::String("yes!".into())).await;
        assert!(r.passed);
        assert!(r.note.contains("looks good"));
    }

    // A judge error must not propagate: it becomes a failing outcome whose
    // note carries the "judge error" text.
    #[tokio::test]
    async fn async_judge_propagates_error_into_outcome() {
        struct FailingJudge;
        #[async_trait]
        impl JudgeModel for FailingJudge {
            async fn judge(&self, _prompt: &str) -> Result<String> {
                Err(atomr_agents_core::AgentError::Tool("boom".into()))
            }
        }
        let s = LlmJudgeScorer::new(Arc::new(FailingJudge));
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!(!r.passed);
        assert!(r.note.contains("judge error"));
    }

    // Two equally-weighted criteria scored 10 and 5 average to 7.5/10,
    // i.e. a normalized score of 0.75, which clears the 0.6 threshold.
    #[tokio::test]
    async fn async_rubric_averages_weighted_scores() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["10\nperfect".into(), "5\nokay".into()]),
        });
        let s = RubricScorer {
            model: m,
            criteria: vec![
                RubricCriterion {
                    name: "correctness".into(),
                    description: "is the answer correct".into(),
                    weight: 1.0,
                },
                RubricCriterion {
                    name: "concision".into(),
                    description: "is it terse".into(),
                    weight: 1.0,
                },
            ],
            pass_at: 0.6,
        };
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!((r.score - 0.75).abs() < 1e-5);
        assert!(r.passed);
    }
}