use std::sync::Arc;

use async_trait::async_trait;
use atomr_agents_core::Result;
use serde::{Deserialize, Serialize};

use crate::scorer::{AsyncScorer, ScorerOutcome};

/// A model that can act as an LLM judge: given a prompt, it returns the raw
/// text of its verdict.
#[async_trait]
pub trait JudgeModel: Send + Sync + 'static {
    async fn judge(&self, prompt: &str) -> Result<String>;
}

/// Pass/fail scorer that delegates the verdict to a [`JudgeModel`].
///
/// The prompt is built from `prompt_template`, with the `{expected}` and
/// `{actual}` placeholders substituted before the judge is called.
pub struct LlmJudgeScorer {
    pub model: Arc<dyn JudgeModel>,
    pub prompt_template: String,
}

impl LlmJudgeScorer {
    pub fn new(model: Arc<dyn JudgeModel>) -> Self {
        Self {
            model,
            prompt_template: default_prompt_template(),
        }
    }

    fn build_prompt(
        &self,
        expected: &atomr_agents_core::Value,
        actual: &atomr_agents_core::Value,
    ) -> String {
        self.prompt_template
            .replace("{expected}", &expected.to_string())
            .replace("{actual}", &actual.to_string())
    }
}

fn default_prompt_template() -> String {
    "You are an evaluator. Given the expected outcome and the actual output, reply on the first line with exactly 'pass' or 'fail' and on the next line a one-sentence justification.\n\nExpected:\n{expected}\n\nActual:\n{actual}".into()
}

/// Parses the judge reply: the first line decides pass/fail and the second
/// line, if present, becomes the note.
fn parse_judge_reply(reply: &str) -> ScorerOutcome {
    let first = reply.lines().next().unwrap_or("").trim().to_lowercase();
    let passed = first == "pass";
    ScorerOutcome {
        passed,
        score: if passed { 1.0 } else { 0.0 },
        note: reply.lines().nth(1).unwrap_or("").trim().to_string(),
    }
}

#[async_trait]
impl AsyncScorer for LlmJudgeScorer {
    async fn score(
        &self,
        expected: &atomr_agents_core::Value,
        actual: &atomr_agents_core::Value,
    ) -> ScorerOutcome {
        let prompt = self.build_prompt(expected, actual);
        match self.model.judge(&prompt).await {
            Ok(reply) => parse_judge_reply(&reply),
            Err(e) => ScorerOutcome {
                passed: false,
                score: 0.0,
                note: format!("judge error: {e}"),
            },
        }
    }
}

/// One criterion in a rubric. `weight` scales this criterion's contribution
/// to the weighted average.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RubricCriterion {
    pub name: String,
    pub description: String,
    pub weight: f32,
}

/// Scores each criterion independently with the judge model, then combines
/// the results into a single weighted, normalized score.
pub struct RubricScorer {
    pub model: Arc<dyn JudgeModel>,
    pub criteria: Vec<RubricCriterion>,
    /// Normalized score in `[0.0, 1.0]` at or above which the outcome passes.
    pub pass_at: f32,
}

impl RubricScorer {
    fn build_criterion_prompt(
        c: &RubricCriterion,
        expected: &atomr_agents_core::Value,
        actual: &atomr_agents_core::Value,
    ) -> String {
        format!(
            "Score from 0 to 10 ONLY. Criterion: {} — {}.\nExpected:\n{}\nActual:\n{}\nFirst line: integer score. Second line: short justification.",
            c.name, c.description, expected, actual
        )
    }

    fn aggregate(
        results: &[(&RubricCriterion, f32)],
        pass_at: f32,
    ) -> ScorerOutcome {
        let mut total = 0.0;
        let mut total_w = 0.0;
        let mut notes = Vec::with_capacity(results.len());
        for (c, score) in results {
            total += score * c.weight;
            total_w += c.weight;
            notes.push(format!("{}={}", c.name, score));
        }
        // Weighted average on the 0-10 scale, then normalized into [0.0, 1.0].
        let avg = if total_w > 0.0 { total / total_w } else { 0.0 };
        let normalized = (avg / 10.0).clamp(0.0, 1.0);
        ScorerOutcome {
            passed: normalized >= pass_at,
            score: normalized,
            note: notes.join(", "),
        }
    }
}

/// Parses the first line of a rubric judge reply as a numeric score,
/// falling back to 0.0 if it cannot be parsed.
fn parse_rubric_score(reply: &str) -> f32 {
    reply
        .lines()
        .next()
        .and_then(|s| s.trim().parse().ok())
        .unwrap_or(0.0)
}

#[async_trait]
impl AsyncScorer for RubricScorer {
    async fn score(
        &self,
        expected: &atomr_agents_core::Value,
        actual: &atomr_agents_core::Value,
    ) -> ScorerOutcome {
        let mut scored: Vec<(&RubricCriterion, f32)> = Vec::with_capacity(self.criteria.len());
        for c in &self.criteria {
            let prompt = Self::build_criterion_prompt(c, expected, actual);
            // A judge error counts as a zero score for that criterion.
            let reply = match self.model.judge(&prompt).await {
                Ok(r) => r,
                Err(e) => format!("0\njudge error: {e}"),
            };
            scored.push((c, parse_rubric_score(&reply)));
        }
        Self::aggregate(&scored, self.pass_at)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            let mut g = self.replies.lock();
            if g.is_empty() {
                return Ok("fail\nout of replies".into());
            }
            Ok(g.remove(0))
        }
    }
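
    // A minimal sketch of the aggregation edge case: with no scored criteria
    // the weighted total is zero, so `aggregate` reports a failing 0.0 score.
    #[test]
    fn rubric_aggregate_with_no_criteria_fails() {
        let r = RubricScorer::aggregate(&[], 0.5);
        assert!(!r.passed);
        assert!(r.score.abs() < 1e-6);
        assert!(r.note.is_empty());
    }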

    #[tokio::test]
    async fn async_judge_pass_passes() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["pass\nlooks good".into()]),
        });
        let s = LlmJudgeScorer::new(m);
        let r = AsyncScorer::score(
            &s,
            &Value::String("yes".into()),
            &Value::String("yes!".into()),
        )
        .await;
        assert!(r.passed);
        assert!(r.note.contains("looks good"));
    }
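
    // A small added check on `parse_judge_reply` itself, assuming replies may
    // vary in case: the first line is lowercased before comparison, and an
    // empty reply falls back to a failing outcome with an empty note.
    #[test]
    fn parse_judge_reply_is_case_insensitive_and_defaults_to_fail() {
        let ok = parse_judge_reply("PASS\nupper-case verdict");
        assert!(ok.passed);
        assert!((ok.score - 1.0).abs() < 1e-6);

        let empty = parse_judge_reply("");
        assert!(!empty.passed);
        assert!(empty.score.abs() < 1e-6);
        assert!(empty.note.is_empty());
    }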

    #[tokio::test]
    async fn async_judge_propagates_error_into_outcome() {
        struct FailingJudge;
        #[async_trait]
        impl JudgeModel for FailingJudge {
            async fn judge(&self, _prompt: &str) -> Result<String> {
                Err(atomr_agents_core::AgentError::Tool("boom".into()))
            }
        }
        let s = LlmJudgeScorer::new(Arc::new(FailingJudge));
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!(!r.passed);
        assert!(r.note.contains("judge error"));
    }
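
    // A quick sketch of `parse_rubric_score`: only the first line is read, and
    // anything that does not parse as a number falls back to 0.0.
    #[test]
    fn parse_rubric_score_reads_first_line_only() {
        assert!((parse_rubric_score("7\ngood enough") - 7.0).abs() < 1e-6);
        assert!(parse_rubric_score("not a number\n9").abs() < 1e-6);
    }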

    #[tokio::test]
    async fn async_rubric_averages_weighted_scores() {
        let m = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec!["10\nperfect".into(), "5\nokay".into()]),
        });
        let s = RubricScorer {
            model: m,
            criteria: vec![
                RubricCriterion {
                    name: "correctness".into(),
                    description: "is the answer correct".into(),
                    weight: 1.0,
                },
                RubricCriterion {
                    name: "concision".into(),
                    description: "is it terse".into(),
                    weight: 1.0,
                },
            ],
            pass_at: 0.6,
        };
        let r = AsyncScorer::score(&s, &Value::Null, &Value::Null).await;
        assert!((r.score - 0.75).abs() < 1e-5);
        assert!(r.passed);
    }
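
    // A sketch checking the default judge prompt: both placeholders should be
    // substituted with the Display form of each Value before the judge runs.
    #[test]
    fn judge_prompt_substitutes_placeholders() {
        let s = LlmJudgeScorer::new(Arc::new(ScriptedJudge {
            replies: Mutex::new(vec![]),
        }));
        let p = s.build_prompt(&Value::String("a".into()), &Value::String("b".into()));
        assert!(!p.contains("{expected}"));
        assert!(!p.contains("{actual}"));
    }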
}