use async_trait::async_trait;
use crate::{rag::RAGError, schemas::Document};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AnswerValidationResult {
pub is_valid: bool,
pub confidence: f64,
pub feedback: Option<String>,
pub issues: Vec<String>,
pub suggestions: Vec<String>,
}
impl AnswerValidationResult {
    /// Builds a passing verdict with the given `confidence`.
    ///
    /// The result carries no feedback, no issues and no suggestions.
    pub fn valid(confidence: f64) -> Self {
        let (issues, suggestions) = (Vec::new(), Vec::new());
        Self {
            is_valid: true,
            confidence,
            feedback: None,
            issues,
            suggestions,
        }
    }

    /// Builds a failing verdict.
    ///
    /// `feedback` is stored as the explanation; `issues` and `suggestions`
    /// are stored as given.
    pub fn invalid(
        confidence: f64,
        feedback: String,
        issues: Vec<String>,
        suggestions: Vec<String>,
    ) -> Self {
        let feedback = Some(feedback);
        Self {
            is_valid: false,
            confidence,
            feedback,
            issues,
            suggestions,
        }
    }
}
/// Strategy for judging whether a generated answer is acceptable.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait AnswerValidator: Send + Sync {
/// Judges `answer`, given the `query` it responds to and the
/// `source_documents` it is supposed to be based on.
///
/// Returns the verdict, or a [`RAGError`] when validation itself fails.
async fn validate(
&self,
query: &str,
answer: &str,
source_documents: &[Document],
) -> Result<AnswerValidationResult, RAGError>;
}
/// [`AnswerValidator`] that asks a language model to judge the answer.
pub struct LLMAnswerValidator {
/// Model invoked with the filled-in validation prompt.
llm: Box<dyn crate::language_models::llm::LLM>,
/// Custom prompt template; `None` selects the built-in default prompt.
validation_prompt: Option<String>,
}
impl LLMAnswerValidator {
    /// Creates a validator backed by `llm`.
    ///
    /// No custom prompt is set, so validation uses the built-in default one.
    pub fn new(llm: Box<dyn crate::language_models::llm::LLM>) -> Self {
        let validation_prompt = None;
        Self {
            llm,
            validation_prompt,
        }
    }

    /// Replaces the validation prompt template and returns the validator.
    ///
    /// During validation the `{query}`, `{answer}` and `{sources}`
    /// placeholders in the template are substituted with the actual values.
    pub fn with_prompt<S: Into<String>>(self, prompt: S) -> Self {
        let mut configured = self;
        configured.validation_prompt = Some(prompt.into());
        configured
    }
}
#[async_trait]
impl AnswerValidator for LLMAnswerValidator {
    /// Asks the language model whether `answer` is supported by the sources.
    ///
    /// The prompt template (custom or default) has its `{query}`, `{answer}`
    /// and `{sources}` placeholders filled in; at most
    /// `MAX_PROMPT_SOURCES` documents are included.
    ///
    /// The model's reply is interpreted in three steps:
    /// 1. as a JSON [`AnswerValidationResult`];
    /// 2. as a JSON object embedded in surrounding text (for example inside a
    ///    Markdown code fence), tolerating missing optional fields;
    /// 3. as free text, using a keyword heuristic that treats negated
    ///    wording ("invalid", "not accurate", ...) as a rejection.
    ///
    /// # Errors
    ///
    /// Returns [`RAGError::AnswerValidationError`] when the LLM call fails.
    /// An unparseable reply is not an error; it yields a heuristic verdict
    /// with the raw reply as feedback.
    async fn validate(
        &self,
        query: &str,
        answer: &str,
        source_documents: &[Document],
    ) -> Result<AnswerValidationResult, RAGError> {
        let sources = source_documents
            .iter()
            .take(MAX_PROMPT_SOURCES)
            .map(|doc| format!("[Source]\n{}\n", doc.page_content))
            .collect::<Vec<_>>()
            .join("\n---\n");

        let formatted_prompt = self
            .validation_prompt
            .as_deref()
            .unwrap_or(DEFAULT_VALIDATION_PROMPT)
            .replace("{query}", query)
            .replace("{answer}", answer)
            .replace("{sources}", &sources);

        let response = self
            .llm
            .invoke(&formatted_prompt)
            .await
            .map_err(|e| RAGError::AnswerValidationError(format!("LLM error: {}", e)))?;

        // Step 1: the reply is exactly the requested JSON.
        if let Ok(result) = serde_json::from_str::<AnswerValidationResult>(&response) {
            return Ok(result);
        }

        // Step 2: the JSON is wrapped in prose / a code fence, or is missing
        // some of the optional fields.
        if let Some(result) = extract_json_object(&response).and_then(parse_lenient_verdict) {
            return Ok(result);
        }

        // Step 3: no usable JSON at all; fall back to reading the text.
        let is_valid = infer_validity_from_text(&response);
        Ok(AnswerValidationResult {
            is_valid,
            confidence: inferred_confidence(is_valid),
            feedback: Some(response),
            issues: Vec::new(),
            suggestions: Vec::new(),
        })
    }
}

/// Maximum number of source documents quoted in the validation prompt.
const MAX_PROMPT_SOURCES: usize = 5;

/// Confidence reported for a "valid" verdict that had to be inferred.
const INFERRED_VALID_CONFIDENCE: f64 = 0.7;

/// Confidence reported for an "invalid" verdict that had to be inferred.
const INFERRED_INVALID_CONFIDENCE: f64 = 0.3;

/// Prompt used when no custom template was configured.
///
/// The placeholders are filled with `str::replace`, not `format!`, so the
/// JSON example uses single braces; doubled braces would be sent verbatim.
const DEFAULT_VALIDATION_PROMPT: &str = "Evaluate whether the following answer is accurate, complete, and aligned with the source documents.\n\n\
    Query: {query}\n\n\
    Answer: {answer}\n\n\
    Source Documents:\n{sources}\n\n\
    Respond with JSON: {\"is_valid\": true/false, \"confidence\": 0.0-1.0, \"feedback\": \"...\", \"issues\": [\"...\"], \"suggestions\": [\"...\"]}";

/// Confidence assigned to a verdict for which the model gave no usable score.
fn inferred_confidence(is_valid: bool) -> f64 {
    if is_valid {
        INFERRED_VALID_CONFIDENCE
    } else {
        INFERRED_INVALID_CONFIDENCE
    }
}

/// Returns the span from the first `{` to the last `}` of `response`, if any.
fn extract_json_object(response: &str) -> Option<&str> {
    let start = response.find('{')?;
    let end = response.rfind('}')?;
    if start < end {
        response.get(start..=end)
    } else {
        None
    }
}

/// Parses a JSON verdict, requiring only a boolean `is_valid` field.
///
/// Missing or mistyped optional fields fall back to defaults: an inferred
/// confidence, no feedback, and empty lists.
fn parse_lenient_verdict(candidate: &str) -> Option<AnswerValidationResult> {
    if let Ok(result) = serde_json::from_str::<AnswerValidationResult>(candidate) {
        return Some(result);
    }

    let value: serde_json::Value = serde_json::from_str(candidate).ok()?;
    let is_valid = value.get("is_valid")?.as_bool()?;
    let confidence = value
        .get("confidence")
        .and_then(|v| v.as_f64())
        .unwrap_or_else(|| inferred_confidence(is_valid));
    let feedback = value
        .get("feedback")
        .and_then(|v| v.as_str())
        .map(str::to_owned);

    Some(AnswerValidationResult {
        is_valid,
        confidence,
        feedback,
        issues: json_string_list(&value, "issues"),
        suggestions: json_string_list(&value, "suggestions"),
    })
}

/// Collects the string elements of the array stored under `key`, if present.
fn json_string_list(value: &serde_json::Value, key: &str) -> Vec<String> {
    value
        .get(key)
        .and_then(|v| v.as_array())
        .map(|items| {
            items
                .iter()
                .filter_map(|item| item.as_str())
                .map(str::to_owned)
                .collect::<Vec<String>>()
        })
        .unwrap_or_default()
}

/// Guesses a verdict from free text.
///
/// Negated wording is checked first because "invalid" and "inaccurate"
/// contain the positive keywords as substrings. The check is deliberately
/// conservative: any negative marker yields `false`.
fn infer_validity_from_text(response: &str) -> bool {
    const NEGATIVE_MARKERS: [&str; 6] = [
        "invalid",
        "not valid",
        "inaccurate",
        "not accurate",
        "\"is_valid\": false",
        "\"is_valid\":false",
    ];

    let lowered = response.to_lowercase();
    if NEGATIVE_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
    {
        return false;
    }
    lowered.contains("valid") || lowered.contains("accurate")
}
/// [`AnswerValidator`] that checks, without calling an LLM, whether the
/// answer shares words with the source documents.
pub struct SourceAlignmentValidator {
/// Minimum number of source documents that must share a word with the
/// answer for the answer to be considered valid.
min_supporting_sources: usize,
}
impl SourceAlignmentValidator {
    /// Threshold applied by [`SourceAlignmentValidator::new`].
    const DEFAULT_MIN_SUPPORTING_SOURCES: usize = 1;

    /// Creates a validator that requires one supporting source document.
    pub fn new() -> Self {
        Self {
            min_supporting_sources: Self::DEFAULT_MIN_SUPPORTING_SOURCES,
        }
    }

    /// Sets the number of supporting source documents required for an answer
    /// to be considered valid, and returns the validator.
    pub fn with_min_sources(self, min: usize) -> Self {
        let mut configured = self;
        configured.min_supporting_sources = min;
        configured
    }
}
/// The default validator is the one produced by
/// [`SourceAlignmentValidator::new`].
impl Default for SourceAlignmentValidator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schemas::Document;

    /// Without any source documents the answer can never be validated.
    #[tokio::test]
    async fn test_source_alignment_validator_no_sources() {
        let validator = SourceAlignmentValidator::new();
        let result = validator.validate("query", "answer", &[]).await.unwrap();
        assert!(!result.is_valid);
        assert!(result.feedback.is_some());
    }

    /// An answer whose words appear in the only source is fully supported.
    #[tokio::test]
    async fn test_source_alignment_validator_with_sources() {
        let validator = SourceAlignmentValidator::new();
        let docs = vec![Document::new("This is test content about machine learning")];
        let result = validator
            .validate("query", "machine learning", &docs)
            .await
            .unwrap();
        assert!(result.is_valid);
        // One of one sources supports the answer.
        assert!((result.confidence - 1.0).abs() < f64::EPSILON);
    }

    /// An answer sharing no words with the source is rejected.
    #[tokio::test]
    async fn test_source_alignment_validator_unsupported_answer() {
        let validator = SourceAlignmentValidator::new();
        let docs = vec![Document::new("This is test content about machine learning")];
        let result = validator
            .validate("query", "quantum physics", &docs)
            .await
            .unwrap();
        assert!(!result.is_valid);
        // Zero of one sources support the answer.
        assert!(result.confidence.abs() < f64::EPSILON);
    }
}
#[async_trait]
impl AnswerValidator for SourceAlignmentValidator {
    /// Checks how many source documents share at least one word with `answer`.
    ///
    /// The answer is lower-cased and split into words on every
    /// non-alphanumeric character, so punctuation attached to a word
    /// ("learning.", "models,") does not prevent a match. Words of three
    /// bytes or fewer are ignored. A document supports the answer when its
    /// lower-cased content contains any remaining word as a substring.
    ///
    /// The answer is valid when the number of supporting documents reaches
    /// the configured minimum. In both outcomes the confidence is the
    /// fraction of documents that support the answer. The query is not used.
    ///
    /// With no source documents the result is always invalid with
    /// confidence 0.0. This implementation never returns `Err`.
    async fn validate(
        &self,
        _query: &str,
        answer: &str,
        source_documents: &[Document],
    ) -> Result<AnswerValidationResult, RAGError> {
        // Words must be longer than this many bytes to count as keywords;
        // shorter ones are mostly articles and prepositions.
        const MAX_IGNORED_WORD_LEN: usize = 3;

        if source_documents.is_empty() {
            return Ok(AnswerValidationResult::invalid(
                0.0,
                "No source documents provided".to_string(),
                vec!["Answer cannot be validated without source documents".to_string()],
                vec!["Ensure source documents are retrieved".to_string()],
            ));
        }

        let answer_lower = answer.to_lowercase();
        // Splitting on non-alphanumerics strips punctuation that
        // whitespace-only splitting would leave glued to the words.
        let keywords: std::collections::HashSet<&str> = answer_lower
            .split(|c: char| !c.is_alphanumeric())
            .filter(|word| word.len() > MAX_IGNORED_WORD_LEN)
            .collect();

        let supporting_count = source_documents
            .iter()
            .filter(|doc| {
                let doc_lower = doc.page_content.to_lowercase();
                keywords.iter().any(|word| doc_lower.contains(*word))
            })
            .count();

        let total_sources = source_documents.len();
        let support_ratio = (supporting_count as f64) / (total_sources as f64);

        if supporting_count < self.min_supporting_sources {
            Ok(AnswerValidationResult::invalid(
                support_ratio,
                format!(
                    "Answer is not well supported by sources ({} of {} sources support it)",
                    supporting_count, total_sources
                ),
                vec!["Answer may contain information not in source documents".to_string()],
                vec!["Review answer for accuracy".to_string()],
            ))
        } else {
            Ok(AnswerValidationResult::valid(support_ratio))
        }
    }
}