oxirs-embed 0.2.4

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
//! Search relevance evaluation module
//!
//! This module provides comprehensive evaluation for search relevance using
//! embedding models, including precision, recall, NDCG, MAP, and other
//! information retrieval metrics.

use super::ApplicationEvalConfig;
use crate::{EmbeddingModel, Vector};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Relevance judgment for search evaluation
///
/// One graded label tying a query to a document/entity. The same
/// (query, document) pair may be judged by several annotators,
/// distinguished by `annotator_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceJudgment {
    /// Query text this judgment applies to
    pub query: String,
    /// Document/entity identifier
    pub document_id: String,
    /// Relevance score (0-3: not relevant, somewhat relevant, relevant, highly relevant)
    pub relevance_score: u8,
    /// Annotator identifier
    pub annotator_id: String,
}

/// Search evaluation metrics
///
/// Rank-cutoff metrics carry their cut-off `K` as payload
/// (e.g. `PrecisionAtK(10)` is P@10).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchMetric {
    /// Precision at K: fraction of the top-K results that are relevant
    PrecisionAtK(usize),
    /// Recall at K: fraction of all relevant documents retrieved in the top K
    RecallAtK(usize),
    /// Mean Average Precision
    MAP,
    /// Normalized Discounted Cumulative Gain at K
    NDCG(usize),
    /// Mean Reciprocal Rank
    MRR,
    /// Expected Reciprocal Rank
    ERR,
    /// Click-through rate simulation
    CTR,
}

/// Per-query search results
///
/// Score maps are keyed by the cut-off K at which the metric was computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResults {
    /// Query text
    pub query: String,
    /// Precision scores at different K values (keyed by K)
    pub precision_scores: HashMap<usize, f64>,
    /// Recall scores at different K values (keyed by K)
    pub recall_scores: HashMap<usize, f64>,
    /// NDCG scores (keyed by K)
    pub ndcg_scores: HashMap<usize, f64>,
    /// Number of documents judged relevant (relevance score > 0) for this query
    pub num_relevant: usize,
    /// Heuristic query difficulty score, capped at 1.0
    pub difficulty_score: f64,
}

/// Query performance analysis
///
/// Aggregate statistics computed over all evaluated queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPerformanceAnalysis {
    /// Average query length in whitespace-separated words
    pub avg_query_length: f64,
    /// Query type distribution (currently left empty by the evaluator)
    pub query_type_distribution: HashMap<String, usize>,
    /// Performance by query difficulty (currently left empty by the evaluator)
    pub performance_by_difficulty: HashMap<String, f64>,
    /// Fraction (0.0-1.0) of queries with no relevant documents
    pub zero_result_queries: f64,
}

/// Search effectiveness metrics
///
/// High-level quality indicators derived from the per-query results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchEffectivenessMetrics {
    /// Overall search satisfaction (simplified: scaled success rate)
    pub search_satisfaction: f64,
    /// Result relevance distribution, keyed by relevance grade
    /// (currently left empty by the evaluator)
    pub relevance_distribution: HashMap<u8, usize>,
    /// Search result diversity (currently a fixed placeholder value)
    pub result_diversity: f64,
    /// Fraction (0.0-1.0) of queries whose top result is relevant
    pub query_success_rate: f64,
}

/// Search evaluation results
///
/// Top-level output of [`SearchEvaluator::evaluate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResults {
    /// Metric scores, keyed by the `Debug` rendering of each [`SearchMetric`]
    /// (e.g. "PrecisionAtK(10)")
    pub metric_scores: HashMap<String, f64>,
    /// Per-query results, keyed by query text
    pub per_query_results: HashMap<String, QueryResults>,
    /// Query performance analysis
    pub query_analysis: QueryPerformanceAnalysis,
    /// Search effectiveness metrics
    pub effectiveness_metrics: SearchEffectivenessMetrics,
}

/// Search relevance evaluator
///
/// Holds a judgment pool and a metric suite; judgments are added with
/// [`SearchEvaluator::add_relevance_judgment`] and scored against a model
/// via [`SearchEvaluator::evaluate`].
pub struct SearchEvaluator {
    /// Search queries mapped to their relevance judgments
    query_relevance: HashMap<String, Vec<RelevanceJudgment>>,
    /// Search metrics to evaluate
    metrics: Vec<SearchMetric>,
}

impl SearchEvaluator {
    /// Create a new search evaluator with an empty judgment pool and the
    /// default metric suite (P@1, P@5, P@10, NDCG@10, MAP, MRR).
    pub fn new() -> Self {
        let default_metrics = vec![
            SearchMetric::PrecisionAtK(1),
            SearchMetric::PrecisionAtK(5),
            SearchMetric::PrecisionAtK(10),
            SearchMetric::NDCG(10),
            SearchMetric::MAP,
            SearchMetric::MRR,
        ];

        Self {
            query_relevance: HashMap::new(),
            metrics: default_metrics,
        }
    }

    /// Record a relevance judgment, grouped under its query text.
    pub fn add_relevance_judgment(&mut self, judgment: RelevanceJudgment) {
        let per_query = self.query_relevance.entry(judgment.query.clone()).or_default();
        per_query.push(judgment);
    }

    /// Run the full search-relevance evaluation against `model`.
    ///
    /// Samples up to `config.sample_size` judged queries (HashMap iteration
    /// order, so the sample is arbitrary), evaluates each one, then
    /// macro-aggregates the configured metrics and derives the performance
    /// and effectiveness summaries.
    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<SearchResults> {
        // Bounded sample of queries that have judgments.
        let sampled: Vec<String> = self
            .query_relevance
            .keys()
            .take(config.sample_size)
            .cloned()
            .collect();

        // Evaluate each sampled query sequentially (each call may hit the model).
        let mut per_query_results = HashMap::with_capacity(sampled.len());
        for query in &sampled {
            let results = self.evaluate_query_search(query, model).await?;
            per_query_results.insert(query.clone(), results);
        }

        // Macro-aggregate each configured metric across the evaluated queries;
        // keys are the Debug rendering of the metric variant.
        let mut metric_scores = HashMap::with_capacity(self.metrics.len());
        for metric in &self.metrics {
            metric_scores.insert(
                format!("{metric:?}"),
                self.calculate_search_metric(metric, &per_query_results)?,
            );
        }

        Ok(SearchResults {
            metric_scores,
            query_analysis: self.analyze_query_performance(&per_query_results)?,
            effectiveness_metrics: self.calculate_effectiveness_metrics(&per_query_results)?,
            per_query_results,
        })
    }

    /// Evaluate search for a specific query
    async fn evaluate_query_search(
        &self,
        query: &str,
        model: &dyn EmbeddingModel,
    ) -> Result<QueryResults> {
        let judgments = self
            .query_relevance
            .get(query)
            .expect("query should exist in query_relevance");

        // Get search results (simplified - would use actual search system)
        let search_results = self.perform_search(query, model).await?;

        // Calculate relevance for each result
        let mut relevance_scores = Vec::new();
        for (doc_id, _score) in &search_results {
            let relevance = judgments
                .iter()
                .find(|j| &j.document_id == doc_id)
                .map(|j| j.relevance_score)
                .unwrap_or(0);
            relevance_scores.push(relevance);
        }

        let num_relevant = judgments.iter().filter(|j| j.relevance_score > 0).count();

        // Calculate metrics at different K values
        let mut precision_scores = HashMap::new();
        let mut recall_scores = HashMap::new();
        let mut ndcg_scores = HashMap::new();

        for &k in &[1, 3, 5, 10] {
            if k <= search_results.len() {
                let relevant_at_k =
                    relevance_scores.iter().take(k).filter(|&&r| r > 0).count() as f64;

                let precision = relevant_at_k / k as f64;
                let recall = if num_relevant > 0 {
                    relevant_at_k / num_relevant as f64
                } else {
                    0.0
                };

                precision_scores.insert(k, precision);
                recall_scores.insert(k, recall);

                // Calculate NDCG
                let ndcg = self.calculate_search_ndcg(&relevance_scores, k)?;
                ndcg_scores.insert(k, ndcg);
            }
        }

        let difficulty_score = self.calculate_query_difficulty(query, num_relevant);

        Ok(QueryResults {
            query: query.to_string(),
            precision_scores,
            recall_scores,
            ndcg_scores,
            num_relevant,
            difficulty_score,
        })
    }

    /// Perform search (simplified implementation).
    ///
    /// Builds a toy 100-dimensional query embedding (dimension i holds
    /// word i's length / 10 — a stand-in for a real query encoder), scores
    /// up to 100 of the model's entities by cosine similarity, and returns
    /// the top 20 as `(entity_id, score)` pairs, best first.
    async fn perform_search(
        &self,
        query: &str,
        model: &dyn EmbeddingModel,
    ) -> Result<Vec<(String, f64)>> {
        // Toy query embedding; extra words beyond 100 are ignored by zip.
        let mut query_embedding = vec![0.0f32; 100];
        for (slot, word) in query_embedding.iter_mut().zip(query.split_whitespace()) {
            *slot = word.len() as f32 / 10.0;
        }
        let query_vector = Vector::new(query_embedding);

        // Score entities against the query; entities whose embedding lookup
        // fails are silently skipped. Capped at 100 entities for efficiency.
        let mut scored: Vec<(String, f64)> = model
            .get_entities()
            .iter()
            .take(100)
            .filter_map(|entity| {
                model
                    .get_entity_embedding(entity)
                    .ok()
                    .map(|emb| (entity.clone(), self.cosine_similarity(&query_vector, &emb)))
            })
            .collect();

        // Best-first ordering; NaN scores compare as equal.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(20);

        Ok(scored)
    }

    /// Calculate NDCG@k for a ranked list of graded relevance scores.
    ///
    /// Uses the exponential gain form (2^rel - 1) with a log2(rank + 1)
    /// discount. Returns 0.0 when there is nothing to rank or when no
    /// document is relevant (ideal DCG of zero).
    fn calculate_search_ndcg(&self, relevance_scores: &[u8], k: usize) -> Result<f64> {
        if k == 0 || relevance_scores.is_empty() {
            return Ok(0.0);
        }

        // DCG over the first k positions of any ranking; zero-grade
        // positions contribute nothing, so they are filtered out.
        let dcg_at_k = |ranking: &[u8]| -> f64 {
            ranking
                .iter()
                .take(k)
                .enumerate()
                .filter(|(_, &rel)| rel > 0)
                .map(|(rank, &rel)| {
                    let gain = (2_u32.pow(rel as u32) - 1) as f64;
                    // rank is 0-based, hence the +2 inside the log discount.
                    gain / (rank as f64 + 2.0).log2()
                })
                .sum()
        };

        let dcg = dcg_at_k(relevance_scores);

        // Ideal ranking: the same grades sorted best-to-worst.
        let mut ideal: Vec<u8> = relevance_scores.to_vec();
        ideal.sort_unstable_by(|a, b| b.cmp(a));
        let idcg = dcg_at_k(&ideal);

        Ok(if idcg > 0.0 { dcg / idcg } else { 0.0 })
    }

    /// Heuristic query difficulty, capped at 1.0.
    ///
    /// Combines query length (longer => harder, 10% weight) with
    /// relevant-document scarcity (fewer relevant docs => harder, 90%
    /// weight). Zero or one relevant document is treated as maximum
    /// scarcity.
    fn calculate_query_difficulty(&self, query: &str, num_relevant: usize) -> f64 {
        let query_length = query.split_whitespace().count() as f64;
        // Fixed: for exactly one relevant document, log2(1) == 0 previously
        // caused a division by zero (an infinite intermediate value that
        // only the final `.min(1.0)` masked). Clamp scarcity to 1.0 for
        // num_relevant <= 1 instead.
        let relevance_factor = if num_relevant <= 1 {
            1.0 // Maximum scarcity / difficulty
        } else {
            1.0 / (num_relevant as f64).log2()
        };

        (query_length * 0.1 + relevance_factor * 0.9).min(1.0)
    }

    /// Aggregate a single search metric by macro-averaging across queries.
    ///
    /// Only queries that actually produced a score at the requested cut-off
    /// contribute to the average; if none did, the metric is 0.0. MAP, MRR,
    /// ERR, CTR, and recall are not yet implemented and return a 0.5
    /// placeholder.
    fn calculate_search_metric(
        &self,
        metric: &SearchMetric,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<f64> {
        if per_query_results.is_empty() {
            return Ok(0.0);
        }

        // Fixed: guard the average against an empty sample — previously,
        // when no query had a score at K, 0.0 / 0 produced NaN.
        let mean = |values: Vec<f64>| -> f64 {
            if values.is_empty() {
                0.0
            } else {
                values.iter().sum::<f64>() / values.len() as f64
            }
        };

        match metric {
            SearchMetric::PrecisionAtK(k) => Ok(mean(
                per_query_results
                    .values()
                    .filter_map(|r| r.precision_scores.get(k))
                    .copied()
                    .collect(),
            )),
            SearchMetric::NDCG(k) => Ok(mean(
                per_query_results
                    .values()
                    .filter_map(|r| r.ndcg_scores.get(k))
                    .copied()
                    .collect(),
            )),
            _ => Ok(0.5), // Placeholder for other metrics
        }
    }

    /// Analyze query performance across all evaluated queries.
    ///
    /// Summarizes average query length (in words) and the fraction of
    /// queries with no relevant documents. Type distribution and
    /// per-difficulty performance are left empty (simplified).
    fn analyze_query_performance(
        &self,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<QueryPerformanceAnalysis> {
        // Fixed: an empty result map previously divided by zero, yielding
        // NaN for both statistics; return zeroed defaults instead.
        if per_query_results.is_empty() {
            return Ok(QueryPerformanceAnalysis {
                avg_query_length: 0.0,
                query_type_distribution: HashMap::new(),
                performance_by_difficulty: HashMap::new(),
                zero_result_queries: 0.0,
            });
        }

        let num_queries = per_query_results.len() as f64;

        let avg_query_length = per_query_results
            .keys()
            .map(|q| q.split_whitespace().count() as f64)
            .sum::<f64>()
            / num_queries;

        // Fraction of queries whose judgment set contains no relevant doc.
        let zero_result_queries = per_query_results
            .values()
            .filter(|r| r.num_relevant == 0)
            .count() as f64
            / num_queries;

        Ok(QueryPerformanceAnalysis {
            avg_query_length,
            query_type_distribution: HashMap::new(), // Simplified
            performance_by_difficulty: HashMap::new(), // Simplified
            zero_result_queries,
        })
    }

    /// Calculate effectiveness metrics.
    ///
    /// A query counts as successful when its top result is relevant
    /// (P@1 > 0). Satisfaction is a simplified scaling of the success
    /// rate; relevance distribution and diversity are placeholders.
    fn calculate_effectiveness_metrics(
        &self,
        per_query_results: &HashMap<String, QueryResults>,
    ) -> Result<SearchEffectivenessMetrics> {
        // Fixed: an empty result map previously divided by zero, producing
        // a NaN success rate; report 0.0 instead.
        let query_success_rate = if per_query_results.is_empty() {
            0.0
        } else {
            let successful_queries = per_query_results
                .values()
                .filter(|r| r.precision_scores.get(&1).unwrap_or(&0.0) > &0.0)
                .count() as f64;
            successful_queries / per_query_results.len() as f64
        };

        Ok(SearchEffectivenessMetrics {
            search_satisfaction: query_success_rate * 0.8, // Simplified
            relevance_distribution: HashMap::new(),        // Simplified
            result_diversity: 0.6,                         // Simplified
            query_success_rate,
        })
    }

    /// Cosine similarity between two embedding vectors.
    ///
    /// The dot product runs over the shared prefix (via `zip`) while each
    /// norm is taken over its full vector; returns 0.0 when either vector
    /// has zero magnitude.
    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
        let dot: f32 = v1
            .values
            .iter()
            .zip(&v2.values)
            .map(|(a, b)| a * b)
            .sum();

        let norm = |values: &[f32]| values.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_product = norm(&v1.values) * norm(&v2.values);

        if norm_product > 0.0 {
            f64::from(dot / norm_product)
        } else {
            0.0
        }
    }
}

// `Default` simply delegates to `SearchEvaluator::new`, so the default
// value carries the standard metric suite.
impl Default for SearchEvaluator {
    fn default() -> Self {
        Self::new()
    }
}