pub use lattix::kge::{
load_dataset, load_triples, Dataset, FilterIndex, InternedDataset, Triple, TripleIds, Vocab,
};
/// Extension methods for [`InternedDataset`].
pub trait InternedDatasetExt {
    /// Adds a reciprocal (inverse) relation for every existing relation and appends the
    /// corresponding reversed triple to every split.
    ///
    /// With `n` relations before the call, relation `r` with id `i` gets a companion named
    /// `"{r}_inv"` with id `n + i`. Every triple `(h, i, t)` in `train`, `valid` and `test`
    /// gains a reversed copy `(t, n + i, h)`, appended after all original triples of that
    /// split in the same order. Afterwards each split is twice as long and there are `2n`
    /// relations.
    ///
    /// # Panics
    ///
    /// Panics if some `"{r}_inv"` name is already in the relation vocabulary (for example
    /// when this method is called twice). In that case the reciprocal ids could not be
    /// `n + i`, and the triples would reference relation ids with no matching name.
    fn add_reciprocals(&mut self);
}

impl InternedDatasetExt for InternedDataset {
    fn add_reciprocals(&mut self) {
        let n_rel = self.relations.len();
        for i in 0..n_rel {
            let name = format!(
                "{}_inv",
                self.relations
                    .get(i)
                    .expect("ids below relations.len() are always present")
            );
            self.relations.intern(name.clone());
            // Interning an existing name returns its old id without growing the vocabulary.
            // The reciprocal id must be exactly `n_rel + i` for the triples appended below
            // to be valid, so refuse to continue on a collision instead of corrupting data.
            assert_eq!(
                self.relations.len(),
                n_rel + i + 1,
                "add_reciprocals: relation `{name}` already exists (called twice?)"
            );
        }

        /// Appends the reversed copy of every triple, reusing the vector's own storage
        /// instead of cloning the whole split first.
        fn augment(triples: &mut Vec<TripleIds>, n_rel: usize) {
            let n = triples.len();
            triples.extend_from_within(..);
            for t in &mut triples[n..] {
                *t = TripleIds::new(t.tail, t.relation + n_rel, t.head);
            }
        }

        augment(&mut self.train, n_rel);
        augment(&mut self.valid, n_rel);
        augment(&mut self.test, n_rel);
    }
}
/// Extension methods for [`Dataset`].
pub trait DatasetExt {
    /// Pools all triples and re-splits them into train / valid / test by fraction.
    ///
    /// Triples are concatenated in `train`, `valid`, `test` order and split *without
    /// shuffling*:
    ///
    /// - the last `round(total * test_frac)` triples become the test split;
    /// - the `round(total * valid_frac)` triples before those become the validation split;
    /// - the remainder stays in train.
    ///
    /// Shuffle the triples beforehand if a random split is wanted.
    ///
    /// Fractions are not validated. If they add up to more than 1, test is filled first,
    /// valid takes whatever remains, and train may end up empty. Negative or NaN fractions
    /// yield an empty split.
    fn split(self, valid_frac: f32, test_frac: f32) -> Dataset;

    /// Loads a single file of triples (tab- or comma-separated, one per line) into the
    /// training split, leaving `valid` and `test` empty.
    ///
    /// Blank lines and lines starting with `#` are ignored. Lines that cannot be parsed as
    /// a triple are skipped, with a warning on stderr.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or is not valid UTF-8.
    fn load_flexible(path: &std::path::Path) -> Result<Dataset, lattix::Error>;
}

impl DatasetExt for Dataset {
    fn split(self, valid_frac: f32, test_frac: f32) -> Dataset {
        let mut all = self.train;
        let mut pooled_valid = self.valid;
        let mut pooled_test = self.test;
        all.reserve(pooled_valid.len() + pooled_test.len());
        all.append(&mut pooled_valid);
        all.append(&mut pooled_test);

        let total = all.len();
        // Work in f64: f32 has a 24-bit mantissa, so for datasets above ~16.7M triples
        // `total as f32` is already rounded and the split sizes drift. `f64::from(f32)` is
        // exact. The float-to-int `as` cast saturates, so negative or NaN products become 0.
        let count = |frac: f32| (total as f64 * f64::from(frac)).round() as usize;
        let n_test = count(test_frac);
        let n_valid = count(valid_frac);

        // `saturating_sub` keeps oversized fractions from underflowing: test takes what it
        // can from the tail, then valid takes from what is left.
        let test = all.split_off(all.len().saturating_sub(n_test));
        let valid = all.split_off(all.len().saturating_sub(n_valid));
        Dataset::new(all, valid, test)
    }

    fn load_flexible(path: &std::path::Path) -> Result<Dataset, lattix::Error> {
        let content = std::fs::read_to_string(path)?;
        let triples = parse_flexible(&content);
        Ok(Dataset::new(triples, Vec::new(), Vec::new()))
    }
}
/// Parses triples from delimited text, one `head<sep>relation<sep>tail` triple per line.
///
/// The separator is chosen per line: a tab if the line contains one, otherwise a comma.
/// Fields are trimmed of surrounding whitespace, and fields after the third are ignored.
/// Blank lines and lines starting with `#` are skipped. A leading UTF-8 byte-order mark is
/// ignored.
///
/// Lines with fewer than three fields, or with an empty head, relation or tail, are
/// skipped. If any were skipped, one summary warning with the count is written to stderr.
fn parse_flexible(content: &str) -> Vec<Triple> {
    // Spreadsheet CSV exports often start with a BOM. `trim` does not remove U+FEFF, so
    // without this the first entity would silently get a different name from its other
    // occurrences.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);

    let mut triples = Vec::new();
    let mut dropped = 0usize;
    for line in content.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let sep = if trimmed.contains('\t') { '\t' } else { ',' };
        let mut fields = trimmed.split(sep).map(str::trim);
        match (fields.next(), fields.next(), fields.next()) {
            // Empty fields (e.g. `A,,B` or a trailing `A,r,`) would otherwise intern "" as
            // an entity or relation name.
            (Some(head), Some(relation), Some(tail))
                if !head.is_empty() && !relation.is_empty() && !tail.is_empty() =>
            {
                triples.push(Triple::new(head, relation, tail));
            }
            _ => dropped += 1,
        }
    }
    if dropped > 0 {
        eprintln!(
            "warning: skipped {dropped} malformed lines (expected 3 non-empty fields per line)"
        );
    }
    triples
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `triples` to `dir/name` as tab-separated lines.
    fn write_triples(dir: &std::path::Path, name: &str, triples: &[(&str, &str, &str)]) {
        let path = dir.join(name);
        let mut f = std::fs::File::create(path).unwrap();
        for (h, r, t) in triples {
            writeln!(f, "{h}\t{r}\t{t}").unwrap();
        }
    }

    #[test]
    fn load_and_intern() {
        let dir = tempfile::tempdir().unwrap();
        write_triples(
            dir.path(),
            "train.txt",
            &[("A", "r1", "B"), ("B", "r2", "C")],
        );
        write_triples(dir.path(), "valid.txt", &[("A", "r1", "C")]);
        write_triples(dir.path(), "test.txt", &[("C", "r2", "A")]);
        let ds = load_dataset(dir.path()).unwrap();
        assert_eq!(ds.train.len(), 2);
        assert_eq!(ds.valid.len(), 1);
        assert_eq!(ds.test.len(), 1);
        let interned = ds.into_interned();
        assert_eq!(interned.num_entities(), 3);
        assert_eq!(interned.num_relations(), 2);
        assert_eq!(interned.all_triples().len(), 4);
        assert_eq!(interned.entities.id("A"), Some(0));
        assert_eq!(interned.entities.id("B"), Some(1));
        assert_eq!(interned.entities.id("C"), Some(2));
    }

    #[test]
    fn reciprocal_relations() {
        let dir = tempfile::tempdir().unwrap();
        write_triples(dir.path(), "train.txt", &[("A", "r1", "B")]);
        write_triples(dir.path(), "valid.txt", &[("B", "r1", "C")]);
        write_triples(dir.path(), "test.txt", &[("C", "r1", "A")]);
        let ds = load_dataset(dir.path()).unwrap();
        let mut interned = ds.into_interned();
        assert_eq!(interned.num_relations(), 1);
        interned.add_reciprocals();
        assert_eq!(interned.num_relations(), 2);
        assert_eq!(interned.relations.get(1), Some("r1_inv"));
        assert_eq!(interned.train.len(), 2);
        let t = interned.train[1];
        assert_eq!(t.head, interned.entities.id("B").unwrap());
        assert_eq!(t.relation, 1);
        assert_eq!(t.tail, interned.entities.id("A").unwrap());
        // Reciprocals are added to every split, not only train.
        assert_eq!(interned.valid.len(), 2);
        assert_eq!(interned.test.len(), 2);
    }

    #[test]
    fn load_flexible_csv() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("triples.csv");
        std::fs::write(&path, "# comment\nAlice,knows,Bob\nBob,works_at,Acme\n").unwrap();
        let ds = Dataset::load_flexible(&path).unwrap();
        assert_eq!(ds.train.len(), 2);
        assert_eq!(ds.train[0].head, "Alice");
        assert!(ds.valid.is_empty());
    }

    #[test]
    fn dataset_split() {
        let ds = Dataset::new(
            (0..100)
                .map(|i| Triple::new(format!("e{i}"), "r", format!("e{}", i + 1)))
                .collect(),
            Vec::new(),
            Vec::new(),
        );
        let ds = ds.split(0.1, 0.1);
        assert_eq!(ds.test.len(), 10);
        assert_eq!(ds.valid.len(), 10);
        assert_eq!(ds.train.len(), 80);
    }

    #[test]
    fn dataset_split_fills_test_first_when_fractions_exceed_one() {
        let ds = Dataset::new(
            (0..10)
                .map(|i| Triple::new(format!("e{i}"), "r", format!("e{}", i + 1)))
                .collect(),
            Vec::new(),
            Vec::new(),
        );
        let ds = ds.split(0.5, 0.8);
        assert_eq!(ds.test.len(), 8);
        assert_eq!(ds.valid.len(), 2);
        assert!(ds.train.is_empty());
        // No shuffling: test is the tail of the pooled triples, valid precedes it.
        assert_eq!(ds.test[0].head, "e2");
        assert_eq!(ds.valid[0].head, "e0");
    }

    #[test]
    fn parse_flexible_drops_malformed_lines() {
        let content = "Alice\tknows\tBob\nbad line\nCarol\tknows\tDave\n";
        let triples = super::parse_flexible(content);
        assert_eq!(triples.len(), 2);
        assert_eq!(triples[0].head, "Alice");
        assert_eq!(triples[1].head, "Carol");
    }

    #[test]
    fn reciprocal_with_multiple_relations() {
        let dir = tempfile::tempdir().unwrap();
        write_triples(
            dir.path(),
            "train.txt",
            &[("A", "r1", "B"), ("C", "r2", "D")],
        );
        write_triples(dir.path(), "valid.txt", &[]);
        write_triples(dir.path(), "test.txt", &[]);
        let ds = load_dataset(dir.path()).unwrap();
        let mut interned = ds.into_interned();
        assert_eq!(interned.num_relations(), 2);
        interned.add_reciprocals();
        assert_eq!(interned.num_relations(), 4);
        assert_eq!(interned.relations.get(2), Some("r1_inv"));
        assert_eq!(interned.relations.get(3), Some("r2_inv"));
        assert_eq!(interned.train.len(), 4);
        // Reversed copies follow the originals in order, with relation id shifted by n_rel.
        assert_eq!(interned.train[2].relation, 2);
        assert_eq!(interned.train[3].relation, 3);
    }
}