oxirs_embed/application_tasks/retrieval.rs

//! Retrieval evaluation module
//!
//! This module provides comprehensive evaluation for information retrieval tasks
//! using embedding models, including document ranking and retrieval effectiveness.
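//!
//! # Example
//!
//! An illustrative sketch of the intended call sequence. The `model` and
//! `config` values are assumed to come from an `EmbeddingModel`
//! implementation and an `ApplicationEvalConfig` defined elsewhere in the
//! crate, so the snippet is not compiled as a doctest.
//!
//! ```ignore
//! let mut evaluator = RetrievalEvaluator::new();
//! evaluator.add_document(DocumentMetadata {
//!     doc_id: "doc-1".to_string(),
//!     title: "Example document".to_string(),
//!     content: "knowledge graph embedding evaluation".to_string(),
//!     category: "example".to_string(),
//!     embedding: None,
//!     relevance_scores: HashMap::new(),
//! });
//! evaluator.add_query(RetrievalQuery {
//!     query_id: "q-1".to_string(),
//!     query_text: "graph embedding".to_string(),
//!     relevant_docs: vec!["doc-1".to_string()],
//!     query_type: "keyword".to_string(),
//! });
//! let results = evaluator.evaluate(&model, &config).await?;
//! println!("Overall retrieval quality: {:.3}", results.overall_quality);
//! ```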

use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Retrieval evaluation metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RetrievalMetric {
    /// Precision at K
    PrecisionAtK(usize),
    /// Recall at K
    RecallAtK(usize),
    /// Mean Average Precision
    MAP,
    /// Normalized Discounted Cumulative Gain
    NDCG(usize),
    /// Mean Reciprocal Rank
    MRR,
    /// F1 Score at K
    F1AtK(usize),
}

/// Document metadata for retrieval
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
    /// Document identifier
    pub doc_id: String,
    /// Document title
    pub title: String,
    /// Document content
    pub content: String,
    /// Document category
    pub category: String,
    /// Document embedding (if available)
    pub embedding: Option<Vec<f32>>,
    /// Per-query relevance scores
    pub relevance_scores: HashMap<String, f64>,
}

/// Retrieval query
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalQuery {
    /// Query identifier
    pub query_id: String,
    /// Query text
    pub query_text: String,
    /// Relevant document IDs
    pub relevant_docs: Vec<String>,
    /// Query type
    pub query_type: String,
}

/// Retrieval analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalAnalysis {
    /// Query performance by type
    pub performance_by_type: HashMap<String, f64>,
    /// Document coverage statistics
    pub document_coverage: f64,
    /// Query completion rate
    pub completion_rate: f64,
    /// Average response time (milliseconds)
    pub avg_response_time: f64,
}

/// Retrieval evaluation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResults {
    /// Metric scores
    pub metric_scores: HashMap<String, f64>,
    /// Per-query results
    pub per_query_results: HashMap<String, QueryRetrievalResults>,
    /// Retrieval analysis
    pub retrieval_analysis: RetrievalAnalysis,
    /// Overall retrieval quality
    pub overall_quality: f64,
}

/// Per-query retrieval results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRetrievalResults {
    /// Query ID
    pub query_id: String,
    /// Retrieved documents with scores
    pub retrieved_docs: Vec<(String, f64)>,
    /// Precision at different K values
    pub precision_at_k: HashMap<usize, f64>,
    /// Recall at different K values
    pub recall_at_k: HashMap<usize, f64>,
    /// NDCG scores
    pub ndcg_scores: HashMap<usize, f64>,
    /// Response time (milliseconds)
    pub response_time: f64,
}

/// Retrieval evaluator
pub struct RetrievalEvaluator {
    /// Document collection
    documents: HashMap<String, DocumentMetadata>,
    /// Retrieval queries
    queries: Vec<RetrievalQuery>,
    /// Evaluation metrics
    metrics: Vec<RetrievalMetric>,
}

impl RetrievalEvaluator {
    /// Create a new retrieval evaluator
    pub fn new() -> Self {
        Self {
            documents: HashMap::new(),
            queries: Vec::new(),
            metrics: vec![
                RetrievalMetric::PrecisionAtK(5),
                RetrievalMetric::PrecisionAtK(10),
                RetrievalMetric::RecallAtK(5),
                RetrievalMetric::RecallAtK(10),
                RetrievalMetric::NDCG(10),
                RetrievalMetric::MAP,
                RetrievalMetric::MRR,
            ],
        }
    }

    /// Add document to collection
    pub fn add_document(&mut self, document: DocumentMetadata) {
        self.documents.insert(document.doc_id.clone(), document);
    }

    /// Add retrieval query
    pub fn add_query(&mut self, query: RetrievalQuery) {
        self.queries.push(query);
    }

    /// Evaluate retrieval performance
    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<RetrievalResults> {
        let mut metric_scores = HashMap::new();
        let mut per_query_results = HashMap::new();

        // Evaluate each query
        let queries_to_evaluate = if self.queries.len() > config.sample_size {
            &self.queries[..config.sample_size]
        } else {
            &self.queries
        };

        for query in queries_to_evaluate {
            let query_results = self.evaluate_query_retrieval(query, model).await?;
            per_query_results.insert(query.query_id.clone(), query_results);
        }

        // Calculate aggregate metrics
        for metric in &self.metrics {
            let score = self.calculate_retrieval_metric(metric, &per_query_results)?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        // Generate retrieval analysis
        let retrieval_analysis = self.analyze_retrieval_performance(&per_query_results)?;

        // Calculate overall quality score
        let overall_quality = self.calculate_overall_quality(&metric_scores);

        Ok(RetrievalResults {
            metric_scores,
            per_query_results,
            retrieval_analysis,
            overall_quality,
        })
    }

    /// Evaluate retrieval for a specific query
    async fn evaluate_query_retrieval(
        &self,
        query: &RetrievalQuery,
        model: &dyn EmbeddingModel,
    ) -> Result<QueryRetrievalResults> {
        let start_time = std::time::Instant::now();

        // Perform document retrieval
        let retrieved_docs = self.retrieve_documents(query, model).await?;

        let response_time = start_time.elapsed().as_millis() as f64;

        // Calculate precision and recall at different K values
        let mut precision_at_k = HashMap::new();
        let mut recall_at_k = HashMap::new();
        let mut ndcg_scores = HashMap::new();

        let relevant_set: std::collections::HashSet<String> =
            query.relevant_docs.iter().cloned().collect();

        for &k in &[1, 3, 5, 10, 20] {
            if k <= retrieved_docs.len() {
                let top_k_docs: std::collections::HashSet<String> = retrieved_docs
                    .iter()
                    .take(k)
                    .map(|(doc_id, _)| doc_id.clone())
                    .collect();

                let relevant_retrieved = top_k_docs.intersection(&relevant_set).count();

                let precision = relevant_retrieved as f64 / k as f64;
                let recall = if !query.relevant_docs.is_empty() {
                    relevant_retrieved as f64 / query.relevant_docs.len() as f64
                } else {
                    0.0
                };

                precision_at_k.insert(k, precision);
                recall_at_k.insert(k, recall);

                // Calculate NDCG (simplified)
                let ndcg = self.calculate_ndcg_for_query(&retrieved_docs, &relevant_set, k);
                ndcg_scores.insert(k, ndcg);
            }
        }

        Ok(QueryRetrievalResults {
            query_id: query.query_id.clone(),
            retrieved_docs,
            precision_at_k,
            recall_at_k,
            ndcg_scores,
            response_time,
        })
    }

    /// Retrieve documents for a query
    async fn retrieve_documents(
        &self,
        query: &RetrievalQuery,
        _model: &dyn EmbeddingModel,
    ) -> Result<Vec<(String, f64)>> {
        // Simple retrieval using text similarity (placeholder)
        let query_words: std::collections::HashSet<&str> =
            query.query_text.split_whitespace().collect();

        let mut doc_scores = Vec::new();

        for (doc_id, doc) in &self.documents {
            // Calculate relevance score based on word overlap with the query
            let doc_words: std::collections::HashSet<&str> =
                doc.content.split_whitespace().collect();

            let overlap = query_words.intersection(&doc_words).count();
            // Guard against division by zero for empty queries
            let score = if query_words.is_empty() {
                0.0
            } else {
                overlap as f64 / query_words.len() as f64
            };

            doc_scores.push((doc_id.clone(), score));
        }

        // Sort by score and return top documents
        doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        doc_scores.truncate(20); // Return top 20 documents

        Ok(doc_scores)
    }

    /// Calculate NDCG for a query
    ///
    /// Uses binary gain: each relevant document at zero-based rank `i`
    /// contributes `1 / log2(i + 2)` to the DCG.
    fn calculate_ndcg_for_query(
        &self,
        retrieved_docs: &[(String, f64)],
        relevant_docs: &std::collections::HashSet<String>,
        k: usize,
    ) -> f64 {
        if k == 0 || retrieved_docs.is_empty() {
            return 0.0;
        }

        let mut dcg = 0.0;
        for (i, (doc_id, _)) in retrieved_docs.iter().take(k).enumerate() {
            if relevant_docs.contains(doc_id) {
                dcg += 1.0 / (i as f64 + 2.0).log2();
            }
        }

        // Calculate ideal DCG: all relevant documents ranked at the top
        let relevant_count = relevant_docs.len().min(k);
        let mut idcg = 0.0;
        for i in 0..relevant_count {
            idcg += 1.0 / (i as f64 + 2.0).log2();
        }

        if idcg > 0.0 {
            dcg / idcg
        } else {
            0.0
        }
    }

    /// Calculate aggregate retrieval metric
    fn calculate_retrieval_metric(
        &self,
        metric: &RetrievalMetric,
        per_query_results: &HashMap<String, QueryRetrievalResults>,
    ) -> Result<f64> {
        if per_query_results.is_empty() {
            return Ok(0.0);
        }

        // Average the per-query scores, guarding against the case where no
        // query reported a value for the requested K
        fn mean(scores: &[f64]) -> f64 {
            if scores.is_empty() {
                0.0
            } else {
                scores.iter().sum::<f64>() / scores.len() as f64
            }
        }

        match metric {
            RetrievalMetric::PrecisionAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.precision_at_k.get(k))
                    .cloned()
                    .collect();
                Ok(mean(&scores))
            }
            RetrievalMetric::RecallAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.recall_at_k.get(k))
                    .cloned()
                    .collect();
                Ok(mean(&scores))
            }
            RetrievalMetric::NDCG(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.ndcg_scores.get(k))
                    .cloned()
                    .collect();
                Ok(mean(&scores))
            }
            _ => Ok(0.5), // Placeholder for MAP, MRR, and F1AtK
        }
    }

    /// Analyze retrieval performance
    fn analyze_retrieval_performance(
        &self,
        per_query_results: &HashMap<String, QueryRetrievalResults>,
    ) -> Result<RetrievalAnalysis> {
        // Guard against division by zero when no queries were evaluated
        if per_query_results.is_empty() {
            return Ok(RetrievalAnalysis {
                performance_by_type: HashMap::new(),
                document_coverage: 0.0,
                completion_rate: 0.0,
                avg_response_time: 0.0,
            });
        }

        let avg_response_time = per_query_results
            .values()
            .map(|r| r.response_time)
            .sum::<f64>()
            / per_query_results.len() as f64;

        let completion_rate = per_query_results
            .values()
            .filter(|r| !r.retrieved_docs.is_empty())
            .count() as f64
            / per_query_results.len() as f64;

        Ok(RetrievalAnalysis {
            performance_by_type: HashMap::new(), // Simplified
            document_coverage: 0.8,              // Simplified placeholder
            completion_rate,
            avg_response_time,
        })
    }

    /// Calculate overall quality score
    fn calculate_overall_quality(&self, metric_scores: &HashMap<String, f64>) -> f64 {
        // Keys match the Debug-formatted metric names produced in `evaluate`
        let relevant_metrics = ["PrecisionAtK(10)", "RecallAtK(10)", "NDCG(10)"];
        let mut total_score = 0.0;
        let mut count = 0;

        for metric_name in &relevant_metrics {
            if let Some(&score) = metric_scores.get(*metric_name) {
                total_score += score;
                count += 1;
            }
        }

        if count > 0 {
            total_score / count as f64
        } else {
            0.0
        }
    }
}

impl Default for RetrievalEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
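
// The tests below exercise only the pure helpers defined in this file and
// make no assumptions about `EmbeddingModel` or `ApplicationEvalConfig`;
// the expected values follow directly from the formulas above.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    #[test]
    fn ndcg_is_one_for_a_perfect_ranking() {
        let evaluator = RetrievalEvaluator::new();
        let retrieved = vec![("d1".to_string(), 0.9), ("d2".to_string(), 0.8)];
        let relevant: HashSet<String> = ["d1", "d2"].iter().map(|s| s.to_string()).collect();
        // Both retrieved documents are relevant, so DCG equals the ideal DCG.
        let ndcg = evaluator.calculate_ndcg_for_query(&retrieved, &relevant, 2);
        assert!((ndcg - 1.0).abs() < 1e-9);
    }

    #[test]
    fn ndcg_is_discounted_when_the_relevant_doc_is_ranked_second() {
        let evaluator = RetrievalEvaluator::new();
        let retrieved = vec![("d1".to_string(), 0.9), ("d2".to_string(), 0.8)];
        let relevant: HashSet<String> = ["d2"].iter().map(|s| s.to_string()).collect();
        // DCG = 1 / log2(3) (relevant doc at rank 2), IDCG = 1 / log2(2) = 1.
        let expected = 1.0 / 3.0f64.log2();
        let ndcg = evaluator.calculate_ndcg_for_query(&retrieved, &relevant, 2);
        assert!((ndcg - expected).abs() < 1e-9);
    }

    #[test]
    fn overall_quality_averages_the_k_10_metrics() {
        let evaluator = RetrievalEvaluator::new();
        let mut scores = HashMap::new();
        scores.insert("PrecisionAtK(10)".to_string(), 0.6);
        scores.insert("RecallAtK(10)".to_string(), 0.4);
        scores.insert("NDCG(10)".to_string(), 0.5);
        // Overall quality is the mean of the three K=10 metrics: 0.5.
        assert!((evaluator.calculate_overall_quality(&scores) - 0.5).abs() < 1e-9);
    }
}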