atomr_agents_eval/
judge.rs1use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_agents_core::Result;
15use serde::{Deserialize, Serialize};
16
17use crate::scorer::{Scorer, ScorerOutcome};
18
19#[async_trait]
20pub trait JudgeModel: Send + Sync + 'static {
21 async fn judge(&self, prompt: &str) -> Result<String>;
22}
23
24pub struct LlmJudgeScorer {
28 pub model: Arc<dyn JudgeModel>,
29 pub prompt_template: String,
30}
31
32impl LlmJudgeScorer {
33 pub fn new(model: Arc<dyn JudgeModel>) -> Self {
34 Self {
35 model,
36 prompt_template: include_str_template_default(),
37 }
38 }
39}
40
/// Returns the default pass/fail prompt used by [`LlmJudgeScorer::new`].
///
/// The prompt asks for `pass`/`fail` on the first line and a one-sentence justification on
/// the next. It contains the `{expected}` and `{actual}` placeholders.
fn include_str_template_default() -> String {
    concat!(
        "You are an evaluator. Given the expected outcome and the actual output, ",
        "reply on the first line with exactly 'pass' or 'fail' and on the next line ",
        "a one-sentence justification.\n\n",
        "Expected:\n{expected}\n\n",
        "Actual:\n{actual}",
    )
    .to_string()
}
44
45impl Scorer for LlmJudgeScorer {
46 fn score(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> ScorerOutcome {
47 let prompt = self
48 .prompt_template
49 .replace("{expected}", &expected.to_string())
50 .replace("{actual}", &actual.to_string());
51 let model = self.model.clone();
54 let reply = tokio::task::block_in_place(|| {
55 tokio::runtime::Handle::try_current()
56 .map(|h| h.block_on(model.judge(&prompt)))
57 .unwrap_or_else(|_| {
58 let rt = tokio::runtime::Builder::new_current_thread()
59 .enable_all()
60 .build()
61 .unwrap();
62 rt.block_on(model.judge(&prompt))
63 })
64 });
65 let reply = reply.unwrap_or_else(|e| format!("fail\n{e}"));
66 let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
67 let passed = first == "pass";
68 ScorerOutcome {
69 passed,
70 score: if passed { 1.0 } else { 0.0 },
71 note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
72 }
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct RubricCriterion {
83 pub name: String,
84 pub description: String,
85 pub weight: f32,
86}
87
88pub struct RubricScorer {
89 pub model: Arc<dyn JudgeModel>,
90 pub criteria: Vec<RubricCriterion>,
91 pub pass_at: f32,
93}
94
95impl Scorer for RubricScorer {
96 fn score(&self, expected: &atomr_agents_core::Value, actual: &atomr_agents_core::Value) -> ScorerOutcome {
97 let mut total = 0.0;
98 let mut total_w = 0.0;
99 let mut notes = Vec::new();
100 for c in &self.criteria {
101 let prompt = format!(
102 "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
103 c.name, c.description, expected, actual
104 );
105 let model = self.model.clone();
106 let reply = tokio::task::block_in_place(|| {
107 tokio::runtime::Handle::try_current()
108 .map(|h| h.block_on(model.judge(&prompt)))
109 .unwrap_or_else(|_| {
110 let rt = tokio::runtime::Builder::new_current_thread()
111 .enable_all()
112 .build()
113 .unwrap();
114 rt.block_on(model.judge(&prompt))
115 })
116 });
117 let reply = reply.unwrap_or_else(|e| format!("0\n{e}"));
118 let score: f32 = reply
119 .lines()
120 .next()
121 .and_then(|s| s.trim().parse().ok())
122 .unwrap_or(0.0);
123 total += score * c.weight;
124 total_w += c.weight;
125 notes.push(format!("{}={}", c.name, score));
126 }
127 let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
128 let normalized = (avg / 10.0).clamp(0.0, 1.0);
129 ScorerOutcome {
130 passed: normalized >= self.pass_at,
131 score: normalized,
132 note: notes.join(", "),
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;
    use std::collections::VecDeque;

    /// Judge double that hands out canned replies in order, then answers
    /// "fail\nout of replies" once the script is exhausted.
    struct ScriptedJudge {
        replies: Mutex<VecDeque<String>>,
    }

    impl ScriptedJudge {
        /// Builds a shared judge that will return `replies` one per call, in order.
        fn with_replies(replies: &[&str]) -> Arc<Self> {
            Arc::new(Self {
                replies: Mutex::new(replies.iter().map(|r| r.to_string()).collect()),
            })
        }
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let next = self.replies.lock().pop_front();
            Ok(next.unwrap_or_else(|| "fail\nout of replies".into()))
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn judge_pass_passes() {
        let scorer = LlmJudgeScorer::new(ScriptedJudge::with_replies(&["pass\nlooks good"]));
        let outcome = scorer.score(
            &Value::String("yes".into()),
            &Value::String("yes!".into()),
        );
        assert!(outcome.passed);
        assert!(outcome.note.contains("looks good"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn rubric_averages_weighted_scores() {
        let criterion = |name: &str, description: &str| RubricCriterion {
            name: name.into(),
            description: description.into(),
            weight: 1.0,
        };
        let scorer = RubricScorer {
            model: ScriptedJudge::with_replies(&["10\nperfect", "5\nokay"]),
            criteria: vec![
                criterion("correctness", "is the answer correct"),
                criterion("concision", "is it terse"),
            ],
            pass_at: 0.6,
        };
        let outcome = scorer.score(&Value::Null, &Value::Null);
        // (10 + 5) / 2 = 7.5 out of 10, normalized to 0.75.
        assert!((outcome.score - 0.75).abs() < 1e-5);
        assert!(outcome.passed);
    }
}