use super::InternedDataset;
use std::collections::{HashMap, HashSet};
/// Lookup index over every triple of an [`InternedDataset`].
///
/// Built by [`FilterIndex::from_dataset`], which records each triple of the
/// dataset's `train`, `valid` and `test` collections three ways so that
/// "which tails/heads are already known?" and "is this exact triple known?"
/// can each be answered with a single hash lookup.
#[derive(Debug, Clone, Default)]
pub struct FilterIndex {
    /// `(head, relation)` -> every tail id seen with that pair.
    by_head_rel: HashMap<(usize, usize), HashSet<usize>>,
    /// `(relation, tail)` -> every head id seen with that pair.
    by_rel_tail: HashMap<(usize, usize), HashSet<usize>>,
    /// Every distinct triple, stored as the tuple returned by `as_tuple()`
    /// (queried by `contains` as `(head, relation, tail)`).
    all: HashSet<(usize, usize, usize)>,
}
impl FilterIndex {
    /// Indexes every triple found in `ds.train`, `ds.valid` and `ds.test`.
    ///
    /// Each triple is registered under its `(head, relation)` pair, under its
    /// `(relation, tail)` pair, and as a whole tuple. Duplicate triples
    /// collapse because all three stores are hash sets.
    pub fn from_dataset(ds: &InternedDataset) -> Self {
        let mut index = Self {
            by_head_rel: HashMap::new(),
            by_rel_tail: HashMap::new(),
            all: HashSet::new(),
        };

        let every_split = ds.train.iter().chain(&ds.valid).chain(&ds.test);
        for triple in every_split {
            let tails = index
                .by_head_rel
                .entry((triple.head, triple.relation))
                .or_default();
            tails.insert(triple.tail);

            let heads = index
                .by_rel_tail
                .entry((triple.relation, triple.tail))
                .or_default();
            heads.insert(triple.head);

            index.all.insert(triple.as_tuple());
        }

        index
    }

    /// Shared empty set handed out when a lookup key was never indexed, so
    /// the accessors can return a plain reference instead of an `Option`.
    fn empty_ids() -> &'static HashSet<usize> {
        static NO_IDS: std::sync::LazyLock<HashSet<usize>> =
            std::sync::LazyLock::new(HashSet::new);
        &NO_IDS
    }

    /// Returns every tail id recorded for `(head, relation)`.
    ///
    /// An empty set is returned when the pair was never indexed.
    pub fn known_tails(&self, head: usize, relation: usize) -> &HashSet<usize> {
        match self.by_head_rel.get(&(head, relation)) {
            Some(tails) => tails,
            None => Self::empty_ids(),
        }
    }

    /// Returns every head id recorded for `(relation, tail)`.
    ///
    /// An empty set is returned when the pair was never indexed.
    pub fn known_heads(&self, relation: usize, tail: usize) -> &HashSet<usize> {
        match self.by_rel_tail.get(&(relation, tail)) {
            Some(heads) => heads,
            None => Self::empty_ids(),
        }
    }

    /// Reports whether the exact triple `(head, relation, tail)` was indexed.
    pub fn contains(&self, head: usize, relation: usize, tail: usize) -> bool {
        let key = (head, relation, tail);
        self.all.contains(&key)
    }

    /// Number of distinct triples held by the index.
    pub fn len(&self) -> usize {
        self.all.len()
    }

    /// `true` when no triple has been indexed.
    pub fn is_empty(&self) -> bool {
        self.all.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kge::{Dataset, Triple};

    /// Small fixture with five distinct triples: three in the first
    /// collection passed to `Dataset::new` and one in each of the other two.
    fn sample_dataset() -> InternedDataset {
        Dataset::new(
            vec![
                Triple::new("a", "r1", "b"),
                Triple::new("a", "r1", "c"),
                Triple::new("b", "r2", "c"),
            ],
            vec![Triple::new("a", "r2", "c")],
            vec![Triple::new("c", "r1", "a")],
        )
        .into_interned()
    }

    #[test]
    fn filter_index_known_tails() {
        let ds = sample_dataset();
        let filter = FilterIndex::from_dataset(&ds);
        let a = ds.entities.id("a").unwrap();
        let b = ds.entities.id("b").unwrap();
        let c = ds.entities.id("c").unwrap();
        let r1 = ds.relations.id("r1").unwrap();

        let tails = filter.known_tails(a, r1);
        assert_eq!(tails.len(), 2);
        // Check membership too: a size-only assertion would still pass if the
        // wrong ids were stored.
        assert!(tails.contains(&b));
        assert!(tails.contains(&c));
    }

    #[test]
    fn filter_index_known_heads() {
        let ds = sample_dataset();
        let filter = FilterIndex::from_dataset(&ds);
        let a = ds.entities.id("a").unwrap();
        let b = ds.entities.id("b").unwrap();
        let c = ds.entities.id("c").unwrap();
        let r1 = ds.relations.id("r1").unwrap();
        let r2 = ds.relations.id("r2").unwrap();

        let heads = filter.known_heads(r1, c);
        assert_eq!(heads.len(), 1);
        assert!(heads.contains(&a));

        // (b, r2, c) and (a, r2, c) come from different collections of the
        // fixture; both heads must be merged under the same key.
        let merged = filter.known_heads(r2, c);
        assert_eq!(merged.len(), 2);
        assert!(merged.contains(&a));
        assert!(merged.contains(&b));
    }

    #[test]
    fn filter_index_contains() {
        let ds = sample_dataset();
        let filter = FilterIndex::from_dataset(&ds);
        let a = ds.entities.id("a").unwrap();
        let b = ds.entities.id("b").unwrap();
        let c = ds.entities.id("c").unwrap();
        let r1 = ds.relations.id("r1").unwrap();
        let r2 = ds.relations.id("r2").unwrap();

        assert!(filter.contains(a, r1, b));
        assert!(!filter.contains(b, r1, a));
        // Triples from the second and third collections are indexed as well.
        assert!(filter.contains(a, r2, c));
        assert!(filter.contains(c, r1, a));
    }

    #[test]
    fn filter_index_len() {
        let ds = sample_dataset();
        let filter = FilterIndex::from_dataset(&ds);
        assert_eq!(filter.len(), 5);
        assert!(!filter.is_empty());
    }

    #[test]
    fn filter_index_unknown_pair_returns_empty() {
        let ds = sample_dataset();
        let filter = FilterIndex::from_dataset(&ds);
        assert!(filter.known_tails(999, 999).is_empty());
        assert!(filter.known_heads(999, 999).is_empty());
    }
}