// oxirs_embed/application_tasks/retrieval.rs

//! Information-retrieval evaluation for embedding models: ranks documents
//! against benchmark queries and scores the rankings with precision@k,
//! recall@k, NDCG, and related metrics.

use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Metrics for scoring ranked retrieval results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RetrievalMetric {
    /// Fraction of the top-k results that are relevant.
    PrecisionAtK(usize),
    /// Fraction of all relevant documents found in the top k.
    RecallAtK(usize),
    /// Mean average precision across queries.
    MAP,
    /// Normalized discounted cumulative gain at cutoff k.
    NDCG(usize),
    /// Mean reciprocal rank of the first relevant result.
    MRR,
    /// Harmonic mean of precision@k and recall@k.
    F1AtK(usize),
}

/// A document in the retrieval corpus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentMetadata {
    /// Unique document identifier.
    pub doc_id: String,
    /// Document title.
    pub title: String,
    /// Full document text used for matching.
    pub content: String,
    /// Topical category label.
    pub category: String,
    /// Optional precomputed embedding vector.
    pub embedding: Option<Vec<f32>>,
    /// Graded relevance scores, keyed by query ID.
    pub relevance_scores: HashMap<String, f64>,
}

/// A benchmark query with its relevance judgments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalQuery {
    /// Unique query identifier.
    pub query_id: String,
    /// Natural-language query text.
    pub query_text: String,
    /// IDs of the documents judged relevant to this query.
    pub relevant_docs: Vec<String>,
    /// Free-form query type label, used for per-type analysis.
    pub query_type: String,
}

/// Aggregate analysis of retrieval behavior across queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalAnalysis {
    /// Mean performance per query type.
    pub performance_by_type: HashMap<String, f64>,
    /// Fraction of the corpus surfaced across all queries.
    pub document_coverage: f64,
    /// Fraction of queries that returned at least one document.
    pub completion_rate: f64,
    /// Mean retrieval latency in milliseconds.
    pub avg_response_time: f64,
}

/// Top-level results of a retrieval evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResults {
    /// Aggregate score per metric, keyed by the metric's `Debug` name.
    pub metric_scores: HashMap<String, f64>,
    /// Detailed per-query results, keyed by query ID.
    pub per_query_results: HashMap<String, QueryRetrievalResults>,
    /// Cross-query analysis of retrieval behavior.
    pub retrieval_analysis: RetrievalAnalysis,
    /// Single aggregate quality score.
    pub overall_quality: f64,
}

/// Retrieval results for a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRetrievalResults {
    /// ID of the evaluated query.
    pub query_id: String,
    /// Retrieved `(document ID, score)` pairs in ranked order.
    pub retrieved_docs: Vec<(String, f64)>,
    /// Precision at each evaluated cutoff k.
    pub precision_at_k: HashMap<usize, f64>,
    /// Recall at each evaluated cutoff k.
    pub recall_at_k: HashMap<usize, f64>,
    /// NDCG at each evaluated cutoff k.
    pub ndcg_scores: HashMap<usize, f64>,
    /// Retrieval latency in milliseconds.
    pub response_time: f64,
}

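/// Evaluates document-retrieval quality for an embedding model.
///
/// Usage sketch (hypothetical: `model` stands for any concrete
/// [`EmbeddingModel`] implementation and `config` for an
/// [`ApplicationEvalConfig`]; neither is constructed here):
///
/// ```ignore
/// let mut evaluator = RetrievalEvaluator::new();
/// evaluator.add_document(document);
/// evaluator.add_query(query);
/// let results = evaluator.evaluate(&model, &config).await?;
/// println!("overall quality: {:.3}", results.overall_quality);
/// ```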
pub struct RetrievalEvaluator {
    /// Document corpus, keyed by document ID.
    documents: HashMap<String, DocumentMetadata>,
    /// Benchmark queries with relevance judgments.
    queries: Vec<RetrievalQuery>,
    /// Metrics to compute during evaluation.
    metrics: Vec<RetrievalMetric>,
}

impl RetrievalEvaluator {
    /// Creates an evaluator with a default set of metrics
    /// (precision/recall at 5 and 10, NDCG@10, MAP, and MRR).
    pub fn new() -> Self {
        Self {
            documents: HashMap::new(),
            queries: Vec::new(),
            metrics: vec![
                RetrievalMetric::PrecisionAtK(5),
                RetrievalMetric::PrecisionAtK(10),
                RetrievalMetric::RecallAtK(5),
                RetrievalMetric::RecallAtK(10),
                RetrievalMetric::NDCG(10),
                RetrievalMetric::MAP,
                RetrievalMetric::MRR,
            ],
        }
    }

    /// Registers a document in the corpus.
    pub fn add_document(&mut self, document: DocumentMetadata) {
        self.documents.insert(document.doc_id.clone(), document);
    }

    /// Registers a benchmark query.
    pub fn add_query(&mut self, query: RetrievalQuery) {
        self.queries.push(query);
    }

    /// Runs retrieval evaluation over (a sample of) the registered queries
    /// and aggregates the configured metrics into a single result set.
    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<RetrievalResults> {
        let mut metric_scores = HashMap::new();
        let mut per_query_results = HashMap::new();

        // Cap the number of evaluated queries at the configured sample size.
        let queries_to_evaluate = if self.queries.len() > config.sample_size {
            &self.queries[..config.sample_size]
        } else {
            &self.queries[..]
        };

        for query in queries_to_evaluate {
            let query_results = self.evaluate_query_retrieval(query, model).await?;
            per_query_results.insert(query.query_id.clone(), query_results);
        }

        // Aggregate each configured metric across queries, keyed by its
        // `Debug` representation (e.g. "PrecisionAtK(10)").
        for metric in &self.metrics {
            let score = self.calculate_retrieval_metric(metric, &per_query_results)?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        let retrieval_analysis = self.analyze_retrieval_performance(&per_query_results)?;
        let overall_quality = self.calculate_overall_quality(&metric_scores);

        Ok(RetrievalResults {
            metric_scores,
            per_query_results,
            retrieval_analysis,
            overall_quality,
        })
    }

    /// Evaluates a single query: retrieves documents, then computes
    /// precision, recall, and NDCG at several cutoffs.
    async fn evaluate_query_retrieval(
        &self,
        query: &RetrievalQuery,
        model: &dyn EmbeddingModel,
    ) -> Result<QueryRetrievalResults> {
        let start_time = std::time::Instant::now();

        let retrieved_docs = self.retrieve_documents(query, model).await?;

        let response_time = start_time.elapsed().as_millis() as f64;

        let mut precision_at_k = HashMap::new();
        let mut recall_at_k = HashMap::new();
        let mut ndcg_scores = HashMap::new();

        let relevant_set: std::collections::HashSet<String> =
            query.relevant_docs.iter().cloned().collect();

        for &k in &[1, 3, 5, 10, 20] {
            // Only compute cutoffs that the result list actually covers.
            if k <= retrieved_docs.len() {
                let top_k_docs: std::collections::HashSet<String> = retrieved_docs
                    .iter()
                    .take(k)
                    .map(|(doc_id, _)| doc_id.clone())
                    .collect();

                let relevant_retrieved = top_k_docs.intersection(&relevant_set).count();

                let precision = relevant_retrieved as f64 / k as f64;
                let recall = if !query.relevant_docs.is_empty() {
                    relevant_retrieved as f64 / query.relevant_docs.len() as f64
                } else {
                    0.0
                };

                precision_at_k.insert(k, precision);
                recall_at_k.insert(k, recall);

                let ndcg = self.calculate_ndcg_for_query(&retrieved_docs, &relevant_set, k);
                ndcg_scores.insert(k, ndcg);
            }
        }

        Ok(QueryRetrievalResults {
            query_id: query.query_id.clone(),
            retrieved_docs,
            precision_at_k,
            recall_at_k,
            ndcg_scores,
            response_time,
        })
    }

    /// Scores every document against the query and returns the top 20.
    ///
    /// Note: this baseline ranks by word overlap between query and document
    /// text; the embedding model is not yet used here.
    async fn retrieve_documents(
        &self,
        query: &RetrievalQuery,
        _model: &dyn EmbeddingModel,
    ) -> Result<Vec<(String, f64)>> {
        let mut doc_scores = Vec::new();

        // The query's word set is the same for every document; build it once.
        let query_words: std::collections::HashSet<&str> =
            query.query_text.split_whitespace().collect();

        for (doc_id, doc) in &self.documents {
            let doc_words: std::collections::HashSet<&str> =
                doc.content.split_whitespace().collect();

            let overlap = query_words.intersection(&doc_words).count();
            // Guard against an empty query to avoid dividing by zero.
            let score = if query_words.is_empty() {
                0.0
            } else {
                overlap as f64 / query_words.len() as f64
            };

            doc_scores.push((doc_id.clone(), score));
        }

        // Sort by descending score and keep the top 20 candidates.
        doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        doc_scores.truncate(20);

        Ok(doc_scores)
    }

    /// Computes NDCG@k with binary relevance: DCG sums `1 / log2(rank + 2)`
    /// over relevant hits (0-based rank) and is normalized by the ideal DCG
    /// of ranking all relevant documents first.
    fn calculate_ndcg_for_query(
        &self,
        retrieved_docs: &[(String, f64)],
        relevant_docs: &std::collections::HashSet<String>,
        k: usize,
    ) -> f64 {
        if k == 0 || retrieved_docs.is_empty() {
            return 0.0;
        }

        let mut dcg = 0.0;
        for (i, (doc_id, _)) in retrieved_docs.iter().take(k).enumerate() {
            if relevant_docs.contains(doc_id) {
                dcg += 1.0 / (i as f64 + 2.0).log2();
            }
        }

        // Ideal DCG: all relevant documents ranked at the top.
        let relevant_count = relevant_docs.len().min(k);
        let mut idcg = 0.0;
        for i in 0..relevant_count {
            idcg += 1.0 / (i as f64 + 2.0).log2();
        }

        if idcg > 0.0 {
            dcg / idcg
        } else {
            0.0
        }
    }

    /// Averages a metric across all evaluated queries.
    fn calculate_retrieval_metric(
        &self,
        metric: &RetrievalMetric,
        per_query_results: &HashMap<String, QueryRetrievalResults>,
    ) -> Result<f64> {
        if per_query_results.is_empty() {
            return Ok(0.0);
        }

        // Mean of per-query scores; 0.0 when no query produced a score at
        // this cutoff (avoids dividing by zero).
        let mean = |scores: Vec<f64>| {
            if scores.is_empty() {
                0.0
            } else {
                scores.iter().sum::<f64>() / scores.len() as f64
            }
        };

        match metric {
            RetrievalMetric::PrecisionAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.precision_at_k.get(k))
                    .cloned()
                    .collect();
                Ok(mean(scores))
            }
            RetrievalMetric::RecallAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.recall_at_k.get(k))
                    .cloned()
                    .collect();
                Ok(mean(scores))
            }
            RetrievalMetric::NDCG(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.ndcg_scores.get(k))
                    .cloned()
                    .collect();
                Ok(mean(scores))
            }
            // MAP, MRR, and F1@k are not implemented yet; return a neutral
            // placeholder score.
            _ => Ok(0.5),
        }
    }

    /// Summarizes latency and completion statistics across queries.
    fn analyze_retrieval_performance(
        &self,
        per_query_results: &HashMap<String, QueryRetrievalResults>,
    ) -> Result<RetrievalAnalysis> {
        // Guard against an empty result set to avoid dividing by zero.
        let query_count = per_query_results.len().max(1) as f64;

        let avg_response_time = per_query_results
            .values()
            .map(|r| r.response_time)
            .sum::<f64>()
            / query_count;

        let completion_rate = per_query_results
            .values()
            .filter(|r| !r.retrieved_docs.is_empty())
            .count() as f64
            / query_count;

        Ok(RetrievalAnalysis {
            // Per-type breakdown and coverage are placeholders for now.
            performance_by_type: HashMap::new(),
            document_coverage: 0.8,
            completion_rate,
            avg_response_time,
        })
    }

    /// Averages the headline metrics (precision@10, recall@10, NDCG@10)
    /// into a single overall quality score.
    fn calculate_overall_quality(&self, metric_scores: &HashMap<String, f64>) -> f64 {
        let relevant_metrics = ["PrecisionAtK(10)", "RecallAtK(10)", "NDCG(10)"];
        let mut total_score = 0.0;
        let mut count = 0;

        for metric_name in &relevant_metrics {
            if let Some(&score) = metric_scores.get(*metric_name) {
                total_score += score;
                count += 1;
            }
        }

        if count > 0 {
            total_score / count as f64
        } else {
            0.0
        }
    }
}

impl Default for RetrievalEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
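
// A minimal smoke test for the NDCG computation, added here as a sketch:
// the expected values below follow from the binary-relevance DCG formula
// documented above and are not taken from the original test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ndcg_rewards_relevant_docs_ranked_first() {
        let evaluator = RetrievalEvaluator::new();
        let retrieved = vec![
            ("d1".to_string(), 0.9),
            ("d2".to_string(), 0.8),
            ("d3".to_string(), 0.7),
        ];

        // Two of three retrieved documents are relevant (ranks 0 and 2):
        // DCG = 1/log2(2) + 1/log2(4) = 1.5; IDCG = 1/log2(2) + 1/log2(3).
        let relevant: std::collections::HashSet<String> =
            ["d1".to_string(), "d3".to_string()].into_iter().collect();
        let ndcg = evaluator.calculate_ndcg_for_query(&retrieved, &relevant, 3);
        let expected = 1.5 / (1.0 + 1.0 / 3.0_f64.log2());
        assert!((ndcg - expected).abs() < 1e-9);

        // A perfect ranking (every retrieved doc relevant) scores exactly 1.
        let all_relevant: std::collections::HashSet<String> =
            ["d1".to_string(), "d2".to_string(), "d3".to_string()]
                .into_iter()
                .collect();
        let ndcg_perfect = evaluator.calculate_ndcg_for_query(&retrieved, &all_relevant, 3);
        assert!((ndcg_perfect - 1.0).abs() < 1e-9);
    }
}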