Skip to main content

oxirs_embed/application_tasks/
search.rs

1//! Search relevance evaluation module
2//!
3//! This module provides comprehensive evaluation for search relevance using
4//! embedding models, including precision, recall, NDCG, MAP, and other
5//! information retrieval metrics.
6
7use super::ApplicationEvalConfig;
8use crate::{EmbeddingModel, Vector};
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single graded relevance judgment for a (query, document) pair,
/// as produced by a human annotator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceJudgment {
    /// The search query being judged.
    pub query: String,
    /// Identifier of the document/entity the judgment applies to.
    pub document_id: String,
    /// Graded relevance on a 0-3 scale
    /// (0 = not relevant, 1 = somewhat relevant, 2 = relevant, 3 = highly relevant).
    pub relevance_score: u8,
    /// Identifier of the annotator who produced this judgment.
    pub annotator_id: String,
}
25
/// Information-retrieval metrics the evaluator can compute.
///
/// NOTE(review): only `PrecisionAtK` and `NDCG` are fully implemented in
/// `calculate_search_metric`; the remaining variants currently fall through
/// to a placeholder score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchMetric {
    /// Precision within the top-K results.
    PrecisionAtK(usize),
    /// Recall within the top-K results.
    RecallAtK(usize),
    /// Mean Average Precision across queries.
    MAP,
    /// Normalized Discounted Cumulative Gain at the given cutoff K.
    NDCG(usize),
    /// Mean Reciprocal Rank of the first relevant result.
    MRR,
    /// Expected Reciprocal Rank.
    ERR,
    /// Click-through rate simulation.
    CTR,
}
44
/// Evaluation results for a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResults {
    /// The query text these results belong to.
    pub query: String,
    /// Precision scores keyed by cutoff K (e.g. 1, 3, 5, 10).
    pub precision_scores: HashMap<usize, f64>,
    /// Recall scores keyed by cutoff K.
    pub recall_scores: HashMap<usize, f64>,
    /// NDCG scores keyed by cutoff K.
    pub ndcg_scores: HashMap<usize, f64>,
    /// Number of documents judged relevant (relevance score > 0) for this query.
    pub num_relevant: usize,
    /// Heuristic difficulty score in [0, 1]; higher means harder.
    pub difficulty_score: f64,
}
61
/// Aggregate analysis of query characteristics and outcomes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPerformanceAnalysis {
    /// Average query length in whitespace-separated words.
    pub avg_query_length: f64,
    /// Distribution of query types (query type label -> count).
    pub query_type_distribution: HashMap<String, usize>,
    /// Average performance bucketed by query difficulty label.
    pub performance_by_difficulty: HashMap<String, f64>,
    /// Fraction (0.0-1.0) of queries with no relevant documents.
    pub zero_result_queries: f64,
}
74
/// High-level effectiveness indicators derived from per-query results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchEffectivenessMetrics {
    /// Overall search satisfaction estimate (0.0-1.0).
    pub search_satisfaction: f64,
    /// Histogram of relevance grades across returned results (grade -> count).
    pub relevance_distribution: HashMap<u8, usize>,
    /// Diversity of search results (0.0-1.0).
    pub result_diversity: f64,
    /// Fraction (0.0-1.0) of queries with at least one relevant top result.
    pub query_success_rate: f64,
}
87
/// Full output of a search relevance evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResults {
    /// Aggregate metric scores keyed by the metric's `Debug` representation
    /// (e.g. "PrecisionAtK(5)", "NDCG(10)").
    pub metric_scores: HashMap<String, f64>,
    /// Detailed results per evaluated query, keyed by query text.
    pub per_query_results: HashMap<String, QueryResults>,
    /// Aggregate analysis of query characteristics.
    pub query_analysis: QueryPerformanceAnalysis,
    /// High-level search effectiveness indicators.
    pub effectiveness_metrics: SearchEffectivenessMetrics,
}
100
/// Search relevance evaluator: holds human relevance judgments and the set of
/// metrics to compute, and scores an embedding model's search behavior
/// against them.
pub struct SearchEvaluator {
    /// Relevance judgments grouped by query text.
    query_relevance: HashMap<String, Vec<RelevanceJudgment>>,
    /// Metrics computed during evaluation.
    metrics: Vec<SearchMetric>,
}
108
109impl SearchEvaluator {
110    /// Create a new search evaluator
111    pub fn new() -> Self {
112        Self {
113            query_relevance: HashMap::new(),
114            metrics: vec![
115                SearchMetric::PrecisionAtK(1),
116                SearchMetric::PrecisionAtK(5),
117                SearchMetric::PrecisionAtK(10),
118                SearchMetric::NDCG(10),
119                SearchMetric::MAP,
120                SearchMetric::MRR,
121            ],
122        }
123    }
124
125    /// Add relevance judgment
126    pub fn add_relevance_judgment(&mut self, judgment: RelevanceJudgment) {
127        self.query_relevance
128            .entry(judgment.query.clone())
129            .or_default()
130            .push(judgment);
131    }
132
133    /// Evaluate search relevance
134    pub async fn evaluate(
135        &self,
136        model: &dyn EmbeddingModel,
137        config: &ApplicationEvalConfig,
138    ) -> Result<SearchResults> {
139        let mut metric_scores = HashMap::new();
140        let mut per_query_results = HashMap::new();
141
142        // Sample queries for evaluation
143        let queries_to_evaluate: Vec<_> = self
144            .query_relevance
145            .keys()
146            .take(config.sample_size)
147            .cloned()
148            .collect();
149
150        for query in &queries_to_evaluate {
151            let query_results = self.evaluate_query_search(query, model).await?;
152            per_query_results.insert(query.clone(), query_results);
153        }
154
155        // Calculate aggregate metrics
156        for metric in &self.metrics {
157            let score = self.calculate_search_metric(metric, &per_query_results)?;
158            metric_scores.insert(format!("{metric:?}"), score);
159        }
160
161        // Analyze query performance
162        let query_analysis = self.analyze_query_performance(&per_query_results)?;
163        let effectiveness_metrics = self.calculate_effectiveness_metrics(&per_query_results)?;
164
165        Ok(SearchResults {
166            metric_scores,
167            per_query_results,
168            query_analysis,
169            effectiveness_metrics,
170        })
171    }
172
173    /// Evaluate search for a specific query
174    async fn evaluate_query_search(
175        &self,
176        query: &str,
177        model: &dyn EmbeddingModel,
178    ) -> Result<QueryResults> {
179        let judgments = self
180            .query_relevance
181            .get(query)
182            .expect("query should exist in query_relevance");
183
184        // Get search results (simplified - would use actual search system)
185        let search_results = self.perform_search(query, model).await?;
186
187        // Calculate relevance for each result
188        let mut relevance_scores = Vec::new();
189        for (doc_id, _score) in &search_results {
190            let relevance = judgments
191                .iter()
192                .find(|j| &j.document_id == doc_id)
193                .map(|j| j.relevance_score)
194                .unwrap_or(0);
195            relevance_scores.push(relevance);
196        }
197
198        let num_relevant = judgments.iter().filter(|j| j.relevance_score > 0).count();
199
200        // Calculate metrics at different K values
201        let mut precision_scores = HashMap::new();
202        let mut recall_scores = HashMap::new();
203        let mut ndcg_scores = HashMap::new();
204
205        for &k in &[1, 3, 5, 10] {
206            if k <= search_results.len() {
207                let relevant_at_k =
208                    relevance_scores.iter().take(k).filter(|&&r| r > 0).count() as f64;
209
210                let precision = relevant_at_k / k as f64;
211                let recall = if num_relevant > 0 {
212                    relevant_at_k / num_relevant as f64
213                } else {
214                    0.0
215                };
216
217                precision_scores.insert(k, precision);
218                recall_scores.insert(k, recall);
219
220                // Calculate NDCG
221                let ndcg = self.calculate_search_ndcg(&relevance_scores, k)?;
222                ndcg_scores.insert(k, ndcg);
223            }
224        }
225
226        let difficulty_score = self.calculate_query_difficulty(query, num_relevant);
227
228        Ok(QueryResults {
229            query: query.to_string(),
230            precision_scores,
231            recall_scores,
232            ndcg_scores,
233            num_relevant,
234            difficulty_score,
235        })
236    }
237
238    /// Perform search (simplified implementation)
239    async fn perform_search(
240        &self,
241        query: &str,
242        model: &dyn EmbeddingModel,
243    ) -> Result<Vec<(String, f64)>> {
244        // Create query embedding (simplified)
245        let query_words: Vec<&str> = query.split_whitespace().collect();
246        let mut query_embedding = vec![0.0f32; 100];
247
248        // Simple word-based embedding (in practice, use proper query embedding)
249        for (i, word) in query_words.iter().enumerate() {
250            if i < query_embedding.len() {
251                query_embedding[i] = word.len() as f32 / 10.0;
252            }
253        }
254        let query_vector = Vector::new(query_embedding);
255
256        // Score entities (documents) against query
257        let entities = model.get_entities();
258        let mut search_results = Vec::new();
259
260        for entity in entities.iter().take(100) {
261            // Limit for efficiency
262            if let Ok(entity_embedding) = model.get_entity_embedding(entity) {
263                let score = self.cosine_similarity(&query_vector, &entity_embedding);
264                search_results.push((entity.clone(), score));
265            }
266        }
267
268        // Sort by score and return top results
269        search_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270        search_results.truncate(20);
271
272        Ok(search_results)
273    }
274
275    /// Calculate NDCG for search results
276    fn calculate_search_ndcg(&self, relevance_scores: &[u8], k: usize) -> Result<f64> {
277        if k == 0 || relevance_scores.is_empty() {
278            return Ok(0.0);
279        }
280
281        let mut dcg = 0.0;
282        for (i, &relevance) in relevance_scores.iter().take(k).enumerate() {
283            if relevance > 0 {
284                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
285                dcg += gain / (i as f64 + 2.0).log2();
286            }
287        }
288
289        // Calculate ideal DCG
290        let mut ideal_relevance: Vec<u8> = relevance_scores.to_vec();
291        ideal_relevance.sort_by(|a, b| b.cmp(a));
292
293        let mut idcg = 0.0;
294        for (i, &relevance) in ideal_relevance.iter().take(k).enumerate() {
295            if relevance > 0 {
296                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
297                idcg += gain / (i as f64 + 2.0).log2();
298            }
299        }
300
301        if idcg > 0.0 {
302            Ok(dcg / idcg)
303        } else {
304            Ok(0.0)
305        }
306    }
307
308    /// Calculate query difficulty
309    fn calculate_query_difficulty(&self, query: &str, num_relevant: usize) -> f64 {
310        let query_length = query.split_whitespace().count() as f64;
311        let relevance_factor = if num_relevant == 0 {
312            1.0 // High difficulty
313        } else {
314            1.0 / (num_relevant as f64).log2()
315        };
316
317        (query_length * 0.1 + relevance_factor * 0.9).min(1.0)
318    }
319
320    /// Calculate aggregate search metric
321    fn calculate_search_metric(
322        &self,
323        metric: &SearchMetric,
324        per_query_results: &HashMap<String, QueryResults>,
325    ) -> Result<f64> {
326        if per_query_results.is_empty() {
327            return Ok(0.0);
328        }
329
330        match metric {
331            SearchMetric::PrecisionAtK(k) => {
332                let scores: Vec<f64> = per_query_results
333                    .values()
334                    .filter_map(|r| r.precision_scores.get(k))
335                    .cloned()
336                    .collect();
337                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
338            }
339            SearchMetric::NDCG(k) => {
340                let scores: Vec<f64> = per_query_results
341                    .values()
342                    .filter_map(|r| r.ndcg_scores.get(k))
343                    .cloned()
344                    .collect();
345                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
346            }
347            _ => Ok(0.5), // Placeholder for other metrics
348        }
349    }
350
351    /// Analyze query performance
352    fn analyze_query_performance(
353        &self,
354        per_query_results: &HashMap<String, QueryResults>,
355    ) -> Result<QueryPerformanceAnalysis> {
356        let avg_query_length = per_query_results
357            .keys()
358            .map(|q| q.split_whitespace().count() as f64)
359            .sum::<f64>()
360            / per_query_results.len() as f64;
361
362        let zero_result_queries = per_query_results
363            .values()
364            .filter(|r| r.num_relevant == 0)
365            .count() as f64
366            / per_query_results.len() as f64;
367
368        Ok(QueryPerformanceAnalysis {
369            avg_query_length,
370            query_type_distribution: HashMap::new(), // Simplified
371            performance_by_difficulty: HashMap::new(), // Simplified
372            zero_result_queries,
373        })
374    }
375
376    /// Calculate effectiveness metrics
377    fn calculate_effectiveness_metrics(
378        &self,
379        per_query_results: &HashMap<String, QueryResults>,
380    ) -> Result<SearchEffectivenessMetrics> {
381        let successful_queries = per_query_results
382            .values()
383            .filter(|r| r.precision_scores.get(&1).unwrap_or(&0.0) > &0.0)
384            .count() as f64;
385
386        let query_success_rate = successful_queries / per_query_results.len() as f64;
387
388        Ok(SearchEffectivenessMetrics {
389            search_satisfaction: query_success_rate * 0.8, // Simplified
390            relevance_distribution: HashMap::new(),        // Simplified
391            result_diversity: 0.6,                         // Simplified
392            query_success_rate,
393        })
394    }
395
396    /// Calculate cosine similarity
397    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
398        let dot_product: f32 = v1
399            .values
400            .iter()
401            .zip(v2.values.iter())
402            .map(|(a, b)| a * b)
403            .sum();
404        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
405        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
406
407        if norm_a > 0.0 && norm_b > 0.0 {
408            (dot_product / (norm_a * norm_b)) as f64
409        } else {
410            0.0
411        }
412    }
413}
414
415impl Default for SearchEvaluator {
416    fn default() -> Self {
417        Self::new()
418    }
419}