use super::KnowledgeGraphEmbedding;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct KnowledgeGraphMetrics {
9 pub mrr_filtered: f32,
11 pub mrr_unfiltered: f32,
13 pub mr_filtered: f32,
15 pub mr_unfiltered: f32,
17 pub hits_at_k_filtered: std::collections::HashMap<u32, f32>,
19 pub hits_at_k_unfiltered: std::collections::HashMap<u32, f32>,
21 pub per_relation_metrics: std::collections::HashMap<String, RelationMetrics>,
23 pub task_breakdown: TaskBreakdownMetrics,
25 pub confidence_intervals: ConfidenceIntervals,
27 pub statistical_tests: StatisticalTestResults,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TrainingMetrics {
34 pub loss: f32,
36 pub loss_history: Vec<f32>,
38 pub accuracy: f32,
40 pub epochs: usize,
42 pub time_elapsed: std::time::Duration,
44 pub kg_metrics: KnowledgeGraphMetrics,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RelationMetrics {
51 pub mrr: f32,
52 pub mr: f32,
53 pub hits_at_k: std::collections::HashMap<u32, f32>,
54 pub sample_count: usize,
55 pub entity_coverage: f32,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct TaskBreakdownMetrics {
61 pub head_prediction: LinkPredictionMetrics,
63 pub tail_prediction: LinkPredictionMetrics,
65 pub relation_prediction: LinkPredictionMetrics,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct LinkPredictionMetrics {
72 pub mrr: f32,
73 pub mr: f32,
74 pub hits_at_k: std::collections::HashMap<u32, f32>,
75 pub auc_roc: f32,
76 pub auc_pr: f32,
77 pub precision_at_k: std::collections::HashMap<u32, f32>,
78 pub recall_at_k: std::collections::HashMap<u32, f32>,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ConfidenceIntervals {
84 pub mrr_ci: (f32, f32),
85 pub mr_ci: (f32, f32),
86 pub hits_at_10_ci: (f32, f32),
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct StatisticalTestResults {
92 pub wilcoxon_p_value: Option<f32>,
94 pub bootstrap_confidence: f32,
96 pub effect_size: Option<f32>,
98}
99
100impl Default for KnowledgeGraphMetrics {
101 fn default() -> Self {
102 let mut hits_at_k = std::collections::HashMap::new();
103 hits_at_k.insert(1, 0.0);
104 hits_at_k.insert(3, 0.0);
105 hits_at_k.insert(10, 0.0);
106 hits_at_k.insert(100, 0.0);
107
108 let mut precision_at_k = std::collections::HashMap::new();
109 precision_at_k.insert(1, 0.0);
110 precision_at_k.insert(3, 0.0);
111 precision_at_k.insert(10, 0.0);
112
113 let mut recall_at_k = std::collections::HashMap::new();
114 recall_at_k.insert(1, 0.0);
115 recall_at_k.insert(3, 0.0);
116 recall_at_k.insert(10, 0.0);
117
118 Self {
119 mrr_filtered: 0.0,
120 mrr_unfiltered: 0.0,
121 mr_filtered: 0.0,
122 mr_unfiltered: 0.0,
123 hits_at_k_filtered: hits_at_k.clone(),
124 hits_at_k_unfiltered: hits_at_k.clone(),
125 per_relation_metrics: std::collections::HashMap::new(),
126 task_breakdown: TaskBreakdownMetrics {
127 head_prediction: LinkPredictionMetrics {
128 mrr: 0.0,
129 mr: 0.0,
130 hits_at_k: hits_at_k.clone(),
131 auc_roc: 0.0,
132 auc_pr: 0.0,
133 precision_at_k: precision_at_k.clone(),
134 recall_at_k: recall_at_k.clone(),
135 },
136 tail_prediction: LinkPredictionMetrics {
137 mrr: 0.0,
138 mr: 0.0,
139 hits_at_k: hits_at_k.clone(),
140 auc_roc: 0.0,
141 auc_pr: 0.0,
142 precision_at_k: precision_at_k.clone(),
143 recall_at_k: recall_at_k.clone(),
144 },
145 relation_prediction: LinkPredictionMetrics {
146 mrr: 0.0,
147 mr: 0.0,
148 hits_at_k: hits_at_k.clone(),
149 auc_roc: 0.0,
150 auc_pr: 0.0,
151 precision_at_k,
152 recall_at_k,
153 },
154 },
155 confidence_intervals: ConfidenceIntervals {
156 mrr_ci: (0.0, 0.0),
157 mr_ci: (0.0, 0.0),
158 hits_at_10_ci: (0.0, 0.0),
159 },
160 statistical_tests: StatisticalTestResults {
161 wilcoxon_p_value: None,
162 bootstrap_confidence: 0.95,
163 effect_size: None,
164 },
165 }
166 }
167}
168
169pub async fn compute_kg_metrics(
171 model: &dyn KnowledgeGraphEmbedding,
172 test_triples: &[(String, String, String)],
173 all_triples: &[(String, String, String)],
174 k_values: &[u32],
175) -> Result<KnowledgeGraphMetrics> {
176 let mut metrics = KnowledgeGraphMetrics::default();
177
178 let all_triples_set: HashSet<(String, String, String)> = all_triples.iter().cloned().collect();
180
181 metrics.task_breakdown.head_prediction = compute_link_prediction_metrics(
183 model,
184 test_triples,
185 &all_triples_set,
186 LinkPredictionTask::HeadPrediction,
187 k_values,
188 )
189 .await?;
190
191 metrics.task_breakdown.tail_prediction = compute_link_prediction_metrics(
193 model,
194 test_triples,
195 &all_triples_set,
196 LinkPredictionTask::TailPrediction,
197 k_values,
198 )
199 .await?;
200
201 metrics.task_breakdown.relation_prediction = compute_link_prediction_metrics(
203 model,
204 test_triples,
205 &all_triples_set,
206 LinkPredictionTask::RelationPrediction,
207 k_values,
208 )
209 .await?;
210
211 metrics.mrr_filtered = (metrics.task_breakdown.head_prediction.mrr
213 + metrics.task_breakdown.tail_prediction.mrr)
214 / 2.0;
215 metrics.mr_filtered = (metrics.task_breakdown.head_prediction.mr
216 + metrics.task_breakdown.tail_prediction.mr)
217 / 2.0;
218
219 for &k in k_values {
221 let head_hits = metrics
222 .task_breakdown
223 .head_prediction
224 .hits_at_k
225 .get(&k)
226 .unwrap_or(&0.0);
227 let tail_hits = metrics
228 .task_breakdown
229 .tail_prediction
230 .hits_at_k
231 .get(&k)
232 .unwrap_or(&0.0);
233 metrics
234 .hits_at_k_filtered
235 .insert(k, (head_hits + tail_hits) / 2.0);
236 }
237
238 metrics.per_relation_metrics =
240 compute_per_relation_metrics(model, test_triples, &all_triples_set, k_values).await?;
241
242 metrics.confidence_intervals = compute_confidence_intervals(
244 &metrics.task_breakdown.head_prediction,
245 &metrics.task_breakdown.tail_prediction,
246 test_triples.len(),
247 )?;
248
249 Ok(metrics)
250}
251
/// Which element of a `(head, relation, tail)` triple is being predicted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkPredictionTask {
    /// Predict the head given the relation and tail.
    HeadPrediction,
    /// Predict the tail given the head and relation.
    TailPrediction,
    /// Predict the relation given the head and tail.
    RelationPrediction,
}
259
260async fn compute_link_prediction_metrics(
262 model: &dyn KnowledgeGraphEmbedding,
263 test_triples: &[(String, String, String)],
264 all_triples: &HashSet<(String, String, String)>,
265 task: LinkPredictionTask,
266 k_values: &[u32],
267) -> Result<LinkPredictionMetrics> {
268 let mut ranks = Vec::new();
269 let mut reciprocal_ranks = Vec::new();
270 let mut hits_at_k = std::collections::HashMap::new();
271 let mut precision_at_k = std::collections::HashMap::new();
272 let mut recall_at_k = std::collections::HashMap::new();
273
274 for &k in k_values {
276 hits_at_k.insert(k, 0.0);
277 precision_at_k.insert(k, 0.0);
278 recall_at_k.insert(k, 0.0);
279 }
280
281 for (head, relation, tail) in test_triples {
282 let rank = match task {
283 LinkPredictionTask::HeadPrediction => {
284 compute_entity_rank(model, "?", relation, tail, all_triples, true).await?
285 }
286 LinkPredictionTask::TailPrediction => {
287 compute_entity_rank(model, head, relation, "?", all_triples, false).await?
288 }
289 LinkPredictionTask::RelationPrediction => {
290 compute_relation_rank(model, head, tail, all_triples).await?
291 }
292 };
293
294 ranks.push(rank as f32);
295 reciprocal_ranks.push(1.0 / rank as f32);
296
297 for &k in k_values {
299 if rank <= k {
300 if let Some(hits) = hits_at_k.get_mut(&k) {
301 *hits += 1.0;
302 }
303 }
304 }
305 }
306
307 let num_samples = test_triples.len() as f32;
308
309 for (_, hits) in hits_at_k.iter_mut() {
311 *hits /= num_samples;
312 }
313
314 for &k in k_values {
316 let hits = hits_at_k.get(&k).unwrap_or(&0.0);
317 precision_at_k.insert(k, *hits); recall_at_k.insert(k, *hits); }
320
321 Ok(LinkPredictionMetrics {
322 mrr: reciprocal_ranks.iter().sum::<f32>() / num_samples,
323 mr: ranks.iter().sum::<f32>() / num_samples,
324 hits_at_k,
325 auc_roc: compute_auc_roc(&ranks)?,
326 auc_pr: compute_auc_pr(&ranks)?,
327 precision_at_k,
328 recall_at_k,
329 })
330}
331
332async fn compute_entity_rank(
334 model: &dyn KnowledgeGraphEmbedding,
335 head: &str,
336 relation: &str,
337 tail: &str,
338 all_triples: &HashSet<(String, String, String)>,
339 predict_head: bool,
340) -> Result<u32> {
341 let entities: Vec<String> = all_triples
343 .iter()
344 .flat_map(|(h, _, t)| vec![h.clone(), t.clone()])
345 .collect::<HashSet<_>>()
346 .into_iter()
347 .collect();
348
349 let mut scores = Vec::new();
350 let correct_entity = if predict_head { head } else { tail };
351
352 for entity in &entities {
353 let test_head = if predict_head { entity } else { head };
354 let test_tail = if predict_head { tail } else { entity };
355
356 if all_triples.contains(&(
358 test_head.to_string(),
359 relation.to_string(),
360 test_tail.to_string(),
361 )) && entity != correct_entity
362 {
363 continue;
364 }
365
366 let score = model.score_triple(test_head, relation, test_tail).await?;
367 scores.push((entity.clone(), score));
368 }
369
370 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
372
373 let rank = scores
375 .iter()
376 .position(|(entity, _)| entity == correct_entity)
377 .unwrap_or(scores.len() - 1)
378 + 1;
379
380 Ok(rank as u32)
381}
382
383async fn compute_relation_rank(
385 model: &dyn KnowledgeGraphEmbedding,
386 head: &str,
387 tail: &str,
388 all_triples: &HashSet<(String, String, String)>,
389) -> Result<u32> {
390 let relations: Vec<String> = all_triples
392 .iter()
393 .map(|(_, r, _)| r.clone())
394 .collect::<HashSet<_>>()
395 .into_iter()
396 .collect();
397
398 let mut scores = Vec::new();
399
400 for relation in &relations {
401 let score = model.score_triple(head, relation, tail).await?;
402 scores.push((relation.clone(), score));
403 }
404
405 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
407
408 Ok(1) }
411
412async fn compute_per_relation_metrics(
414 model: &dyn KnowledgeGraphEmbedding,
415 test_triples: &[(String, String, String)],
416 all_triples: &HashSet<(String, String, String)>,
417 k_values: &[u32],
418) -> Result<std::collections::HashMap<String, RelationMetrics>> {
419 let mut relation_metrics = std::collections::HashMap::new();
420
421 let mut relation_groups: std::collections::HashMap<String, Vec<(String, String, String)>> =
423 std::collections::HashMap::new();
424
425 for triple in test_triples {
426 relation_groups
427 .entry(triple.1.clone())
428 .or_default()
429 .push(triple.clone());
430 }
431
432 for (relation, relation_triples) in relation_groups {
434 let metrics = compute_link_prediction_metrics(
435 model,
436 &relation_triples,
437 all_triples,
438 LinkPredictionTask::TailPrediction,
439 k_values,
440 )
441 .await?;
442
443 let entity_count = relation_triples
444 .iter()
445 .flat_map(|(h, _, t)| vec![h, t])
446 .collect::<HashSet<_>>()
447 .len();
448
449 relation_metrics.insert(
450 relation,
451 RelationMetrics {
452 mrr: metrics.mrr,
453 mr: metrics.mr,
454 hits_at_k: metrics.hits_at_k,
455 sample_count: relation_triples.len(),
456 entity_coverage: entity_count as f32 / relation_triples.len() as f32,
457 },
458 );
459 }
460
461 Ok(relation_metrics)
462}
463
464fn compute_confidence_intervals(
466 head_metrics: &LinkPredictionMetrics,
467 tail_metrics: &LinkPredictionMetrics,
468 sample_size: usize,
469) -> Result<ConfidenceIntervals> {
470 let combined_mrr = (head_metrics.mrr + tail_metrics.mrr) / 2.0;
472 let combined_mr = (head_metrics.mr + tail_metrics.mr) / 2.0;
473 let combined_hits_10 = (head_metrics.hits_at_k.get(&10).unwrap_or(&0.0)
474 + tail_metrics.hits_at_k.get(&10).unwrap_or(&0.0))
475 / 2.0;
476
477 let se_factor = 1.96 / (sample_size as f32).sqrt(); Ok(ConfidenceIntervals {
481 mrr_ci: (
482 (combined_mrr - combined_mrr * se_factor).max(0.0),
483 (combined_mrr + combined_mrr * se_factor).min(1.0),
484 ),
485 mr_ci: (
486 (combined_mr - combined_mr * se_factor).max(1.0),
487 combined_mr + combined_mr * se_factor,
488 ),
489 hits_at_10_ci: (
490 (combined_hits_10 - combined_hits_10 * se_factor).max(0.0),
491 (combined_hits_10 + combined_hits_10 * se_factor).min(1.0),
492 ),
493 })
494}
495
496fn compute_auc_roc(ranks: &[f32]) -> Result<f32> {
498 let max_rank = ranks.iter().fold(0.0f32, |a, &b| a.max(b));
500 let normalized_ranks: Vec<f32> = ranks.iter().map(|&r| 1.0 - (r / max_rank)).collect();
501 Ok(normalized_ranks.iter().sum::<f32>() / ranks.len() as f32)
502}
503
504fn compute_auc_pr(ranks: &[f32]) -> Result<f32> {
506 compute_auc_roc(ranks)
508}
509
510pub fn create_evaluation_report(metrics: &KnowledgeGraphMetrics) -> String {
512 format!(
513 "Knowledge Graph Embedding Evaluation Report\n\
514 ==========================================\n\
515 \n\
516 Overall Performance:\n\
517 - MRR (filtered): {:.4}\n\
518 - Mean Rank (filtered): {:.1}\n\
519 - Hits@1: {:.4}\n\
520 - Hits@3: {:.4}\n\
521 - Hits@10: {:.4}\n\
522 \n\
523 Task Breakdown:\n\
524 - Head Prediction MRR: {:.4}\n\
525 - Tail Prediction MRR: {:.4}\n\
526 - Relation Prediction MRR: {:.4}\n\
527 \n\
528 Confidence Intervals (95%):\n\
529 - MRR: [{:.4}, {:.4}]\n\
530 - Hits@10: [{:.4}, {:.4}]\n\
531 \n\
532 Per-Relation Performance:\n\
533 {} relations evaluated\n",
534 metrics.mrr_filtered,
535 metrics.mr_filtered,
536 metrics.hits_at_k_filtered.get(&1).unwrap_or(&0.0),
537 metrics.hits_at_k_filtered.get(&3).unwrap_or(&0.0),
538 metrics.hits_at_k_filtered.get(&10).unwrap_or(&0.0),
539 metrics.task_breakdown.head_prediction.mrr,
540 metrics.task_breakdown.tail_prediction.mrr,
541 metrics.task_breakdown.relation_prediction.mrr,
542 metrics.confidence_intervals.mrr_ci.0,
543 metrics.confidence_intervals.mrr_ci.1,
544 metrics.confidence_intervals.hits_at_10_ci.0,
545 metrics.confidence_intervals.hits_at_10_ci.1,
546 metrics.per_relation_metrics.len()
547 )
548}