Skip to main content

oxirs_core/ai/embeddings/
evaluation.rs

1use super::KnowledgeGraphEmbedding;
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5
6/// Comprehensive knowledge graph evaluation metrics
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct KnowledgeGraphMetrics {
9    /// Mean Reciprocal Rank (filtered)
10    pub mrr_filtered: f32,
11    /// Mean Reciprocal Rank (unfiltered)
12    pub mrr_unfiltered: f32,
13    /// Mean Rank (filtered)
14    pub mr_filtered: f32,
15    /// Mean Rank (unfiltered)
16    pub mr_unfiltered: f32,
17    /// Hits@K metrics (filtered)
18    pub hits_at_k_filtered: std::collections::HashMap<u32, f32>,
19    /// Hits@K metrics (unfiltered)
20    pub hits_at_k_unfiltered: std::collections::HashMap<u32, f32>,
21    /// Per-relation type performance
22    pub per_relation_metrics: std::collections::HashMap<String, RelationMetrics>,
23    /// Link prediction task breakdown
24    pub task_breakdown: TaskBreakdownMetrics,
25    /// Confidence intervals (95%)
26    pub confidence_intervals: ConfidenceIntervals,
27    /// Statistical significance test results
28    pub statistical_tests: StatisticalTestResults,
29}
30
31/// Comprehensive training metrics for knowledge graph embeddings
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TrainingMetrics {
34    /// Final training loss
35    pub loss: f32,
36    /// Loss history across epochs
37    pub loss_history: Vec<f32>,
38    /// Basic accuracy (deprecated, use ranking metrics instead)
39    pub accuracy: f32,
40    /// Number of training epochs completed
41    pub epochs: usize,
42    /// Total training time
43    pub time_elapsed: std::time::Duration,
44    /// Knowledge graph specific metrics
45    pub kg_metrics: KnowledgeGraphMetrics,
46}
47
48/// Per-relation performance metrics
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RelationMetrics {
51    pub mrr: f32,
52    pub mr: f32,
53    pub hits_at_k: std::collections::HashMap<u32, f32>,
54    pub sample_count: usize,
55    pub entity_coverage: f32,
56}
57
58/// Breakdown by link prediction tasks
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct TaskBreakdownMetrics {
61    /// Head entity prediction (?, r, t)
62    pub head_prediction: LinkPredictionMetrics,
63    /// Tail entity prediction (h, r, ?)
64    pub tail_prediction: LinkPredictionMetrics,
65    /// Relation prediction (h, ?, t)
66    pub relation_prediction: LinkPredictionMetrics,
67}
68
69/// Link prediction specific metrics
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct LinkPredictionMetrics {
72    pub mrr: f32,
73    pub mr: f32,
74    pub hits_at_k: std::collections::HashMap<u32, f32>,
75    pub auc_roc: f32,
76    pub auc_pr: f32,
77    pub precision_at_k: std::collections::HashMap<u32, f32>,
78    pub recall_at_k: std::collections::HashMap<u32, f32>,
79}
80
81/// Confidence intervals for metrics
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ConfidenceIntervals {
84    pub mrr_ci: (f32, f32),
85    pub mr_ci: (f32, f32),
86    pub hits_at_10_ci: (f32, f32),
87}
88
89/// Statistical significance test results
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct StatisticalTestResults {
92    /// Wilcoxon signed-rank test p-value vs baseline
93    pub wilcoxon_p_value: Option<f32>,
94    /// Bootstrap test confidence level
95    pub bootstrap_confidence: f32,
96    /// Effect size (Cohen's d)
97    pub effect_size: Option<f32>,
98}
99
100impl Default for KnowledgeGraphMetrics {
101    fn default() -> Self {
102        let mut hits_at_k = std::collections::HashMap::new();
103        hits_at_k.insert(1, 0.0);
104        hits_at_k.insert(3, 0.0);
105        hits_at_k.insert(10, 0.0);
106        hits_at_k.insert(100, 0.0);
107
108        let mut precision_at_k = std::collections::HashMap::new();
109        precision_at_k.insert(1, 0.0);
110        precision_at_k.insert(3, 0.0);
111        precision_at_k.insert(10, 0.0);
112
113        let mut recall_at_k = std::collections::HashMap::new();
114        recall_at_k.insert(1, 0.0);
115        recall_at_k.insert(3, 0.0);
116        recall_at_k.insert(10, 0.0);
117
118        Self {
119            mrr_filtered: 0.0,
120            mrr_unfiltered: 0.0,
121            mr_filtered: 0.0,
122            mr_unfiltered: 0.0,
123            hits_at_k_filtered: hits_at_k.clone(),
124            hits_at_k_unfiltered: hits_at_k.clone(),
125            per_relation_metrics: std::collections::HashMap::new(),
126            task_breakdown: TaskBreakdownMetrics {
127                head_prediction: LinkPredictionMetrics {
128                    mrr: 0.0,
129                    mr: 0.0,
130                    hits_at_k: hits_at_k.clone(),
131                    auc_roc: 0.0,
132                    auc_pr: 0.0,
133                    precision_at_k: precision_at_k.clone(),
134                    recall_at_k: recall_at_k.clone(),
135                },
136                tail_prediction: LinkPredictionMetrics {
137                    mrr: 0.0,
138                    mr: 0.0,
139                    hits_at_k: hits_at_k.clone(),
140                    auc_roc: 0.0,
141                    auc_pr: 0.0,
142                    precision_at_k: precision_at_k.clone(),
143                    recall_at_k: recall_at_k.clone(),
144                },
145                relation_prediction: LinkPredictionMetrics {
146                    mrr: 0.0,
147                    mr: 0.0,
148                    hits_at_k: hits_at_k.clone(),
149                    auc_roc: 0.0,
150                    auc_pr: 0.0,
151                    precision_at_k,
152                    recall_at_k,
153                },
154            },
155            confidence_intervals: ConfidenceIntervals {
156                mrr_ci: (0.0, 0.0),
157                mr_ci: (0.0, 0.0),
158                hits_at_10_ci: (0.0, 0.0),
159            },
160            statistical_tests: StatisticalTestResults {
161                wilcoxon_p_value: None,
162                bootstrap_confidence: 0.95,
163                effect_size: None,
164            },
165        }
166    }
167}
168
169/// Compute comprehensive knowledge graph metrics for link prediction
170pub async fn compute_kg_metrics(
171    model: &dyn KnowledgeGraphEmbedding,
172    test_triples: &[(String, String, String)],
173    all_triples: &[(String, String, String)],
174    k_values: &[u32],
175) -> Result<KnowledgeGraphMetrics> {
176    let mut metrics = KnowledgeGraphMetrics::default();
177
178    // Convert to hashset for efficient filtering
179    let all_triples_set: HashSet<(String, String, String)> = all_triples.iter().cloned().collect();
180
181    // Head prediction metrics
182    metrics.task_breakdown.head_prediction = compute_link_prediction_metrics(
183        model,
184        test_triples,
185        &all_triples_set,
186        LinkPredictionTask::HeadPrediction,
187        k_values,
188    )
189    .await?;
190
191    // Tail prediction metrics
192    metrics.task_breakdown.tail_prediction = compute_link_prediction_metrics(
193        model,
194        test_triples,
195        &all_triples_set,
196        LinkPredictionTask::TailPrediction,
197        k_values,
198    )
199    .await?;
200
201    // Relation prediction metrics
202    metrics.task_breakdown.relation_prediction = compute_link_prediction_metrics(
203        model,
204        test_triples,
205        &all_triples_set,
206        LinkPredictionTask::RelationPrediction,
207        k_values,
208    )
209    .await?;
210
211    // Aggregate metrics across tasks
212    metrics.mrr_filtered = (metrics.task_breakdown.head_prediction.mrr
213        + metrics.task_breakdown.tail_prediction.mrr)
214        / 2.0;
215    metrics.mr_filtered = (metrics.task_breakdown.head_prediction.mr
216        + metrics.task_breakdown.tail_prediction.mr)
217        / 2.0;
218
219    // Aggregate Hits@K
220    for &k in k_values {
221        let head_hits = metrics
222            .task_breakdown
223            .head_prediction
224            .hits_at_k
225            .get(&k)
226            .unwrap_or(&0.0);
227        let tail_hits = metrics
228            .task_breakdown
229            .tail_prediction
230            .hits_at_k
231            .get(&k)
232            .unwrap_or(&0.0);
233        metrics
234            .hits_at_k_filtered
235            .insert(k, (head_hits + tail_hits) / 2.0);
236    }
237
238    // Compute per-relation metrics
239    metrics.per_relation_metrics =
240        compute_per_relation_metrics(model, test_triples, &all_triples_set, k_values).await?;
241
242    // Compute confidence intervals
243    metrics.confidence_intervals = compute_confidence_intervals(
244        &metrics.task_breakdown.head_prediction,
245        &metrics.task_breakdown.tail_prediction,
246        test_triples.len(),
247    )?;
248
249    Ok(metrics)
250}
251
/// Link prediction task types.
///
/// Identifies which slot of a `(head, relation, tail)` triple is masked and
/// has to be predicted. The enum carries no data, so it is `Copy` and
/// comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkPredictionTask {
    /// Predict the head entity: (?, r, t)
    HeadPrediction,
    /// Predict the tail entity: (h, r, ?)
    TailPrediction,
    /// Predict the relation: (h, ?, t)
    RelationPrediction,
}
259
260/// Compute link prediction metrics for specific task
261async fn compute_link_prediction_metrics(
262    model: &dyn KnowledgeGraphEmbedding,
263    test_triples: &[(String, String, String)],
264    all_triples: &HashSet<(String, String, String)>,
265    task: LinkPredictionTask,
266    k_values: &[u32],
267) -> Result<LinkPredictionMetrics> {
268    let mut ranks = Vec::new();
269    let mut reciprocal_ranks = Vec::new();
270    let mut hits_at_k = std::collections::HashMap::new();
271    let mut precision_at_k = std::collections::HashMap::new();
272    let mut recall_at_k = std::collections::HashMap::new();
273
274    // Initialize counters
275    for &k in k_values {
276        hits_at_k.insert(k, 0.0);
277        precision_at_k.insert(k, 0.0);
278        recall_at_k.insert(k, 0.0);
279    }
280
281    for (head, relation, tail) in test_triples {
282        let rank = match task {
283            LinkPredictionTask::HeadPrediction => {
284                compute_entity_rank(model, "?", relation, tail, all_triples, true).await?
285            }
286            LinkPredictionTask::TailPrediction => {
287                compute_entity_rank(model, head, relation, "?", all_triples, false).await?
288            }
289            LinkPredictionTask::RelationPrediction => {
290                compute_relation_rank(model, head, tail, all_triples).await?
291            }
292        };
293
294        ranks.push(rank as f32);
295        reciprocal_ranks.push(1.0 / rank as f32);
296
297        // Update hits@k counters
298        for &k in k_values {
299            if rank <= k {
300                if let Some(hits) = hits_at_k.get_mut(&k) {
301                    *hits += 1.0;
302                }
303            }
304        }
305    }
306
307    let num_samples = test_triples.len() as f32;
308
309    // Normalize hits@k
310    for (_, hits) in hits_at_k.iter_mut() {
311        *hits /= num_samples;
312    }
313
314    // Compute precision and recall at k (simplified)
315    for &k in k_values {
316        let hits = hits_at_k.get(&k).unwrap_or(&0.0);
317        precision_at_k.insert(k, *hits); // Simplified: assume precision = hits@k
318        recall_at_k.insert(k, *hits); // Simplified: assume recall = hits@k
319    }
320
321    Ok(LinkPredictionMetrics {
322        mrr: reciprocal_ranks.iter().sum::<f32>() / num_samples,
323        mr: ranks.iter().sum::<f32>() / num_samples,
324        hits_at_k,
325        auc_roc: compute_auc_roc(&ranks)?,
326        auc_pr: compute_auc_pr(&ranks)?,
327        precision_at_k,
328        recall_at_k,
329    })
330}
331
332/// Compute rank of correct entity in filtered setting
333async fn compute_entity_rank(
334    model: &dyn KnowledgeGraphEmbedding,
335    head: &str,
336    relation: &str,
337    tail: &str,
338    all_triples: &HashSet<(String, String, String)>,
339    predict_head: bool,
340) -> Result<u32> {
341    // Get all entities (simplified - in practice would use entity vocabulary)
342    let entities: Vec<String> = all_triples
343        .iter()
344        .flat_map(|(h, _, t)| vec![h.clone(), t.clone()])
345        .collect::<HashSet<_>>()
346        .into_iter()
347        .collect();
348
349    let mut scores = Vec::new();
350    let correct_entity = if predict_head { head } else { tail };
351
352    for entity in &entities {
353        let test_head = if predict_head { entity } else { head };
354        let test_tail = if predict_head { tail } else { entity };
355
356        // Skip if this would create a known triple (filtered setting)
357        if all_triples.contains(&(
358            test_head.to_string(),
359            relation.to_string(),
360            test_tail.to_string(),
361        )) && entity != correct_entity
362        {
363            continue;
364        }
365
366        let score = model.score_triple(test_head, relation, test_tail).await?;
367        scores.push((entity.clone(), score));
368    }
369
370    // Sort by score (descending)
371    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
372
373    // Find rank of correct entity
374    let rank = scores
375        .iter()
376        .position(|(entity, _)| entity == correct_entity)
377        .unwrap_or(scores.len() - 1)
378        + 1;
379
380    Ok(rank as u32)
381}
382
383/// Compute rank of correct relation
384async fn compute_relation_rank(
385    model: &dyn KnowledgeGraphEmbedding,
386    head: &str,
387    tail: &str,
388    all_triples: &HashSet<(String, String, String)>,
389) -> Result<u32> {
390    // Get all relations
391    let relations: Vec<String> = all_triples
392        .iter()
393        .map(|(_, r, _)| r.clone())
394        .collect::<HashSet<_>>()
395        .into_iter()
396        .collect();
397
398    let mut scores = Vec::new();
399
400    for relation in &relations {
401        let score = model.score_triple(head, relation, tail).await?;
402        scores.push((relation.clone(), score));
403    }
404
405    // Sort by score (descending)
406    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
407
408    // Find rank (simplified - assumes first relation is correct)
409    Ok(1) // Placeholder
410}
411
412/// Compute per-relation performance metrics
413async fn compute_per_relation_metrics(
414    model: &dyn KnowledgeGraphEmbedding,
415    test_triples: &[(String, String, String)],
416    all_triples: &HashSet<(String, String, String)>,
417    k_values: &[u32],
418) -> Result<std::collections::HashMap<String, RelationMetrics>> {
419    let mut relation_metrics = std::collections::HashMap::new();
420
421    // Group test triples by relation
422    let mut relation_groups: std::collections::HashMap<String, Vec<(String, String, String)>> =
423        std::collections::HashMap::new();
424
425    for triple in test_triples {
426        relation_groups
427            .entry(triple.1.clone())
428            .or_default()
429            .push(triple.clone());
430    }
431
432    // Compute metrics for each relation
433    for (relation, relation_triples) in relation_groups {
434        let metrics = compute_link_prediction_metrics(
435            model,
436            &relation_triples,
437            all_triples,
438            LinkPredictionTask::TailPrediction,
439            k_values,
440        )
441        .await?;
442
443        let entity_count = relation_triples
444            .iter()
445            .flat_map(|(h, _, t)| vec![h, t])
446            .collect::<HashSet<_>>()
447            .len();
448
449        relation_metrics.insert(
450            relation,
451            RelationMetrics {
452                mrr: metrics.mrr,
453                mr: metrics.mr,
454                hits_at_k: metrics.hits_at_k,
455                sample_count: relation_triples.len(),
456                entity_coverage: entity_count as f32 / relation_triples.len() as f32,
457            },
458        );
459    }
460
461    Ok(relation_metrics)
462}
463
464/// Compute confidence intervals using bootstrap sampling
465fn compute_confidence_intervals(
466    head_metrics: &LinkPredictionMetrics,
467    tail_metrics: &LinkPredictionMetrics,
468    sample_size: usize,
469) -> Result<ConfidenceIntervals> {
470    // Simplified confidence interval computation
471    let combined_mrr = (head_metrics.mrr + tail_metrics.mrr) / 2.0;
472    let combined_mr = (head_metrics.mr + tail_metrics.mr) / 2.0;
473    let combined_hits_10 = (head_metrics.hits_at_k.get(&10).unwrap_or(&0.0)
474        + tail_metrics.hits_at_k.get(&10).unwrap_or(&0.0))
475        / 2.0;
476
477    // Standard error approximation
478    let se_factor = 1.96 / (sample_size as f32).sqrt(); // 95% CI
479
480    Ok(ConfidenceIntervals {
481        mrr_ci: (
482            (combined_mrr - combined_mrr * se_factor).max(0.0),
483            (combined_mrr + combined_mrr * se_factor).min(1.0),
484        ),
485        mr_ci: (
486            (combined_mr - combined_mr * se_factor).max(1.0),
487            combined_mr + combined_mr * se_factor,
488        ),
489        hits_at_10_ci: (
490            (combined_hits_10 - combined_hits_10 * se_factor).max(0.0),
491            (combined_hits_10 + combined_hits_10 * se_factor).min(1.0),
492        ),
493    })
494}
495
496/// Compute AUC-ROC score
497fn compute_auc_roc(ranks: &[f32]) -> Result<f32> {
498    // Simplified AUC computation
499    let max_rank = ranks.iter().fold(0.0f32, |a, &b| a.max(b));
500    let normalized_ranks: Vec<f32> = ranks.iter().map(|&r| 1.0 - (r / max_rank)).collect();
501    Ok(normalized_ranks.iter().sum::<f32>() / ranks.len() as f32)
502}
503
504/// Compute AUC-PR score
505fn compute_auc_pr(ranks: &[f32]) -> Result<f32> {
506    // Simplified AUC-PR computation (placeholder)
507    compute_auc_roc(ranks)
508}
509
510/// Create evaluation report
511pub fn create_evaluation_report(metrics: &KnowledgeGraphMetrics) -> String {
512    format!(
513        "Knowledge Graph Embedding Evaluation Report\n\
514            ==========================================\n\
515            \n\
516            Overall Performance:\n\
517            - MRR (filtered): {:.4}\n\
518            - Mean Rank (filtered): {:.1}\n\
519            - Hits@1: {:.4}\n\
520            - Hits@3: {:.4}\n\
521            - Hits@10: {:.4}\n\
522            \n\
523            Task Breakdown:\n\
524            - Head Prediction MRR: {:.4}\n\
525            - Tail Prediction MRR: {:.4}\n\
526            - Relation Prediction MRR: {:.4}\n\
527            \n\
528            Confidence Intervals (95%):\n\
529            - MRR: [{:.4}, {:.4}]\n\
530            - Hits@10: [{:.4}, {:.4}]\n\
531            \n\
532            Per-Relation Performance:\n\
533            {} relations evaluated\n",
534        metrics.mrr_filtered,
535        metrics.mr_filtered,
536        metrics.hits_at_k_filtered.get(&1).unwrap_or(&0.0),
537        metrics.hits_at_k_filtered.get(&3).unwrap_or(&0.0),
538        metrics.hits_at_k_filtered.get(&10).unwrap_or(&0.0),
539        metrics.task_breakdown.head_prediction.mrr,
540        metrics.task_breakdown.tail_prediction.mrr,
541        metrics.task_breakdown.relation_prediction.mrr,
542        metrics.confidence_intervals.mrr_ci.0,
543        metrics.confidence_intervals.mrr_ci.1,
544        metrics.confidence_intervals.hits_at_10_ci.0,
545        metrics.confidence_intervals.hits_at_10_ci.1,
546        metrics.per_relation_metrics.len()
547    )
548}