oxirs_embed/application_tasks/
search.rs

1//! Search relevance evaluation module
2//!
3//! This module provides comprehensive evaluation for search relevance using
4//! embedding models, including precision, recall, NDCG, MAP, and other
5//! information retrieval metrics.
6
7use super::ApplicationEvalConfig;
8use crate::{EmbeddingModel, Vector};
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single human relevance judgment tying one query to one document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceJudgment {
    /// Query text this judgment applies to
    pub query: String,
    /// Document/entity identifier
    pub document_id: String,
    /// Graded relevance on a 0-3 scale (0 = not relevant, 1 = somewhat relevant, 2 = relevant, 3 = highly relevant)
    pub relevance_score: u8,
    /// Identifier of the annotator who produced this judgment
    pub annotator_id: String,
}
25
/// Search evaluation metrics.
///
/// NOTE: only `PrecisionAtK` and `NDCG` are currently aggregated; the other
/// variants fall back to a placeholder score in `calculate_search_metric`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchMetric {
    /// Precision at K (fraction of the top-K results that are relevant)
    PrecisionAtK(usize),
    /// Recall at K (fraction of all relevant documents found in the top K)
    RecallAtK(usize),
    /// Mean Average Precision
    MAP,
    /// Normalized Discounted Cumulative Gain at K
    NDCG(usize),
    /// Mean Reciprocal Rank
    MRR,
    /// Expected Reciprocal Rank
    ERR,
    /// Click-through rate simulation
    CTR,
}
44
/// Evaluation results for a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResults {
    /// Query text
    pub query: String,
    /// Precision@K, keyed by cutoff K
    pub precision_scores: HashMap<usize, f64>,
    /// Recall@K, keyed by cutoff K
    pub recall_scores: HashMap<usize, f64>,
    /// NDCG@K, keyed by cutoff K
    pub ndcg_scores: HashMap<usize, f64>,
    /// Number of documents judged relevant (relevance score > 0) for this query
    pub num_relevant: usize,
    /// Heuristic query difficulty in [0, 1]; higher means harder
    pub difficulty_score: f64,
}
61
/// Aggregate statistics over the evaluated query set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPerformanceAnalysis {
    /// Average query length in whitespace-separated tokens
    pub avg_query_length: f64,
    /// Query type distribution (currently unpopulated)
    pub query_type_distribution: HashMap<String, usize>,
    /// Performance by query difficulty bucket (currently unpopulated)
    pub performance_by_difficulty: HashMap<String, f64>,
    /// Fraction of queries (in [0, 1]) with no judged-relevant documents
    pub zero_result_queries: f64,
}
74
/// Coarse effectiveness indicators derived from per-query results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchEffectivenessMetrics {
    /// Overall search satisfaction proxy (derived from the success rate)
    pub search_satisfaction: f64,
    /// Distribution of relevance grades across results (currently unpopulated)
    pub relevance_distribution: HashMap<u8, usize>,
    /// Search result diversity (currently a fixed placeholder value)
    pub result_diversity: f64,
    /// Fraction of queries whose top result is relevant (P@1 > 0)
    pub query_success_rate: f64,
}
87
/// Full output of a search-relevance evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResults {
    /// Aggregate metric scores, keyed by the metric's `Debug` representation
    pub metric_scores: HashMap<String, f64>,
    /// Per-query results, keyed by query text
    pub per_query_results: HashMap<String, QueryResults>,
    /// Query-level performance statistics
    pub query_analysis: QueryPerformanceAnalysis,
    /// Coarse search effectiveness indicators
    pub effectiveness_metrics: SearchEffectivenessMetrics,
}
100
/// Search relevance evaluator: holds relevance judgments and the metric
/// suite to compute over an embedding model's (simulated) search results.
pub struct SearchEvaluator {
    /// Relevance judgments grouped by query text
    query_relevance: HashMap<String, Vec<RelevanceJudgment>>,
    /// Search metrics to evaluate
    metrics: Vec<SearchMetric>,
}
108
109impl SearchEvaluator {
110    /// Create a new search evaluator
111    pub fn new() -> Self {
112        Self {
113            query_relevance: HashMap::new(),
114            metrics: vec![
115                SearchMetric::PrecisionAtK(1),
116                SearchMetric::PrecisionAtK(5),
117                SearchMetric::PrecisionAtK(10),
118                SearchMetric::NDCG(10),
119                SearchMetric::MAP,
120                SearchMetric::MRR,
121            ],
122        }
123    }
124
125    /// Add relevance judgment
126    pub fn add_relevance_judgment(&mut self, judgment: RelevanceJudgment) {
127        self.query_relevance
128            .entry(judgment.query.clone())
129            .or_default()
130            .push(judgment);
131    }
132
133    /// Evaluate search relevance
134    pub async fn evaluate(
135        &self,
136        model: &dyn EmbeddingModel,
137        config: &ApplicationEvalConfig,
138    ) -> Result<SearchResults> {
139        let mut metric_scores = HashMap::new();
140        let mut per_query_results = HashMap::new();
141
142        // Sample queries for evaluation
143        let queries_to_evaluate: Vec<_> = self
144            .query_relevance
145            .keys()
146            .take(config.sample_size)
147            .cloned()
148            .collect();
149
150        for query in &queries_to_evaluate {
151            let query_results = self.evaluate_query_search(query, model).await?;
152            per_query_results.insert(query.clone(), query_results);
153        }
154
155        // Calculate aggregate metrics
156        for metric in &self.metrics {
157            let score = self.calculate_search_metric(metric, &per_query_results)?;
158            metric_scores.insert(format!("{metric:?}"), score);
159        }
160
161        // Analyze query performance
162        let query_analysis = self.analyze_query_performance(&per_query_results)?;
163        let effectiveness_metrics = self.calculate_effectiveness_metrics(&per_query_results)?;
164
165        Ok(SearchResults {
166            metric_scores,
167            per_query_results,
168            query_analysis,
169            effectiveness_metrics,
170        })
171    }
172
173    /// Evaluate search for a specific query
174    async fn evaluate_query_search(
175        &self,
176        query: &str,
177        model: &dyn EmbeddingModel,
178    ) -> Result<QueryResults> {
179        let judgments = self.query_relevance.get(query).unwrap();
180
181        // Get search results (simplified - would use actual search system)
182        let search_results = self.perform_search(query, model).await?;
183
184        // Calculate relevance for each result
185        let mut relevance_scores = Vec::new();
186        for (doc_id, _score) in &search_results {
187            let relevance = judgments
188                .iter()
189                .find(|j| &j.document_id == doc_id)
190                .map(|j| j.relevance_score)
191                .unwrap_or(0);
192            relevance_scores.push(relevance);
193        }
194
195        let num_relevant = judgments.iter().filter(|j| j.relevance_score > 0).count();
196
197        // Calculate metrics at different K values
198        let mut precision_scores = HashMap::new();
199        let mut recall_scores = HashMap::new();
200        let mut ndcg_scores = HashMap::new();
201
202        for &k in &[1, 3, 5, 10] {
203            if k <= search_results.len() {
204                let relevant_at_k =
205                    relevance_scores.iter().take(k).filter(|&&r| r > 0).count() as f64;
206
207                let precision = relevant_at_k / k as f64;
208                let recall = if num_relevant > 0 {
209                    relevant_at_k / num_relevant as f64
210                } else {
211                    0.0
212                };
213
214                precision_scores.insert(k, precision);
215                recall_scores.insert(k, recall);
216
217                // Calculate NDCG
218                let ndcg = self.calculate_search_ndcg(&relevance_scores, k)?;
219                ndcg_scores.insert(k, ndcg);
220            }
221        }
222
223        let difficulty_score = self.calculate_query_difficulty(query, num_relevant);
224
225        Ok(QueryResults {
226            query: query.to_string(),
227            precision_scores,
228            recall_scores,
229            ndcg_scores,
230            num_relevant,
231            difficulty_score,
232        })
233    }
234
235    /// Perform search (simplified implementation)
236    async fn perform_search(
237        &self,
238        query: &str,
239        model: &dyn EmbeddingModel,
240    ) -> Result<Vec<(String, f64)>> {
241        // Create query embedding (simplified)
242        let query_words: Vec<&str> = query.split_whitespace().collect();
243        let mut query_embedding = vec![0.0f32; 100];
244
245        // Simple word-based embedding (in practice, use proper query embedding)
246        for (i, word) in query_words.iter().enumerate() {
247            if i < query_embedding.len() {
248                query_embedding[i] = word.len() as f32 / 10.0;
249            }
250        }
251        let query_vector = Vector::new(query_embedding);
252
253        // Score entities (documents) against query
254        let entities = model.get_entities();
255        let mut search_results = Vec::new();
256
257        for entity in entities.iter().take(100) {
258            // Limit for efficiency
259            if let Ok(entity_embedding) = model.get_entity_embedding(entity) {
260                let score = self.cosine_similarity(&query_vector, &entity_embedding);
261                search_results.push((entity.clone(), score));
262            }
263        }
264
265        // Sort by score and return top results
266        search_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
267        search_results.truncate(20);
268
269        Ok(search_results)
270    }
271
272    /// Calculate NDCG for search results
273    fn calculate_search_ndcg(&self, relevance_scores: &[u8], k: usize) -> Result<f64> {
274        if k == 0 || relevance_scores.is_empty() {
275            return Ok(0.0);
276        }
277
278        let mut dcg = 0.0;
279        for (i, &relevance) in relevance_scores.iter().take(k).enumerate() {
280            if relevance > 0 {
281                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
282                dcg += gain / (i as f64 + 2.0).log2();
283            }
284        }
285
286        // Calculate ideal DCG
287        let mut ideal_relevance: Vec<u8> = relevance_scores.to_vec();
288        ideal_relevance.sort_by(|a, b| b.cmp(a));
289
290        let mut idcg = 0.0;
291        for (i, &relevance) in ideal_relevance.iter().take(k).enumerate() {
292            if relevance > 0 {
293                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
294                idcg += gain / (i as f64 + 2.0).log2();
295            }
296        }
297
298        if idcg > 0.0 {
299            Ok(dcg / idcg)
300        } else {
301            Ok(0.0)
302        }
303    }
304
305    /// Calculate query difficulty
306    fn calculate_query_difficulty(&self, query: &str, num_relevant: usize) -> f64 {
307        let query_length = query.split_whitespace().count() as f64;
308        let relevance_factor = if num_relevant == 0 {
309            1.0 // High difficulty
310        } else {
311            1.0 / (num_relevant as f64).log2()
312        };
313
314        (query_length * 0.1 + relevance_factor * 0.9).min(1.0)
315    }
316
317    /// Calculate aggregate search metric
318    fn calculate_search_metric(
319        &self,
320        metric: &SearchMetric,
321        per_query_results: &HashMap<String, QueryResults>,
322    ) -> Result<f64> {
323        if per_query_results.is_empty() {
324            return Ok(0.0);
325        }
326
327        match metric {
328            SearchMetric::PrecisionAtK(k) => {
329                let scores: Vec<f64> = per_query_results
330                    .values()
331                    .filter_map(|r| r.precision_scores.get(k))
332                    .cloned()
333                    .collect();
334                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
335            }
336            SearchMetric::NDCG(k) => {
337                let scores: Vec<f64> = per_query_results
338                    .values()
339                    .filter_map(|r| r.ndcg_scores.get(k))
340                    .cloned()
341                    .collect();
342                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
343            }
344            _ => Ok(0.5), // Placeholder for other metrics
345        }
346    }
347
348    /// Analyze query performance
349    fn analyze_query_performance(
350        &self,
351        per_query_results: &HashMap<String, QueryResults>,
352    ) -> Result<QueryPerformanceAnalysis> {
353        let avg_query_length = per_query_results
354            .keys()
355            .map(|q| q.split_whitespace().count() as f64)
356            .sum::<f64>()
357            / per_query_results.len() as f64;
358
359        let zero_result_queries = per_query_results
360            .values()
361            .filter(|r| r.num_relevant == 0)
362            .count() as f64
363            / per_query_results.len() as f64;
364
365        Ok(QueryPerformanceAnalysis {
366            avg_query_length,
367            query_type_distribution: HashMap::new(), // Simplified
368            performance_by_difficulty: HashMap::new(), // Simplified
369            zero_result_queries,
370        })
371    }
372
373    /// Calculate effectiveness metrics
374    fn calculate_effectiveness_metrics(
375        &self,
376        per_query_results: &HashMap<String, QueryResults>,
377    ) -> Result<SearchEffectivenessMetrics> {
378        let successful_queries = per_query_results
379            .values()
380            .filter(|r| r.precision_scores.get(&1).unwrap_or(&0.0) > &0.0)
381            .count() as f64;
382
383        let query_success_rate = successful_queries / per_query_results.len() as f64;
384
385        Ok(SearchEffectivenessMetrics {
386            search_satisfaction: query_success_rate * 0.8, // Simplified
387            relevance_distribution: HashMap::new(),        // Simplified
388            result_diversity: 0.6,                         // Simplified
389            query_success_rate,
390        })
391    }
392
393    /// Calculate cosine similarity
394    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
395        let dot_product: f32 = v1
396            .values
397            .iter()
398            .zip(v2.values.iter())
399            .map(|(a, b)| a * b)
400            .sum();
401        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
402        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
403
404        if norm_a > 0.0 && norm_b > 0.0 {
405            (dot_product / (norm_a * norm_b)) as f64
406        } else {
407            0.0
408        }
409    }
410}
411
impl Default for SearchEvaluator {
    /// Equivalent to [`SearchEvaluator::new`]: empty judgments, default metric suite.
    fn default() -> Self {
        Self::new()
    }
}