oxirs_embed/application_tasks/search.rs

//! Search-task evaluation for embedding models: computes standard IR metrics
//! (precision@k, recall@k, NDCG) against human relevance judgments.

use super::ApplicationEvalConfig;
use crate::{EmbeddingModel, Vector};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A human relevance judgment for a single (query, document) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceJudgment {
    /// The search query text.
    pub query: String,
    /// Identifier of the judged document.
    pub document_id: String,
    /// Graded relevance (0 = not relevant; higher values are more relevant).
    pub relevance_score: u8,
    /// Identifier of the annotator who produced the judgment.
    pub annotator_id: String,
}

/// Metrics that the search evaluator can report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchMetric {
    /// Precision at cut-off k.
    PrecisionAtK(usize),
    /// Recall at cut-off k.
    RecallAtK(usize),
    /// Mean Average Precision.
    MAP,
    /// Normalized Discounted Cumulative Gain at cut-off k.
    NDCG(usize),
    /// Mean Reciprocal Rank.
    MRR,
    /// Expected Reciprocal Rank.
    ERR,
    /// Click-Through Rate.
    CTR,
}

/// Per-query evaluation results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResults {
    /// The evaluated query text.
    pub query: String,
    /// Precision at each evaluated cut-off k.
    pub precision_scores: HashMap<usize, f64>,
    /// Recall at each evaluated cut-off k.
    pub recall_scores: HashMap<usize, f64>,
    /// NDCG at each evaluated cut-off k.
    pub ndcg_scores: HashMap<usize, f64>,
    /// Number of documents judged relevant for this query.
    pub num_relevant: usize,
    /// Estimated query difficulty in [0, 1].
    pub difficulty_score: f64,
}

61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct QueryPerformanceAnalysis {
65 pub avg_query_length: f64,
67 pub query_type_distribution: HashMap<String, usize>,
69 pub performance_by_difficulty: HashMap<String, f64>,
71 pub zero_result_queries: f64,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SearchEffectivenessMetrics {
78 pub search_satisfaction: f64,
80 pub relevance_distribution: HashMap<u8, usize>,
82 pub result_diversity: f64,
84 pub query_success_rate: f64,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct SearchResults {
91 pub metric_scores: HashMap<String, f64>,
93 pub per_query_results: HashMap<String, QueryResults>,
95 pub query_analysis: QueryPerformanceAnalysis,
97 pub effectiveness_metrics: SearchEffectivenessMetrics,
99}

/// Evaluator for search tasks driven by human relevance judgments.
pub struct SearchEvaluator {
    /// Relevance judgments grouped by query.
    query_relevance: HashMap<String, Vec<RelevanceJudgment>>,
    /// Metrics to compute during evaluation.
    metrics: Vec<SearchMetric>,
}

impl SearchEvaluator {
    /// Creates an evaluator with a default set of metrics.
    pub fn new() -> Self {
        Self {
            query_relevance: HashMap::new(),
            metrics: vec![
                SearchMetric::PrecisionAtK(1),
                SearchMetric::PrecisionAtK(5),
                SearchMetric::PrecisionAtK(10),
                SearchMetric::NDCG(10),
                SearchMetric::MAP,
                SearchMetric::MRR,
            ],
        }
    }

    /// Registers a relevance judgment for later evaluation.
    pub fn add_relevance_judgment(&mut self, judgment: RelevanceJudgment) {
        self.query_relevance
            .entry(judgment.query.clone())
            .or_default()
            .push(judgment);
    }

    /// Evaluates search quality of `model` against the stored relevance judgments.
    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<SearchResults> {
        let mut metric_scores = HashMap::new();
        let mut per_query_results = HashMap::new();

        // Evaluate at most `sample_size` queries.
        let queries_to_evaluate: Vec<_> = self
            .query_relevance
            .keys()
            .take(config.sample_size)
            .cloned()
            .collect();

        for query in &queries_to_evaluate {
            let query_results = self.evaluate_query_search(query, model).await?;
            per_query_results.insert(query.clone(), query_results);
        }

        // Aggregate per-query results into one score per configured metric.
        for metric in &self.metrics {
            let score = self.calculate_search_metric(metric, &per_query_results)?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        let query_analysis = self.analyze_query_performance(&per_query_results)?;
        let effectiveness_metrics = self.calculate_effectiveness_metrics(&per_query_results)?;

        Ok(SearchResults {
            metric_scores,
            per_query_results,
            query_analysis,
            effectiveness_metrics,
        })
    }

    /// Runs a search for a single query and scores it against its judgments.
    async fn evaluate_query_search(
        &self,
        query: &str,
        model: &dyn EmbeddingModel,
    ) -> Result<QueryResults> {
        let judgments = self
            .query_relevance
            .get(query)
            .ok_or_else(|| anyhow::anyhow!("no relevance judgments for query: {query}"))?;

        let search_results = self.perform_search(query, model).await?;

        // Look up the judged relevance of each returned document (0 if unjudged).
        let mut relevance_scores = Vec::new();
        for (doc_id, _score) in &search_results {
            let relevance = judgments
                .iter()
                .find(|j| &j.document_id == doc_id)
                .map(|j| j.relevance_score)
                .unwrap_or(0);
            relevance_scores.push(relevance);
        }

        let num_relevant = judgments.iter().filter(|j| j.relevance_score > 0).count();

        let mut precision_scores = HashMap::new();
        let mut recall_scores = HashMap::new();
        let mut ndcg_scores = HashMap::new();

        for &k in &[1, 3, 5, 10] {
            if k <= search_results.len() {
                let relevant_at_k =
                    relevance_scores.iter().take(k).filter(|&&r| r > 0).count() as f64;

                let precision = relevant_at_k / k as f64;
                let recall = if num_relevant > 0 {
                    relevant_at_k / num_relevant as f64
                } else {
                    0.0
                };

                precision_scores.insert(k, precision);
                recall_scores.insert(k, recall);

                let ndcg = self.calculate_search_ndcg(&relevance_scores, k)?;
                ndcg_scores.insert(k, ndcg);
            }
        }

        let difficulty_score = self.calculate_query_difficulty(query, num_relevant);

        Ok(QueryResults {
            query: query.to_string(),
            precision_scores,
            recall_scores,
            ndcg_scores,
            num_relevant,
            difficulty_score,
        })
    }

    /// Searches the model's entities for a query and returns (entity, score) pairs.
    async fn perform_search(
        &self,
        query: &str,
        model: &dyn EmbeddingModel,
    ) -> Result<Vec<(String, f64)>> {
        // Build a simple placeholder query embedding from word lengths; a real
        // deployment would encode the query with the embedding model itself.
        let query_words: Vec<&str> = query.split_whitespace().collect();
        let mut query_embedding = vec![0.0f32; 100];

        for (i, word) in query_words.iter().enumerate() {
            if i < query_embedding.len() {
                query_embedding[i] = word.len() as f32 / 10.0;
            }
        }
        let query_vector = Vector::new(query_embedding);

        // Score up to 100 candidate entities by cosine similarity to the query.
        let entities = model.get_entities();
        let mut search_results = Vec::new();

        for entity in entities.iter().take(100) {
            if let Ok(entity_embedding) = model.get_entity_embedding(entity) {
                let score = self.cosine_similarity(&query_vector, &entity_embedding);
                search_results.push((entity.clone(), score));
            }
        }

        // Rank by descending similarity and keep the top 20 results.
        search_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        search_results.truncate(20);

        Ok(search_results)
    }

    /// Computes NDCG@k with the standard exponential gain (2^rel - 1) and
    /// logarithmic position discount.
    fn calculate_search_ndcg(&self, relevance_scores: &[u8], k: usize) -> Result<f64> {
        if k == 0 || relevance_scores.is_empty() {
            return Ok(0.0);
        }

        // DCG over the results in ranked order.
        let mut dcg = 0.0;
        for (i, &relevance) in relevance_scores.iter().take(k).enumerate() {
            if relevance > 0 {
                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
                dcg += gain / (i as f64 + 2.0).log2();
            }
        }

        // Ideal DCG over the same judgments sorted by descending relevance.
        let mut ideal_relevance: Vec<u8> = relevance_scores.to_vec();
        ideal_relevance.sort_by(|a, b| b.cmp(a));

        let mut idcg = 0.0;
        for (i, &relevance) in ideal_relevance.iter().take(k).enumerate() {
            if relevance > 0 {
                let gain = (2_u32.pow(relevance as u32) - 1) as f64;
                idcg += gain / (i as f64 + 2.0).log2();
            }
        }

        if idcg > 0.0 {
            Ok(dcg / idcg)
        } else {
            Ok(0.0)
        }
    }

    /// Estimates query difficulty in [0, 1] from query length and the number of
    /// relevant documents (fewer relevant documents means a harder query).
    fn calculate_query_difficulty(&self, query: &str, num_relevant: usize) -> f64 {
        let query_length = query.split_whitespace().count() as f64;
        // Guard num_relevant <= 1: log2(1) is 0, which would divide by zero.
        let relevance_factor = if num_relevant <= 1 {
            1.0
        } else {
            1.0 / (num_relevant as f64).log2()
        };

        (query_length * 0.1 + relevance_factor * 0.9).min(1.0)
    }

    /// Averages a per-query metric across all evaluated queries.
    fn calculate_search_metric(
        &self,
        metric: &SearchMetric,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<f64> {
        if per_query_results.is_empty() {
            return Ok(0.0);
        }

        match metric {
            SearchMetric::PrecisionAtK(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.precision_scores.get(k))
                    .cloned()
                    .collect();
                if scores.is_empty() {
                    return Ok(0.0);
                }
                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
            }
            SearchMetric::NDCG(k) => {
                let scores: Vec<f64> = per_query_results
                    .values()
                    .filter_map(|r| r.ndcg_scores.get(k))
                    .cloned()
                    .collect();
                if scores.is_empty() {
                    return Ok(0.0);
                }
                Ok(scores.iter().sum::<f64>() / scores.len() as f64)
            }
            // The remaining metrics (MAP, MRR, ERR, CTR) are not yet implemented;
            // return a neutral placeholder score.
            _ => Ok(0.5),
        }
    }

    /// Analyzes query characteristics across all evaluated queries.
    fn analyze_query_performance(
        &self,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<QueryPerformanceAnalysis> {
        // Avoid dividing by zero when no queries were evaluated.
        if per_query_results.is_empty() {
            return Ok(QueryPerformanceAnalysis {
                avg_query_length: 0.0,
                query_type_distribution: HashMap::new(),
                performance_by_difficulty: HashMap::new(),
                zero_result_queries: 0.0,
            });
        }

        let avg_query_length = per_query_results
            .keys()
            .map(|q| q.split_whitespace().count() as f64)
            .sum::<f64>()
            / per_query_results.len() as f64;

        let zero_result_queries = per_query_results
            .values()
            .filter(|r| r.num_relevant == 0)
            .count() as f64
            / per_query_results.len() as f64;

        Ok(QueryPerformanceAnalysis {
            avg_query_length,
            // Query-type and difficulty breakdowns are not yet computed.
            query_type_distribution: HashMap::new(),
            performance_by_difficulty: HashMap::new(),
            zero_result_queries,
        })
    }

    /// Derives coarse effectiveness indicators from the per-query results.
    fn calculate_effectiveness_metrics(
        &self,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<SearchEffectivenessMetrics> {
        // Avoid dividing by zero when no queries were evaluated.
        if per_query_results.is_empty() {
            return Ok(SearchEffectivenessMetrics {
                search_satisfaction: 0.0,
                relevance_distribution: HashMap::new(),
                result_diversity: 0.0,
                query_success_rate: 0.0,
            });
        }
        // A query "succeeds" if its top-ranked result is relevant (precision@1 > 0).
        let successful_queries = per_query_results
            .values()
            .filter(|r| r.precision_scores.get(&1).unwrap_or(&0.0) > &0.0)
            .count() as f64;

        let query_success_rate = successful_queries / per_query_results.len() as f64;

        Ok(SearchEffectivenessMetrics {
            // Satisfaction and diversity are heuristic placeholders for now.
            search_satisfaction: query_success_rate * 0.8,
            relevance_distribution: HashMap::new(),
            result_diversity: 0.6,
            query_success_rate,
        })
    }

    /// Cosine similarity between two vectors (0.0 if either has zero norm).
    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
        let dot_product: f32 = v1
            .values
            .iter()
            .zip(v2.values.iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 0.0 && norm_b > 0.0 {
            (dot_product / (norm_a * norm_b)) as f64
        } else {
            0.0
        }
    }
}

impl Default for SearchEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
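
// A minimal test sketch covering only the model-independent helpers above;
// exercising `evaluate` or `perform_search` would require a concrete
// `EmbeddingModel` implementation, which is assumed unavailable here, and the
// expected values below are illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ndcg_is_one_for_ideally_ordered_results() {
        let evaluator = SearchEvaluator::new();

        // Results already sorted by relevance: DCG equals IDCG, so NDCG is exactly 1.0.
        let ideal = evaluator.calculate_search_ndcg(&[3, 2, 1, 0], 4).unwrap();
        assert!((ideal - 1.0).abs() < 1e-12);

        // The same judgments in the worst order must score strictly lower.
        let worst = evaluator.calculate_search_ndcg(&[0, 1, 2, 3], 4).unwrap();
        assert!(worst < 1.0);
    }

    #[test]
    fn query_difficulty_is_capped_at_one() {
        let evaluator = SearchEvaluator::new();

        // Queries with no relevant documents take the maximum relevance factor,
        // and the combined score is clamped to 1.0.
        let difficulty = evaluator.calculate_query_difficulty("rare query terms", 0);
        assert!(difficulty <= 1.0);
    }
}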