oxirs_embed/application_tasks/
search.rs1use super::ApplicationEvalConfig;
8use crate::{EmbeddingModel, Vector};
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single human relevance judgment tying a query to one document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceJudgment {
    /// The search query text this judgment applies to.
    pub query: String,
    /// Identifier of the judged document.
    pub document_id: String,
    /// Graded relevance label; 0 is treated as "not relevant" by the
    /// evaluator, higher is more relevant.
    /// NOTE(review): exact grading scale (e.g. 0-4) is not defined here — confirm.
    pub relevance_score: u8,
    /// Identifier of the human annotator who produced this judgment.
    pub annotator_id: String,
}
25
/// Ranking-quality metrics that can be computed over search results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchMetric {
    /// Precision within the top-k results.
    PrecisionAtK(usize),
    /// Recall within the top-k results.
    RecallAtK(usize),
    /// Mean Average Precision.
    MAP,
    /// Normalized Discounted Cumulative Gain at the given cutoff k.
    NDCG(usize),
    /// Mean Reciprocal Rank.
    MRR,
    /// Expected Reciprocal Rank.
    ERR,
    /// Click-Through Rate.
    CTR,
}
44
/// Evaluation outcome for a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResults {
    /// The evaluated query text.
    pub query: String,
    /// Precision keyed by cutoff k.
    pub precision_scores: HashMap<usize, f64>,
    /// Recall keyed by cutoff k.
    pub recall_scores: HashMap<usize, f64>,
    /// NDCG keyed by cutoff k.
    pub ndcg_scores: HashMap<usize, f64>,
    /// Number of documents judged relevant (relevance_score > 0) for this query.
    pub num_relevant: usize,
    /// Heuristic difficulty estimate (length- and judgment-count-based).
    pub difficulty_score: f64,
}
61
/// Aggregate statistics over the whole evaluated query set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPerformanceAnalysis {
    /// Mean number of whitespace-separated terms per query.
    pub avg_query_length: f64,
    /// Count of queries per query type/category.
    pub query_type_distribution: HashMap<String, usize>,
    /// Average performance bucketed by difficulty label.
    pub performance_by_difficulty: HashMap<String, f64>,
    /// Fraction (0..=1) of queries with no relevant judged documents.
    pub zero_result_queries: f64,
}
74
/// User-facing effectiveness indicators derived from per-query results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchEffectivenessMetrics {
    /// Estimated overall search satisfaction (derived from success rate).
    pub search_satisfaction: f64,
    /// Histogram of observed relevance grades.
    pub relevance_distribution: HashMap<u8, usize>,
    /// Diversity of returned result sets.
    pub result_diversity: f64,
    /// Fraction (0..=1) of queries whose top-ranked result was relevant.
    pub query_success_rate: f64,
}
87
/// Full output of a search evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResults {
    /// Corpus-level score per metric, keyed by the metric's Debug name.
    pub metric_scores: HashMap<String, f64>,
    /// Detailed results keyed by query text.
    pub per_query_results: HashMap<String, QueryResults>,
    /// Aggregate query-set statistics.
    pub query_analysis: QueryPerformanceAnalysis,
    /// User-facing effectiveness indicators.
    pub effectiveness_metrics: SearchEffectivenessMetrics,
}
100
/// Evaluates an embedding model's search quality against human relevance judgments.
pub struct SearchEvaluator {
    /// Relevance judgments grouped by query text.
    query_relevance: HashMap<String, Vec<RelevanceJudgment>>,
    /// The set of metrics computed during evaluation.
    metrics: Vec<SearchMetric>,
}
108
109impl SearchEvaluator {
110 pub fn new() -> Self {
112 Self {
113 query_relevance: HashMap::new(),
114 metrics: vec![
115 SearchMetric::PrecisionAtK(1),
116 SearchMetric::PrecisionAtK(5),
117 SearchMetric::PrecisionAtK(10),
118 SearchMetric::NDCG(10),
119 SearchMetric::MAP,
120 SearchMetric::MRR,
121 ],
122 }
123 }
124
125 pub fn add_relevance_judgment(&mut self, judgment: RelevanceJudgment) {
127 self.query_relevance
128 .entry(judgment.query.clone())
129 .or_default()
130 .push(judgment);
131 }
132
133 pub async fn evaluate(
135 &self,
136 model: &dyn EmbeddingModel,
137 config: &ApplicationEvalConfig,
138 ) -> Result<SearchResults> {
139 let mut metric_scores = HashMap::new();
140 let mut per_query_results = HashMap::new();
141
142 let queries_to_evaluate: Vec<_> = self
144 .query_relevance
145 .keys()
146 .take(config.sample_size)
147 .cloned()
148 .collect();
149
150 for query in &queries_to_evaluate {
151 let query_results = self.evaluate_query_search(query, model).await?;
152 per_query_results.insert(query.clone(), query_results);
153 }
154
155 for metric in &self.metrics {
157 let score = self.calculate_search_metric(metric, &per_query_results)?;
158 metric_scores.insert(format!("{metric:?}"), score);
159 }
160
161 let query_analysis = self.analyze_query_performance(&per_query_results)?;
163 let effectiveness_metrics = self.calculate_effectiveness_metrics(&per_query_results)?;
164
165 Ok(SearchResults {
166 metric_scores,
167 per_query_results,
168 query_analysis,
169 effectiveness_metrics,
170 })
171 }
172
173 async fn evaluate_query_search(
175 &self,
176 query: &str,
177 model: &dyn EmbeddingModel,
178 ) -> Result<QueryResults> {
179 let judgments = self
180 .query_relevance
181 .get(query)
182 .expect("query should exist in query_relevance");
183
184 let search_results = self.perform_search(query, model).await?;
186
187 let mut relevance_scores = Vec::new();
189 for (doc_id, _score) in &search_results {
190 let relevance = judgments
191 .iter()
192 .find(|j| &j.document_id == doc_id)
193 .map(|j| j.relevance_score)
194 .unwrap_or(0);
195 relevance_scores.push(relevance);
196 }
197
198 let num_relevant = judgments.iter().filter(|j| j.relevance_score > 0).count();
199
200 let mut precision_scores = HashMap::new();
202 let mut recall_scores = HashMap::new();
203 let mut ndcg_scores = HashMap::new();
204
205 for &k in &[1, 3, 5, 10] {
206 if k <= search_results.len() {
207 let relevant_at_k =
208 relevance_scores.iter().take(k).filter(|&&r| r > 0).count() as f64;
209
210 let precision = relevant_at_k / k as f64;
211 let recall = if num_relevant > 0 {
212 relevant_at_k / num_relevant as f64
213 } else {
214 0.0
215 };
216
217 precision_scores.insert(k, precision);
218 recall_scores.insert(k, recall);
219
220 let ndcg = self.calculate_search_ndcg(&relevance_scores, k)?;
222 ndcg_scores.insert(k, ndcg);
223 }
224 }
225
226 let difficulty_score = self.calculate_query_difficulty(query, num_relevant);
227
228 Ok(QueryResults {
229 query: query.to_string(),
230 precision_scores,
231 recall_scores,
232 ndcg_scores,
233 num_relevant,
234 difficulty_score,
235 })
236 }
237
238 async fn perform_search(
240 &self,
241 query: &str,
242 model: &dyn EmbeddingModel,
243 ) -> Result<Vec<(String, f64)>> {
244 let query_words: Vec<&str> = query.split_whitespace().collect();
246 let mut query_embedding = vec![0.0f32; 100];
247
248 for (i, word) in query_words.iter().enumerate() {
250 if i < query_embedding.len() {
251 query_embedding[i] = word.len() as f32 / 10.0;
252 }
253 }
254 let query_vector = Vector::new(query_embedding);
255
256 let entities = model.get_entities();
258 let mut search_results = Vec::new();
259
260 for entity in entities.iter().take(100) {
261 if let Ok(entity_embedding) = model.get_entity_embedding(entity) {
263 let score = self.cosine_similarity(&query_vector, &entity_embedding);
264 search_results.push((entity.clone(), score));
265 }
266 }
267
268 search_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270 search_results.truncate(20);
271
272 Ok(search_results)
273 }
274
275 fn calculate_search_ndcg(&self, relevance_scores: &[u8], k: usize) -> Result<f64> {
277 if k == 0 || relevance_scores.is_empty() {
278 return Ok(0.0);
279 }
280
281 let mut dcg = 0.0;
282 for (i, &relevance) in relevance_scores.iter().take(k).enumerate() {
283 if relevance > 0 {
284 let gain = (2_u32.pow(relevance as u32) - 1) as f64;
285 dcg += gain / (i as f64 + 2.0).log2();
286 }
287 }
288
289 let mut ideal_relevance: Vec<u8> = relevance_scores.to_vec();
291 ideal_relevance.sort_by(|a, b| b.cmp(a));
292
293 let mut idcg = 0.0;
294 for (i, &relevance) in ideal_relevance.iter().take(k).enumerate() {
295 if relevance > 0 {
296 let gain = (2_u32.pow(relevance as u32) - 1) as f64;
297 idcg += gain / (i as f64 + 2.0).log2();
298 }
299 }
300
301 if idcg > 0.0 {
302 Ok(dcg / idcg)
303 } else {
304 Ok(0.0)
305 }
306 }
307
308 fn calculate_query_difficulty(&self, query: &str, num_relevant: usize) -> f64 {
310 let query_length = query.split_whitespace().count() as f64;
311 let relevance_factor = if num_relevant == 0 {
312 1.0 } else {
314 1.0 / (num_relevant as f64).log2()
315 };
316
317 (query_length * 0.1 + relevance_factor * 0.9).min(1.0)
318 }
319
320 fn calculate_search_metric(
322 &self,
323 metric: &SearchMetric,
324 per_query_results: &HashMap<String, QueryResults>,
325 ) -> Result<f64> {
326 if per_query_results.is_empty() {
327 return Ok(0.0);
328 }
329
330 match metric {
331 SearchMetric::PrecisionAtK(k) => {
332 let scores: Vec<f64> = per_query_results
333 .values()
334 .filter_map(|r| r.precision_scores.get(k))
335 .cloned()
336 .collect();
337 Ok(scores.iter().sum::<f64>() / scores.len() as f64)
338 }
339 SearchMetric::NDCG(k) => {
340 let scores: Vec<f64> = per_query_results
341 .values()
342 .filter_map(|r| r.ndcg_scores.get(k))
343 .cloned()
344 .collect();
345 Ok(scores.iter().sum::<f64>() / scores.len() as f64)
346 }
347 _ => Ok(0.5), }
349 }
350
351 fn analyze_query_performance(
353 &self,
354 per_query_results: &HashMap<String, QueryResults>,
355 ) -> Result<QueryPerformanceAnalysis> {
356 let avg_query_length = per_query_results
357 .keys()
358 .map(|q| q.split_whitespace().count() as f64)
359 .sum::<f64>()
360 / per_query_results.len() as f64;
361
362 let zero_result_queries = per_query_results
363 .values()
364 .filter(|r| r.num_relevant == 0)
365 .count() as f64
366 / per_query_results.len() as f64;
367
368 Ok(QueryPerformanceAnalysis {
369 avg_query_length,
370 query_type_distribution: HashMap::new(), performance_by_difficulty: HashMap::new(), zero_result_queries,
373 })
374 }
375
376 fn calculate_effectiveness_metrics(
378 &self,
379 per_query_results: &HashMap<String, QueryResults>,
380 ) -> Result<SearchEffectivenessMetrics> {
381 let successful_queries = per_query_results
382 .values()
383 .filter(|r| r.precision_scores.get(&1).unwrap_or(&0.0) > &0.0)
384 .count() as f64;
385
386 let query_success_rate = successful_queries / per_query_results.len() as f64;
387
388 Ok(SearchEffectivenessMetrics {
389 search_satisfaction: query_success_rate * 0.8, relevance_distribution: HashMap::new(), result_diversity: 0.6, query_success_rate,
393 })
394 }
395
396 fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
398 let dot_product: f32 = v1
399 .values
400 .iter()
401 .zip(v2.values.iter())
402 .map(|(a, b)| a * b)
403 .sum();
404 let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
405 let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
406
407 if norm_a > 0.0 && norm_b > 0.0 {
408 (dot_product / (norm_a * norm_b)) as f64
409 } else {
410 0.0
411 }
412 }
413}
414
impl Default for SearchEvaluator {
    /// Equivalent to [`SearchEvaluator::new`]: default metric suite, no judgments.
    fn default() -> Self {
        Self::new()
    }
}