use std::sync::Arc;
use async_trait::async_trait;
use super::eval_case::Invocation;
use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
use super::evaluator::{EvalError, Evaluator};
use crate::llm::BaseLlm;
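/// Which aspect of the agent's behavior the rubric judge is asked to score.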
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RubricMode {
FinalResponse,
ToolUse,
}
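/// An LLM-as-judge evaluator that scores agent conversations against a set of
/// natural-language rubric criteria, in either final-response or tool-use mode.
///
/// A minimal usage sketch. It assumes `my_llm` is some `Arc<dyn BaseLlm>`
/// implementation and `invocations` is a slice of recorded [`Invocation`]s;
/// both are caller-supplied placeholders, not defined in this module.
///
/// ```ignore
/// let evaluator = RubricEvaluator::for_response(vec![
///     "Is the response factually accurate?".into(),
///     "Is the response well formatted?".into(),
/// ])
/// .with_llm(my_llm);
///
/// let result = evaluator.evaluate(&invocations, None).await?;
/// println!("overall rubric score: {}", result.overall_score);
/// ```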
pub struct RubricEvaluator {
rubrics: Vec<String>,
judge_model: Option<String>,
mode: RubricMode,
llm: Option<Arc<dyn BaseLlm>>,
}
impl RubricEvaluator {
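    /// Creates an evaluator for the given rubric criteria, defaulting to
    /// final-response mode (equivalent to [`Self::for_response`]).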
    pub fn new(rubrics: Vec<String>) -> Self {
        Self::for_response(rubrics)
    }
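    /// Creates an evaluator that judges final response quality.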
pub fn for_response(rubrics: Vec<String>) -> Self {
Self {
rubrics,
judge_model: None,
mode: RubricMode::FinalResponse,
llm: None,
}
}
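    /// Creates an evaluator that judges tool-use quality.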
pub fn for_tool_use(rubrics: Vec<String>) -> Self {
Self {
rubrics,
judge_model: None,
mode: RubricMode::ToolUse,
llm: None,
}
}
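    /// Records a judge model identifier on the evaluator. The judging calls
    /// themselves are made through the LLM supplied via [`Self::with_llm`].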
pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
self.judge_model = Some(model.into());
self
}
pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
self.llm = Some(llm);
self
}
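    /// Builds the judge prompt: the rubric criteria, the actual conversation
    /// (including tool calls and results), the expected conversation when one
    /// is provided, and the required JSON response format.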
fn build_prompt(&self, actual: &Invocation, expected: Option<&Invocation>) -> String {
let mode_label = match self.mode {
RubricMode::FinalResponse => "FINAL RESPONSE QUALITY",
RubricMode::ToolUse => "TOOL USE QUALITY",
};
let mut prompt = format!(
"You are an expert evaluator assessing {mode_label}.\n\n\
Score the agent's performance on a scale of 0.0 to 1.0 for EACH rubric criterion.\n\n"
);
prompt.push_str("RUBRIC CRITERIA:\n");
for (i, rubric) in self.rubrics.iter().enumerate() {
prompt.push_str(&format!("{}. {}\n", i + 1, rubric));
}
prompt.push('\n');
prompt.push_str("ACTUAL AGENT CONVERSATION:\n");
for turn in &actual.turns {
prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
if !turn.tool_calls.is_empty() {
prompt.push_str(&format!(
" Tool calls: {}\n",
serde_json::json!(turn.tool_calls)
));
}
if !turn.tool_results.is_empty() {
prompt.push_str(&format!(
" Tool results: {}\n",
serde_json::json!(turn.tool_results)
));
}
}
if let Some(expected) = expected {
prompt.push_str("\nEXPECTED CONVERSATION:\n");
for turn in &expected.turns {
prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
if !turn.tool_calls.is_empty() {
prompt.push_str(&format!(
" Tool calls: {}\n",
serde_json::json!(turn.tool_calls)
));
}
}
}
prompt.push_str(
"\nRespond with ONLY a JSON object:\n\
{\"scores\": [<float per rubric criterion>], \
\"overall_score\": <float average>, \
\"explanation\": \"<text>\"}\n",
);
prompt
}
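    /// Parses the judge's reply, first as a bare JSON object and then by
    /// extracting the outermost `{...}` span from surrounding text. On failure
    /// it returns a score of 0.0 with the raw reply in the explanation.
    /// The rubric count is not needed for parsing and is currently unused.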
    fn parse_response(text: &str, _num_rubrics: usize) -> (f64, String) {
if let Some((score, explanation)) = try_parse_json(text) {
return (score, explanation);
}
if let Some(start) = text.find('{') {
if let Some(end) = text[start..].rfind('}') {
let json_str = &text[start..=start + end];
if let Some((score, explanation)) = try_parse_json(json_str) {
return (score, explanation);
}
}
}
        (
            0.0,
            format!("Failed to parse rubric judge response: {text}"),
        )
}
}
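/// Extracts `(score, explanation)` from a judge JSON object, preferring
/// `overall_score` and otherwise averaging the `scores` array; all scores are
/// clamped to the 0.0..=1.0 range.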
fn try_parse_json(text: &str) -> Option<(f64, String)> {
let v: serde_json::Value = serde_json::from_str(text).ok()?;
let score = if let Some(overall) = v["overall_score"].as_f64() {
overall.clamp(0.0, 1.0)
} else if let Some(scores) = v["scores"].as_array() {
let sum: f64 = scores
.iter()
.filter_map(|s| s.as_f64())
.map(|s| s.clamp(0.0, 1.0))
.sum();
let count = scores.len().max(1) as f64;
sum / count
} else {
return None;
};
let explanation = v["explanation"]
.as_str()
.unwrap_or("No explanation")
.to_string();
Some((score, explanation))
}
#[async_trait]
impl Evaluator for RubricEvaluator {
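    /// Runs the judge LLM once per actual invocation, pairing each one
    /// positionally with an expected invocation when available, and averages
    /// the per-invocation scores into a single rubric metric.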
async fn evaluate(
&self,
actual: &[Invocation],
expected: Option<&[Invocation]>,
) -> Result<EvalResult, EvalError> {
let llm = self.llm.as_ref().ok_or_else(|| {
EvalError::Llm(
"RubricEvaluator requires an LLM instance — call .with_llm() before evaluating"
.into(),
)
})?;
let mut per_invocation = Vec::new();
let mut total_score = 0.0;
for (i, actual_inv) in actual.iter().enumerate() {
let expected_inv = expected.and_then(|e| e.get(i));
let prompt = self.build_prompt(actual_inv, expected_inv);
let request = crate::llm::LlmRequest::from_text(&prompt);
let response = llm
.generate(request)
.await
.map_err(|e| EvalError::Llm(e.to_string()))?;
let (score, explanation) = Self::parse_response(&response.text(), self.rubrics.len());
total_score += score;
per_invocation.push(PerInvocationResult {
invocation_id: if actual_inv.id.is_empty() {
format!("inv-{i}")
} else {
actual_inv.id.clone()
},
score,
explanation: Some(explanation),
});
}
let overall_score = if actual.is_empty() {
0.0
} else {
total_score / actual.len() as f64
};
let metric_name = match self.mode {
RubricMode::FinalResponse => "rubric_based_final_response_quality_v1",
RubricMode::ToolUse => "rubric_based_tool_use_quality_v1",
};
Ok(EvalResult {
overall_score,
metrics: vec![EvalMetric {
name: metric_name.into(),
score: overall_score,
per_invocation,
}],
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_valid_response() {
let json = r#"{"scores": [0.8, 0.9], "overall_score": 0.85, "explanation": "Good"}"#;
let (score, explanation) = RubricEvaluator::parse_response(json, 2);
assert!((score - 0.85).abs() < f64::EPSILON);
assert_eq!(explanation, "Good");
}
#[test]
fn parse_scores_only() {
let json = r#"{"scores": [0.8, 0.6]}"#;
let (score, _) = RubricEvaluator::parse_response(json, 2);
assert!((score - 0.7).abs() < f64::EPSILON);
}
#[test]
fn parse_embedded_json() {
let text = r#"Here is my evaluation: {"overall_score": 0.9, "explanation": "Great"}"#;
let (score, _) = RubricEvaluator::parse_response(text, 1);
assert!((score - 0.9).abs() < f64::EPSILON);
}
#[test]
fn parse_invalid() {
let (score, explanation) = RubricEvaluator::parse_response("no json here", 1);
assert!((score - 0.0).abs() < f64::EPSILON);
assert!(explanation.contains("Failed to parse"));
}
#[test]
fn for_response_mode() {
let eval = RubricEvaluator::for_response(vec!["Accuracy".into()]);
assert_eq!(eval.mode, RubricMode::FinalResponse);
}
#[test]
fn for_tool_use_mode() {
let eval = RubricEvaluator::for_tool_use(vec!["Tool selection".into()]);
assert_eq!(eval.mode, RubricMode::ToolUse);
}
#[test]
fn build_prompt_includes_rubrics() {
use crate::evaluation::eval_case::InvocationTurn;
let eval = RubricEvaluator::new(vec![
"Is the response accurate?".into(),
"Is it well-formatted?".into(),
]);
let inv = Invocation {
id: "test".into(),
turns: vec![InvocationTurn {
role: "user".into(),
content: "Hello".into(),
tool_calls: vec![],
tool_results: vec![],
}],
metadata: serde_json::Value::Null,
};
let prompt = eval.build_prompt(&inv, None);
assert!(prompt.contains("Is the response accurate?"));
assert!(prompt.contains("Is it well-formatted?"));
assert!(prompt.contains("FINAL RESPONSE QUALITY"));
}
#[test]
fn with_judge_model() {
let eval = RubricEvaluator::new(vec!["test".into()]).with_judge_model("gemini-2.0-flash");
assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
}
#[test]
fn score_clamped() {
let json = r#"{"overall_score": 1.5, "explanation": "Over"}"#;
let (score, _) = RubricEvaluator::parse_response(json, 1);
assert!((score - 1.0).abs() < f64::EPSILON);
}
}