pub use lattix::kge::{
load_dataset, load_triples, Dataset, FilterIndex, InternedDataset, Triple, TripleIds, Vocab,
};
pub use lattix::Error as DatasetError;
/// Extension trait adding extra constructors to [`Dataset`].
pub trait DatasetExt {
    /// Builds a dataset from one flat list of triples by shuffling it deterministically
    /// with `seed` and cutting it into train / validation / test parts.
    ///
    /// For the `Dataset` implementation in this module, the three ratios act as relative
    /// weights: they are divided by their sum, so `(0.8, 0.1, 0.1)` and `(8.0, 1.0, 1.0)`
    /// describe the same proportions. Every input triple ends up in exactly one part.
    ///
    /// # Panics
    ///
    /// The `Dataset` implementation panics if any ratio is negative, or if all three
    /// ratios are zero.
    fn from_all_triples(
        triples: Vec<Triple>,
        train_ratio: f64,
        valid_ratio: f64,
        test_ratio: f64,
        seed: u64,
    ) -> Dataset;
}
impl DatasetExt for Dataset {
    /// Shuffles `triples` deterministically and splits them into train/valid/test sets.
    ///
    /// The ratios are relative weights normalized by their sum. Cut points are computed
    /// from *cumulative* fractions and rounded once each, so the three parts always add
    /// up to `triples.len()` and rounding error is not pushed onto the test split.
    ///
    /// The shuffle is a Fisher–Yates shuffle driven by a 64-bit linear congruential
    /// generator seeded with `seed`, so the same `(triples, seed)` input always produces
    /// the same split.
    ///
    /// # Panics
    ///
    /// Panics if any ratio is negative or NaN, if any ratio is infinite, or if all three
    /// ratios are zero. Ratios are validated even when `triples` is empty.
    fn from_all_triples(
        mut triples: Vec<Triple>,
        train_ratio: f64,
        valid_ratio: f64,
        test_ratio: f64,
        seed: u64,
    ) -> Dataset {
        // Validate up front so bad ratios are rejected consistently (including for empty
        // input) and before spending time on the shuffle. NaN fails `>= 0.0`, so it is
        // reported by the first assert.
        assert!(
            train_ratio >= 0.0 && valid_ratio >= 0.0 && test_ratio >= 0.0,
            "split ratios must be non-negative"
        );
        // Infinite weights would make `ratio / total` NaN, which `as usize` turns into 0.
        assert!(
            train_ratio.is_finite() && valid_ratio.is_finite() && test_ratio.is_finite(),
            "split ratios must be finite"
        );
        let total = train_ratio + valid_ratio + test_ratio;
        assert!(total > 0.0, "at least one split ratio must be positive");

        let n = triples.len();

        // Fisher–Yates shuffle using Knuth's MMIX LCG constants. Only the top 31 bits of
        // the state are used, and `% (i + 1)` carries a small modulo bias for very large
        // inputs; both are acceptable for dataset splitting. The sequence is deliberately
        // left unchanged so existing seeds keep reproducing the same permutation.
        let mut rng = seed;
        let lcg = |s: &mut u64| -> usize {
            *s = s
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (*s >> 33) as usize
        };
        for i in (1..n).rev() {
            let j = lcg(&mut rng) % (i + 1);
            triples.swap(i, j);
        }

        // Round each cumulative cut point once, rather than rounding the train and
        // valid sizes separately and summing (which can overshoot `n` and biases the
        // leftover test split).
        let n_f = n as f64;
        let train_end = ((train_ratio / total) * n_f).round() as usize;
        let valid_end = (((train_ratio + valid_ratio) / total) * n_f).round() as usize;
        // With finite, non-negative ratios both cut points already satisfy
        // `train_end <= valid_end <= n`; clamp anyway as a guard against float edge cases.
        let valid_end = valid_end.min(n);
        let train_end = train_end.min(valid_end);

        // `split_off` keeps `[0, at)` in `triples` and returns `[at, len)`.
        let test = triples.split_off(valid_end);
        let valid = triples.split_off(train_end);
        Dataset::new(triples, valid, test)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five distinct triples forming a chain `a -> b -> c -> d -> e -> f`.
    fn chain_triples() -> Vec<Triple> {
        vec![
            Triple::new("a", "r", "b"),
            Triple::new("b", "r", "c"),
            Triple::new("c", "r", "d"),
            Triple::new("d", "r", "e"),
            Triple::new("e", "r", "f"),
        ]
    }

    #[test]
    fn dataset_from_all_triples_splits_correctly() {
        let ds = Dataset::from_all_triples(chain_triples(), 0.6, 0.2, 0.2, 42);
        assert_eq!(ds.train.len() + ds.valid.len() + ds.test.len(), 5);
        // 0.6 / 0.2 / 0.2 of 5 triples is exactly 3 / 1 / 1.
        assert_eq!(ds.train.len(), 3);
        assert_eq!(ds.valid.len(), 1);
        assert_eq!(ds.test.len(), 1);
    }

    #[test]
    fn dataset_from_all_triples_normalizes_ratios() {
        // Ratios are relative weights, so scaling all of them must not change the split.
        let ds = Dataset::from_all_triples(chain_triples(), 6.0, 2.0, 2.0, 42);
        assert_eq!(ds.train.len(), 3);
        assert_eq!(ds.valid.len(), 1);
        assert_eq!(ds.test.len(), 1);
    }

    #[test]
    fn dataset_from_all_triples_handles_empty_input() {
        let ds = Dataset::from_all_triples(Vec::new(), 0.8, 0.1, 0.1, 0);
        assert_eq!(ds.train.len(), 0);
        assert_eq!(ds.valid.len(), 0);
        assert_eq!(ds.test.len(), 0);
    }

    #[test]
    #[should_panic(expected = "split ratios must be non-negative")]
    fn dataset_from_all_triples_rejects_negative_ratio() {
        let _ = Dataset::from_all_triples(chain_triples(), 0.8, -0.1, 0.3, 0);
    }

    #[test]
    #[should_panic(expected = "at least one split ratio must be positive")]
    fn dataset_from_all_triples_rejects_all_zero_ratios() {
        let _ = Dataset::from_all_triples(chain_triples(), 0.0, 0.0, 0.0, 0);
    }

    #[test]
    fn dataset_into_interned_roundtrips() {
        let ds = Dataset::new(
            vec![Triple::new("a", "r", "b"), Triple::new("b", "r", "c")],
            vec![Triple::new("a", "r", "c")],
            vec![],
        );
        let interned = ds.into_interned();
        assert_eq!(interned.entities.len(), 3);
        assert_eq!(interned.relations.len(), 1);
        assert_eq!(interned.train.len(), 2);
        assert_eq!(interned.valid.len(), 1);
        let t0 = interned.train[0];
        assert_eq!(interned.entities.get(t0.head), Some("a"));
        assert_eq!(interned.relations.get(t0.relation), Some("r"));
        assert_eq!(interned.entities.get(t0.tail), Some("b"));
    }

    #[test]
    fn from_arrays_roundtrips() {
        let train = vec![(0, 0, 1), (1, 0, 2)];
        let valid = vec![(0, 0, 2)];
        let test = vec![(2, 0, 0)];
        let ds = InternedDataset::from_arrays(&train, &valid, &test, 3, 1);
        assert_eq!(ds.num_entities(), 3);
        assert_eq!(ds.num_relations(), 1);
        assert_eq!(ds.entities.get(0), Some("e0"));
    }
}