use anyhow::Result;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tracing::info;
use crate::{EmbeddingModel, Triple};
/// Tunable parameters controlling how a [`LinkPredictor`] scores, ranks,
/// and filters candidate completions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPredictionConfig {
// Maximum number of ranked predictions returned per query.
pub top_k: usize,
// Minimum min-max-normalized confidence (0.0..=1.0) a candidate must
// reach to be kept in the result list.
pub min_confidence: f32,
// When true, candidates that would re-create an already-known triple
// (registered via `add_known_triples`) are skipped.
pub filter_known_triples: bool,
// When true, candidate scoring uses rayon parallel iterators.
pub parallel: bool,
// Batch size for bulk operations.
// NOTE(review): not read anywhere in this file — confirm callers use it,
// otherwise it is dead configuration.
pub batch_size: usize,
}
impl Default for LinkPredictionConfig {
fn default() -> Self {
Self {
top_k: 10,
min_confidence: 0.5,
filter_known_triples: true,
parallel: true,
batch_size: 100,
}
}
}
/// One ranked candidate produced by a [`LinkPredictor`] query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPrediction {
// Identifier of the predicted entity or relation.
pub predicted_id: String,
// Raw model score for the completed triple (results are sorted by this,
// descending).
pub score: f32,
// Score min-max-normalized against the other top-k candidates, in 0.0..=1.0.
pub confidence: f32,
// 1-based position in the ranked result list.
pub rank: usize,
}
/// Which slot of a (head, relation, tail) triple is being predicted.
// NOTE(review): not referenced elsewhere in this file — presumably consumed by
// callers; also the only public type here without serde derives. Confirm
// whether Serialize/Deserialize should be added for consistency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PredictionType {
// Predict the object of (subject, predicate, ?).
TailEntity,
// Predict the subject of (?, predicate, object).
HeadEntity,
// Predict the predicate of (subject, ?, object).
Relation,
}
/// Rank-based evaluation metrics for link prediction, maintained as running
/// averages by [`LinkPredictionMetrics::update`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinkPredictionMetrics {
// Mean of the 1-based ranks of the true answers.
pub mean_rank: f32,
// Mean reciprocal rank: average of 1/rank.
pub mrr: f32,
// Fraction of queries where the true answer ranked first.
pub hits_at_1: f32,
// Fraction of queries where the true answer ranked in the top 3.
pub hits_at_3: f32,
// Fraction of queries where the true answer ranked in the top 5.
pub hits_at_5: f32,
// Fraction of queries where the true answer ranked in the top 10.
pub hits_at_10: f32,
// Number of ranks folded into the averages so far.
pub num_predictions: usize,
}
impl LinkPredictionMetrics {
    /// Create an empty accumulator: all averages start at zero.
    pub fn new() -> Self {
        Self {
            mean_rank: 0.0,
            mrr: 0.0,
            hits_at_1: 0.0,
            hits_at_3: 0.0,
            hits_at_5: 0.0,
            hits_at_10: 0.0,
            num_predictions: 0,
        }
    }

    /// Fold one observed rank (1-based) into the running averages.
    ///
    /// Every metric uses the same incremental-mean formula
    /// `m' = (m * (n - 1) + sample) / n`, so no per-sample history is kept.
    pub fn update(&mut self, rank: usize) {
        // Ranks are 1-based (see `evaluate`, which passes `position + 1`);
        // a rank of 0 would make the reciprocal rank infinite.
        debug_assert!(rank >= 1, "ranks are 1-based");
        self.num_predictions += 1;
        let n = self.num_predictions as f32;
        // 1.0 when the true answer landed within the top k, else 0.0.
        let hit = |k: usize| -> f32 {
            if rank <= k {
                1.0
            } else {
                0.0
            }
        };
        self.mean_rank = Self::running_mean(self.mean_rank, rank as f32, n);
        self.mrr = Self::running_mean(self.mrr, 1.0 / rank as f32, n);
        self.hits_at_1 = Self::running_mean(self.hits_at_1, hit(1), n);
        self.hits_at_3 = Self::running_mean(self.hits_at_3, hit(3), n);
        self.hits_at_5 = Self::running_mean(self.hits_at_5, hit(5), n);
        self.hits_at_10 = Self::running_mean(self.hits_at_10, hit(10), n);
    }

    /// Incremental mean: the previous mean over `n - 1` samples combined with
    /// one new sample to give the mean over `n` samples.
    fn running_mean(prev: f32, sample: f32, n: f32) -> f32 {
        (prev * (n - 1.0) + sample) / n
    }
}
impl Default for LinkPredictionMetrics {
// Delegates to `new()`: an empty accumulator with every metric at zero.
fn default() -> Self {
Self::new()
}
}
/// Ranks candidate entities/relations for the missing slot of a triple by
/// scoring each completion with an embedding model.
pub struct LinkPredictor<M: EmbeddingModel> {
// Prediction-time tuning knobs (top-k, confidence floor, parallelism, ...).
config: LinkPredictionConfig,
// The embedding model used to score candidate triples.
model: M,
// Triples already known to the graph, stored as
// (subject, predicate, object) string tuples; used to filter out known
// facts when `config.filter_known_triples` is set.
known_triples: HashSet<(String, String, String)>,
}
impl<M: EmbeddingModel> LinkPredictor<M> {
    /// Create a predictor over `model` with the given configuration and no
    /// known triples registered.
    pub fn new(config: LinkPredictionConfig, model: M) -> Self {
        Self {
            config,
            model,
            known_triples: HashSet::new(),
        }
    }

    /// Register triples already present in the graph so they can be excluded
    /// from predictions when `config.filter_known_triples` is set.
    pub fn add_known_triples(&mut self, triples: &[Triple]) {
        self.known_triples.extend(triples.iter().map(|triple| {
            (
                triple.subject.to_string(),
                triple.predicate.to_string(),
                triple.object.to_string(),
            )
        }));
    }

    /// Predict the most likely objects for `(subject, predicate, ?)`.
    ///
    /// Returns at most `config.top_k` predictions sorted by descending score,
    /// filtered by `config.min_confidence`.
    pub fn predict_tail(
        &self,
        subject: &str,
        predicate: &str,
        candidate_entities: &[String],
    ) -> Result<Vec<LinkPrediction>> {
        let scored = self.score_candidates(candidate_entities, |tail| {
            // NOTE(review): the `contains` lookup allocates three Strings per
            // candidate; a borrowed-key scheme would avoid that if this shows
            // up in profiles.
            if self.config.filter_known_triples
                && self.known_triples.contains(&(
                    subject.to_string(),
                    predicate.to_string(),
                    tail.to_string(),
                ))
            {
                return None;
            }
            // Candidates the model cannot score are dropped, not errors.
            self.model.score_triple(subject, predicate, tail).ok()
        });
        self.rank_and_filter(scored)
    }

    /// Predict the most likely subjects for `(?, predicate, object)`.
    pub fn predict_head(
        &self,
        predicate: &str,
        object: &str,
        candidate_entities: &[String],
    ) -> Result<Vec<LinkPrediction>> {
        let scored = self.score_candidates(candidate_entities, |head| {
            if self.config.filter_known_triples
                && self.known_triples.contains(&(
                    head.to_string(),
                    predicate.to_string(),
                    object.to_string(),
                ))
            {
                return None;
            }
            self.model.score_triple(head, predicate, object).ok()
        });
        self.rank_and_filter(scored)
    }

    /// Predict the most likely predicates for `(subject, ?, object)`.
    pub fn predict_relation(
        &self,
        subject: &str,
        object: &str,
        candidate_relations: &[String],
    ) -> Result<Vec<LinkPrediction>> {
        let scored = self.score_candidates(candidate_relations, |relation| {
            if self.config.filter_known_triples
                && self.known_triples.contains(&(
                    subject.to_string(),
                    relation.to_string(),
                    object.to_string(),
                ))
            {
                return None;
            }
            self.model.score_triple(subject, relation, object).ok()
        });
        self.rank_and_filter(scored)
    }

    /// Predict tails for many `(head, relation)` queries in parallel.
    ///
    /// Best-effort: a query whose prediction fails yields an empty result
    /// list rather than aborting the whole batch.
    pub fn predict_tails_batch(
        &self,
        queries: &[(String, String)],
        candidate_entities: &[String],
    ) -> Result<Vec<Vec<LinkPrediction>>> {
        Ok(queries
            .par_iter()
            .map(|(head, relation)| {
                self.predict_tail(head, relation, candidate_entities)
                    .unwrap_or_default()
            })
            .collect())
    }

    /// Evaluate tail prediction over `test_triples`, accumulating rank-based
    /// metrics (mean rank, MRR, hits@k).
    ///
    /// NOTE(review): ranks are positions within the top-k, confidence-filtered
    /// prediction list, so a true answer pruned by those filters never updates
    /// the metrics — this biases the reported numbers optimistically.
    pub fn evaluate(
        &self,
        test_triples: &[Triple],
        candidate_entities: &[String],
    ) -> Result<LinkPredictionMetrics> {
        let mut metrics = LinkPredictionMetrics::new();
        info!(
            "Evaluating link prediction on {} test triples",
            test_triples.len()
        );
        for triple in test_triples {
            let subject = triple.subject.to_string();
            let predicate = triple.predicate.to_string();
            // Hoisted so the String is built once per triple instead of once
            // per compared prediction inside `position`.
            let target = triple.object.to_string();
            if let Ok(predictions) = self.predict_tail(&subject, &predicate, candidate_entities) {
                if let Some(rank) = predictions
                    .iter()
                    .position(|pred| pred.predicted_id == target)
                {
                    // `position` is 0-based; metrics expect 1-based ranks.
                    metrics.update(rank + 1);
                }
            }
        }
        info!(
            "Evaluation complete: MRR={:.4}, Hits@10={:.4}",
            metrics.mrr, metrics.hits_at_10
        );
        Ok(metrics)
    }

    /// Score every candidate with `score_one`, either sequentially or with
    /// rayon depending on `config.parallel`.
    ///
    /// `score_one` returns `None` to drop a candidate (known triple, or the
    /// model failed to score it). Shared by all three `predict_*` methods,
    /// which previously each duplicated both the parallel and the sequential
    /// loop.
    fn score_candidates<F>(&self, candidates: &[String], score_one: F) -> Vec<(String, f64)>
    where
        F: Fn(&str) -> Option<f64> + Sync,
    {
        if self.config.parallel {
            candidates
                .par_iter()
                .filter_map(|candidate| score_one(candidate).map(|score| (candidate.clone(), score)))
                .collect()
        } else {
            candidates
                .iter()
                .filter_map(|candidate| score_one(candidate).map(|score| (candidate.clone(), score)))
                .collect()
        }
    }

    /// Sort by descending score, keep the top-k, min-max normalize scores
    /// into confidences, and drop entries below `config.min_confidence`.
    fn rank_and_filter(&self, mut scored: Vec<(String, f64)>) -> Result<Vec<LinkPrediction>> {
        // NaN scores compare as Equal so the sort cannot panic.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(self.config.top_k);
        let max_score = scored.first().map(|(_, s)| *s).unwrap_or(1.0);
        let min_score = scored.last().map(|(_, s)| *s).unwrap_or(0.0);
        let score_range = max_score - min_score;
        let predictions: Vec<LinkPrediction> = scored
            .into_iter()
            .enumerate()
            .filter_map(|(rank, (id, score))| {
                // When all retained scores are (near-)equal the min-max
                // normalization is undefined; treat every candidate as fully
                // confident instead of assigning 0 and silently discarding
                // them all (the previous behavior, which emptied the result
                // for single-candidate or all-tied queries).
                let confidence = if score_range > f64::EPSILON {
                    (score - min_score) / score_range
                } else {
                    1.0
                };
                if confidence >= self.config.min_confidence as f64 {
                    Some(LinkPrediction {
                        predicted_id: id,
                        score: score as f32,
                        confidence: confidence as f32,
                        // 1-based rank within the returned list.
                        rank: rank + 1,
                    })
                } else {
                    None
                }
            })
            .collect();
        Ok(predictions)
    }

    /// Borrow the underlying embedding model.
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Mutably borrow the underlying embedding model.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::transe::TransE;
use crate::{ModelConfig, NamedNode};
// Trains a tiny TransE model and checks that tail prediction returns a
// non-empty, score-descending list bounded by top_k.
#[tokio::test]
async fn test_link_prediction_tail() {
let config = ModelConfig {
dimensions: 50,
learning_rate: 0.01,
max_epochs: 50,
..Default::default()
};
let mut model = TransE::new(config);
model
.add_triple(Triple::new(
NamedNode::new("alice").expect("should succeed"),
NamedNode::new("knows").expect("should succeed"),
NamedNode::new("bob").expect("should succeed"),
))
.expect("should succeed");
model
.add_triple(Triple::new(
NamedNode::new("alice").expect("should succeed"),
NamedNode::new("knows").expect("should succeed"),
NamedNode::new("charlie").expect("should succeed"),
))
.expect("should succeed");
model
.add_triple(Triple::new(
NamedNode::new("bob").expect("should succeed"),
NamedNode::new("likes").expect("should succeed"),
NamedNode::new("dave").expect("should succeed"),
))
.expect("should succeed");
model.train(Some(50)).await.expect("should succeed");
// filter_known_triples is off so known tails ("bob", "charlie") stay
// eligible as candidates.
let pred_config = LinkPredictionConfig {
top_k: 5,
filter_known_triples: false,
..Default::default()
};
let predictor = LinkPredictor::new(pred_config, model);
let candidates = vec!["bob".to_string(), "charlie".to_string(), "dave".to_string()];
let predictions = predictor
.predict_tail("alice", "knows", &candidates)
.expect("should succeed");
assert!(!predictions.is_empty());
assert!(predictions.len() <= 5);
// Results must be sorted by descending score. The preceding non-empty
// assert guards the `len() - 1` subtraction against underflow.
for i in 0..predictions.len() - 1 {
assert!(predictions[i].score >= predictions[i + 1].score);
}
}
// Pure-computation check of the running-average metric updates.
// NOTE(review): nothing here is awaited — #[tokio::test] could be a plain
// #[test].
#[tokio::test]
async fn test_link_prediction_metrics() {
let mut metrics = LinkPredictionMetrics::new();
// Ranks 1, 3, 10 => all three are hits@10, only the first is hits@1.
metrics.update(1); metrics.update(3); metrics.update(10);
assert_eq!(metrics.num_predictions, 3);
assert!(metrics.mrr > 0.0);
assert!(metrics.hits_at_1 > 0.0);
assert!(metrics.hits_at_10 == 1.0); }
// Batch prediction returns one (possibly empty) result list per query.
#[tokio::test]
async fn test_batch_prediction() {
let config = ModelConfig {
dimensions: 50,
max_epochs: 30,
..Default::default()
};
let mut model = TransE::new(config);
model
.add_triple(Triple::new(
NamedNode::new("a").expect("should succeed"),
NamedNode::new("r1").expect("should succeed"),
NamedNode::new("b").expect("should succeed"),
))
.expect("should succeed");
model.train(Some(30)).await.expect("should succeed");
let predictor = LinkPredictor::new(LinkPredictionConfig::default(), model);
let queries = vec![("a".to_string(), "r1".to_string())];
let candidates = vec!["b".to_string()];
let results = predictor
.predict_tails_batch(&queries, &candidates)
.expect("should succeed");
// Only the outer shape is asserted; the inner list may legitimately be
// empty after confidence filtering.
assert_eq!(results.len(), 1);
}
}