use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
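
/// Ranking metrics supported by the evaluator; the `usize` payload is the
/// rank cutoff `k`.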
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RetrievalMetric {
PrecisionAtK(usize),
RecallAtK(usize),
MAP,
NDCG(usize),
MRR,
F1AtK(usize),
}
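
/// A corpus document, optionally carrying a precomputed embedding and
/// per-query graded relevance scores.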
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
pub doc_id: String,
pub title: String,
pub content: String,
pub category: String,
pub embedding: Option<Vec<f32>>,
pub relevance_scores: HashMap<String, f64>,
}
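
/// An evaluation query together with its gold-standard relevant documents.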
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalQuery {
pub query_id: String,
pub query_text: String,
pub relevant_docs: Vec<String>,
pub query_type: String,
}
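
/// Aggregate diagnostics computed across all evaluated queries.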
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalAnalysis {
pub performance_by_type: HashMap<String, f64>,
pub document_coverage: f64,
pub completion_rate: f64,
pub avg_response_time: f64,
}
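
/// Top-level output of a retrieval evaluation run.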
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResults {
pub metric_scores: HashMap<String, f64>,
pub per_query_results: HashMap<String, QueryRetrievalResults>,
pub retrieval_analysis: RetrievalAnalysis,
pub overall_quality: f64,
}
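
/// Per-query retrieval output and metric scores, keyed by rank cutoff.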
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRetrievalResults {
pub query_id: String,
pub retrieved_docs: Vec<(String, f64)>,
pub precision_at_k: HashMap<usize, f64>,
pub recall_at_k: HashMap<usize, f64>,
pub ndcg_scores: HashMap<usize, f64>,
pub response_time: f64,
}
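
/// Evaluates retrieval quality of an embedding model over a document corpus
/// and a set of labeled queries.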
pub struct RetrievalEvaluator {
documents: HashMap<String, DocumentMetadata>,
queries: Vec<RetrievalQuery>,
metrics: Vec<RetrievalMetric>,
}
impl RetrievalEvaluator {
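    /// Create an evaluator with a default metric suite: precision and
    /// recall at 5 and 10, nDCG@10, MAP, and MRR.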
pub fn new() -> Self {
Self {
documents: HashMap::new(),
queries: Vec::new(),
metrics: vec![
RetrievalMetric::PrecisionAtK(5),
RetrievalMetric::PrecisionAtK(10),
RetrievalMetric::RecallAtK(5),
RetrievalMetric::RecallAtK(10),
RetrievalMetric::NDCG(10),
RetrievalMetric::MAP,
RetrievalMetric::MRR,
],
}
}
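
    /// Register a document in the corpus, keyed by its `doc_id`.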
pub fn add_document(&mut self, document: DocumentMetadata) {
self.documents.insert(document.doc_id.clone(), document);
}
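
    /// Register a labeled query to evaluate against the corpus.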
pub fn add_query(&mut self, query: RetrievalQuery) {
self.queries.push(query);
}
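
    /// Run retrieval for up to `config.sample_size` queries, then aggregate
    /// per-query scores into corpus-level metrics.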
pub async fn evaluate(
&self,
model: &dyn EmbeddingModel,
config: &ApplicationEvalConfig,
) -> Result<RetrievalResults> {
let mut metric_scores = HashMap::new();
let mut per_query_results = HashMap::new();
let queries_to_evaluate = if self.queries.len() > config.sample_size {
&self.queries[..config.sample_size]
} else {
&self.queries
};
for query in queries_to_evaluate {
let query_results = self.evaluate_query_retrieval(query, model).await?;
per_query_results.insert(query.query_id.clone(), query_results);
}
for metric in &self.metrics {
let score = self.calculate_retrieval_metric(metric, &per_query_results)?;
metric_scores.insert(format!("{metric:?}"), score);
}
let retrieval_analysis = self.analyze_retrieval_performance(&per_query_results)?;
let overall_quality = self.calculate_overall_quality(&metric_scores);
Ok(RetrievalResults {
metric_scores,
per_query_results,
retrieval_analysis,
overall_quality,
})
}
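
    /// Retrieve documents for a single query and score the ranking at
    /// several rank cutoffs.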
async fn evaluate_query_retrieval(
&self,
query: &RetrievalQuery,
model: &dyn EmbeddingModel,
) -> Result<QueryRetrievalResults> {
let start_time = std::time::Instant::now();
let retrieved_docs = self.retrieve_documents(query, model).await?;
let response_time = start_time.elapsed().as_millis() as f64;
let mut precision_at_k = HashMap::new();
let mut recall_at_k = HashMap::new();
let mut ndcg_scores = HashMap::new();
let relevant_set: std::collections::HashSet<String> =
query.relevant_docs.iter().cloned().collect();
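        // Score only the cutoffs for which at least k documents were
        // actually retrieved; smaller result sets skip the larger cutoffs.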
for &k in &[1, 3, 5, 10, 20] {
if k <= retrieved_docs.len() {
let top_k_docs: std::collections::HashSet<String> = retrieved_docs
.iter()
.take(k)
.map(|(doc_id, _)| doc_id.clone())
.collect();
let relevant_retrieved = top_k_docs.intersection(&relevant_set).count();
let precision = relevant_retrieved as f64 / k as f64;
let recall = if !query.relevant_docs.is_empty() {
relevant_retrieved as f64 / query.relevant_docs.len() as f64
} else {
0.0
};
precision_at_k.insert(k, precision);
recall_at_k.insert(k, recall);
let ndcg = self.calculate_ndcg_for_query(&retrieved_docs, &relevant_set, k);
ndcg_scores.insert(k, ndcg);
}
}
Ok(QueryRetrievalResults {
query_id: query.query_id.clone(),
retrieved_docs,
precision_at_k,
recall_at_k,
ndcg_scores,
response_time,
})
}
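
    /// Rank documents for a query. This is a placeholder that scores by
    /// word overlap between the query and document text; the embedding
    /// model is accepted but not yet used for similarity search.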
async fn retrieve_documents(
&self,
query: &RetrievalQuery,
_model: &dyn EmbeddingModel,
) -> Result<Vec<(String, f64)>> {
        // Tokenize the query once; recomputing it per document is wasted work.
        let query_words: std::collections::HashSet<&str> =
            query.query_text.split_whitespace().collect();
        let mut doc_scores = Vec::new();
        for (doc_id, doc) in &self.documents {
            let doc_words: std::collections::HashSet<&str> =
                doc.content.split_whitespace().collect();
            let overlap = query_words.intersection(&doc_words).count();
            // Guard against an empty query, which would otherwise yield NaN.
            let score = if query_words.is_empty() {
                0.0
            } else {
                overlap as f64 / query_words.len() as f64
            };
            doc_scores.push((doc_id.clone(), score));
        }
doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
doc_scores.truncate(20);
Ok(doc_scores)
}
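
    /// Binary-relevance nDCG@k: DCG sums 1/log2(rank + 1) over relevant
    /// hits (rank is 1-based, hence `i + 2` below) and is normalized by
    /// the ideal DCG of placing all relevant documents first.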
fn calculate_ndcg_for_query(
&self,
retrieved_docs: &[(String, f64)],
relevant_docs: &std::collections::HashSet<String>,
k: usize,
) -> f64 {
if k == 0 || retrieved_docs.is_empty() {
return 0.0;
}
let mut dcg = 0.0;
for (i, (doc_id, _)) in retrieved_docs.iter().take(k).enumerate() {
if relevant_docs.contains(doc_id) {
dcg += 1.0 / (i as f64 + 2.0).log2();
}
}
let relevant_count = relevant_docs.len().min(k);
let mut idcg = 0.0;
for i in 0..relevant_count {
idcg += 1.0 / (i as f64 + 2.0).log2();
}
if idcg > 0.0 {
dcg / idcg
} else {
0.0
}
}
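
    /// Average a metric across all evaluated queries.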
fn calculate_retrieval_metric(
&self,
metric: &RetrievalMetric,
per_query_results: &HashMap<String, QueryRetrievalResults>,
) -> Result<f64> {
if per_query_results.is_empty() {
return Ok(0.0);
}
        match metric {
            RetrievalMetric::PrecisionAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.precision_at_k.get(k))
                    .copied()
                    .collect();
                // Some queries may lack this cutoff; avoid dividing by zero.
                if scores.is_empty() {
                    return Ok(0.0);
                }
                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
            }
            RetrievalMetric::RecallAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.recall_at_k.get(k))
                    .copied()
                    .collect();
                if scores.is_empty() {
                    return Ok(0.0);
                }
                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
            }
            RetrievalMetric::NDCG(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.ndcg_scores.get(k))
                    .copied()
                    .collect();
                if scores.is_empty() {
                    return Ok(0.0);
                }
                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
            }
            // MAP, MRR, and F1@k are not implemented yet; return a neutral
            // placeholder so the metric still appears in the report.
            _ => Ok(0.5),
        }
}
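
    /// Compute aggregate diagnostics (latency, completion rate) over the
    /// per-query results.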
fn analyze_retrieval_performance(
&self,
per_query_results: &HashMap<String, QueryRetrievalResults>,
) -> Result<RetrievalAnalysis> {
        // Guard against division by zero when no queries were evaluated.
        let query_count = per_query_results.len().max(1) as f64;
        let avg_response_time = per_query_results
            .values()
            .map(|r| r.response_time)
            .sum::<f64>()
            / query_count;
        let completion_rate = per_query_results
            .values()
            .filter(|r| !r.retrieved_docs.is_empty())
            .count() as f64
            / query_count;
        Ok(RetrievalAnalysis {
            performance_by_type: HashMap::new(),
            // Placeholder until per-corpus coverage is actually measured.
            document_coverage: 0.8,
            completion_rate,
            avg_response_time,
        })
}
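
    /// Overall quality is the mean of the headline @10 metrics. The keys
    /// match the `Debug` formatting used when populating `metric_scores`.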
fn calculate_overall_quality(&self, metric_scores: &HashMap<String, f64>) -> f64 {
let relevant_metrics = ["PrecisionAtK(10)", "RecallAtK(10)", "NDCG(10)"];
let mut total_score = 0.0;
let mut count = 0;
for metric_name in &relevant_metrics {
if let Some(&score) = metric_scores.get(*metric_name) {
total_score += score;
count += 1;
}
}
if count > 0 {
total_score / count as f64
} else {
0.0
}
}
}
impl Default for RetrievalEvaluator {
fn default() -> Self {
Self::new()
}
}
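
// A minimal sanity-check sketch, exercising only the synchronous helpers so
// no `EmbeddingModel` implementation is required. The hand-computed values
// below follow the binary-relevance nDCG formula used above; treat these as
// illustrative assumptions rather than a full test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ndcg_matches_hand_computation() {
        // Two of the top three retrieved docs are relevant (ranks 1 and 3):
        // DCG = 1/log2(2) + 1/log2(4); the ideal ordering puts both relevant
        // docs first, so IDCG = 1/log2(2) + 1/log2(3).
        let evaluator = RetrievalEvaluator::new();
        let retrieved = vec![
            ("d1".to_string(), 0.9),
            ("d2".to_string(), 0.5),
            ("d3".to_string(), 0.4),
        ];
        let relevant: std::collections::HashSet<String> =
            ["d1".to_string(), "d3".to_string()].into_iter().collect();
        let ndcg = evaluator.calculate_ndcg_for_query(&retrieved, &relevant, 3);
        let dcg = 1.0 / 2.0f64.log2() + 1.0 / 4.0f64.log2();
        let idcg = 1.0 / 2.0f64.log2() + 1.0 / 3.0f64.log2();
        assert!((ndcg - dcg / idcg).abs() < 1e-9);
    }

    #[test]
    fn overall_quality_averages_available_headline_metrics() {
        // Only two of the three headline metrics are present; the mean is
        // taken over the metrics actually found.
        let evaluator = RetrievalEvaluator::new();
        let mut scores = HashMap::new();
        scores.insert("PrecisionAtK(10)".to_string(), 0.6);
        scores.insert("RecallAtK(10)".to_string(), 0.4);
        assert!((evaluator.calculate_overall_quality(&scores) - 0.5).abs() < 1e-9);
    }
}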