oxirs_embed/link_prediction.rs
1//! Link Prediction for Knowledge Graph Completion
2//!
3//! This module provides comprehensive link prediction functionality for knowledge graphs,
4//! enabling entity prediction, relation prediction, ranking, and evaluation capabilities.
5//!
6//! Link prediction is a fundamental task in knowledge graph completion where we predict
7//! missing links (triples) based on learned embeddings. This module supports multiple
8//! prediction tasks and provides evaluation metrics following standard benchmarks.
9//!
10//! # Overview
11//!
12//! The module provides:
13//! - **Tail entity prediction**: Given (head, relation, ?), predict the tail entity
14//! - **Head entity prediction**: Given (?, relation, tail), predict the head entity
15//! - **Relation prediction**: Given (head, ?, tail), predict the relation
16//! - **Batch prediction**: Process multiple queries efficiently in parallel
17//! - **Evaluation metrics**: MRR, Hits@K, Mean Rank for benchmarking
18//! - **Filtered ranking**: Remove known triples from evaluation
19//!
20//! # Quick Start
21//!
22//! ```rust,no_run
23//! use oxirs_embed::{
24//! TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
25//! link_prediction::{LinkPredictor, LinkPredictionConfig},
26//! };
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train a knowledge graph embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! // Add training triples
34//! model.add_triple(Triple::new(
35//! NamedNode::new("alice")?,
36//! NamedNode::new("knows")?,
37//! NamedNode::new("bob")?,
38//! ))?;
39//! model.add_triple(Triple::new(
40//! NamedNode::new("bob")?,
41//! NamedNode::new("knows")?,
42//! NamedNode::new("charlie")?,
43//! ))?;
44//!
45//! // Train the model
46//! model.train(Some(100)).await?;
47//!
48//! // 2. Create link predictor
49//! let pred_config = LinkPredictionConfig {
50//! top_k: 10,
51//! min_confidence: 0.5,
52//! filter_known_triples: true,
53//! ..Default::default()
54//! };
55//! let predictor = LinkPredictor::new(pred_config, model);
56//!
57//! // 3. Predict tail entities
58//! let candidates = vec!["bob".to_string(), "charlie".to_string(), "dave".to_string()];
59//! let predictions = predictor.predict_tail("alice", "knows", &candidates)?;
60//!
61//! for pred in predictions {
62//! println!("Entity: {}, Score: {:.3}, Confidence: {:.3}, Rank: {}",
63//! pred.predicted_id, pred.score, pred.confidence, pred.rank);
64//! }
65//! # Ok(())
66//! # }
67//! ```
68//!
69//! # Prediction Tasks
70//!
71//! ## Tail Entity Prediction
72//!
73//! Given a subject and predicate, predict the most likely objects:
74//!
75//! ```rust,no_run
76//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
77//! # async fn example() -> anyhow::Result<()> {
78//! # let model = TransE::new(ModelConfig::default());
79//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
80//! let candidates = vec!["paris".to_string(), "london".to_string(), "berlin".to_string()];
81//! let predictions = predictor.predict_tail("france", "has_capital", &candidates)?;
82//! // Expected: "paris" should rank first with high confidence
83//! # Ok(())
84//! # }
85//! ```
86//!
87//! ## Head Entity Prediction
88//!
89//! Given a predicate and object, predict the most likely subjects:
90//!
91//! ```rust,no_run
92//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
93//! # async fn example() -> anyhow::Result<()> {
94//! # let model = TransE::new(ModelConfig::default());
95//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
96//! let candidates = vec!["france".to_string(), "germany".to_string(), "uk".to_string()];
97//! let predictions = predictor.predict_head("has_capital", "paris", &candidates)?;
98//! // Expected: "france" should rank first
99//! # Ok(())
100//! # }
101//! ```
102//!
103//! ## Relation Prediction
104//!
105//! Given a subject and object, predict the most likely relations:
106//!
107//! ```rust,no_run
108//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
109//! # async fn example() -> anyhow::Result<()> {
110//! # let model = TransE::new(ModelConfig::default());
111//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
112//! let candidates = vec!["has_capital".to_string(), "located_in".to_string()];
113//! let predictions = predictor.predict_relation("france", "paris", &candidates)?;
114//! # Ok(())
115//! # }
116//! ```
117//!
118//! # Batch Processing
119//!
120//! For efficient processing of multiple queries:
121//!
122//! ```rust,no_run
123//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
124//! # async fn example() -> anyhow::Result<()> {
125//! # let model = TransE::new(ModelConfig::default());
126//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
127//! let queries = vec![
128//! ("france".to_string(), "has_capital".to_string()),
129//! ("germany".to_string(), "has_capital".to_string()),
130//! ];
131//! let candidates = vec!["paris".to_string(), "berlin".to_string()];
132//! let batch_results = predictor.predict_tails_batch(&queries, &candidates)?;
133//! # Ok(())
134//! # }
135//! ```
136//!
137//! # Evaluation
138//!
139//! Evaluate link prediction performance on a test set:
140//!
141//! ```rust,no_run
142//! # use oxirs_embed::{TransE, ModelConfig, Triple, NamedNode, link_prediction::{LinkPredictor, LinkPredictionConfig}};
143//! # async fn example() -> anyhow::Result<()> {
144//! # let model = TransE::new(ModelConfig::default());
145//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
146//! let test_triples = vec![
147//! Triple::new(
148//! NamedNode::new("france")?,
149//! NamedNode::new("has_capital")?,
150//! NamedNode::new("paris")?,
151//! ),
152//! ];
153//! let candidates = vec!["paris".to_string(), "london".to_string()];
154//!
155//! let metrics = predictor.evaluate(&test_triples, &candidates)?;
156//! println!("Mean Rank: {:.2}", metrics.mean_rank);
157//! println!("MRR: {:.4}", metrics.mrr);
158//! println!("Hits@1: {:.4}", metrics.hits_at_1);
159//! println!("Hits@10: {:.4}", metrics.hits_at_10);
160//! # Ok(())
161//! # }
162//! ```
163//!
164//! # Configuration
165//!
166//! The [`LinkPredictionConfig`] allows fine-tuning prediction behavior:
167//!
168//! ```rust
169//! use oxirs_embed::link_prediction::LinkPredictionConfig;
170//!
171//! let config = LinkPredictionConfig {
172//! top_k: 10, // Return top 10 predictions
173//! min_confidence: 0.5, // Filter predictions below 50% confidence
174//! filter_known_triples: true, // Remove known facts from ranking
175//! parallel: true, // Use parallel processing
176//! batch_size: 100, // Batch size for processing
177//! };
178//! ```
179//!
180//! # Evaluation Metrics
181//!
182//! The module provides standard knowledge graph evaluation metrics:
183//!
184//! - **Mean Rank (MR)**: Average rank of correct entities (lower is better)
185//! - **Mean Reciprocal Rank (MRR)**: Average of 1/rank (higher is better, 0-1)
186//! - **Hits@K**: Percentage of correct entities in top-K predictions (higher is better)
187//!
188//! These metrics are computed following the filtered setting used in standard benchmarks
189//! like FB15k-237 and WN18RR, where known triples are removed from ranking to avoid
190//! trivial predictions.
191//!
192//! # Performance Considerations
193//!
194//! - Enable `parallel: true` for large-scale predictions
195//! - Use batch processing for multiple queries
196//! - Filter known triples to avoid redundant computation
197//! - Adjust `top_k` and `min_confidence` based on application needs
198//!
199//! # See Also
200//!
201//! - [`LinkPredictor`]: Main prediction interface
202//! - [`LinkPredictionConfig`]: Configuration options
203//! - [`LinkPredictionMetrics`]: Evaluation metrics
204
205use anyhow::Result;
206use rayon::prelude::*;
207use serde::{Deserialize, Serialize};
208use std::collections::HashSet;
209use tracing::info;
210
211use crate::{EmbeddingModel, Triple};
212
213/// Link prediction configuration
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct LinkPredictionConfig {
216 /// Top-K candidates to return
217 pub top_k: usize,
218 /// Minimum confidence threshold (0.0 to 1.0)
219 pub min_confidence: f32,
220 /// Use filtering to remove known triples from ranking
221 pub filter_known_triples: bool,
222 /// Enable parallel processing
223 pub parallel: bool,
224 /// Batch size for batch predictions
225 pub batch_size: usize,
226}
227
228impl Default for LinkPredictionConfig {
229 fn default() -> Self {
230 Self {
231 top_k: 10,
232 min_confidence: 0.5,
233 filter_known_triples: true,
234 parallel: true,
235 batch_size: 100,
236 }
237 }
238}
239
240/// Link prediction result
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct LinkPrediction {
243 /// Predicted entity or relation ID
244 pub predicted_id: String,
245 /// Prediction score (higher is better)
246 pub score: f32,
247 /// Confidence level (0.0 to 1.0)
248 pub confidence: f32,
249 /// Rank in the candidate list (1-indexed)
250 pub rank: usize,
251}
252
/// Which slot of a triple a prediction query leaves open.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// The object is unknown: `(head, relation, ?)`.
    TailEntity,
    /// The subject is unknown: `(?, relation, tail)`.
    HeadEntity,
    /// The predicate is unknown: `(head, ?, tail)`.
    Relation,
}
263
264/// Evaluation metrics for link prediction
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct LinkPredictionMetrics {
267 /// Mean Rank (lower is better)
268 pub mean_rank: f32,
269 /// Mean Reciprocal Rank (higher is better, 0-1)
270 pub mrr: f32,
271 /// Hits@1 (percentage, 0-1)
272 pub hits_at_1: f32,
273 /// Hits@3 (percentage, 0-1)
274 pub hits_at_3: f32,
275 /// Hits@5 (percentage, 0-1)
276 pub hits_at_5: f32,
277 /// Hits@10 (percentage, 0-1)
278 pub hits_at_10: f32,
279 /// Number of predictions evaluated
280 pub num_predictions: usize,
281}
282
283impl LinkPredictionMetrics {
284 /// Create empty metrics
285 pub fn new() -> Self {
286 Self {
287 mean_rank: 0.0,
288 mrr: 0.0,
289 hits_at_1: 0.0,
290 hits_at_3: 0.0,
291 hits_at_5: 0.0,
292 hits_at_10: 0.0,
293 num_predictions: 0,
294 }
295 }
296
297 /// Update metrics with a new rank
298 pub fn update(&mut self, rank: usize) {
299 self.num_predictions += 1;
300 let n = self.num_predictions as f32;
301
302 // Update mean rank
303 self.mean_rank = ((self.mean_rank * (n - 1.0)) + rank as f32) / n;
304
305 // Update MRR
306 let reciprocal_rank = 1.0 / rank as f32;
307 self.mrr = ((self.mrr * (n - 1.0)) + reciprocal_rank) / n;
308
309 // Update Hits@K
310 if rank <= 1 {
311 self.hits_at_1 = ((self.hits_at_1 * (n - 1.0)) + 1.0) / n;
312 } else {
313 self.hits_at_1 = (self.hits_at_1 * (n - 1.0)) / n;
314 }
315
316 if rank <= 3 {
317 self.hits_at_3 = ((self.hits_at_3 * (n - 1.0)) + 1.0) / n;
318 } else {
319 self.hits_at_3 = (self.hits_at_3 * (n - 1.0)) / n;
320 }
321
322 if rank <= 5 {
323 self.hits_at_5 = ((self.hits_at_5 * (n - 1.0)) + 1.0) / n;
324 } else {
325 self.hits_at_5 = (self.hits_at_5 * (n - 1.0)) / n;
326 }
327
328 if rank <= 10 {
329 self.hits_at_10 = ((self.hits_at_10 * (n - 1.0)) + 1.0) / n;
330 } else {
331 self.hits_at_10 = (self.hits_at_10 * (n - 1.0)) / n;
332 }
333 }
334}
335
336impl Default for LinkPredictionMetrics {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342/// Link predictor for knowledge graph completion
343pub struct LinkPredictor<M: EmbeddingModel> {
344 config: LinkPredictionConfig,
345 model: M,
346 known_triples: HashSet<(String, String, String)>,
347}
348
349impl<M: EmbeddingModel> LinkPredictor<M> {
350 /// Create new link predictor
351 pub fn new(config: LinkPredictionConfig, model: M) -> Self {
352 Self {
353 config,
354 model,
355 known_triples: HashSet::new(),
356 }
357 }
358
359 /// Add known triples for filtering
360 pub fn add_known_triples(&mut self, triples: &[Triple]) {
361 for triple in triples {
362 self.known_triples.insert((
363 triple.subject.to_string(),
364 triple.predicate.to_string(),
365 triple.object.to_string(),
366 ));
367 }
368 }
369
370 /// Predict tail entities given head and relation
371 pub fn predict_tail(
372 &self,
373 subject: &str,
374 predicate: &str,
375 candidate_entities: &[String],
376 ) -> Result<Vec<LinkPrediction>> {
377 // Score all candidates
378 let scored: Vec<(String, f64)> = if self.config.parallel {
379 candidate_entities
380 .par_iter()
381 .filter_map(|tail| {
382 if self.config.filter_known_triples
383 && self.known_triples.contains(&(
384 subject.to_string(),
385 predicate.to_string(),
386 tail.clone(),
387 ))
388 {
389 return None;
390 }
391
392 self.model
393 .score_triple(subject, predicate, tail)
394 .ok()
395 .map(|score| (tail.clone(), score))
396 })
397 .collect()
398 } else {
399 candidate_entities
400 .iter()
401 .filter_map(|tail| {
402 if self.config.filter_known_triples
403 && self.known_triples.contains(&(
404 subject.to_string(),
405 predicate.to_string(),
406 tail.clone(),
407 ))
408 {
409 return None;
410 }
411
412 self.model
413 .score_triple(subject, predicate, tail)
414 .ok()
415 .map(|score| (tail.clone(), score))
416 })
417 .collect()
418 };
419
420 self.rank_and_filter(scored)
421 }
422
423 /// Predict head entities given relation and tail
424 pub fn predict_head(
425 &self,
426 predicate: &str,
427 object: &str,
428 candidate_entities: &[String],
429 ) -> Result<Vec<LinkPrediction>> {
430 // Score all candidates
431 let scored: Vec<(String, f64)> = if self.config.parallel {
432 candidate_entities
433 .par_iter()
434 .filter_map(|head| {
435 if self.config.filter_known_triples
436 && self.known_triples.contains(&(
437 head.clone(),
438 predicate.to_string(),
439 object.to_string(),
440 ))
441 {
442 return None;
443 }
444
445 self.model
446 .score_triple(head, predicate, object)
447 .ok()
448 .map(|score| (head.clone(), score))
449 })
450 .collect()
451 } else {
452 candidate_entities
453 .iter()
454 .filter_map(|head| {
455 if self.config.filter_known_triples
456 && self.known_triples.contains(&(
457 head.clone(),
458 predicate.to_string(),
459 object.to_string(),
460 ))
461 {
462 return None;
463 }
464
465 self.model
466 .score_triple(head, predicate, object)
467 .ok()
468 .map(|score| (head.clone(), score))
469 })
470 .collect()
471 };
472
473 self.rank_and_filter(scored)
474 }
475
476 /// Predict relations given head and tail
477 pub fn predict_relation(
478 &self,
479 subject: &str,
480 object: &str,
481 candidate_relations: &[String],
482 ) -> Result<Vec<LinkPrediction>> {
483 // Score all candidate relations
484 let scored: Vec<(String, f64)> = if self.config.parallel {
485 candidate_relations
486 .par_iter()
487 .filter_map(|relation| {
488 if self.config.filter_known_triples
489 && self.known_triples.contains(&(
490 subject.to_string(),
491 relation.clone(),
492 object.to_string(),
493 ))
494 {
495 return None;
496 }
497
498 self.model
499 .score_triple(subject, relation, object)
500 .ok()
501 .map(|score| (relation.clone(), score))
502 })
503 .collect()
504 } else {
505 candidate_relations
506 .iter()
507 .filter_map(|relation| {
508 if self.config.filter_known_triples
509 && self.known_triples.contains(&(
510 subject.to_string(),
511 relation.clone(),
512 object.to_string(),
513 ))
514 {
515 return None;
516 }
517
518 self.model
519 .score_triple(subject, relation, object)
520 .ok()
521 .map(|score| (relation.clone(), score))
522 })
523 .collect()
524 };
525
526 self.rank_and_filter(scored)
527 }
528
529 /// Batch prediction of tails
530 pub fn predict_tails_batch(
531 &self,
532 queries: &[(String, String)], // (head, relation) pairs
533 candidate_entities: &[String],
534 ) -> Result<Vec<Vec<LinkPrediction>>> {
535 queries
536 .par_iter()
537 .map(|(head, relation)| {
538 self.predict_tail(head, relation, candidate_entities)
539 .unwrap_or_default()
540 })
541 .collect::<Vec<_>>()
542 .into_iter()
543 .map(Ok)
544 .collect()
545 }
546
547 /// Evaluate link prediction on a test set
548 pub fn evaluate(
549 &self,
550 test_triples: &[Triple],
551 candidate_entities: &[String],
552 ) -> Result<LinkPredictionMetrics> {
553 let mut metrics = LinkPredictionMetrics::new();
554
555 info!(
556 "Evaluating link prediction on {} test triples",
557 test_triples.len()
558 );
559
560 for triple in test_triples {
561 // Predict tail
562 if let Ok(predictions) = self.predict_tail(
563 &triple.subject.to_string(),
564 &triple.predicate.to_string(),
565 candidate_entities,
566 ) {
567 // Find rank of correct tail
568 if let Some(rank) = predictions
569 .iter()
570 .position(|pred| pred.predicted_id == triple.object.to_string())
571 {
572 metrics.update(rank + 1); // 1-indexed rank
573 }
574 }
575 }
576
577 info!(
578 "Evaluation complete: MRR={:.4}, Hits@10={:.4}",
579 metrics.mrr, metrics.hits_at_10
580 );
581
582 Ok(metrics)
583 }
584
585 /// Rank and filter predictions
586 fn rank_and_filter(&self, mut scored: Vec<(String, f64)>) -> Result<Vec<LinkPrediction>> {
587 // Sort by score descending
588 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
589
590 // Take top-K
591 scored.truncate(self.config.top_k);
592
593 // Normalize scores to confidence (0-1)
594 let max_score = scored.first().map(|(_, s)| *s).unwrap_or(1.0);
595 let min_score = scored.last().map(|(_, s)| *s).unwrap_or(0.0);
596 let score_range = (max_score - min_score).max(1e-10);
597
598 let predictions: Vec<LinkPrediction> = scored
599 .into_iter()
600 .enumerate()
601 .filter_map(|(rank, (id, score))| {
602 let confidence = (score - min_score) / score_range;
603
604 if confidence >= self.config.min_confidence as f64 {
605 Some(LinkPrediction {
606 predicted_id: id,
607 score: score as f32,
608 confidence: confidence as f32,
609 rank: rank + 1, // 1-indexed
610 })
611 } else {
612 None
613 }
614 })
615 .collect();
616
617 Ok(predictions)
618 }
619
620 /// Get reference to underlying model
621 pub fn model(&self) -> &M {
622 &self.model
623 }
624
625 /// Get mutable reference to underlying model
626 pub fn model_mut(&mut self) -> &mut M {
627 &mut self.model
628 }
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::transe::TransE;
    use crate::{ModelConfig, NamedNode};

    /// Builds a triple from three bare identifiers, panicking on invalid input.
    fn triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    /// Converts string literals into the owned identifiers the predictor takes.
    fn owned(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    #[tokio::test]
    async fn test_link_prediction_tail() {
        let mut model = TransE::new(ModelConfig {
            dimensions: 50,
            learning_rate: 0.01,
            max_epochs: 50,
            ..Default::default()
        });

        // Training data
        let facts = [
            ("alice", "knows", "bob"),
            ("alice", "knows", "charlie"),
            ("bob", "likes", "dave"),
        ];
        for (subject, predicate, object) in facts {
            model.add_triple(triple(subject, predicate, object)).unwrap();
        }

        model.train(Some(50)).await.unwrap();

        let predictor = LinkPredictor::new(
            LinkPredictionConfig {
                top_k: 5,
                filter_known_triples: false,
                ..Default::default()
            },
            model,
        );

        let candidates = owned(&["bob", "charlie", "dave"]);
        let predictions = predictor
            .predict_tail("alice", "knows", &candidates)
            .unwrap();

        assert!(!predictions.is_empty());
        assert!(predictions.len() <= 5);

        // Scores must be in non-increasing order.
        assert!(predictions
            .windows(2)
            .all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_link_prediction_metrics() {
        let mut metrics = LinkPredictionMetrics::new();

        // A perfect prediction, then rank 3, then rank 10.
        for rank in [1, 3, 10] {
            metrics.update(rank);
        }

        assert_eq!(metrics.num_predictions, 3);
        assert!(metrics.mrr > 0.0);
        assert!(metrics.hits_at_1 > 0.0);
        assert!(metrics.hits_at_10 == 1.0); // All within top 10
    }

    #[tokio::test]
    async fn test_batch_prediction() {
        let mut model = TransE::new(ModelConfig {
            dimensions: 50,
            max_epochs: 30,
            ..Default::default()
        });

        model.add_triple(triple("a", "r1", "b")).unwrap();
        model.train(Some(30)).await.unwrap();

        let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);

        let queries = vec![("a".to_string(), "r1".to_string())];
        let candidates = owned(&["b"]);

        let results = predictor
            .predict_tails_batch(&queries, &candidates)
            .unwrap();

        assert_eq!(results.len(), 1);
    }
}