use crate::dataset::Triple;
use crate::metrics::{hits_at_k, mean_rank, mean_reciprocal_rank};
use std::collections::{HashMap, HashSet};
use super::{EvaluationResults, PerRelationResults, RelationTransform};
/// Index of known `(head, relation, tail)` triples keyed by entity and relation names.
///
/// Used by filtered evaluation to exclude other true answers from a rank. Both
/// directions are stored, so "known tails of (head, relation)" and "known heads of
/// (tail, relation)" are each two hash lookups.
#[derive(Debug, Default, Clone)]
pub struct FilteredTripleIndex {
    /// head -> relation -> tails seen with that head and relation.
    tails_by_head_rel: HashMap<String, HashMap<String, HashSet<String>>>,
    /// tail -> relation -> heads seen with that tail and relation.
    heads_by_tail_rel: HashMap<String, HashMap<String, HashSet<String>>>,
}

impl FilteredTripleIndex {
    /// Builds an index containing every triple yielded by `triples`.
    pub fn from_triples<'a, I>(triples: I) -> Self
    where
        I: IntoIterator<Item = &'a Triple>,
    {
        let mut built = Self::default();
        built.extend(triples);
        built
    }

    /// Builds an index over the train, validation and test splits of `dataset`.
    pub fn from_dataset(dataset: &crate::dataset::Dataset) -> Self {
        let all_splits = dataset
            .train
            .iter()
            .chain(dataset.valid.iter())
            .chain(dataset.test.iter());
        Self::from_triples(all_splits)
    }

    /// Adds `triples` to the index. Repeated triples are absorbed by the inner sets.
    pub fn extend<'a, I>(&mut self, triples: I)
    where
        I: IntoIterator<Item = &'a Triple>,
    {
        for triple in triples {
            Self::insert_edge(
                &mut self.tails_by_head_rel,
                &triple.head,
                &triple.relation,
                &triple.tail,
            );
            Self::insert_edge(
                &mut self.heads_by_tail_rel,
                &triple.tail,
                &triple.relation,
                &triple.head,
            );
        }
    }

    /// Records `other` under `map[anchor][relation]`, creating the levels as needed.
    fn insert_edge(
        map: &mut HashMap<String, HashMap<String, HashSet<String>>>,
        anchor: &str,
        relation: &str,
        other: &str,
    ) {
        map.entry(anchor.to_owned())
            .or_default()
            .entry(relation.to_owned())
            .or_default()
            .insert(other.to_owned());
    }

    /// Returns `true` if `(head, relation, tail)` was indexed.
    #[inline]
    pub fn is_known_tail(&self, head: &str, relation: &str, tail: &str) -> bool {
        self.known_tails(head, relation)
            .is_some_and(|tails| tails.contains(tail))
    }

    /// All tails indexed for `(head, relation)`, or `None` if there are none.
    #[inline]
    pub fn known_tails(&self, head: &str, relation: &str) -> Option<&HashSet<String>> {
        self.tails_by_head_rel.get(head)?.get(relation)
    }

    /// Returns `true` if `(head, relation, tail)` was indexed (queried from the tail side).
    #[inline]
    pub fn is_known_head(&self, tail: &str, relation: &str, head: &str) -> bool {
        self.known_heads(tail, relation)
            .is_some_and(|heads| heads.contains(head))
    }

    /// All heads indexed for `(tail, relation)`, or `None` if there are none.
    #[inline]
    pub fn known_heads(&self, tail: &str, relation: &str) -> Option<&HashSet<String>> {
        self.heads_by_tail_rel.get(tail)?.get(relation)
    }
}
/// Index of known triples keyed by interned entity and relation ids.
///
/// This is the id-based counterpart of the name-keyed index. Each direction is a flat
/// map from an `(entity, relation)` pair to the set of entities on the other side.
#[derive(Debug, Default, Clone)]
pub struct FilteredTripleIndexIds {
    /// `(head, relation)` -> tails seen with them.
    tails_by_head_rel: HashMap<(usize, usize), HashSet<usize>>,
    /// `(tail, relation)` -> heads seen with them.
    heads_by_tail_rel: HashMap<(usize, usize), HashSet<usize>>,
}

impl FilteredTripleIndexIds {
    /// Builds an index containing every triple yielded by `triples`.
    pub fn from_triples<'a, I>(triples: I) -> Self
    where
        I: IntoIterator<Item = &'a crate::dataset::TripleIds>,
    {
        let mut built = Self::default();
        built.extend(triples);
        built
    }

    /// Builds an index over the train, validation and test splits of `dataset`.
    pub fn from_dataset(dataset: &crate::dataset::InternedDataset) -> Self {
        let all_splits = dataset
            .train
            .iter()
            .chain(dataset.valid.iter())
            .chain(dataset.test.iter());
        Self::from_triples(all_splits)
    }

    /// Adds `triples` to the index. Repeated triples are absorbed by the sets.
    pub fn extend<'a, I>(&mut self, triples: I)
    where
        I: IntoIterator<Item = &'a crate::dataset::TripleIds>,
    {
        let Self {
            tails_by_head_rel,
            heads_by_tail_rel,
        } = self;
        for ids in triples {
            let (head, relation, tail) = (ids.head, ids.relation, ids.tail);
            tails_by_head_rel
                .entry((head, relation))
                .or_default()
                .insert(tail);
            heads_by_tail_rel
                .entry((tail, relation))
                .or_default()
                .insert(head);
        }
    }

    /// Returns `true` if `(head, relation, tail)` was indexed.
    #[inline]
    pub fn is_known_tail(&self, head: usize, relation: usize, tail: usize) -> bool {
        self.known_tails(head, relation)
            .is_some_and(|tails| tails.contains(&tail))
    }

    /// All tails indexed for `(head, relation)`, or `None` if there are none.
    #[inline]
    pub fn known_tails(&self, head: usize, relation: usize) -> Option<&HashSet<usize>> {
        let key = (head, relation);
        self.tails_by_head_rel.get(&key)
    }

    /// Returns `true` if `(head, relation, tail)` was indexed (queried from the tail side).
    #[inline]
    pub fn is_known_head(&self, tail: usize, relation: usize, head: usize) -> bool {
        self.known_heads(tail, relation)
            .is_some_and(|heads| heads.contains(&head))
    }

    /// All heads indexed for `(tail, relation)`, or `None` if there are none.
    #[inline]
    pub fn known_heads(&self, tail: usize, relation: usize) -> Option<&HashSet<usize>> {
        let key = (tail, relation);
        self.heads_by_tail_rel.get(&key)
    }
}
/// Computes the 1-based rank of `target` among all entities in `entity_boxes`.
///
/// Every entity is scored with `score_fn`, and higher scores rank first. Exact ties are
/// broken deterministically: a tied entity ranks ahead of the target only if its name
/// sorts before `target`. When `filter_known` is given, its members (other than the
/// target itself) are removed from the count, which gives the "filtered" rank.
///
/// Returns `Ok(usize::MAX)` when `target` has no box.
///
/// # Errors
/// Propagates any `score_fn` error, and returns `BoxError::Internal` if the target's
/// score or any candidate's score is NaN.
fn rank_among_entities<B, F>(
    entity_boxes: &HashMap<String, B>,
    target: &str,
    score_fn: F,
    filter_known: Option<&HashSet<String>>,
) -> Result<usize, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
    F: Fn(&B) -> Result<f32, crate::BoxError>,
{
    let Some(target_box) = entity_boxes.get(target) else {
        return Ok(usize::MAX);
    };
    let target_score = score_fn(target_box)?;
    if target_score.is_nan() {
        return Err(crate::BoxError::Internal(
            "NaN containment score encountered (target)".to_string(),
        ));
    }

    // For one non-target candidate, returns (outscores target, wins tie-break).
    let compare = |name: &str, candidate: &B| -> Result<(bool, bool), crate::BoxError> {
        let score = score_fn(candidate)?;
        if score.is_nan() {
            return Err(crate::BoxError::Internal(
                "NaN containment score encountered".to_string(),
            ));
        }
        let outscores = score > target_score;
        let wins_tie = score == target_score && name < target;
        Ok((outscores, wins_tie))
    };

    let mut ahead = 0usize;
    let mut ahead_on_tie = 0usize;
    for (name, candidate) in entity_boxes
        .iter()
        .filter(|(name, _)| name.as_str() != target)
    {
        let (outscores, wins_tie) = compare(name.as_str(), candidate)?;
        ahead += usize::from(outscores);
        ahead_on_tie += usize::from(wins_tie);
    }

    if let Some(known) = filter_known {
        // Known entities were counted above, so their contribution is subtracted here.
        let mut known_ahead = 0usize;
        let mut known_ahead_on_tie = 0usize;
        for name in known.iter().filter(|name| name.as_str() != target) {
            if let Some(candidate) = entity_boxes.get(name) {
                let (outscores, wins_tie) = compare(name.as_str(), candidate)?;
                known_ahead += usize::from(outscores);
                known_ahead_on_tie += usize::from(wins_tie);
            }
        }
        ahead = ahead.saturating_sub(known_ahead);
        ahead_on_tie = ahead_on_tie.saturating_sub(known_ahead_on_tie);
    }

    Ok(ahead + ahead_on_tie + 1)
}
pub(crate) fn evaluate_link_prediction_inner<B>(
test_triples: &[Triple],
entity_boxes: &HashMap<String, B>,
relation_transforms: Option<&HashMap<String, RelationTransform>>,
filter: Option<&FilteredTripleIndex>,
) -> Result<EvaluationResults, crate::BoxError>
where
B: crate::Box<Scalar = f32>,
{
if let Some(transforms) = relation_transforms {
for (rel, transform) in transforms {
if !transform.is_identity() {
return Err(crate::BoxError::Internal(format!(
"Non-Identity RelationTransform for relation '{}' requires the interned \
evaluation path with NdarrayBox",
rel
)));
}
}
}
let mut tail_ranks = Vec::with_capacity(test_triples.len());
let mut head_ranks = Vec::with_capacity(test_triples.len());
let mut per_triple: Vec<(&str, usize, usize)> = Vec::with_capacity(test_triples.len());
for triple in test_triples {
let head_box = entity_boxes
.get(&triple.head)
.ok_or_else(|| crate::BoxError::Internal(format!("Missing entity: {}", triple.head)))?;
let filter_tails = filter.and_then(|f| f.known_tails(&triple.head, &triple.relation));
let t_rank = rank_among_entities(
entity_boxes,
&triple.tail,
|candidate| head_box.containment_prob_fast(candidate),
filter_tails,
)?;
let tail_box = entity_boxes
.get(&triple.tail)
.ok_or_else(|| crate::BoxError::Internal(format!("Missing entity: {}", triple.tail)))?;
let filter_heads = filter.and_then(|f| f.known_heads(&triple.tail, &triple.relation));
let h_rank = rank_among_entities(
entity_boxes,
&triple.head,
|candidate| candidate.containment_prob_fast(tail_box),
filter_heads,
)?;
tail_ranks.push(t_rank);
head_ranks.push(h_rank);
per_triple.push((triple.relation.as_str(), t_rank, h_rank));
}
let all_ranks: Vec<usize> = tail_ranks
.iter()
.chain(head_ranks.iter())
.copied()
.collect();
let mrr = mean_reciprocal_rank(all_ranks.iter().copied());
let tail_mrr = mean_reciprocal_rank(tail_ranks.iter().copied());
let head_mrr = mean_reciprocal_rank(head_ranks.iter().copied());
let hits_at_1 = hits_at_k(all_ranks.iter().copied(), 1);
let hits_at_3 = hits_at_k(all_ranks.iter().copied(), 3);
let hits_at_10 = hits_at_k(all_ranks.iter().copied(), 10);
let mean_rank_val = mean_rank(all_ranks.iter().copied());
let per_relation = aggregate_per_relation(&per_triple);
Ok(EvaluationResults {
mrr,
head_mrr,
tail_mrr,
hits_at_1,
hits_at_3,
hits_at_10,
mean_rank: mean_rank_val,
per_relation,
})
}
/// Groups per-triple ranks by relation name and summarizes each group.
///
/// Each entry of `per_triple` is `(relation, tail_rank, head_rank)`. Both ranks of a
/// triple feed its relation's MRR and Hits@10, and `count` is the number of triples for
/// that relation. The output is sorted by relation name.
fn aggregate_per_relation(per_triple: &[(&str, usize, usize)]) -> Vec<PerRelationResults> {
    let mut grouped: HashMap<&str, Vec<usize>> = HashMap::new();
    for (relation, tail_rank, head_rank) in per_triple.iter().copied() {
        grouped
            .entry(relation)
            .or_default()
            .extend([tail_rank, head_rank]);
    }

    let mut summaries = Vec::with_capacity(grouped.len());
    for (relation, ranks) in grouped {
        summaries.push(PerRelationResults {
            relation: relation.to_owned(),
            mrr: mean_reciprocal_rank(ranks.iter().copied()),
            hits_at_10: hits_at_k(ranks.iter().copied(), 10),
            // Two ranks (tail + head) were recorded per triple.
            count: ranks.len() / 2,
        });
    }
    summaries.sort_by(|lhs, rhs| lhs.relation.cmp(&rhs.relation));
    summaries
}
/// Which side of the containment score the query box takes when ranking candidates.
///
/// - `Forward`: candidates are scored as `query.containment_prob_fast(candidate)`. The
///   interned evaluators use this to rank tails given a head box.
/// - `Reverse`: candidates are scored as `candidate.containment_prob_fast(query)`. The
///   interned evaluators use this to rank heads given a tail box.
///
/// The enum carries no data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ScoreDirection {
    /// Score `query.containment_prob_fast(candidate)`.
    Forward,
    /// Score `candidate.containment_prob_fast(query)`.
    Reverse,
}
/// Computes the 1-based rank of `target_id` among all entity boxes, scoring each candidate
/// against `query_box` on the side selected by `direction`.
///
/// - `Forward` scores `query_box.containment_prob_fast(candidate)`. The main pass batches
///   candidates `CHUNK` at a time through `containment_prob_many`, using `scores_buf` as
///   scratch space.
/// - `Reverse` scores `candidate.containment_prob_fast(query_box)`, one candidate at a time.
///
/// In both directions `scores_buf` is grown to at least `CHUNK` entries once the target
/// has been scored. Exact ties are broken by vocabulary label: a tied candidate ranks ahead
/// only if its label sorts before the target's. Members of `filter_known` (other than the
/// target) are removed from the count, which gives the filtered rank.
///
/// Returns `Ok(usize::MAX)` if `target_id` has no box.
///
/// NOTE(review): the filtered correction re-scores known entities with
/// `containment_prob_fast` even for `Forward`, whose main pass used
/// `containment_prob_many`. This assumes both give identical values — confirm.
///
/// # Errors
/// `BoxError::Internal` on any NaN score, or on a missing label for the target or for a
/// tied candidate. Scoring errors are propagated.
pub(crate) fn rank_among_entities_interned<B>(
    entity_boxes: &[B],
    entities: &crate::dataset::Vocab,
    target_id: usize,
    query_box: &B,
    direction: &ScoreDirection,
    filter_known: Option<&HashSet<usize>>,
    scores_buf: &mut Vec<f32>,
) -> Result<usize, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    const CHUNK: usize = 4096;

    let Some(target_box) = entity_boxes.get(target_id) else {
        return Ok(usize::MAX);
    };
    let target_name = entities.get(target_id).ok_or_else(|| {
        crate::BoxError::Internal(format!("Missing entity label (target): {}", target_id))
    })?;

    // Scores a single candidate with the query on the side chosen by `direction`.
    let score_one = |candidate: &B| match direction {
        ScoreDirection::Forward => query_box.containment_prob_fast(candidate),
        ScoreDirection::Reverse => candidate.containment_prob_fast(query_box),
    };

    let target_score = score_one(target_box)?;
    if target_score.is_nan() {
        return Err(crate::BoxError::Internal(
            "NaN containment score encountered (target)".to_string(),
        ));
    }
    if scores_buf.len() < CHUNK {
        scores_buf.resize(CHUNK, 0.0);
    }

    // For a non-target candidate with `score`, reports (outscores target, wins tie-break).
    // `filtered` only chooses which message names a candidate whose label is missing.
    let standing = |entity_id: usize,
                    score: f32,
                    filtered: bool|
     -> Result<(bool, bool), crate::BoxError> {
        if score.is_nan() {
            return Err(crate::BoxError::Internal(
                "NaN containment score encountered".to_string(),
            ));
        }
        if score > target_score {
            return Ok((true, false));
        }
        if score != target_score {
            return Ok((false, false));
        }
        let name = entities.get(entity_id).ok_or_else(|| {
            crate::BoxError::Internal(if filtered {
                format!("Missing entity label (filtered): {}", entity_id)
            } else {
                format!("Missing entity label (candidate): {}", entity_id)
            })
        })?;
        Ok((false, name < target_name))
    };

    let mut better = 0usize;
    let mut tie_before = 0usize;
    match direction {
        ScoreDirection::Forward => {
            for (chunk_index, chunk) in entity_boxes.chunks(CHUNK).enumerate() {
                let offset = chunk_index * CHUNK;
                let filled = chunk.len();
                query_box.containment_prob_many(chunk, &mut scores_buf[..filled])?;
                for (i, &score) in scores_buf[..filled].iter().enumerate() {
                    let entity_id = offset + i;
                    if entity_id == target_id {
                        continue;
                    }
                    let (ahead, wins_tie) = standing(entity_id, score, false)?;
                    better += usize::from(ahead);
                    tie_before += usize::from(wins_tie);
                }
            }
        }
        ScoreDirection::Reverse => {
            for (entity_id, candidate) in entity_boxes.iter().enumerate() {
                if entity_id == target_id {
                    continue;
                }
                let (ahead, wins_tie) = standing(entity_id, score_one(candidate)?, false)?;
                better += usize::from(ahead);
                tie_before += usize::from(wins_tie);
            }
        }
    }

    if let Some(known) = filter_known {
        // Known entities were already counted above, so remove their contribution.
        let mut known_better = 0usize;
        let mut known_tie_before = 0usize;
        for &known_id in known {
            if known_id == target_id {
                continue;
            }
            let Some(candidate) = entity_boxes.get(known_id) else {
                continue;
            };
            let (ahead, wins_tie) = standing(known_id, score_one(candidate)?, true)?;
            known_better += usize::from(ahead);
            known_tie_before += usize::from(wins_tie);
        }
        better = better.saturating_sub(known_better);
        tie_before = tie_before.saturating_sub(known_tie_before);
    }

    Ok(better + tie_before + 1)
}
/// Shared ranking routine for relations with a non-identity [`RelationTransform`].
///
/// `query_box` is moved by `transform` (through `apply_to_bounds`), and every entity is
/// then scored against the moved box:
/// - `ScoreDirection::Forward` scores `translated.containment_prob_fast(candidate)`;
/// - `ScoreDirection::Reverse` scores `candidate.containment_prob_fast(translated)`.
///
/// Exact ties are broken by entity label. Members of `filter_known` (other than the
/// target) are excluded from the count.
///
/// The query bounds are copied element by element. A box whose bound arrays are not
/// contiguous is therefore translated correctly, instead of silently getting empty
/// bounds as the previous `as_slice().unwrap_or(&[])` did.
///
/// NOTE(review): the translated box is always built with `1.0` as the third
/// `NdarrayBox::new` argument, regardless of `query_box` — confirm this is intended.
///
/// # Errors
/// Returns `BoxError::Internal` if:
/// - the target id has no box or no label;
/// - any score is NaN, including for filtered entities (previously those were silently
///   skipped);
/// - a tied candidate has no label.
///
/// Errors from `NdarrayBox::new` and from the scoring calls are propagated.
#[cfg(feature = "ndarray-backend")]
fn rank_with_translated_query(
    entity_boxes: &[crate::ndarray_backend::NdarrayBox],
    entities: &crate::dataset::Vocab,
    target_id: usize,
    query_box: &crate::ndarray_backend::NdarrayBox,
    transform: &RelationTransform,
    direction: &ScoreDirection,
    filter_known: Option<&HashSet<usize>>,
) -> Result<usize, crate::BoxError> {
    use crate::Box as BoxTrait;

    // Copy the bounds element-wise so non-contiguous storage cannot degrade to empty bounds.
    let min_bounds: Vec<f32> = query_box.min().iter().copied().collect();
    let max_bounds: Vec<f32> = query_box.max().iter().copied().collect();
    let (new_min, new_max) =
        transform.apply_to_bounds(min_bounds.as_slice(), max_bounds.as_slice());
    let translated = crate::ndarray_backend::NdarrayBox::new(
        ndarray::Array1::from_vec(new_min),
        ndarray::Array1::from_vec(new_max),
        1.0,
    )?;

    // Scores one candidate against the translated query on the side chosen by `direction`.
    let score_against = |candidate: &crate::ndarray_backend::NdarrayBox| match direction {
        ScoreDirection::Forward => translated.containment_prob_fast(candidate),
        ScoreDirection::Reverse => candidate.containment_prob_fast(&translated),
    };

    let target_box = entity_boxes.get(target_id).ok_or_else(|| {
        crate::BoxError::Internal(format!("Missing entity id (target): {target_id}"))
    })?;
    let target_score = score_against(target_box)?;
    if target_score.is_nan() {
        return Err(crate::BoxError::Internal(
            "NaN containment score encountered (target)".to_string(),
        ));
    }
    let target_name = entities
        .get(target_id)
        .ok_or_else(|| crate::BoxError::Internal(format!("Missing entity label: {target_id}")))?;

    // For a non-target candidate, returns (outscores target, wins label tie-break).
    let standing = |entity_id: usize,
                    candidate: &crate::ndarray_backend::NdarrayBox|
     -> Result<(bool, bool), crate::BoxError> {
        let score = score_against(candidate)?;
        if score.is_nan() {
            return Err(crate::BoxError::Internal(
                "NaN containment score encountered".to_string(),
            ));
        }
        if score > target_score {
            return Ok((true, false));
        }
        if score != target_score {
            return Ok((false, false));
        }
        let name = entities.get(entity_id).ok_or_else(|| {
            crate::BoxError::Internal(format!("Missing entity label: {entity_id}"))
        })?;
        Ok((false, name < target_name))
    };

    let mut better = 0usize;
    let mut tie_before = 0usize;
    for (entity_id, candidate) in entity_boxes.iter().enumerate() {
        if entity_id == target_id {
            continue;
        }
        let (ahead, wins_tie) = standing(entity_id, candidate)?;
        better += usize::from(ahead);
        tie_before += usize::from(wins_tie);
    }

    if let Some(known) = filter_known {
        // Known entities were counted in the full pass above, so subtract them here.
        let mut filtered_better = 0usize;
        let mut filtered_tie_before = 0usize;
        for &known_id in known {
            if known_id == target_id {
                continue;
            }
            let Some(candidate) = entity_boxes.get(known_id) else {
                continue;
            };
            let (ahead, wins_tie) = standing(known_id, candidate)?;
            filtered_better += usize::from(ahead);
            filtered_tie_before += usize::from(wins_tie);
        }
        better = better.saturating_sub(filtered_better);
        tie_before = tie_before.saturating_sub(filtered_tie_before);
    }

    Ok(better + tie_before + 1)
}

/// Ranks the true tail `target_id` for a translated relation.
///
/// The head box `query_box` is moved by `transform`, and each candidate is scored as
/// `translated_head.containment_prob_fast(candidate)`.
///
/// # Errors
/// See [`rank_with_translated_query`].
#[cfg(feature = "ndarray-backend")]
pub(crate) fn rank_with_translated_query_forward(
    entity_boxes: &[crate::ndarray_backend::NdarrayBox],
    entities: &crate::dataset::Vocab,
    target_id: usize,
    query_box: &crate::ndarray_backend::NdarrayBox,
    transform: &RelationTransform,
    filter_known: Option<&HashSet<usize>>,
) -> Result<usize, crate::BoxError> {
    rank_with_translated_query(
        entity_boxes,
        entities,
        target_id,
        query_box,
        transform,
        &ScoreDirection::Forward,
        filter_known,
    )
}

/// Ranks the true head `target_id` for a translated relation.
///
/// The tail box `query_box` is moved by the *inverse* of `transform` (a translation by
/// `-d`), and each candidate is scored as
/// `candidate.containment_prob_fast(translated_tail)`.
///
/// # Errors
/// See [`rank_with_translated_query`].
#[cfg(feature = "ndarray-backend")]
pub(crate) fn rank_with_translated_query_reverse(
    entity_boxes: &[crate::ndarray_backend::NdarrayBox],
    entities: &crate::dataset::Vocab,
    target_id: usize,
    query_box: &crate::ndarray_backend::NdarrayBox,
    transform: &RelationTransform,
    filter_known: Option<&HashSet<usize>>,
) -> Result<usize, crate::BoxError> {
    let inverse_transform = match transform {
        RelationTransform::Identity => RelationTransform::Identity,
        RelationTransform::Translation(d) => {
            RelationTransform::Translation(d.iter().map(|x| -x).collect())
        }
    };
    rank_with_translated_query(
        entity_boxes,
        entities,
        target_id,
        query_box,
        &inverse_transform,
        &ScoreDirection::Reverse,
        filter_known,
    )
}
/// Shared implementation of the interned (id-keyed) link-prediction entry points.
///
/// For each test triple, the true tail is ranked with the head box as a `Forward` query,
/// and the true head is ranked with the tail box as a `Reverse` query. When `filter` is
/// given, other known answers are excluded. One scratch buffer is reused across all
/// queries for batched forward scoring.
///
/// # Errors
/// `BoxError::Internal` if a triple's head or tail id has no box. Ranking errors are
/// propagated.
pub(crate) fn evaluate_link_prediction_interned_inner<B>(
    test_triples: &[crate::dataset::TripleIds],
    entity_boxes: &[B],
    entities: &crate::dataset::Vocab,
    filter: Option<&FilteredTripleIndexIds>,
) -> Result<EvaluationResults, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    let expected = test_triples.len();
    let mut tail_ranks: Vec<usize> = Vec::with_capacity(expected);
    let mut head_ranks: Vec<usize> = Vec::with_capacity(expected);
    let mut per_triple: Vec<(usize, usize, usize)> = Vec::with_capacity(expected);
    let mut scratch = vec![0.0f32; 4096];

    for ids in test_triples {
        let head_box = entity_boxes.get(ids.head).ok_or_else(|| {
            crate::BoxError::Internal(format!("Missing entity id (head): {}", ids.head))
        })?;
        let tail_box = entity_boxes.get(ids.tail).ok_or_else(|| {
            crate::BoxError::Internal(format!("Missing entity id (tail): {}", ids.tail))
        })?;

        let tail_rank = rank_among_entities_interned(
            entity_boxes,
            entities,
            ids.tail,
            head_box,
            &ScoreDirection::Forward,
            filter.and_then(|index| index.known_tails(ids.head, ids.relation)),
            &mut scratch,
        )?;
        let head_rank = rank_among_entities_interned(
            entity_boxes,
            entities,
            ids.head,
            tail_box,
            &ScoreDirection::Reverse,
            filter.and_then(|index| index.known_heads(ids.tail, ids.relation)),
            &mut scratch,
        )?;

        tail_ranks.push(tail_rank);
        head_ranks.push(head_rank);
        per_triple.push((ids.relation, tail_rank, head_rank));
    }

    collect_evaluation_results(&tail_ranks, &head_ranks, &per_triple)
}
/// Interned evaluation that honours per-relation [`RelationTransform`]s.
///
/// `relation_transforms` is indexed by relation id, and ids without an entry fall back to
/// `RelationTransform::Identity`. Identity relations use the generic interned ranking.
/// Other relations rank against a translated query box: the head moved by the transform
/// for tails, and the tail moved by its inverse for heads.
///
/// # Errors
/// `BoxError::Internal` if a triple's head or tail id has no box. Ranking errors are
/// propagated.
#[cfg(feature = "ndarray-backend")]
pub(crate) fn evaluate_interned_with_transforms_inner(
    test_triples: &[crate::dataset::TripleIds],
    entity_boxes: &[crate::ndarray_backend::NdarrayBox],
    entities: &crate::dataset::Vocab,
    relation_transforms: &[RelationTransform],
    filter: Option<&FilteredTripleIndexIds>,
) -> Result<EvaluationResults, crate::BoxError> {
    let expected = test_triples.len();
    let mut tail_ranks: Vec<usize> = Vec::with_capacity(expected);
    let mut head_ranks: Vec<usize> = Vec::with_capacity(expected);
    let mut per_triple: Vec<(usize, usize, usize)> = Vec::with_capacity(expected);
    let mut scratch = vec![0.0f32; 4096];

    for ids in test_triples {
        let head_box = entity_boxes.get(ids.head).ok_or_else(|| {
            crate::BoxError::Internal(format!("Missing entity id (head): {}", ids.head))
        })?;
        let tail_box = entity_boxes.get(ids.tail).ok_or_else(|| {
            crate::BoxError::Internal(format!("Missing entity id (tail): {}", ids.tail))
        })?;
        let transform = relation_transforms
            .get(ids.relation)
            .unwrap_or(&RelationTransform::Identity);
        let translated = !transform.is_identity();
        let known_tails = filter.and_then(|index| index.known_tails(ids.head, ids.relation));
        let known_heads = filter.and_then(|index| index.known_heads(ids.tail, ids.relation));

        let tail_rank = if translated {
            rank_with_translated_query_forward(
                entity_boxes,
                entities,
                ids.tail,
                head_box,
                transform,
                known_tails,
            )?
        } else {
            rank_among_entities_interned(
                entity_boxes,
                entities,
                ids.tail,
                head_box,
                &ScoreDirection::Forward,
                known_tails,
                &mut scratch,
            )?
        };

        let head_rank = if translated {
            rank_with_translated_query_reverse(
                entity_boxes,
                entities,
                ids.head,
                tail_box,
                transform,
                known_heads,
            )?
        } else {
            rank_among_entities_interned(
                entity_boxes,
                entities,
                ids.head,
                tail_box,
                &ScoreDirection::Reverse,
                known_heads,
                &mut scratch,
            )?
        };

        tail_ranks.push(tail_rank);
        head_ranks.push(head_rank);
        per_triple.push((ids.relation, tail_rank, head_rank));
    }

    collect_evaluation_results(&tail_ranks, &head_ranks, &per_triple)
}
/// Turns per-query ranks into an [`EvaluationResults`] summary.
///
/// MRR, Hits@{1,3,10} and mean rank are computed over the pooled tail and head ranks.
/// `tail_mrr` and `head_mrr` use each direction alone. The per-relation breakdown comes
/// from `per_triple` entries of the form `(relation_id, tail_rank, head_rank)`.
///
/// This never fails; the `Result` return keeps call sites uniform.
pub(crate) fn collect_evaluation_results(
    tail_ranks: &[usize],
    head_ranks: &[usize],
    per_triple: &[(usize, usize, usize)],
) -> Result<EvaluationResults, crate::BoxError> {
    let mut pooled = Vec::with_capacity(tail_ranks.len() + head_ranks.len());
    pooled.extend_from_slice(tail_ranks);
    pooled.extend_from_slice(head_ranks);
    let pooled_ranks = || pooled.iter().copied();

    Ok(EvaluationResults {
        mrr: mean_reciprocal_rank(pooled_ranks()),
        head_mrr: mean_reciprocal_rank(head_ranks.iter().copied()),
        tail_mrr: mean_reciprocal_rank(tail_ranks.iter().copied()),
        hits_at_1: hits_at_k(pooled_ranks(), 1),
        hits_at_3: hits_at_k(pooled_ranks(), 3),
        hits_at_10: hits_at_k(pooled_ranks(), 10),
        mean_rank: mean_rank(pooled_ranks()),
        per_relation: aggregate_per_relation_ids(per_triple),
    })
}
/// Groups per-triple ranks by relation id and summarizes each group.
///
/// Each entry of `per_triple` is `(relation_id, tail_rank, head_rank)`. Both ranks of a
/// triple feed its relation's MRR and Hits@10, and `count` is the number of triples for
/// that relation. The output is ordered by ascending numeric relation id.
fn aggregate_per_relation_ids(per_triple: &[(usize, usize, usize)]) -> Vec<PerRelationResults> {
    let mut by_rel: HashMap<usize, Vec<usize>> = HashMap::new();
    for &(rel, t_rank, h_rank) in per_triple {
        let ranks = by_rel.entry(rel).or_default();
        ranks.push(t_rank);
        ranks.push(h_rank);
    }

    // Sort on the numeric id *before* stringifying it. Sorting the rendered strings, as
    // before, ordered relations lexicographically ("10" before "2").
    let mut grouped: Vec<(usize, Vec<usize>)> = by_rel.into_iter().collect();
    grouped.sort_unstable_by_key(|&(rel, _)| rel);

    grouped
        .into_iter()
        .map(|(rel, ranks)| PerRelationResults {
            relation: rel.to_string(),
            mrr: mean_reciprocal_rank(ranks.iter().copied()),
            hits_at_10: hits_at_k(ranks.iter().copied(), 10),
            // Two ranks (tail + head) were recorded per triple.
            count: ranks.len() / 2,
        })
        .collect()
}
/// Evaluates raw (unfiltered) link prediction over name-keyed entity boxes.
///
/// For every test triple, the true tail is ranked among all entities by
/// `head.containment_prob_fast(candidate)`, and the true head by
/// `candidate.containment_prob_fast(tail)`. Both ranks feed the returned metrics.
///
/// # Errors
/// Fails if a triple references an entity with no box, or if any score is NaN.
pub fn evaluate_link_prediction<B>(
    test_triples: &[Triple],
    entity_boxes: &HashMap<String, B>,
) -> Result<EvaluationResults, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    let no_transforms = None;
    let no_filter = None;
    evaluate_link_prediction_inner(test_triples, entity_boxes, no_transforms, no_filter)
}
/// Evaluates filtered link prediction over name-keyed entity boxes.
///
/// This is the same as the unfiltered evaluation, except that other answers recorded in
/// `filter` for the same `(head, relation)` or `(tail, relation)` are excluded from
/// each rank.
///
/// # Errors
/// Fails if a triple references an entity with no box, or if any score is NaN.
pub fn evaluate_link_prediction_filtered<B>(
    test_triples: &[Triple],
    entity_boxes: &HashMap<String, B>,
    filter: &FilteredTripleIndex,
) -> Result<EvaluationResults, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    let known = Some(filter);
    evaluate_link_prediction_inner(test_triples, entity_boxes, None, known)
}
/// Evaluates raw (unfiltered) link prediction over id-indexed entity boxes.
///
/// `entity_boxes[i]` is the box for entity id `i`. `entities` supplies the labels used to
/// break score ties deterministically.
///
/// # Errors
/// Fails if a triple references an id with no box, if a needed label is missing, or if
/// any score is NaN.
pub fn evaluate_link_prediction_interned<B>(
    test_triples: &[crate::dataset::TripleIds],
    entity_boxes: &[B],
    entities: &crate::dataset::Vocab,
) -> Result<EvaluationResults, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    let no_filter = None;
    evaluate_link_prediction_interned_inner(test_triples, entity_boxes, entities, no_filter)
}
/// Evaluates filtered link prediction over id-indexed entity boxes.
///
/// Other answers recorded in `filter` for the same `(head, relation)` or
/// `(tail, relation)` are excluded from each rank.
///
/// # Errors
/// Fails if a triple references an id with no box, if a needed label is missing, or if
/// any score is NaN.
pub fn evaluate_link_prediction_interned_filtered<B>(
    test_triples: &[crate::dataset::TripleIds],
    entity_boxes: &[B],
    entities: &crate::dataset::Vocab,
    filter: &FilteredTripleIndexIds,
) -> Result<EvaluationResults, crate::BoxError>
where
    B: crate::Box<Scalar = f32>,
{
    let known = Some(filter);
    evaluate_link_prediction_interned_inner(test_triples, entity_boxes, entities, known)
}
/// Evaluates interned link prediction with per-relation transforms on `NdarrayBox`es.
///
/// `relation_transforms[r]` is applied to relation id `r`, and ids without an entry are
/// treated as identity. Pass `Some(filter)` for filtered ranks, or `None` for raw ranks.
///
/// # Errors
/// Fails if a triple references an id with no box, if a needed label is missing, or if
/// any score is NaN.
#[cfg(feature = "ndarray-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
pub fn evaluate_link_prediction_interned_with_transforms(
    test_triples: &[crate::dataset::TripleIds],
    entity_boxes: &[crate::ndarray_backend::NdarrayBox],
    entities: &crate::dataset::Vocab,
    relation_transforms: &[RelationTransform],
    filter: Option<&FilteredTripleIndexIds>,
) -> Result<EvaluationResults, crate::BoxError> {
    let transforms = relation_transforms;
    evaluate_interned_with_transforms_inner(test_triples, entity_boxes, entities, transforms, filter)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn filtered_triple_index_membership() {
        let triples = [
            Triple {
                head: "h".to_string(),
                relation: "r".to_string(),
                tail: "t1".to_string(),
            },
            Triple {
                head: "h".to_string(),
                relation: "r".to_string(),
                tail: "t2".to_string(),
            },
            Triple {
                head: "h".to_string(),
                relation: "r2".to_string(),
                tail: "t3".to_string(),
            },
        ];
        let idx = FilteredTripleIndex::from_triples(triples.iter());
        assert!(idx.is_known_tail("h", "r", "t1"));
        assert!(idx.is_known_tail("h", "r", "t2"));
        assert!(!idx.is_known_tail("h", "r", "t3"));
        assert!(idx.is_known_tail("h", "r2", "t3"));
        assert!(!idx.is_known_tail("missing", "r", "t1"));
    }

    /// The linear "count strictly-better + ties-before" rank used by the ranking helpers
    /// must agree with the position in a full deterministic sort (score desc, name asc).
    #[test]
    fn link_prediction_rank_linear_matches_deterministic_sort() {
        for n in [1usize, 2, 10, 100] {
            let ids: Vec<String> = (0..n).map(|i| format!("e{i:03}")).collect();
            let mut scores: Vec<(String, f32)> = Vec::with_capacity(n);
            for j in 0..n {
                let i = (j.wrapping_mul(17) + 3) % n;
                let s = (i % 7) as f32 / 7.0;
                scores.push((ids[i].clone(), s));
            }
            let tail = ids[n / 2].clone();
            let tail_score = ((n / 2) % 7) as f32 / 7.0;
            let mut better = 0usize;
            let mut tie_before = 0usize;
            for (id, s) in &scores {
                if id == &tail {
                    continue;
                }
                if *s > tail_score {
                    better += 1;
                } else if *s == tail_score && id.as_str() < tail.as_str() {
                    tie_before += 1;
                }
            }
            let rank_linear = better + tie_before + 1;
            scores.sort_by(|a, b| {
                b.1.partial_cmp(&a.1)
                    .expect("no NaNs in test scores")
                    .then_with(|| a.0.cmp(&b.0))
            });
            let rank_sort = scores
                .iter()
                .position(|(id, _)| id == &tail)
                .map(|pos| pos + 1)
                .unwrap_or(usize::MAX);
            assert_eq!(rank_linear, rank_sort);
        }
    }

    /// Covers the head-side lookups of the name-keyed index. This replaces a former no-op
    /// test that asserted nothing.
    #[test]
    fn filtered_triple_index_known_heads() {
        let triples = [
            Triple {
                head: "h1".to_string(),
                relation: "r".to_string(),
                tail: "t".to_string(),
            },
            Triple {
                head: "h2".to_string(),
                relation: "r".to_string(),
                tail: "t".to_string(),
            },
        ];
        let idx = FilteredTripleIndex::from_triples(triples.iter());
        assert!(idx.is_known_head("t", "r", "h1"));
        assert!(idx.is_known_head("t", "r", "h2"));
        assert!(!idx.is_known_head("t", "r2", "h1"));
        // Heads are keyed by tail, so a head name is not a valid tail key.
        assert!(!idx.is_known_head("h1", "r", "t"));
        let heads = idx.known_heads("t", "r").expect("heads indexed for (t, r)");
        assert_eq!(heads.len(), 2);
        assert!(idx.known_heads("missing", "r").is_none());
    }

    /// Covers the head-side lookups of the id-keyed index.
    #[test]
    fn filtered_triple_index_ids_known_heads() {
        use crate::dataset::TripleIds;
        let triples = [
            TripleIds {
                head: 1,
                relation: 0,
                tail: 9,
            },
            TripleIds {
                head: 2,
                relation: 0,
                tail: 9,
            },
        ];
        let idx = FilteredTripleIndexIds::from_triples(triples.iter());
        assert!(idx.is_known_head(9, 0, 1));
        assert!(idx.is_known_head(9, 0, 2));
        assert!(!idx.is_known_head(9, 1, 1));
        assert!(!idx.is_known_head(1, 0, 9));
        assert_eq!(idx.known_heads(9, 0).map(|h| h.len()), Some(2));
        assert!(idx.known_heads(0, 0).is_none());
    }

    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_with_ndarray_boxes() {
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let a = NdarrayBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap();
        let b = NdarrayBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap();
        let c = NdarrayBox::new(array![50.0, 50.0], array![51.0, 51.0], 1.0).unwrap();
        let mut entity_boxes = HashMap::new();
        entity_boxes.insert("A".to_string(), a);
        entity_boxes.insert("B".to_string(), b);
        entity_boxes.insert("C".to_string(), c);
        let test_triples = vec![Triple {
            head: "A".to_string(),
            relation: "r".to_string(),
            tail: "B".to_string(),
        }];
        let results = evaluate_link_prediction(&test_triples, &entity_boxes).unwrap();
        assert!(
            results.mrr > 0.0,
            "MRR should be positive, got {}",
            results.mrr
        );
        assert!(results.mean_rank >= 1.0, "mean_rank should be >= 1");
    }

    /// Smoke test: evaluating zero triples must not error or panic.
    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_empty_triples() {
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let a = NdarrayBox::new(array![0.0], array![1.0], 1.0).unwrap();
        let mut entity_boxes = HashMap::new();
        entity_boxes.insert("A".to_string(), a);
        let results = evaluate_link_prediction::<NdarrayBox>(&[], &entity_boxes).unwrap();
        // NOTE(review): metric values for empty input depend on crate::metrics; not asserted.
        let _ = results;
    }

    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_filtered_excludes_known_tails() {
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let a = NdarrayBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap();
        let b = NdarrayBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap();
        let c = NdarrayBox::new(array![2.0, 2.0], array![4.0, 4.0], 1.0).unwrap();
        let d = NdarrayBox::new(array![50.0, 50.0], array![51.0, 51.0], 1.0).unwrap();
        let mut entity_boxes = HashMap::new();
        entity_boxes.insert("A".to_string(), a);
        entity_boxes.insert("B".to_string(), b);
        entity_boxes.insert("C".to_string(), c);
        entity_boxes.insert("D".to_string(), d);
        let test_triples = vec![Triple {
            head: "A".to_string(),
            relation: "r".to_string(),
            tail: "B".to_string(),
        }];
        let filter_triples = [
            Triple {
                head: "A".into(),
                relation: "r".into(),
                tail: "C".into(),
            },
            Triple {
                head: "A".into(),
                relation: "r".into(),
                tail: "B".into(),
            },
        ];
        let filter = FilteredTripleIndex::from_triples(filter_triples.iter());
        let unfiltered = evaluate_link_prediction(&test_triples, &entity_boxes).unwrap();
        let filtered =
            evaluate_link_prediction_filtered(&test_triples, &entity_boxes, &filter).unwrap();
        assert!(
            filtered.mean_rank <= unfiltered.mean_rank,
            "filtered rank ({}) should be <= unfiltered rank ({})",
            filtered.mean_rank,
            unfiltered.mean_rank
        );
    }

    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_interned_with_ndarray_boxes() {
        use crate::dataset::{TripleIds, Vocab};
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let mut vocab = Vocab::default();
        let id_a = vocab.intern("A".to_string());
        let id_b = vocab.intern("B".to_string());
        let _id_c = vocab.intern("C".to_string());
        let id_r = 0usize;
        let boxes = vec![
            NdarrayBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap(),
            NdarrayBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap(),
            NdarrayBox::new(array![50.0, 50.0], array![51.0, 51.0], 1.0).unwrap(),
        ];
        let test_triples = vec![TripleIds {
            head: id_a,
            relation: id_r,
            tail: id_b,
        }];
        let results = evaluate_link_prediction_interned(&test_triples, &boxes, &vocab).unwrap();
        assert!(
            results.mrr > 0.0,
            "MRR should be positive, got {}",
            results.mrr
        );
        assert!(results.mean_rank >= 1.0);
    }

    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_interned_filtered_with_ndarray_boxes() {
        use crate::dataset::{TripleIds, Vocab};
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let mut vocab = Vocab::default();
        let id_a = vocab.intern("A".to_string());
        let id_b = vocab.intern("B".to_string());
        let id_c = vocab.intern("C".to_string());
        let id_r = 0usize;
        let boxes = vec![
            NdarrayBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap(),
            NdarrayBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap(),
            NdarrayBox::new(array![2.0, 2.0], array![4.0, 4.0], 1.0).unwrap(),
        ];
        let test_triples = vec![TripleIds {
            head: id_a,
            relation: id_r,
            tail: id_b,
        }];
        let known_triples = [
            TripleIds {
                head: id_a,
                relation: id_r,
                tail: id_c,
            },
            TripleIds {
                head: id_a,
                relation: id_r,
                tail: id_b,
            },
        ];
        let filter = FilteredTripleIndexIds::from_triples(known_triples.iter());
        let unfiltered = evaluate_link_prediction_interned(&test_triples, &boxes, &vocab).unwrap();
        let filtered =
            evaluate_link_prediction_interned_filtered(&test_triples, &boxes, &vocab, &filter)
                .unwrap();
        assert!(
            filtered.mean_rank <= unfiltered.mean_rank,
            "filtered rank ({}) should be <= unfiltered rank ({})",
            filtered.mean_rank,
            unfiltered.mean_rank
        );
    }

    #[test]
    fn filtered_triple_index_ids_membership() {
        use crate::dataset::TripleIds;
        let triples = [
            TripleIds {
                head: 0,
                relation: 0,
                tail: 1,
            },
            TripleIds {
                head: 0,
                relation: 0,
                tail: 2,
            },
            TripleIds {
                head: 0,
                relation: 1,
                tail: 3,
            },
        ];
        let idx = FilteredTripleIndexIds::from_triples(triples.iter());
        assert!(idx.is_known_tail(0, 0, 1));
        assert!(idx.is_known_tail(0, 0, 2));
        assert!(!idx.is_known_tail(0, 0, 3));
        assert!(idx.is_known_tail(0, 1, 3));
        assert!(!idx.is_known_tail(1, 0, 1));
    }

    #[test]
    fn filtered_triple_index_ids_known_tails() {
        use crate::dataset::TripleIds;
        let triples = [
            TripleIds {
                head: 0,
                relation: 0,
                tail: 10,
            },
            TripleIds {
                head: 0,
                relation: 0,
                tail: 20,
            },
        ];
        let idx = FilteredTripleIndexIds::from_triples(triples.iter());
        let tails = idx.known_tails(0, 0).unwrap();
        assert!(tails.contains(&10));
        assert!(tails.contains(&20));
        assert!(!tails.contains(&30));
        assert!(idx.known_tails(1, 0).is_none());
    }

    /// Ranking must not depend on HashMap iteration order, so repeated runs must agree.
    #[test]
    #[cfg(feature = "ndarray-backend")]
    fn evaluate_link_prediction_deterministic() {
        use crate::ndarray_backend::NdarrayBox;
        use ndarray::array;
        let a = NdarrayBox::new(array![0.0, 0.0], array![10.0, 10.0], 1.0).unwrap();
        let b = NdarrayBox::new(array![1.0, 1.0], array![3.0, 3.0], 1.0).unwrap();
        let c = NdarrayBox::new(array![50.0, 50.0], array![51.0, 51.0], 1.0).unwrap();
        let mut entity_boxes = HashMap::new();
        entity_boxes.insert("A".to_string(), a);
        entity_boxes.insert("B".to_string(), b);
        entity_boxes.insert("C".to_string(), c);
        let test_triples = vec![Triple {
            head: "A".to_string(),
            relation: "r".to_string(),
            tail: "B".to_string(),
        }];
        let r1 = evaluate_link_prediction(&test_triples, &entity_boxes).unwrap();
        let r2 = evaluate_link_prediction(&test_triples, &entity_boxes).unwrap();
        assert_eq!(r1.mrr, r2.mrr, "MRR differs across runs");
        assert_eq!(r1.hits_at_1, r2.hits_at_1, "Hits@1 differs across runs");
        assert_eq!(r1.hits_at_3, r2.hits_at_3, "Hits@3 differs across runs");
        assert_eq!(r1.hits_at_10, r2.hits_at_10, "Hits@10 differs across runs");
        assert_eq!(r1.mean_rank, r2.mean_rank, "mean_rank differs across runs");
    }
}