oxirs_embed/
link_prediction.rs

1//! Link Prediction for Knowledge Graph Completion
2//!
3//! This module provides comprehensive link prediction functionality for knowledge graphs,
4//! enabling entity prediction, relation prediction, ranking, and evaluation capabilities.
5//!
6//! Link prediction is a fundamental task in knowledge graph completion where we predict
7//! missing links (triples) based on learned embeddings. This module supports multiple
8//! prediction tasks and provides evaluation metrics following standard benchmarks.
9//!
10//! # Overview
11//!
12//! The module provides:
13//! - **Tail entity prediction**: Given (head, relation, ?), predict the tail entity
14//! - **Head entity prediction**: Given (?, relation, tail), predict the head entity
15//! - **Relation prediction**: Given (head, ?, tail), predict the relation
16//! - **Batch prediction**: Process multiple queries efficiently in parallel
17//! - **Evaluation metrics**: MRR, Hits@K, Mean Rank for benchmarking
18//! - **Filtered ranking**: Remove known triples from evaluation
19//!
20//! # Quick Start
21//!
22//! ```rust,no_run
23//! use oxirs_embed::{
24//!     TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
25//!     link_prediction::{LinkPredictor, LinkPredictionConfig},
26//! };
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train a knowledge graph embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! // Add training triples
34//! model.add_triple(Triple::new(
35//!     NamedNode::new("alice")?,
36//!     NamedNode::new("knows")?,
37//!     NamedNode::new("bob")?,
38//! ))?;
39//! model.add_triple(Triple::new(
40//!     NamedNode::new("bob")?,
41//!     NamedNode::new("knows")?,
42//!     NamedNode::new("charlie")?,
43//! ))?;
44//!
45//! // Train the model
46//! model.train(Some(100)).await?;
47//!
48//! // 2. Create link predictor
49//! let pred_config = LinkPredictionConfig {
50//!     top_k: 10,
51//!     min_confidence: 0.5,
52//!     filter_known_triples: true,
53//!     ..Default::default()
54//! };
55//! let predictor = LinkPredictor::new(pred_config, model);
56//!
57//! // 3. Predict tail entities
58//! let candidates = vec!["bob".to_string(), "charlie".to_string(), "dave".to_string()];
59//! let predictions = predictor.predict_tail("alice", "knows", &candidates)?;
60//!
61//! for pred in predictions {
62//!     println!("Entity: {}, Score: {:.3}, Confidence: {:.3}, Rank: {}",
63//!              pred.predicted_id, pred.score, pred.confidence, pred.rank);
64//! }
65//! # Ok(())
66//! # }
67//! ```
68//!
69//! # Prediction Tasks
70//!
71//! ## Tail Entity Prediction
72//!
73//! Given a subject and predicate, predict the most likely objects:
74//!
75//! ```rust,no_run
76//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
77//! # async fn example() -> anyhow::Result<()> {
78//! # let model = TransE::new(ModelConfig::default());
79//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
80//! let candidates = vec!["paris".to_string(), "london".to_string(), "berlin".to_string()];
81//! let predictions = predictor.predict_tail("france", "has_capital", &candidates)?;
82//! // Expected: "paris" should rank first with high confidence
83//! # Ok(())
84//! # }
85//! ```
86//!
87//! ## Head Entity Prediction
88//!
89//! Given a predicate and object, predict the most likely subjects:
90//!
91//! ```rust,no_run
92//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
93//! # async fn example() -> anyhow::Result<()> {
94//! # let model = TransE::new(ModelConfig::default());
95//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
96//! let candidates = vec!["france".to_string(), "germany".to_string(), "uk".to_string()];
97//! let predictions = predictor.predict_head("has_capital", "paris", &candidates)?;
98//! // Expected: "france" should rank first
99//! # Ok(())
100//! # }
101//! ```
102//!
103//! ## Relation Prediction
104//!
105//! Given a subject and object, predict the most likely relations:
106//!
107//! ```rust,no_run
108//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
109//! # async fn example() -> anyhow::Result<()> {
110//! # let model = TransE::new(ModelConfig::default());
111//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
112//! let candidates = vec!["has_capital".to_string(), "located_in".to_string()];
113//! let predictions = predictor.predict_relation("france", "paris", &candidates)?;
114//! # Ok(())
115//! # }
116//! ```
117//!
118//! # Batch Processing
119//!
120//! For efficient processing of multiple queries:
121//!
122//! ```rust,no_run
123//! # use oxirs_embed::{TransE, ModelConfig, link_prediction::{LinkPredictor, LinkPredictionConfig}};
124//! # async fn example() -> anyhow::Result<()> {
125//! # let model = TransE::new(ModelConfig::default());
126//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
127//! let queries = vec![
128//!     ("france".to_string(), "has_capital".to_string()),
129//!     ("germany".to_string(), "has_capital".to_string()),
130//! ];
131//! let candidates = vec!["paris".to_string(), "berlin".to_string()];
132//! let batch_results = predictor.predict_tails_batch(&queries, &candidates)?;
133//! # Ok(())
134//! # }
135//! ```
136//!
137//! # Evaluation
138//!
139//! Evaluate link prediction performance on a test set:
140//!
141//! ```rust,no_run
142//! # use oxirs_embed::{TransE, ModelConfig, Triple, NamedNode, link_prediction::{LinkPredictor, LinkPredictionConfig}};
143//! # async fn example() -> anyhow::Result<()> {
144//! # let model = TransE::new(ModelConfig::default());
145//! # let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
146//! let test_triples = vec![
147//!     Triple::new(
148//!         NamedNode::new("france")?,
149//!         NamedNode::new("has_capital")?,
150//!         NamedNode::new("paris")?,
151//!     ),
152//! ];
153//! let candidates = vec!["paris".to_string(), "london".to_string()];
154//!
155//! let metrics = predictor.evaluate(&test_triples, &candidates)?;
156//! println!("Mean Rank: {:.2}", metrics.mean_rank);
157//! println!("MRR: {:.4}", metrics.mrr);
158//! println!("Hits@1: {:.4}", metrics.hits_at_1);
159//! println!("Hits@10: {:.4}", metrics.hits_at_10);
160//! # Ok(())
161//! # }
162//! ```
163//!
164//! # Configuration
165//!
166//! The [`LinkPredictionConfig`] allows fine-tuning prediction behavior:
167//!
168//! ```rust
169//! use oxirs_embed::link_prediction::LinkPredictionConfig;
170//!
171//! let config = LinkPredictionConfig {
172//!     top_k: 10,                      // Return top 10 predictions
173//!     min_confidence: 0.5,             // Filter predictions below 50% confidence
174//!     filter_known_triples: true,      // Remove known facts from ranking
175//!     parallel: true,                  // Use parallel processing
176//!     batch_size: 100,                 // Batch size for processing
177//! };
178//! ```
179//!
180//! # Evaluation Metrics
181//!
182//! The module provides standard knowledge graph evaluation metrics:
183//!
184//! - **Mean Rank (MR)**: Average rank of correct entities (lower is better)
185//! - **Mean Reciprocal Rank (MRR)**: Average of 1/rank (higher is better, 0-1)
186//! - **Hits@K**: Percentage of correct entities in top-K predictions (higher is better)
187//!
188//! These metrics are computed following the filtered setting used in standard benchmarks
189//! like FB15k-237 and WN18RR, where known triples are removed from ranking to avoid
190//! trivial predictions.
191//!
192//! # Performance Considerations
193//!
194//! - Enable `parallel: true` for large-scale predictions
195//! - Use batch processing for multiple queries
196//! - Filter known triples to avoid redundant computation
197//! - Adjust `top_k` and `min_confidence` based on application needs
198//!
199//! # See Also
200//!
201//! - [`LinkPredictor`]: Main prediction interface
202//! - [`LinkPredictionConfig`]: Configuration options
203//! - [`LinkPredictionMetrics`]: Evaluation metrics
204
205use anyhow::Result;
206use rayon::prelude::*;
207use serde::{Deserialize, Serialize};
208use std::collections::HashSet;
209use tracing::info;
210
211use crate::{EmbeddingModel, Triple};
212
/// Link prediction configuration
///
/// Controls how [`LinkPredictor`] scores, ranks, and trims candidate lists.
/// Every field is public, so a value can be built with struct-update syntax
/// on top of [`Default::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPredictionConfig {
    /// Top-K candidates to return
    ///
    /// The score-sorted candidate list is truncated to this length before
    /// confidences are computed.
    pub top_k: usize,
    /// Minimum confidence threshold (0.0 to 1.0)
    ///
    /// Predictions whose normalized confidence is below this value are
    /// dropped from the returned list.
    pub min_confidence: f32,
    /// Use filtering to remove known triples from ranking
    ///
    /// When set, a candidate that would complete a triple registered through
    /// `LinkPredictor::add_known_triples` is skipped before it is scored.
    pub filter_known_triples: bool,
    /// Enable parallel processing
    ///
    /// When set, the single-query prediction methods score candidates with
    /// rayon parallel iterators instead of a sequential iterator.
    pub parallel: bool,
    /// Batch size for batch predictions
    ///
    /// NOTE(review): no code in this file reads this field; presumably it is
    /// reserved or consumed elsewhere — confirm before relying on it.
    pub batch_size: usize,
}
227
228impl Default for LinkPredictionConfig {
229    fn default() -> Self {
230        Self {
231            top_k: 10,
232            min_confidence: 0.5,
233            filter_known_triples: true,
234            parallel: true,
235            batch_size: 100,
236        }
237    }
238}
239
/// Link prediction result
///
/// One ranked candidate produced by a [`LinkPredictor`] query. Depending on
/// the query, `predicted_id` names an entity (head/tail prediction) or a
/// relation (relation prediction).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPrediction {
    /// Predicted entity or relation ID
    pub predicted_id: String,
    /// Prediction score (higher is better)
    ///
    /// This is the model's `f64` triple score narrowed to `f32`.
    pub score: f32,
    /// Confidence level (0.0 to 1.0)
    ///
    /// Obtained by min-max normalizing the scores of the candidates that
    /// survived the top-K cut, so it is relative to that list rather than an
    /// absolute probability.
    pub confidence: f32,
    /// Rank in the candidate list (1-indexed)
    ///
    /// Position in the score-sorted list, best candidate first.
    pub rank: usize,
}
252
/// Link prediction type
///
/// Names which slot of a `(head, relation, tail)` triple is being predicted.
///
/// NOTE(review): nothing else in this file references this enum; presumably
/// it is part of the public API for callers — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
    /// Predict tail entity: (head, relation, ?)
    TailEntity,
    /// Predict head entity: (?, relation, tail)
    HeadEntity,
    /// Predict predicate: (head, ?, tail)
    Relation,
}
263
/// Evaluation metrics for link prediction
///
/// All averages are maintained incrementally: each call to `update` folds one
/// more rank into the running means, and `num_predictions` counts how many
/// ranks have been folded in so far.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPredictionMetrics {
    /// Mean Rank (lower is better)
    pub mean_rank: f32,
    /// Mean Reciprocal Rank (higher is better, 0-1)
    pub mrr: f32,
    /// Hits@1 (fraction of ranks that were <= 1, in 0-1)
    pub hits_at_1: f32,
    /// Hits@3 (fraction of ranks that were <= 3, in 0-1)
    pub hits_at_3: f32,
    /// Hits@5 (fraction of ranks that were <= 5, in 0-1)
    pub hits_at_5: f32,
    /// Hits@10 (fraction of ranks that were <= 10, in 0-1)
    pub hits_at_10: f32,
    /// Number of predictions evaluated
    pub num_predictions: usize,
}
282
283impl LinkPredictionMetrics {
284    /// Create empty metrics
285    pub fn new() -> Self {
286        Self {
287            mean_rank: 0.0,
288            mrr: 0.0,
289            hits_at_1: 0.0,
290            hits_at_3: 0.0,
291            hits_at_5: 0.0,
292            hits_at_10: 0.0,
293            num_predictions: 0,
294        }
295    }
296
297    /// Update metrics with a new rank
298    pub fn update(&mut self, rank: usize) {
299        self.num_predictions += 1;
300        let n = self.num_predictions as f32;
301
302        // Update mean rank
303        self.mean_rank = ((self.mean_rank * (n - 1.0)) + rank as f32) / n;
304
305        // Update MRR
306        let reciprocal_rank = 1.0 / rank as f32;
307        self.mrr = ((self.mrr * (n - 1.0)) + reciprocal_rank) / n;
308
309        // Update Hits@K
310        if rank <= 1 {
311            self.hits_at_1 = ((self.hits_at_1 * (n - 1.0)) + 1.0) / n;
312        } else {
313            self.hits_at_1 = (self.hits_at_1 * (n - 1.0)) / n;
314        }
315
316        if rank <= 3 {
317            self.hits_at_3 = ((self.hits_at_3 * (n - 1.0)) + 1.0) / n;
318        } else {
319            self.hits_at_3 = (self.hits_at_3 * (n - 1.0)) / n;
320        }
321
322        if rank <= 5 {
323            self.hits_at_5 = ((self.hits_at_5 * (n - 1.0)) + 1.0) / n;
324        } else {
325            self.hits_at_5 = (self.hits_at_5 * (n - 1.0)) / n;
326        }
327
328        if rank <= 10 {
329            self.hits_at_10 = ((self.hits_at_10 * (n - 1.0)) + 1.0) / n;
330        } else {
331            self.hits_at_10 = (self.hits_at_10 * (n - 1.0)) / n;
332        }
333    }
334}
335
336impl Default for LinkPredictionMetrics {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
/// Link predictor for knowledge graph completion
///
/// Wraps an [`EmbeddingModel`] and uses its triple scores to rank candidate
/// heads, tails, or relations for an incomplete triple.
pub struct LinkPredictor<M: EmbeddingModel> {
    /// Ranking and filtering options applied to every query.
    config: LinkPredictionConfig,
    /// The embedding model whose `score_triple` drives the ranking.
    model: M,
    /// `(subject, predicate, object)` strings registered through
    /// `add_known_triples`; consulted when `filter_known_triples` is enabled.
    known_triples: HashSet<(String, String, String)>,
}
348
349impl<M: EmbeddingModel> LinkPredictor<M> {
350    /// Create new link predictor
351    pub fn new(config: LinkPredictionConfig, model: M) -> Self {
352        Self {
353            config,
354            model,
355            known_triples: HashSet::new(),
356        }
357    }
358
359    /// Add known triples for filtering
360    pub fn add_known_triples(&mut self, triples: &[Triple]) {
361        for triple in triples {
362            self.known_triples.insert((
363                triple.subject.to_string(),
364                triple.predicate.to_string(),
365                triple.object.to_string(),
366            ));
367        }
368    }
369
370    /// Predict tail entities given head and relation
371    pub fn predict_tail(
372        &self,
373        subject: &str,
374        predicate: &str,
375        candidate_entities: &[String],
376    ) -> Result<Vec<LinkPrediction>> {
377        // Score all candidates
378        let scored: Vec<(String, f64)> = if self.config.parallel {
379            candidate_entities
380                .par_iter()
381                .filter_map(|tail| {
382                    if self.config.filter_known_triples
383                        && self.known_triples.contains(&(
384                            subject.to_string(),
385                            predicate.to_string(),
386                            tail.clone(),
387                        ))
388                    {
389                        return None;
390                    }
391
392                    self.model
393                        .score_triple(subject, predicate, tail)
394                        .ok()
395                        .map(|score| (tail.clone(), score))
396                })
397                .collect()
398        } else {
399            candidate_entities
400                .iter()
401                .filter_map(|tail| {
402                    if self.config.filter_known_triples
403                        && self.known_triples.contains(&(
404                            subject.to_string(),
405                            predicate.to_string(),
406                            tail.clone(),
407                        ))
408                    {
409                        return None;
410                    }
411
412                    self.model
413                        .score_triple(subject, predicate, tail)
414                        .ok()
415                        .map(|score| (tail.clone(), score))
416                })
417                .collect()
418        };
419
420        self.rank_and_filter(scored)
421    }
422
423    /// Predict head entities given relation and tail
424    pub fn predict_head(
425        &self,
426        predicate: &str,
427        object: &str,
428        candidate_entities: &[String],
429    ) -> Result<Vec<LinkPrediction>> {
430        // Score all candidates
431        let scored: Vec<(String, f64)> = if self.config.parallel {
432            candidate_entities
433                .par_iter()
434                .filter_map(|head| {
435                    if self.config.filter_known_triples
436                        && self.known_triples.contains(&(
437                            head.clone(),
438                            predicate.to_string(),
439                            object.to_string(),
440                        ))
441                    {
442                        return None;
443                    }
444
445                    self.model
446                        .score_triple(head, predicate, object)
447                        .ok()
448                        .map(|score| (head.clone(), score))
449                })
450                .collect()
451        } else {
452            candidate_entities
453                .iter()
454                .filter_map(|head| {
455                    if self.config.filter_known_triples
456                        && self.known_triples.contains(&(
457                            head.clone(),
458                            predicate.to_string(),
459                            object.to_string(),
460                        ))
461                    {
462                        return None;
463                    }
464
465                    self.model
466                        .score_triple(head, predicate, object)
467                        .ok()
468                        .map(|score| (head.clone(), score))
469                })
470                .collect()
471        };
472
473        self.rank_and_filter(scored)
474    }
475
476    /// Predict relations given head and tail
477    pub fn predict_relation(
478        &self,
479        subject: &str,
480        object: &str,
481        candidate_relations: &[String],
482    ) -> Result<Vec<LinkPrediction>> {
483        // Score all candidate relations
484        let scored: Vec<(String, f64)> = if self.config.parallel {
485            candidate_relations
486                .par_iter()
487                .filter_map(|relation| {
488                    if self.config.filter_known_triples
489                        && self.known_triples.contains(&(
490                            subject.to_string(),
491                            relation.clone(),
492                            object.to_string(),
493                        ))
494                    {
495                        return None;
496                    }
497
498                    self.model
499                        .score_triple(subject, relation, object)
500                        .ok()
501                        .map(|score| (relation.clone(), score))
502                })
503                .collect()
504        } else {
505            candidate_relations
506                .iter()
507                .filter_map(|relation| {
508                    if self.config.filter_known_triples
509                        && self.known_triples.contains(&(
510                            subject.to_string(),
511                            relation.clone(),
512                            object.to_string(),
513                        ))
514                    {
515                        return None;
516                    }
517
518                    self.model
519                        .score_triple(subject, relation, object)
520                        .ok()
521                        .map(|score| (relation.clone(), score))
522                })
523                .collect()
524        };
525
526        self.rank_and_filter(scored)
527    }
528
529    /// Batch prediction of tails
530    pub fn predict_tails_batch(
531        &self,
532        queries: &[(String, String)], // (head, relation) pairs
533        candidate_entities: &[String],
534    ) -> Result<Vec<Vec<LinkPrediction>>> {
535        queries
536            .par_iter()
537            .map(|(head, relation)| {
538                self.predict_tail(head, relation, candidate_entities)
539                    .unwrap_or_default()
540            })
541            .collect::<Vec<_>>()
542            .into_iter()
543            .map(Ok)
544            .collect()
545    }
546
547    /// Evaluate link prediction on a test set
548    pub fn evaluate(
549        &self,
550        test_triples: &[Triple],
551        candidate_entities: &[String],
552    ) -> Result<LinkPredictionMetrics> {
553        let mut metrics = LinkPredictionMetrics::new();
554
555        info!(
556            "Evaluating link prediction on {} test triples",
557            test_triples.len()
558        );
559
560        for triple in test_triples {
561            // Predict tail
562            if let Ok(predictions) = self.predict_tail(
563                &triple.subject.to_string(),
564                &triple.predicate.to_string(),
565                candidate_entities,
566            ) {
567                // Find rank of correct tail
568                if let Some(rank) = predictions
569                    .iter()
570                    .position(|pred| pred.predicted_id == triple.object.to_string())
571                {
572                    metrics.update(rank + 1); // 1-indexed rank
573                }
574            }
575        }
576
577        info!(
578            "Evaluation complete: MRR={:.4}, Hits@10={:.4}",
579            metrics.mrr, metrics.hits_at_10
580        );
581
582        Ok(metrics)
583    }
584
585    /// Rank and filter predictions
586    fn rank_and_filter(&self, mut scored: Vec<(String, f64)>) -> Result<Vec<LinkPrediction>> {
587        // Sort by score descending
588        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
589
590        // Take top-K
591        scored.truncate(self.config.top_k);
592
593        // Normalize scores to confidence (0-1)
594        let max_score = scored.first().map(|(_, s)| *s).unwrap_or(1.0);
595        let min_score = scored.last().map(|(_, s)| *s).unwrap_or(0.0);
596        let score_range = (max_score - min_score).max(1e-10);
597
598        let predictions: Vec<LinkPrediction> = scored
599            .into_iter()
600            .enumerate()
601            .filter_map(|(rank, (id, score))| {
602                let confidence = (score - min_score) / score_range;
603
604                if confidence >= self.config.min_confidence as f64 {
605                    Some(LinkPrediction {
606                        predicted_id: id,
607                        score: score as f32,
608                        confidence: confidence as f32,
609                        rank: rank + 1, // 1-indexed
610                    })
611                } else {
612                    None
613                }
614            })
615            .collect();
616
617        Ok(predictions)
618    }
619
620    /// Get reference to underlying model
621    pub fn model(&self) -> &M {
622        &self.model
623    }
624
625    /// Get mutable reference to underlying model
626    pub fn model_mut(&mut self) -> &mut M {
627        &mut self.model
628    }
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::transe::TransE;
    use crate::{ModelConfig, NamedNode};

    /// Builds a triple from three node names; panics on an invalid name,
    /// which is acceptable in test fixtures.
    fn triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    /// Tail prediction returns a non-empty, top-K bounded, score-sorted list.
    #[tokio::test]
    async fn test_link_prediction_tail() {
        let config = ModelConfig {
            dimensions: 50,
            learning_rate: 0.01,
            max_epochs: 50,
            ..Default::default()
        };

        let mut model = TransE::new(config);

        // Add training data
        model.add_triple(triple("alice", "knows", "bob")).unwrap();
        model
            .add_triple(triple("alice", "knows", "charlie"))
            .unwrap();
        model.add_triple(triple("bob", "likes", "dave")).unwrap();

        // Train model
        model.train(Some(50)).await.unwrap();

        // Create link predictor
        let pred_config = LinkPredictionConfig {
            top_k: 5,
            filter_known_triples: false,
            ..Default::default()
        };

        let predictor = LinkPredictor::new(pred_config, model);

        // Predict tails
        let candidates = vec!["bob".to_string(), "charlie".to_string(), "dave".to_string()];

        let predictions = predictor
            .predict_tail("alice", "knows", &candidates)
            .unwrap();

        assert!(!predictions.is_empty());
        assert!(predictions.len() <= 5);

        // Check that predictions are ranked; `windows` is safe for any length,
        // unlike `0..len() - 1`, which underflows on an empty list.
        for pair in predictions.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }

        // Ranks are 1-indexed and follow list order.
        for (index, prediction) in predictions.iter().enumerate() {
            assert_eq!(prediction.rank, index + 1);
        }
    }

    /// Running averages match the closed-form values for ranks 1, 3 and 10.
    ///
    /// Nothing here awaits, so a plain `#[test]` is used instead of spinning
    /// up an async runtime.
    #[test]
    fn test_link_prediction_metrics() {
        let mut metrics = LinkPredictionMetrics::new();

        // Simulate some predictions
        metrics.update(1); // Perfect prediction
        metrics.update(3); // Rank 3
        metrics.update(10); // Rank 10

        assert_eq!(metrics.num_predictions, 3);
        assert!((metrics.mean_rank - 14.0 / 3.0).abs() < 1e-5);
        assert!((metrics.mrr - (1.0 + 1.0 / 3.0 + 0.1) / 3.0).abs() < 1e-5);
        assert!((metrics.hits_at_1 - 1.0 / 3.0).abs() < 1e-5);
        assert!((metrics.hits_at_3 - 2.0 / 3.0).abs() < 1e-5);
        assert!((metrics.hits_at_5 - 2.0 / 3.0).abs() < 1e-5);
        assert!(metrics.hits_at_10 == 1.0); // All within top 10
    }

    /// Batch prediction yields exactly one result list per query.
    #[tokio::test]
    async fn test_batch_prediction() {
        let config = ModelConfig {
            dimensions: 50,
            max_epochs: 30,
            ..Default::default()
        };

        let mut model = TransE::new(config);

        model.add_triple(triple("a", "r1", "b")).unwrap();

        model.train(Some(30)).await.unwrap();

        let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);

        let queries = vec![("a".to_string(), "r1".to_string())];

        let candidates = vec!["b".to_string()];

        let results = predictor
            .predict_tails_batch(&queries, &candidates)
            .unwrap();

        assert_eq!(results.len(), queries.len());
        assert!(results[0].len() <= candidates.len());
    }
}