mod dataset;
mod filter;
mod metrics;
mod triple;
pub use dataset::{Dataset, InternedDataset, Vocab};
pub use filter::FilterIndex;
pub use metrics::{
adjusted_mean_rank, hits_at_k, mean_rank, mean_reciprocal_rank, per_relation_mrr,
realistic_rank,
};
pub use triple::{Triple, TripleIds};
use crate::{Error, Result};
use std::path::Path;
/// Loads a complete dataset from the directory `path`.
///
/// Three files are read from that directory, in this order: `train.txt`,
/// `valid.txt` and `test.txt`. Each one is parsed with [`load_triples`] and
/// the resulting triple lists are handed to [`Dataset::new`].
///
/// # Errors
///
/// Returns the first error produced by [`load_triples`]; splits after the
/// failing one are not read.
pub fn load_dataset(path: &Path) -> Result<Dataset> {
    // One closure for all three splits so the file-name handling lives in a single place.
    let read_split = |file_name: &str| load_triples(&path.join(file_name));

    // Arguments are evaluated left to right, so the read order is train, valid, test.
    Ok(Dataset::new(
        read_split("train.txt")?,
        read_split("valid.txt")?,
        read_split("test.txt")?,
    ))
}
pub fn load_triples(path: &Path) -> Result<Vec<Triple>> {
use std::fs::File;
use std::io::{BufRead, BufReader};
if !path.exists() {
return Err(Error::MissingFile(format!("{}", path.display())));
}
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut triples = Vec::new();
for (line_num, line_result) in reader.lines().enumerate() {
let line = line_result?;
let trimmed = line.trim();
if trimmed.is_empty() || trimmed.starts_with('#') {
continue;
}
let parts: Vec<&str> = if trimmed.contains('\t') {
trimmed.split('\t').collect()
} else {
trimmed.split_whitespace().collect()
};
if parts.len() == 3 {
triples.push(Triple::new(parts[0], parts[1], parts[2]));
} else {
return Err(Error::InvalidFormat(format!(
"{}:{}: expected 3 fields (head, relation, tail), got {}",
path.display(),
line_num + 1,
parts.len()
)));
}
}
Ok(triples)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates `name` inside `dir`, fills it with `contents` and returns its path.
    fn write_fixture(dir: &Path, name: &str, contents: &[u8]) -> std::path::PathBuf {
        let target = dir.join(name);
        let mut handle = std::fs::File::create(&target).unwrap();
        handle.write_all(contents).unwrap();
        target
    }

    // Tab-delimited lines are parsed into head / relation / tail.
    #[test]
    fn load_triples_tab_separated() {
        let tmp = tempfile::tempdir().unwrap();
        let file = write_fixture(tmp.path(), "test.txt", b"e1\tr1\te2\ne3\tr2\te4\n");

        let loaded = load_triples(&file).unwrap();

        assert_eq!(loaded.len(), 2);
        let first = &loaded[0];
        assert_eq!(first.head, "e1");
        assert_eq!(first.relation, "r1");
        assert_eq!(first.tail, "e2");
    }

    // Lines without tabs fall back to whitespace splitting.
    #[test]
    fn load_triples_whitespace_separated() {
        let tmp = tempfile::tempdir().unwrap();
        let file = write_fixture(tmp.path(), "test.txt", b"e1 r1 e2\n");

        let loaded = load_triples(&file).unwrap();

        assert_eq!(loaded.len(), 1);
    }

    // `#` comment lines and blank lines do not produce triples.
    #[test]
    fn load_triples_skips_comments_and_blanks() {
        let tmp = tempfile::tempdir().unwrap();
        let file = write_fixture(tmp.path(), "test.txt", b"# comment\n\ne1\tr1\te2\n");

        let loaded = load_triples(&file).unwrap();

        assert_eq!(loaded.len(), 1);
    }

    // A line with only two fields is rejected.
    #[test]
    fn load_triples_errors_on_malformed() {
        let tmp = tempfile::tempdir().unwrap();
        let file = write_fixture(tmp.path(), "test.txt", b"e1 r1\n");

        assert!(load_triples(&file).is_err());
    }

    // A path that does not exist yields an error.
    #[test]
    fn load_triples_errors_on_missing_file() {
        assert!(load_triples(Path::new("/nonexistent/test.txt")).is_err());
    }

    // All three split files are read and exposed on the dataset.
    #[test]
    fn load_dataset_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        for name in &["train.txt", "valid.txt", "test.txt"] {
            write_fixture(tmp.path(), name, b"e1\tr1\te2\n");
        }

        let ds = load_dataset(tmp.path()).unwrap();

        assert_eq!(ds.train.len(), 1);
        assert_eq!(ds.valid.len(), 1);
        assert_eq!(ds.test.len(), 1);
    }

    // This module's `Triple` converts into the crate-level triple type.
    #[test]
    fn kge_triple_to_core_triple_conversion() {
        let source = Triple::new("Alice", "knows", "Bob");
        let converted: crate::Triple = source.into();

        assert_eq!(converted.subject().as_str(), "Alice");
        assert_eq!(converted.predicate().as_str(), "knows");
        assert_eq!(converted.object().as_str(), "Bob");
    }

    // The crate-level triple type converts back into this module's `Triple`.
    #[test]
    fn core_triple_to_kge_triple_conversion() {
        let source = crate::Triple::new("Alice", "knows", "Bob");
        let converted: Triple = source.into();

        assert_eq!(converted.head, "Alice");
        assert_eq!(converted.relation, "knows");
        assert_eq!(converted.tail, "Bob");
    }
}