use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead};
use std::path::Path;
/// A single knowledge-graph fact: a (head entity, relation, tail entity)
/// triple, stored as owned strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Triple {
    pub head: String,
    pub relation: String,
    pub tail: String,
}

impl Triple {
    /// Builds a `Triple` from any values convertible into `String`
    /// (e.g. `&str` or `String`), avoiding explicit `.to_string()` at call
    /// sites.
    pub fn new(
        head: impl Into<String>,
        relation: impl Into<String>,
        tail: impl Into<String>,
    ) -> Self {
        let (head, relation, tail) = (head.into(), relation.into(), tail.into());
        Self { head, relation, tail }
    }
}
/// A bidirectional string <-> dense-id vocabulary.
///
/// Ids are assigned densely in first-insertion order, so `from_id[id]`
/// is always the string that was interned as `id`.
#[derive(Debug, Clone, Default)]
pub struct Vocab {
    /// Forward lookup: string -> id.
    to_id: HashMap<String, usize>,
    /// Reverse lookup: id -> string (ids are indices into this Vec).
    from_id: Vec<String>,
}

impl Vocab {
    /// Number of distinct interned strings.
    pub fn len(&self) -> usize {
        self.from_id.len()
    }

    /// True when nothing has been interned yet.
    pub fn is_empty(&self) -> bool {
        self.from_id.is_empty()
    }

    /// Returns the id for `s`, assigning the next dense id if unseen.
    ///
    /// Uses the `HashMap` entry API so the key is hashed only once
    /// (the original did a `get` followed by an `insert`, hashing twice),
    /// and the string is cloned only on first insertion.
    pub fn intern(&mut self, s: String) -> usize {
        use std::collections::hash_map::Entry;
        let next_id = self.from_id.len();
        match self.to_id.entry(s) {
            Entry::Occupied(e) => *e.get(),
            Entry::Vacant(e) => {
                // `e.key()` is the not-yet-inserted string; clone it into
                // the reverse table before consuming the entry.
                self.from_id.push(e.key().clone());
                e.insert(next_id);
                next_id
            }
        }
    }

    /// Looks up the string for `id`, if one was assigned.
    pub fn get(&self, id: usize) -> Option<&str> {
        self.from_id.get(id).map(String::as_str)
    }

    /// Looks up the id for `s`, if it was interned.
    pub fn id(&self, s: &str) -> Option<usize> {
        self.to_id.get(s).copied()
    }
}
/// A [`Triple`] with its strings replaced by dense vocabulary ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TripleIds {
    // Head entity id (index into the entity vocabulary).
    pub head: usize,
    // Relation id (index into the relation vocabulary).
    pub relation: usize,
    // Tail entity id (index into the entity vocabulary).
    pub tail: usize,
}
/// The result of interning a [`Dataset`]: id-based splits plus the two
/// vocabularies needed to map ids back to strings.
#[derive(Debug, Clone)]
pub struct InternedDataset {
    pub train: Vec<TripleIds>,
    pub valid: Vec<TripleIds>,
    pub test: Vec<TripleIds>,
    // Entity vocabulary, shared by heads and tails.
    pub entities: Vocab,
    // Relation vocabulary.
    pub relations: Vocab,
}
/// Errors produced while loading dataset and mapping files.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum DatasetError {
    // Underlying I/O failure while opening or reading a file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    // A line did not match the expected column layout.
    #[error("Invalid data format: {0}")]
    InvalidFormat(String),
    // A required file was absent on disk.
    #[error("Missing file: {0}")]
    MissingFile(String),
}
/// A knowledge-graph dataset split into train/valid/test triples, with
/// optional id -> human-readable-name maps for entities and relations.
#[derive(Debug, Clone)]
pub struct Dataset {
    pub train: Vec<Triple>,
    pub valid: Vec<Triple>,
    pub test: Vec<Triple>,
    // Populated from `entities.dict` when present; otherwise `None`.
    pub entity_map: Option<HashMap<String, String>>,
    // Populated from `relations.dict` when present; otherwise `None`.
    pub relation_map: Option<HashMap<String, String>>,
}
impl Dataset {
    /// Creates a dataset from the three splits, with no name mappings.
    pub fn new(train: Vec<Triple>, valid: Vec<Triple>, test: Vec<Triple>) -> Self {
        Self {
            train,
            valid,
            test,
            entity_map: None,
            relation_map: None,
        }
    }

    /// Alias for [`Dataset::new`].
    pub fn from_triples(train: Vec<Triple>, valid: Vec<Triple>, test: Vec<Triple>) -> Self {
        Self::new(train, valid, test)
    }

    /// Iterates over every triple across all three splits, in
    /// train -> valid -> test order.
    fn all_triples(&self) -> impl Iterator<Item = &Triple> + '_ {
        self.train
            .iter()
            .chain(self.valid.iter())
            .chain(self.test.iter())
    }

    /// All distinct entity names appearing as a head or tail in any split.
    pub fn entities(&self) -> std::collections::HashSet<String> {
        self.all_triples()
            .flat_map(|t| [t.head.clone(), t.tail.clone()])
            .collect()
    }

    /// All distinct relation names appearing in any split.
    pub fn relations(&self) -> std::collections::HashSet<String> {
        self.all_triples().map(|t| t.relation.clone()).collect()
    }

    /// Summary counts over the whole dataset.
    pub fn stats(&self) -> DatasetStats {
        DatasetStats {
            num_entities: self.entities().len(),
            num_relations: self.relations().len(),
            num_train: self.train.len(),
            num_valid: self.valid.len(),
            num_test: self.test.len(),
        }
    }

    /// Consumes the dataset, replacing every string with a dense id.
    ///
    /// Ids are assigned in first-seen order: splits are scanned
    /// train -> valid -> test, and within each triple the order is
    /// head -> relation -> tail (struct-literal fields evaluate
    /// left-to-right), matching the original interning order.
    pub fn into_interned(self) -> InternedDataset {
        let mut entities = Vocab::default();
        let mut relations = Vocab::default();
        let mut intern_split = |split: Vec<Triple>| -> Vec<TripleIds> {
            split
                .into_iter()
                .map(|t| TripleIds {
                    head: entities.intern(t.head),
                    relation: relations.intern(t.relation),
                    tail: entities.intern(t.tail),
                })
                .collect()
        };
        let train = intern_split(self.train);
        let valid = intern_split(self.valid);
        let test = intern_split(self.test);
        InternedDataset {
            train,
            valid,
            test,
            entities,
            relations,
        }
    }
}
/// Summary counts computed by [`Dataset::stats`].
#[derive(Debug, Clone)]
pub struct DatasetStats {
    // Distinct head/tail entity names across all splits.
    pub num_entities: usize,
    // Distinct relation names across all splits.
    pub num_relations: usize,
    pub num_train: usize,
    pub num_valid: usize,
    pub num_test: usize,
}
pub fn load_dataset(path: &Path) -> Result<Dataset, DatasetError> {
let train_path = path.join("train.txt");
let valid_path = path.join("valid.txt");
let test_path = path.join("test.txt");
let train_triples = load_triples(&train_path)?;
let valid_triples = load_triples(&valid_path)?;
let test_triples = load_triples(&test_path)?;
let entity_map = load_map(&path.join("entities.dict")).ok();
let relation_map = load_map(&path.join("relations.dict")).ok();
let mut dataset = Dataset::new(train_triples, valid_triples, test_triples);
dataset.entity_map = entity_map;
dataset.relation_map = relation_map;
Ok(dataset)
}
fn load_triples(file_path: &Path) -> Result<Vec<Triple>, DatasetError> {
if !file_path.exists() {
return Err(DatasetError::MissingFile(format!(
"Dataset file not found: {}",
file_path.display()
)));
}
let file = File::open(file_path)?;
let reader = io::BufReader::new(file);
let mut triples = Vec::new();
for (line_num, line_result) in reader.lines().enumerate() {
let line = line_result?;
let trimmed = line.trim();
if trimmed.is_empty() || trimmed.starts_with('#') {
continue; }
let parts: Vec<&str> = if trimmed.contains('\t') {
trimmed.split('\t').collect()
} else {
trimmed.split_whitespace().collect()
};
if parts.len() == 3 {
triples.push(Triple {
head: parts[0].to_string(),
relation: parts[1].to_string(),
tail: parts[2].to_string(),
});
} else if !trimmed.is_empty() {
return Err(DatasetError::InvalidFormat(format!(
"Line {} has invalid format: '{}'. Expected 3 parts (head, relation, tail).",
line_num + 1,
trimmed
)));
}
}
Ok(triples)
}
/// Reads an `ID Name` mapping file (e.g. `entities.dict`), skipping blank
/// lines and lines starting with `#`. Tab-containing lines split on tabs,
/// others on arbitrary whitespace; fields after the first are joined with
/// single spaces to form the name.
///
/// NOTE(review): duplicate IDs silently overwrite earlier entries —
/// confirm that last-wins is the intended policy.
///
/// # Errors
/// - [`DatasetError::MissingFile`] if the file does not exist.
/// - [`DatasetError::Io`] on read failure.
/// - [`DatasetError::InvalidFormat`] when a non-comment line has fewer
///   than two fields.
pub fn load_map(file_path: &Path) -> Result<HashMap<String, String>, DatasetError> {
    if !file_path.exists() {
        return Err(DatasetError::MissingFile(format!(
            "Mapping file not found: {}",
            file_path.display()
        )));
    }
    let file = File::open(file_path)?;
    let reader = io::BufReader::new(file);
    let mut map = HashMap::new();
    for (line_num, line_result) in reader.lines().enumerate() {
        let line = line_result?;
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        let parts: Vec<&str> = if trimmed.contains('\t') {
            trimmed.split('\t').collect()
        } else {
            trimmed.split_whitespace().collect()
        };
        // `trimmed` is known non-empty here, so a short line is always a
        // format error. The original's `else if !trimmed.is_empty()`
        // guard was always true.
        if parts.len() >= 2 {
            map.insert(parts[0].to_string(), parts[1..].join(" "));
        } else {
            return Err(DatasetError::InvalidFormat(format!(
                "Line {} has invalid format: '{}'. Expected 2 parts (ID, Name).",
                line_num + 1,
                trimmed
            )));
        }
    }
    Ok(map)
}
#[cfg(test)]
mod intern_tests {
    use super::*;
    /// Interning must assign dense ids in first-seen order and round-trip
    /// back to the original strings via the vocabularies.
    #[test]
    fn dataset_into_interned_roundtrips_ids() {
        // Two train triples sharing relation "r" and entity "b";
        // one valid triple reusing "a" and "c".
        let ds = Dataset::new(
            vec![
                Triple {
                    head: "a".to_string(),
                    relation: "r".to_string(),
                    tail: "b".to_string(),
                },
                Triple {
                    head: "b".to_string(),
                    relation: "r".to_string(),
                    tail: "c".to_string(),
                },
            ],
            vec![Triple {
                head: "a".to_string(),
                relation: "r".to_string(),
                tail: "c".to_string(),
            }],
            vec![],
        );
        let interned = ds.into_interned();
        // "r" is the only relation; entities are {a, b, c}.
        assert_eq!(interned.relations.len(), 1);
        assert_eq!(interned.entities.len(), 3);
        assert_eq!(interned.train.len(), 2);
        assert_eq!(interned.valid.len(), 1);
        // Ids of the first train triple must map back to its strings.
        let t0 = interned.train[0];
        assert_eq!(interned.entities.get(t0.head), Some("a"));
        assert_eq!(interned.relations.get(t0.relation), Some("r"));
        assert_eq!(interned.entities.get(t0.tail), Some("b"));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;
    /// Space-separated triples parse into the expected structs.
    #[test]
    fn test_load_triples_success() -> Result<(), DatasetError> {
        let dir = tempdir()?;
        let file_path = dir.path().join("test.txt");
        let mut file = File::create(&file_path)?;
        writeln!(file, "e1 r1 e2")?;
        writeln!(file, "e3 r2 e4")?;
        let triples = load_triples(&file_path)?;
        assert_eq!(triples.len(), 2);
        assert_eq!(
            triples[0],
            Triple {
                head: "e1".to_string(),
                relation: "r1".to_string(),
                tail: "e2".to_string()
            }
        );
        Ok(())
    }
    /// Tab-separated lines take the tab-splitting path.
    #[test]
    fn test_load_triples_tab_separated() -> Result<(), DatasetError> {
        let dir = tempdir()?;
        let file_path = dir.path().join("test.txt");
        let mut file = File::create(&file_path)?;
        writeln!(file, "e1\tr1\te2")?;
        let triples = load_triples(&file_path)?;
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].head, "e1");
        Ok(())
    }
    /// A line with only two fields is rejected as InvalidFormat.
    #[test]
    fn test_load_triples_invalid_format() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "e1 r1").unwrap();
        let err = load_triples(&file_path).unwrap_err();
        assert!(matches!(err, DatasetError::InvalidFormat(_)));
    }
    /// `ID Name` mapping lines populate the map keyed by ID.
    #[test]
    fn test_load_map_success() -> Result<(), DatasetError> {
        let dir = tempdir()?;
        let file_path = dir.path().join("map.txt");
        let mut file = File::create(&file_path)?;
        writeln!(file, "0 entity_0")?;
        writeln!(file, "1 entity_1")?;
        let map = load_map(&file_path)?;
        assert_eq!(map.len(), 2);
        assert_eq!(map["0"], "entity_0");
        Ok(())
    }
    /// A directory with all three split files loads end-to-end.
    #[test]
    fn test_load_dataset_success() -> Result<(), DatasetError> {
        let dir = tempdir()?;
        let train_path = dir.path().join("train.txt");
        let valid_path = dir.path().join("valid.txt");
        let test_path = dir.path().join("test.txt");
        File::create(&train_path)?.write_all(b"e1 r1 e2\n")?;
        File::create(&valid_path)?.write_all(b"e3 r2 e4\n")?;
        File::create(&test_path)?.write_all(b"e5 r3 e6\n")?;
        let dataset = load_dataset(dir.path())?;
        assert_eq!(dataset.train.len(), 1);
        assert_eq!(dataset.valid.len(), 1);
        assert_eq!(dataset.test.len(), 1);
        Ok(())
    }
    /// Entity set is the union of heads and tails ("e2" counted once).
    #[test]
    fn test_dataset_entities() {
        let dataset = Dataset::new(
            vec![
                Triple {
                    head: "e1".to_string(),
                    relation: "r1".to_string(),
                    tail: "e2".to_string(),
                },
                Triple {
                    head: "e2".to_string(),
                    relation: "r1".to_string(),
                    tail: "e3".to_string(),
                },
            ],
            vec![],
            vec![],
        );
        let entities = dataset.entities();
        assert_eq!(entities.len(), 3);
        assert!(entities.contains("e1"));
        assert!(entities.contains("e2"));
        assert!(entities.contains("e3"));
    }
    /// Stats aggregate across all three splits (4 entities, 1 relation).
    #[test]
    fn test_dataset_stats() {
        let dataset = Dataset::new(
            vec![Triple {
                head: "e1".to_string(),
                relation: "r1".to_string(),
                tail: "e2".to_string(),
            }],
            vec![Triple {
                head: "e2".to_string(),
                relation: "r1".to_string(),
                tail: "e3".to_string(),
            }],
            vec![Triple {
                head: "e3".to_string(),
                relation: "r1".to_string(),
                tail: "e4".to_string(),
            }],
        );
        let stats = dataset.stats();
        assert_eq!(stats.num_train, 1);
        assert_eq!(stats.num_valid, 1);
        assert_eq!(stats.num_test, 1);
        assert_eq!(stats.num_entities, 4);
        assert_eq!(stats.num_relations, 1);
    }
}