use std::collections::HashSet;
/// A `(head, relation, tail)` triple whose three components are owned strings.
///
/// Equality and hashing are derived, so two triples are equal exactly when all
/// three fields are equal, and triples can be stored in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EvaluationTriple {
    /// First component of the triple.
    pub head: String,
    /// Middle component: the relation label between `head` and `tail`.
    pub relation: String,
    /// Last component of the triple.
    pub tail: String,
}

impl EvaluationTriple {
    /// Builds a triple from anything convertible into `String`.
    ///
    /// Each argument is converted with `Into<String>`, so `&str` and `String`
    /// values can be mixed freely.
    pub fn new(
        head: impl Into<String>,
        relation: impl Into<String>,
        tail: impl Into<String>,
    ) -> Self {
        let (head, relation, tail) = (head.into(), relation.into(), tail.into());
        Self {
            head,
            relation,
            tail,
        }
    }

    /// Borrows the fields as a `(head, relation, tail)` tuple of string slices.
    pub fn as_tuple(&self) -> (&str, &str, &str) {
        let Self {
            head,
            relation,
            tail,
        } = self;
        (head.as_str(), relation.as_str(), tail.as_str())
    }
}
/// Train/validation/test splits of triples plus the vocabularies derived from them.
///
/// All fields are public. [`KgcDataset::from_splits`] is the constructor that
/// builds `entity_set` and `relation_set` from the three splits; the other
/// constructors delegate to it.
#[derive(Debug, Clone)]
pub struct KgcDataset {
    /// Triples of the training split.
    pub train: Vec<EvaluationTriple>,
    /// Triples of the validation split.
    pub valid: Vec<EvaluationTriple>,
    /// Triples of the test split.
    pub test: Vec<EvaluationTriple>,
    /// Every distinct head or tail string found in any split.
    pub entity_set: HashSet<String>,
    /// Every distinct relation string found in any split.
    pub relation_set: HashSet<String>,
}

impl KgcDataset {
    /// Builds a dataset from three already-separated splits.
    ///
    /// The splits are stored as given (no deduplication, no reordering). The
    /// entity and relation sets are collected from all three splits.
    pub fn from_splits(
        train: Vec<EvaluationTriple>,
        valid: Vec<EvaluationTriple>,
        test: Vec<EvaluationTriple>,
    ) -> Self {
        let mut entity_set = HashSet::new();
        let mut relation_set = HashSet::new();
        for triple in train.iter().chain(valid.iter()).chain(test.iter()) {
            entity_set.insert(triple.head.clone());
            entity_set.insert(triple.tail.clone());
            relation_set.insert(triple.relation.clone());
        }
        Self {
            train,
            valid,
            test,
            entity_set,
            relation_set,
        }
    }

    /// Returns a small hard-coded dataset of eight triples.
    ///
    /// The last triple forms the test split, the one before it the validation
    /// split, and the remaining six the training split.
    pub fn tiny_synthetic() -> Self {
        let all: Vec<EvaluationTriple> = vec![
            EvaluationTriple::new("alice", "knows", "bob"),
            EvaluationTriple::new("alice", "knows", "carol"),
            EvaluationTriple::new("bob", "knows", "carol"),
            EvaluationTriple::new("alice", "lives_in", "paris"),
            EvaluationTriple::new("bob", "lives_in", "london"),
            EvaluationTriple::new("carol", "works_at", "acme"),
            EvaluationTriple::new("paris", "located_in", "france"),
            EvaluationTriple::new("london", "located_in", "uk"),
        ];
        let n = all.len();
        // Saturating arithmetic keeps the slice bounds valid even if the
        // fixture list were ever shortened below two entries.
        let train_end = n.saturating_sub(2);
        let valid_end = n.saturating_sub(1);
        let train = all[..train_end].to_vec();
        let valid = all[train_end..valid_end].to_vec();
        let test = all[valid_end..].to_vec();
        Self::from_splits(train, valid, test)
    }

    /// Parses three tab-separated texts (one per split) into a dataset.
    ///
    /// Each data line is expected to hold `head<TAB>relation<TAB>tail`.
    /// Parsing rules:
    ///
    /// * Lines that are blank, or whose first non-whitespace character is `#`,
    ///   are skipped.
    /// * Each field is trimmed of surrounding whitespace.
    /// * Only the first three tab-separated columns are used; any further
    ///   columns are ignored rather than being merged into the tail.
    /// * Lines with fewer than three columns, or with an empty value in any of
    ///   the first three columns, are silently skipped.
    pub fn from_tsv(train_tsv: &str, valid_tsv: &str, test_tsv: &str) -> Self {
        let parse = |tsv: &str| -> Vec<EvaluationTriple> {
            tsv.lines()
                .filter(|line| {
                    let trimmed = line.trim();
                    !trimmed.is_empty() && !trimmed.starts_with('#')
                })
                .filter_map(|line| {
                    // Split on every tab (not `splitn(3, ..)`) so a fourth
                    // column can never leak into the tail string and from
                    // there into `entity_set`.
                    let mut fields = line.split('\t').map(str::trim);
                    let head = fields.next()?;
                    let relation = fields.next()?;
                    let tail = fields.next()?;
                    if head.is_empty() || relation.is_empty() || tail.is_empty() {
                        return None;
                    }
                    Some(EvaluationTriple::new(head, relation, tail))
                })
                .collect()
        };
        let train = parse(train_tsv);
        let valid = parse(valid_tsv);
        let test = parse(test_tsv);
        Self::from_splits(train, valid, test)
    }

    /// Collects every triple from all three splits into a set of owned
    /// `(head, relation, tail)` tuples; duplicates collapse into one entry.
    pub fn all_positives(&self) -> HashSet<(String, String, String)> {
        self.train
            .iter()
            .chain(self.valid.iter())
            .chain(self.test.iter())
            .map(|t| (t.head.clone(), t.relation.clone(), t.tail.clone()))
            .collect()
    }

    /// Returns the contents of `entity_set` as a sorted vector of owned strings.
    pub fn sorted_entities(&self) -> Vec<String> {
        let mut v: Vec<String> = self.entity_set.iter().cloned().collect();
        // Set elements are unique, so an unstable sort gives the same result.
        v.sort_unstable();
        v
    }

    /// Returns the contents of `relation_set` as a sorted vector of owned strings.
    pub fn sorted_relations(&self) -> Vec<String> {
        let mut v: Vec<String> = self.relation_set.iter().cloned().collect();
        v.sort_unstable();
        v
    }

    /// Total number of triples across the three splits (duplicates counted).
    pub fn total_triples(&self) -> usize {
        self.train.len() + self.valid.len() + self.test.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each split of the built-in fixture must contain at least one triple.
    #[test]
    fn test_tiny_synthetic_non_empty_splits() {
        let dataset = KgcDataset::tiny_synthetic();
        let KgcDataset {
            train, valid, test, ..
        } = &dataset;
        assert!(!train.is_empty(), "train should be non-empty");
        assert!(!valid.is_empty(), "valid should be non-empty");
        assert!(!test.is_empty(), "test should be non-empty");
    }

    /// All eight entity names used by the fixture must be in `entity_set`.
    #[test]
    fn test_entity_set_populated() {
        let dataset = KgcDataset::tiny_synthetic();
        let expected = [
            "alice", "bob", "carol", "paris", "london", "acme", "france", "uk",
        ];
        for &entity in &expected {
            assert!(
                dataset.entity_set.contains(entity),
                "entity_set should contain '{entity}'"
            );
        }
    }

    /// All four relation names used by the fixture must be in `relation_set`.
    #[test]
    fn test_relation_set_populated() {
        let dataset = KgcDataset::tiny_synthetic();
        let expected = ["knows", "lives_in", "works_at", "located_in"];
        for &rel in &expected {
            assert!(
                dataset.relation_set.contains(rel),
                "relation_set should contain '{rel}'"
            );
        }
    }

    /// `all_positives` must contain every triple of every split, and nothing
    /// else (the fixture has no duplicate triples).
    #[test]
    fn test_all_positives_coverage() {
        let dataset = KgcDataset::tiny_synthetic();
        let known = dataset.all_positives();
        let every_triple = dataset
            .train
            .iter()
            .chain(&dataset.valid)
            .chain(&dataset.test);
        for triple in every_triple {
            let EvaluationTriple {
                head,
                relation,
                tail,
            } = triple;
            let key = (head.clone(), relation.clone(), tail.clone());
            assert!(
                known.contains(&key),
                "all_positives missing ({}, {}, {})",
                head,
                relation,
                tail
            );
        }
        assert_eq!(known.len(), dataset.total_triples());
    }

    /// Well-formed TSV input is split into the right number of triples per split.
    #[test]
    fn test_from_tsv_parsing() {
        let dataset = KgcDataset::from_tsv(
            "alice\tknows\tbob\nbob\tknows\tcarol\n",
            "alice\tlives_in\tparis\n",
            "bob\tlives_in\tlondon\n",
        );
        assert_eq!(dataset.train.len(), 2);
        assert_eq!(dataset.valid.len(), 1);
        assert_eq!(dataset.test.len(), 1);
        assert!(dataset.entity_set.contains("alice"));
        assert!(dataset.relation_set.contains("knows"));
    }

    /// Blank lines and `#` comment lines must not produce triples.
    #[test]
    fn test_from_tsv_skips_blanks_and_comments() {
        let input = "# header\nalice\tknows\tbob\n\n# another comment\nbob\tknows\tcarol\n";
        let dataset = KgcDataset::from_tsv(input, "", "");
        assert_eq!(dataset.train.len(), 2, "should parse exactly 2 data lines");
    }

    /// Re-sorting the output of `sorted_entities` must leave it unchanged.
    #[test]
    fn test_sorted_entities_is_sorted() {
        let dataset = KgcDataset::tiny_synthetic();
        let actual = dataset.sorted_entities();
        let expected = {
            let mut resorted = actual.clone();
            resorted.sort_unstable();
            resorted
        };
        assert_eq!(actual, expected, "sorted_entities should return sorted output");
    }

    /// `total_triples` must equal the sum of the three split lengths.
    #[test]
    fn test_total_triples_consistency() {
        let dataset = KgcDataset::tiny_synthetic();
        let by_hand = dataset.train.len() + dataset.valid.len() + dataset.test.len();
        assert_eq!(dataset.total_triples(), by_hand);
    }
}