use async_trait::async_trait;
use crate::{rag::RAGError, schemas::Document};
/// Outcome of a retrieval validation pass.
///
/// Serializable so it can be produced directly from an LLM's JSON response
/// (see `LLMRetrievalValidator`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RetrievalValidationResult {
    /// Whether the retrieved documents were judged adequate to proceed.
    pub is_valid: bool,
    /// Validator's confidence in the verdict, nominally in [0.0, 1.0].
    pub confidence: f64,
    /// Human-readable explanation; set on failure, `None` on success.
    pub feedback: Option<String>,
    /// Suggested remediations (e.g. query rewrites) when validation fails.
    pub suggestions: Vec<String>,
}
impl RetrievalValidationResult {
    /// Builds a passing result with the given confidence and no feedback
    /// or suggestions attached.
    pub fn valid(confidence: f64) -> Self {
        Self {
            is_valid: true,
            confidence,
            feedback: None,
            suggestions: vec![],
        }
    }

    /// Builds a failing result carrying an explanation and a list of
    /// follow-up suggestions for the caller.
    pub fn invalid(confidence: f64, feedback: String, suggestions: Vec<String>) -> Self {
        Self {
            is_valid: false,
            confidence,
            feedback: Some(feedback),
            suggestions,
        }
    }
}
/// Judges whether a set of retrieved documents is adequate for a query.
///
/// Implementors must be `Send + Sync` so a validator can be shared across
/// async tasks.
#[async_trait]
pub trait RetrievalValidator: Send + Sync {
    /// Validates `documents` against `query`.
    ///
    /// Returns a `RetrievalValidationResult` describing the verdict, or a
    /// `RAGError` if the validation process itself fails.
    async fn validate(
        &self,
        query: &str,
        documents: &[Document],
    ) -> Result<RetrievalValidationResult, RAGError>;
}
/// Heuristic validator: checks the document count and, optionally, each
/// document's `"score"` metadata against a minimum relevance threshold.
pub struct RelevanceValidator {
    // Minimum number of documents required (defaults to 1 via `new`).
    min_documents: usize,
    // If set, documents' "score" metadata must reach this value; `None`
    // disables the score check entirely.
    min_relevance: Option<f64>,
}
impl RelevanceValidator {
    /// Creates a validator requiring a single document and no score check.
    pub fn new() -> Self {
        Self {
            min_documents: 1,
            min_relevance: None,
        }
    }

    /// Builder: require at least `min` retrieved documents.
    pub fn with_min_documents(mut self, min: usize) -> Self {
        self.min_documents = min;
        self
    }

    /// Builder: require document `"score"` metadata to reach `threshold`.
    pub fn with_min_relevance(mut self, threshold: f64) -> Self {
        self.min_relevance = Some(threshold);
        self
    }
}
impl Default for RelevanceValidator {
    /// Same configuration as [`RelevanceValidator::new`]: one document
    /// minimum, no relevance threshold.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schemas::Document;

    #[tokio::test]
    async fn test_relevance_validator_min_documents() {
        // One document retrieved but two required -> validation must fail.
        let validator = RelevanceValidator::new().with_min_documents(2);
        let retrieved = vec![Document::new("test")];
        let outcome = validator.validate("test query", &retrieved).await.unwrap();
        assert!(!outcome.is_valid);
    }

    #[tokio::test]
    async fn test_relevance_validator_sufficient_documents() {
        // Exactly the required document count -> validation must pass.
        let validator = RelevanceValidator::new().with_min_documents(1);
        let retrieved = vec![Document::new("test content")];
        let outcome = validator.validate("test query", &retrieved).await.unwrap();
        assert!(outcome.is_valid);
    }
}
#[async_trait]
impl RetrievalValidator for RelevanceValidator {
    /// Two-gate heuristic check: first the retrieved-document count, then
    /// (if a threshold is configured) the documents' `"score"` metadata.
    /// The query text is not inspected.
    async fn validate(
        &self,
        _query: &str,
        documents: &[Document],
    ) -> Result<RetrievalValidationResult, RAGError> {
        // Gate 1: enough documents came back at all.
        if documents.len() < self.min_documents {
            return Ok(RetrievalValidationResult::invalid(
                0.0,
                format!(
                    "Insufficient documents retrieved: got {}, need at least {}",
                    documents.len(),
                    self.min_documents
                ),
                vec!["Try expanding the query or using different search terms".to_string()],
            ));
        }
        // Gate 2: relevance scores, only when a threshold was configured.
        if let Some(threshold) = self.min_relevance {
            // Documents without a numeric "score" entry are skipped, same as
            // the documents whose metadata value is not representable as f64.
            let scores: Vec<f64> = documents
                .iter()
                .filter_map(|doc| doc.metadata.get("score").and_then(|v| v.as_f64()))
                .collect();
            if !scores.is_empty() {
                let avg_score = scores.iter().sum::<f64>() / scores.len() as f64;
                let any_below = scores.iter().any(|&s| s < threshold);
                if any_below || avg_score < threshold {
                    return Ok(RetrievalValidationResult::invalid(
                        avg_score,
                        format!("Some documents have low relevance scores (avg: {:.2}, min required: {:.2})",
                            avg_score, threshold),
                        vec!["Try refining the query to be more specific".to_string()],
                    ));
                }
            }
        }
        // Both gates passed; confidence is a fixed heuristic value.
        Ok(RetrievalValidationResult::valid(0.8))
    }
}
/// Validator that delegates the relevance judgment to an LLM.
pub struct LLMRetrievalValidator {
    // The model used to evaluate retrieval quality.
    llm: Box<dyn crate::language_models::llm::LLM>,
    // Optional custom prompt; `{query}` and `{documents}` placeholders are
    // substituted before invocation. `None` means use the built-in prompt.
    validation_prompt: Option<String>,
}
impl LLMRetrievalValidator {
    /// Wraps `llm`, using the built-in validation prompt until overridden.
    pub fn new(llm: Box<dyn crate::language_models::llm::LLM>) -> Self {
        Self {
            llm,
            validation_prompt: None,
        }
    }

    /// Builder: override the validation prompt. The `{query}` and
    /// `{documents}` placeholders are substituted at validation time.
    pub fn with_prompt<S: Into<String>>(mut self, prompt: S) -> Self {
        self.validation_prompt = Some(prompt.into());
        self
    }
}
#[async_trait]
impl RetrievalValidator for LLMRetrievalValidator {
    /// Asks the wrapped LLM whether `documents` can answer `query`.
    ///
    /// The model is prompted to reply with JSON matching
    /// `RetrievalValidationResult`; if that parse fails, a keyword-based
    /// fallback classifies the free-text reply.
    ///
    /// # Errors
    /// Returns `RAGError::RetrievalValidationError` when the LLM call fails.
    async fn validate(
        &self,
        query: &str,
        documents: &[Document],
    ) -> Result<RetrievalValidationResult, RAGError> {
        if documents.is_empty() {
            return Ok(RetrievalValidationResult::invalid(
                0.0,
                "No documents retrieved".to_string(),
                vec!["Try expanding the query".to_string()],
            ));
        }
        // Cap the context at 5 documents to keep the prompt bounded.
        let doc_texts: Vec<String> = documents
            .iter()
            .take(5)
            .map(|doc| format!("[Document]\n{}\n", doc.page_content))
            .collect();
        let prompt = self.validation_prompt.as_deref().unwrap_or(
            "Evaluate whether the following documents are relevant and sufficient to answer the query.\n\n\
            Query: {query}\n\n\
            Documents:\n{documents}\n\n\
            Respond with JSON: {{\"is_valid\": true/false, \"confidence\": 0.0-1.0, \"feedback\": \"...\", \"suggestions\": [\"...\"]}}"
        );
        let formatted_prompt = prompt
            .replace("{query}", query)
            .replace("{documents}", &doc_texts.join("\n---\n"));
        let response = self
            .llm
            .invoke(&formatted_prompt)
            .await
            .map_err(|e| RAGError::RetrievalValidationError(format!("LLM error: {}", e)))?;
        // Models often wrap JSON in markdown fences; strip them before parsing
        // so a well-formed-but-fenced reply still takes the structured path.
        let trimmed = response.trim();
        let json_candidate = trimmed
            .strip_prefix("```json")
            .or_else(|| trimmed.strip_prefix("```"))
            .unwrap_or(trimmed)
            .trim_end_matches("```")
            .trim();
        match serde_json::from_str::<RetrievalValidationResult>(json_candidate) {
            Ok(result) => Ok(result),
            Err(_) => {
                // Keyword fallback. Negative markers are checked explicitly
                // because naive substring tests are self-defeating here:
                // "invalid" contains "valid" and "insufficient" contains
                // "sufficient", so the previous logic flipped clearly
                // negative verdicts to is_valid = true.
                let lower = response.to_lowercase();
                let negative = lower.contains("invalid")
                    || lower.contains("insufficient")
                    || lower.contains("not valid")
                    || lower.contains("not relevant");
                let positive = lower.contains("valid")
                    || lower.contains("yes")
                    || lower.contains("sufficient");
                let is_valid = positive && !negative;
                Ok(RetrievalValidationResult {
                    is_valid,
                    // Heuristic classification deserves only moderate confidence.
                    confidence: if is_valid { 0.7 } else { 0.3 },
                    feedback: Some(response),
                    suggestions: Vec::new(),
                })
            }
        }
    }
}