Skip to main content

ai_agents_eval/
judge.rs

1use std::sync::Arc;
2
3use ai_agents_core::{ChatMessage, LLMProvider};
4use ai_agents_llm::LLMRegistry;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::{EvalError, Result};
9
10/// Semantic judge assertion declared in an eval suite.
11#[derive(Debug, Clone, Deserialize, Serialize)]
12pub struct JudgeAssertion {
13    /// Optional LLM alias or provider used for judge calls.
14    #[serde(default)]
15    pub llm: Option<String>,
16    /// Minimum overall score required to pass.
17    #[serde(default = "default_threshold")]
18    pub pass_threshold: f32,
19    /// Criteria used by the judge prompt.
20    #[serde(default)]
21    pub criteria: Vec<JudgeCriterion>,
22}
23
24/// Text or weighted object form for judge criteria.
25#[derive(Debug, Clone, Deserialize, Serialize)]
26#[serde(untagged)]
27pub enum JudgeCriterion {
28    Text(String),
29    Object {
30        name: String,
31        description: String,
32        #[serde(default = "default_weight")]
33        weight: f32,
34    },
35}
36
37/// Default behavior for LLM judge evaluation.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct JudgeConfig {
40    /// Whether this feature is enabled.
41    #[serde(default = "default_true")]
42    pub enabled: bool,
43    /// Optional LLM alias or provider used for judge calls.
44    #[serde(default)]
45    pub llm: Option<String>,
46    /// Criteria used when assertions omit criteria.
47    #[serde(default)]
48    pub default_criteria: Vec<JudgeCriterion>,
49    /// Minimum overall score required to pass.
50    #[serde(default = "default_threshold")]
51    pub pass_threshold: f32,
52    /// Whether judge responses must be strict JSON.
53    #[serde(default = "default_true")]
54    pub require_json: bool,
55}
56
57impl Default for JudgeConfig {
58    fn default() -> Self {
59        Self {
60            enabled: true,
61            llm: None,
62            default_criteria: Vec::new(),
63            pass_threshold: default_threshold(),
64            require_json: true,
65        }
66    }
67}
68
69/// Parsed JSON result returned by an LLM judge.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct JudgeResult {
72    /// Scores for individual criteria.
73    pub criteria_scores: Vec<CriterionScore>,
74    /// Aggregated score used for pass or fail.
75    pub overall_score: f32,
76    /// Brief feedback returned by the judge.
77    pub overall_feedback: String,
78    /// Passed count or boolean result.
79    pub passed: bool,
80    /// Optional raw judge response for debugging.
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub raw_response: Option<String>,
83}
84
85/// Score for one criterion inside a judge result.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct CriterionScore {
88    /// Human-readable name or criterion name.
89    pub name: String,
90    /// Numeric score assigned by the judge.
91    pub score: f32,
92    /// Brief explanation for the score.
93    #[serde(default)]
94    pub explanation: String,
95}
96
/// Borrowed context handed to the judge prompt for one evaluation.
///
/// The struct only holds string slices, so it is `Copy`: callers can pass it
/// by value and still keep using (or logging) their own copy.
#[derive(Debug, Clone, Copy)]
pub struct JudgeInput<'a> {
    /// Assistant response text or redacted output value.
    pub response: &'a str,
    /// Optional user input for judge prompt context.
    pub user_input: Option<&'a str>,
    /// Optional scenario ID for judge prompt context.
    pub scenario_id: Option<&'a str>,
    /// Optional language label for filtering, metrics, and judge context.
    pub language: Option<&'a str>,
}
108
109/// Resolves judge LLM aliases from the runtime registry.
110pub struct JudgeResolver {
111    /// Runtime LLM registry used to resolve aliases.
112    registry: Arc<LLMRegistry>,
113    /// Configuration used by this component.
114    config: JudgeConfig,
115}
116
117impl JudgeResolver {
118    pub fn new(registry: Arc<LLMRegistry>, config: JudgeConfig) -> Self {
119        Self { registry, config }
120    }
121
122    pub fn resolve(&self, alias: Option<&str>) -> Result<LLMJudge> {
123        let llm = if let Some(alias) = alias {
124            self.registry
125                .get(alias)
126                .map_err(|error| EvalError::Judge(error.to_string()))?
127        } else {
128            self.registry
129                .router()
130                .or_else(|_| self.registry.default())
131                .map_err(|error| EvalError::Judge(error.to_string()))?
132        };
133        Ok(LLMJudge::new(llm, self.config.clone()))
134    }
135}
136
137/// Wrapper that asks an LLM to score semantic response quality.
138pub struct LLMJudge {
139    /// Optional LLM alias or provider used for judge calls.
140    llm: Arc<dyn LLMProvider>,
141    /// Configuration used by this component.
142    config: JudgeConfig,
143}
144
145impl LLMJudge {
146    pub fn new(llm: Arc<dyn LLMProvider>, config: JudgeConfig) -> Self {
147        Self { llm, config }
148    }
149
150    pub async fn evaluate(
151        &self,
152        response: &str,
153        assertion: &JudgeAssertion,
154    ) -> Result<JudgeResult> {
155        self.evaluate_input(
156            JudgeInput {
157                response,
158                user_input: None,
159                scenario_id: None,
160                language: None,
161            },
162            assertion,
163        )
164        .await
165    }
166
167    pub async fn evaluate_input(
168        &self,
169        input: JudgeInput<'_>,
170        assertion: &JudgeAssertion,
171    ) -> Result<JudgeResult> {
172        let criteria = if assertion.criteria.is_empty() {
173            self.config.default_criteria.clone()
174        } else {
175            assertion.criteria.clone()
176        };
177        if criteria.is_empty() {
178            return Err(EvalError::Judge("judge assertion has no criteria".into()));
179        }
180        let threshold = assertion.pass_threshold;
181        let prompt = build_prompt(input, &criteria, threshold);
182        let llm_response = self
183            .llm
184            .complete(&[ChatMessage::user(&prompt)], None)
185            .await
186            .map_err(|error| EvalError::Judge(error.to_string()))?;
187        let value = extract_json(&llm_response.content)
188            .ok_or_else(|| EvalError::Judge("judge did not return JSON".into()))?;
189        let mut result: JudgeResult = serde_json::from_value(value)
190            .map_err(|error| EvalError::Judge(format!("invalid judge JSON: {}", error)))?;
191        result.passed = result.overall_score >= threshold;
192        if !self.config.require_json {
193            result.raw_response = Some(llm_response.content);
194        }
195        Ok(result)
196    }
197}
198
199fn build_prompt(input: JudgeInput<'_>, criteria: &[JudgeCriterion], threshold: f32) -> String {
200    let criteria_text = criteria
201        .iter()
202        .enumerate()
203        .map(|(idx, criterion)| match criterion {
204            JudgeCriterion::Text(text) => format!("{}. {} (weight 1.0)", idx + 1, text),
205            JudgeCriterion::Object {
206                name,
207                description,
208                weight,
209            } => {
210                format!("{}. {}: {} (weight {})", idx + 1, name, description, weight)
211            }
212        })
213        .collect::<Vec<_>>()
214        .join("\n");
215    let user_input = input.user_input.unwrap_or("");
216    let scenario_id = input.scenario_id.unwrap_or("");
217    let language = input.language.unwrap_or("");
218    let response = input.response;
219    format!(
220        r#"Evaluate the assistant response against the criteria.
221Evaluate semantic meaning across languages. Do not require exact wording unless a criterion says so.
222Return strict JSON only with this shape:
223{{"criteria_scores":[{{"name":"criterion","score":0.0,"explanation":"brief"}}],"overall_score":0.0,"overall_feedback":"brief","passed":false}}
224Pass threshold: {threshold}
225
226Scenario ID: {scenario_id}
227Language: {language}
228User input: {user_input}
229
230Criteria:
231{criteria_text}
232
233Assistant response:
234{response}"#
235    )
236}
237
238fn extract_json(text: &str) -> Option<Value> {
239    if let Ok(value) = serde_json::from_str(text.trim()) {
240        return Some(value);
241    }
242    let start = text.find('{')?;
243    let end = text.rfind('}')?;
244    serde_json::from_str(&text[start..=end]).ok()
245}
246
/// Serde default for `pass_threshold` fields.
fn default_threshold() -> f32 {
    const DEFAULT_PASS_THRESHOLD: f32 = 0.75;
    DEFAULT_PASS_THRESHOLD
}
250
/// Serde default for the `weight` of an object-form criterion.
fn default_weight() -> f32 {
    const DEFAULT_CRITERION_WEIGHT: f32 = 1.0;
    DEFAULT_CRITERION_WEIGHT
}
254
/// Serde default for boolean flags that are on unless configured otherwise.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}