use std::sync::Arc;
use async_trait::async_trait;
use super::eval_case::Invocation;
use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
use super::evaluator::{EvalError, Evaluator};
use crate::llm::BaseLlm;
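
/// An LLM-as-judge evaluator that scores groundedness: whether every factual
/// claim in an agent's response is supported by the conversation context
/// (user inputs and tool results). Higher scores mean fewer hallucinations.
///
/// A minimal usage sketch, assuming you already hold a judge LLM as an
/// `Arc<dyn BaseLlm>` (here called `judge_llm`):
///
/// ```ignore
/// let evaluator = HallucinationEvaluator::new()
///     .with_judge_model("gemini-2.0-flash")
///     .with_llm(judge_llm);
/// let result = evaluator.evaluate(&invocations, None).await?;
/// println!("groundedness: {}", result.overall_score);
/// ```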
pub struct HallucinationEvaluator {
    /// Name of the judge model, if one has been set.
    judge_model: Option<String>,
    /// When true, all non-empty model turns are judged, not just the final one.
    evaluate_intermediate: bool,
    /// The LLM that performs the groundedness judgment.
    llm: Option<Arc<dyn BaseLlm>>,
}

impl HallucinationEvaluator {
pub fn new() -> Self {
Self {
judge_model: None,
evaluate_intermediate: false,
llm: None,
}
}
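
    /// When `eval` is true, every non-empty model turn is judged; the default
    /// is to judge only the final model response.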
pub fn with_intermediate(mut self, eval: bool) -> Self {
self.evaluate_intermediate = eval;
self
}
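
    /// Sets the judge model name. Note that generation itself goes through
    /// whatever LLM is supplied via [`Self::with_llm`]; this name is stored
    /// on the evaluator but not otherwise consumed here.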
pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
self.judge_model = Some(model.into());
self
}
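
    /// Supplies the LLM that performs the groundedness judgment. Required
    /// before calling `evaluate`.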
pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
self.llm = Some(llm);
self
}
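
    /// Builds the grounding context for one invocation: user inputs and tool
    /// results in turn order. This is the "source of truth" the judge checks
    /// each response against.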
fn extract_context(inv: &Invocation) -> String {
let mut context = String::new();
for turn in &inv.turns {
match turn.role.as_str() {
"user" => {
context.push_str(&format!("USER INPUT: {}\n", turn.content));
}
"model" if !turn.tool_results.is_empty() => {
for result in &turn.tool_results {
context.push_str(&format!("TOOL RESULT: {}\n", result));
}
}
_ => {}
}
}
context
}
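
    /// Collects the responses to judge: every non-empty model turn when
    /// `include_intermediate` is true, otherwise only the last one.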
fn extract_responses(inv: &Invocation, include_intermediate: bool) -> Vec<String> {
let model_turns: Vec<&str> = inv
.turns
.iter()
.filter(|t| t.role == "model" && !t.content.is_empty())
.map(|t| t.content.as_str())
.collect();
if include_intermediate {
model_turns.into_iter().map(String::from).collect()
} else {
model_turns
.last()
.map(|s| vec![s.to_string()])
.unwrap_or_default()
}
}
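
    /// Renders the judge prompt: embeds the grounding context and the
    /// response, and asks for a JSON verdict with a score in 0.0..=1.0.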
fn build_prompt(context: &str, response: &str) -> String {
format!(
"You are an expert evaluator assessing GROUNDEDNESS (absence of hallucination).\n\n\
Your task: determine whether the agent's response is fully supported by the \
provided context. A response is grounded if every factual claim it makes can \
be traced back to information in the context.\n\n\
GROUNDING CONTEXT (source of truth):\n\
{context}\n\n\
AGENT RESPONSE TO EVALUATE:\n\
{response}\n\n\
Scoring guide:\n\
- 1.0: Every claim is directly supported by the context\n\
- 0.75: Most claims are supported, minor unsupported details\n\
- 0.5: Mix of supported and unsupported claims\n\
- 0.25: Mostly unsupported claims with some grounded elements\n\
- 0.0: Entirely fabricated or contradicts the context\n\n\
Respond with ONLY a JSON object:\n\
{{\"score\": <float>, \"hallucinated_claims\": [\"<claim1>\", ...], \"explanation\": \"<text>\"}}"
)
}
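
    /// Parses the judge's reply. Tries the full text as JSON first, then
    /// falls back to the outermost `{...}` span (judges often wrap JSON in
    /// prose); if both fail, returns a zero score with a diagnostic.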
fn parse_response(text: &str) -> (f64, String) {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
return extract_score_and_explanation(&v);
}
if let Some(start) = text.find('{') {
if let Some(end) = text[start..].rfind('}') {
let json_str = &text[start..=start + end];
if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
return extract_score_and_explanation(&v);
}
}
}
(
0.0,
format!("Failed to parse hallucination judge response: {text}"),
)
}
}

impl Default for HallucinationEvaluator {
fn default() -> Self {
Self::new()
}
}
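
/// Extracts `score` (clamped to 0.0..=1.0) and `explanation` from the judge's
/// JSON verdict, appending any listed hallucinated claims to the explanation.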
fn extract_score_and_explanation(v: &serde_json::Value) -> (f64, String) {
let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
let mut explanation = v["explanation"]
.as_str()
.unwrap_or("No explanation")
.to_string();
if let Some(claims) = v["hallucinated_claims"].as_array() {
let claim_strs: Vec<&str> = claims.iter().filter_map(|c| c.as_str()).collect();
if !claim_strs.is_empty() {
explanation.push_str(&format!(
" | Hallucinated claims: {}",
claim_strs.join("; ")
));
}
}
(score, explanation)
}

#[async_trait]
impl Evaluator for HallucinationEvaluator {
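    /// Judges each invocation's responses against that invocation's own
    /// grounding context, then averages the per-invocation scores into a
    /// single `groundedness` metric.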
async fn evaluate(
&self,
actual: &[Invocation],
_expected: Option<&[Invocation]>,
) -> Result<EvalResult, EvalError> {
        let llm = self.llm.as_ref().ok_or_else(|| {
            EvalError::Llm(
                "HallucinationEvaluator requires an LLM; call .with_llm() before evaluating".into(),
            )
        })?;
let mut per_invocation = Vec::new();
let mut total_score = 0.0;
for (i, actual_inv) in actual.iter().enumerate() {
let context = Self::extract_context(actual_inv);
let responses = Self::extract_responses(actual_inv, self.evaluate_intermediate);
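            // Nothing to judge: count this invocation as fully grounded
            // rather than penalizing it.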
if responses.is_empty() {
per_invocation.push(PerInvocationResult {
invocation_id: inv_id(actual_inv, i),
score: 1.0,
explanation: Some("No model responses to evaluate".into()),
});
total_score += 1.0;
continue;
}
let mut resp_total = 0.0;
let mut explanations = Vec::new();
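            // Judge each response against the same grounding context and
            // average the scores for this invocation.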
for response in &responses {
let prompt = Self::build_prompt(&context, response);
let request = crate::llm::LlmRequest::from_text(&prompt);
let llm_response = llm
.generate(request)
.await
.map_err(|e| EvalError::Llm(e.to_string()))?;
let (score, explanation) = Self::parse_response(&llm_response.text());
resp_total += score;
explanations.push(explanation);
}
let avg_score = resp_total / responses.len() as f64;
total_score += avg_score;
per_invocation.push(PerInvocationResult {
invocation_id: inv_id(actual_inv, i),
score: avg_score,
explanation: Some(explanations.join(" | ")),
});
}
let overall_score = if actual.is_empty() {
0.0
} else {
total_score / actual.len() as f64
};
Ok(EvalResult {
overall_score,
metrics: vec![EvalMetric {
name: "groundedness".into(),
score: overall_score,
per_invocation,
}],
})
}
}
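
/// Returns the invocation's id, falling back to a positional `inv-<index>`
/// when the id is empty.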
fn inv_id(inv: &Invocation, index: usize) -> String {
if inv.id.is_empty() {
format!("inv-{index}")
} else {
inv.id.clone()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::evaluation::eval_case::InvocationTurn;

    #[test]
fn extract_context_includes_user_and_tools() {
let inv = Invocation {
id: "test".into(),
turns: vec![
InvocationTurn {
role: "user".into(),
content: "What is the weather?".into(),
tool_calls: vec![],
tool_results: vec![],
},
InvocationTurn {
role: "model".into(),
content: String::new(),
tool_calls: vec![serde_json::json!({"name": "get_weather"})],
tool_results: vec![serde_json::json!({"temp": 22})],
},
InvocationTurn {
role: "model".into(),
content: "It's 22 degrees.".into(),
tool_calls: vec![],
tool_results: vec![],
},
],
metadata: serde_json::Value::Null,
};
let context = HallucinationEvaluator::extract_context(&inv);
assert!(context.contains("What is the weather?"));
assert!(context.contains("22"));
}

    #[test]
fn extract_responses_final_only() {
let inv = Invocation {
id: "test".into(),
turns: vec![
InvocationTurn {
role: "model".into(),
content: "first".into(),
tool_calls: vec![],
tool_results: vec![],
},
InvocationTurn {
role: "model".into(),
content: "second".into(),
tool_calls: vec![],
tool_results: vec![],
},
],
metadata: serde_json::Value::Null,
};
let responses = HallucinationEvaluator::extract_responses(&inv, false);
assert_eq!(responses.len(), 1);
assert_eq!(responses[0], "second");
}

    #[test]
fn extract_responses_all() {
let inv = Invocation {
id: "test".into(),
turns: vec![
InvocationTurn {
role: "model".into(),
content: "first".into(),
tool_calls: vec![],
tool_results: vec![],
},
InvocationTurn {
role: "model".into(),
content: "second".into(),
tool_calls: vec![],
tool_results: vec![],
},
],
metadata: serde_json::Value::Null,
};
let responses = HallucinationEvaluator::extract_responses(&inv, true);
assert_eq!(responses.len(), 2);
}

    #[test]
fn parse_valid_response() {
let json = r#"{"score": 0.9, "hallucinated_claims": [], "explanation": "Well grounded"}"#;
let (score, explanation) = HallucinationEvaluator::parse_response(json);
assert!((score - 0.9).abs() < f64::EPSILON);
assert!(explanation.contains("Well grounded"));
}

    #[test]
fn parse_response_with_claims() {
let json = r#"{"score": 0.5, "hallucinated_claims": ["temp was 25 not 22"], "explanation": "Partial"}"#;
let (score, explanation) = HallucinationEvaluator::parse_response(json);
assert!((score - 0.5).abs() < f64::EPSILON);
assert!(explanation.contains("temp was 25 not 22"));
}

    #[test]
fn parse_invalid() {
let (score, explanation) = HallucinationEvaluator::parse_response("garbage");
assert!((score - 0.0).abs() < f64::EPSILON);
assert!(explanation.contains("Failed to parse"));
}

    #[test]
fn default_impl() {
let eval = HallucinationEvaluator::default();
assert!(!eval.evaluate_intermediate);
assert!(eval.judge_model.is_none());
}

    #[test]
fn builder_methods() {
let eval = HallucinationEvaluator::new()
.with_intermediate(true)
.with_judge_model("gemini-2.0-flash");
assert!(eval.evaluate_intermediate);
assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
}

    #[test]
fn build_prompt_structure() {
let prompt = HallucinationEvaluator::build_prompt(
"USER INPUT: What is 2+2?\nTOOL RESULT: {\"answer\": 4}",
"The answer is 4.",
);
assert!(prompt.contains("GROUNDEDNESS"));
assert!(prompt.contains("GROUNDING CONTEXT"));
assert!(prompt.contains("What is 2+2?"));
assert!(prompt.contains("The answer is 4."));
}
}