1use std::sync::Arc;
2
3use ai_agents_core::{ChatMessage, LLMProvider};
4use ai_agents_llm::LLMRegistry;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::{EvalError, Result};
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
12pub struct JudgeAssertion {
13 #[serde(default)]
15 pub llm: Option<String>,
16 #[serde(default = "default_threshold")]
18 pub pass_threshold: f32,
19 #[serde(default)]
21 pub criteria: Vec<JudgeCriterion>,
22}
23
24#[derive(Debug, Clone, Deserialize, Serialize)]
26#[serde(untagged)]
27pub enum JudgeCriterion {
28 Text(String),
29 Object {
30 name: String,
31 description: String,
32 #[serde(default = "default_weight")]
33 weight: f32,
34 },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct JudgeConfig {
40 #[serde(default = "default_true")]
42 pub enabled: bool,
43 #[serde(default)]
45 pub llm: Option<String>,
46 #[serde(default)]
48 pub default_criteria: Vec<JudgeCriterion>,
49 #[serde(default = "default_threshold")]
51 pub pass_threshold: f32,
52 #[serde(default = "default_true")]
54 pub require_json: bool,
55}
56
57impl Default for JudgeConfig {
58 fn default() -> Self {
59 Self {
60 enabled: true,
61 llm: None,
62 default_criteria: Vec::new(),
63 pass_threshold: default_threshold(),
64 require_json: true,
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct JudgeResult {
72 pub criteria_scores: Vec<CriterionScore>,
74 pub overall_score: f32,
76 pub overall_feedback: String,
78 pub passed: bool,
80 #[serde(default, skip_serializing_if = "Option::is_none")]
82 pub raw_response: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct CriterionScore {
88 pub name: String,
90 pub score: f32,
92 #[serde(default)]
94 pub explanation: String,
95}
96
/// Borrowed view of everything the judge prompt is built from.
///
/// All fields are shared references, so the value is trivially `Copy`; `Debug`
/// is derived so inputs can be logged or shown in test failures.
#[derive(Debug, Clone, Copy)]
pub struct JudgeInput<'a> {
    /// The assistant response being judged.
    pub response: &'a str,
    /// The user message that produced the response, if known. Rendered as an
    /// empty string in the prompt when absent.
    pub user_input: Option<&'a str>,
    /// Identifier of the scenario being evaluated, if any. Rendered as an
    /// empty string in the prompt when absent.
    pub scenario_id: Option<&'a str>,
    /// Language label passed to the judge, if any. Rendered as an empty string
    /// in the prompt when absent.
    pub language: Option<&'a str>,
}
108
109pub struct JudgeResolver {
111 registry: Arc<LLMRegistry>,
113 config: JudgeConfig,
115}
116
117impl JudgeResolver {
118 pub fn new(registry: Arc<LLMRegistry>, config: JudgeConfig) -> Self {
119 Self { registry, config }
120 }
121
122 pub fn resolve(&self, alias: Option<&str>) -> Result<LLMJudge> {
123 let llm = if let Some(alias) = alias {
124 self.registry
125 .get(alias)
126 .map_err(|error| EvalError::Judge(error.to_string()))?
127 } else {
128 self.registry
129 .router()
130 .or_else(|_| self.registry.default())
131 .map_err(|error| EvalError::Judge(error.to_string()))?
132 };
133 Ok(LLMJudge::new(llm, self.config.clone()))
134 }
135}
136
137pub struct LLMJudge {
139 llm: Arc<dyn LLMProvider>,
141 config: JudgeConfig,
143}
144
145impl LLMJudge {
146 pub fn new(llm: Arc<dyn LLMProvider>, config: JudgeConfig) -> Self {
147 Self { llm, config }
148 }
149
150 pub async fn evaluate(
151 &self,
152 response: &str,
153 assertion: &JudgeAssertion,
154 ) -> Result<JudgeResult> {
155 self.evaluate_input(
156 JudgeInput {
157 response,
158 user_input: None,
159 scenario_id: None,
160 language: None,
161 },
162 assertion,
163 )
164 .await
165 }
166
167 pub async fn evaluate_input(
168 &self,
169 input: JudgeInput<'_>,
170 assertion: &JudgeAssertion,
171 ) -> Result<JudgeResult> {
172 let criteria = if assertion.criteria.is_empty() {
173 self.config.default_criteria.clone()
174 } else {
175 assertion.criteria.clone()
176 };
177 if criteria.is_empty() {
178 return Err(EvalError::Judge("judge assertion has no criteria".into()));
179 }
180 let threshold = assertion.pass_threshold;
181 let prompt = build_prompt(input, &criteria, threshold);
182 let llm_response = self
183 .llm
184 .complete(&[ChatMessage::user(&prompt)], None)
185 .await
186 .map_err(|error| EvalError::Judge(error.to_string()))?;
187 let value = extract_json(&llm_response.content)
188 .ok_or_else(|| EvalError::Judge("judge did not return JSON".into()))?;
189 let mut result: JudgeResult = serde_json::from_value(value)
190 .map_err(|error| EvalError::Judge(format!("invalid judge JSON: {}", error)))?;
191 result.passed = result.overall_score >= threshold;
192 if !self.config.require_json {
193 result.raw_response = Some(llm_response.content);
194 }
195 Ok(result)
196 }
197}
198
199fn build_prompt(input: JudgeInput<'_>, criteria: &[JudgeCriterion], threshold: f32) -> String {
200 let criteria_text = criteria
201 .iter()
202 .enumerate()
203 .map(|(idx, criterion)| match criterion {
204 JudgeCriterion::Text(text) => format!("{}. {} (weight 1.0)", idx + 1, text),
205 JudgeCriterion::Object {
206 name,
207 description,
208 weight,
209 } => {
210 format!("{}. {}: {} (weight {})", idx + 1, name, description, weight)
211 }
212 })
213 .collect::<Vec<_>>()
214 .join("\n");
215 let user_input = input.user_input.unwrap_or("");
216 let scenario_id = input.scenario_id.unwrap_or("");
217 let language = input.language.unwrap_or("");
218 let response = input.response;
219 format!(
220 r#"Evaluate the assistant response against the criteria.
221Evaluate semantic meaning across languages. Do not require exact wording unless a criterion says so.
222Return strict JSON only with this shape:
223{{"criteria_scores":[{{"name":"criterion","score":0.0,"explanation":"brief"}}],"overall_score":0.0,"overall_feedback":"brief","passed":false}}
224Pass threshold: {threshold}
225
226Scenario ID: {scenario_id}
227Language: {language}
228User input: {user_input}
229
230Criteria:
231{criteria_text}
232
233Assistant response:
234{response}"#
235 )
236}
237
238fn extract_json(text: &str) -> Option<Value> {
239 if let Ok(value) = serde_json::from_str(text.trim()) {
240 return Some(value);
241 }
242 let start = text.find('{')?;
243 let end = text.rfind('}')?;
244 serde_json::from_str(&text[start..=end]).ok()
245}
246
/// Serde default for pass thresholds: `0.75`.
fn default_threshold() -> f32 {
    const DEFAULT_PASS_THRESHOLD: f32 = 0.75;
    DEFAULT_PASS_THRESHOLD
}
250
/// Serde default for a criterion's weight: `1.0`.
fn default_weight() -> f32 {
    const DEFAULT_CRITERION_WEIGHT: f32 = 1.0;
    DEFAULT_CRITERION_WEIGHT
}
254
/// Serde default for boolean flags that are on unless disabled explicitly.
fn default_true() -> bool {
    const ON_BY_DEFAULT: bool = true;
    ON_BY_DEFAULT
}