use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead};
use std::path::Path;
use crate::dataset::{DatasetError, Triple};
/// A single taxonomy term loaded from the terms file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaxonomyNode {
    /// Original numeric ID as it appears in the terms file
    /// (not the position inside `TaxonomyDataset::nodes`).
    pub id: usize,
    /// Surface form of the term (the text after the tab in the terms file).
    pub name: String,
    /// Optional gloss looked up by `name` from the dictionary file;
    /// `None` when no dictionary was supplied or the name has no entry.
    pub definition: Option<String>,
}
/// An in-memory taxonomy: nodes plus directed parent -> child edges.
#[derive(Debug, Clone)]
pub struct TaxonomyDataset {
    /// All nodes, in the order they appeared in the terms file.
    pub nodes: Vec<TaxonomyNode>,
    /// Edges as `(parent_id, child_id)` pairs of ORIGINAL node IDs
    /// (use `node_index` to map an ID to its position in `nodes`).
    pub edges: Vec<(usize, usize)>,
    /// Maps an original node ID to its index in `nodes`.
    pub node_index: HashMap<usize, usize>,
}
impl TaxonomyDataset {
    /// Loads a taxonomy from a tab-separated terms file and an edge file.
    ///
    /// When `dict_path` is provided and the file exists, it supplies optional
    /// node definitions keyed by term name. A missing dictionary file is
    /// deliberately treated as "no definitions" rather than an error
    /// (best-effort enrichment), while missing terms/taxo files are hard errors.
    ///
    /// # Errors
    /// Returns `DatasetError::MissingFile` when a required file does not exist
    /// and `DatasetError::InvalidFormat` when a line cannot be parsed or an
    /// edge references an ID absent from the terms file.
    pub fn load(
        terms_path: &Path,
        taxo_path: &Path,
        dict_path: Option<&Path>,
    ) -> Result<Self, DatasetError> {
        let definitions: HashMap<String, String> = match dict_path {
            Some(path) if path.exists() => load_definitions(path)?,
            _ => HashMap::new(),
        };
        let (nodes, node_index) = load_terms(terms_path, &definitions)?;
        let edges = load_edges(taxo_path, &node_index)?;
        Ok(Self {
            nodes,
            edges,
            node_index,
        })
    }

    /// Converts every parent -> child edge into a knowledge-graph triple with
    /// the CHILD as head and the PARENT as tail: `(child, "hypernym", parent)`.
    pub fn to_triples(&self) -> Vec<Triple> {
        self.edges
            .iter()
            .map(|&(parent_id, child_id)| {
                // Indexing is safe: `load_edges` rejects IDs that are not in
                // `node_index`, so every stored edge endpoint resolves.
                let parent_name = &self.nodes[self.node_index[&parent_id]].name;
                let child_name = &self.nodes[self.node_index[&child_id]].name;
                Triple {
                    head: child_name.clone(),
                    relation: "hypernym".to_string(),
                    tail: parent_name.clone(),
                }
            })
            .collect()
    }

    /// Deterministically splits the edges into `(train, val, test)`.
    ///
    /// The same `seed` always yields the same split (see
    /// `deterministic_shuffle`). Ratios are applied to the total edge count
    /// with rounding; the test set receives whatever remains after the train
    /// and validation portions are taken.
    #[allow(clippy::type_complexity)]
    pub fn split(
        &self,
        train_ratio: f64,
        val_ratio: f64,
        seed: u64,
    ) -> (
        Vec<(usize, usize)>,
        Vec<(usize, usize)>,
        Vec<(usize, usize)>,
    ) {
        let mut edges = self.edges.clone();
        deterministic_shuffle(&mut edges, seed);
        let n = edges.len();
        // Clamp BOTH boundaries to `n`. Previously only `val_end` was
        // clamped, so an oversized `train_ratio` (> 1.0, or 1.0 plus rounding)
        // left `train_end > val_end` and made the second `split_off` panic.
        let train_end = ((n as f64 * train_ratio).round() as usize).min(n);
        let val_end = (train_end + (n as f64 * val_ratio).round() as usize).min(n);
        // `val_end >= train_end` holds by construction, so both calls are in
        // bounds: first peel off the test tail, then the validation tail.
        let test = edges.split_off(val_end);
        let val = edges.split_off(train_end);
        let train = edges;
        (train, val, test)
    }

    /// Number of loaded nodes.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of loaded edges.
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }
}
/// Shuffles `slice` in place with a Fisher-Yates pass driven by a seeded
/// xorshift64 generator, so the permutation depends only on `seed` and
/// `slice.len()` and is fully reproducible across runs.
fn deterministic_shuffle<T>(slice: &mut [T], seed: u64) {
    // xorshift64 has 0 as a fixed point: a zero state never changes, which
    // previously happened for `seed == u64::MAX` (the +1 wraps to 0) and
    // degenerated the shuffle into "swap everything with index 0". The +1
    // covers seed == 0; the fallback constant covers the wrap case. All
    // other seeds produce exactly the same permutation as before.
    let mut state = seed.wrapping_add(1);
    if state == 0 {
        state = 0x9E37_79B9_7F4A_7C15;
    }
    for i in (1..slice.len()).rev() {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Modulo introduces a bias that is negligible for lengths far
        // below 2^64.
        let j = (state as usize) % (i + 1);
        slice.swap(i, j);
    }
}
/// Parses a terms file of tab-separated `id\tname` lines into the node list
/// plus an `id -> vector index` map. Blank lines and lines starting with `#`
/// are skipped. A definition is attached when `definitions` contains an
/// entry for the node's name.
///
/// # Errors
/// `MissingFile` when `path` does not exist; `InvalidFormat` for a line
/// without a tab separator, an unparseable ID, or a duplicate ID.
fn load_terms(
    path: &Path,
    definitions: &HashMap<String, String>,
) -> Result<(Vec<TaxonomyNode>, HashMap<usize, usize>), DatasetError> {
    if !path.exists() {
        return Err(DatasetError::MissingFile(format!(
            "Terms file not found: {}",
            path.display()
        )));
    }
    let file = File::open(path)?;
    let reader = io::BufReader::new(file);
    let mut nodes = Vec::new();
    let mut node_index = HashMap::new();
    for (line_num, line_result) in reader.lines().enumerate() {
        let line = line_result?;
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        // splitn(2, ..) keeps any further tabs inside the name column.
        let parts: Vec<&str> = trimmed.splitn(2, '\t').collect();
        if parts.len() != 2 {
            return Err(DatasetError::InvalidFormat(format!(
                "{}:{}: expected tab-separated 'id\\tname', got '{}'",
                path.display(),
                line_num + 1,
                trimmed,
            )));
        }
        let id: usize = parts[0].parse().map_err(|_| {
            DatasetError::InvalidFormat(format!(
                "{}:{}: invalid node ID '{}'",
                path.display(),
                line_num + 1,
                parts[0],
            ))
        })?;
        let name = parts[1].to_string();
        let definition = definitions.get(&name).cloned();
        let idx = nodes.len();
        // Previously a duplicate ID silently overwrote the index entry while
        // both nodes stayed in `nodes`, leaving an orphaned node and an
        // inconsistent node count. Reject it like other malformed input.
        if node_index.insert(id, idx).is_some() {
            return Err(DatasetError::InvalidFormat(format!(
                "{}:{}: duplicate node ID {}",
                path.display(),
                line_num + 1,
                id,
            )));
        }
        nodes.push(TaxonomyNode {
            id,
            name,
            definition,
        });
    }
    Ok((nodes, node_index))
}
/// Reads the taxonomy edge file: one `parent_id\tchild_id` pair per line,
/// skipping blank lines and `#` comments. Every referenced ID must already
/// exist in `node_index` (i.e. in the terms file).
///
/// # Errors
/// `MissingFile` when `path` does not exist; `InvalidFormat` for a line that
/// is not exactly two tab-separated fields, an unparseable ID, or an ID that
/// the terms file never declared.
fn load_edges(
    path: &Path,
    node_index: &HashMap<usize, usize>,
) -> Result<Vec<(usize, usize)>, DatasetError> {
    if !path.exists() {
        return Err(DatasetError::MissingFile(format!(
            "Taxonomy file not found: {}",
            path.display()
        )));
    }
    let reader = io::BufReader::new(File::open(path)?);
    let mut edges = Vec::new();
    for (line_num, line_result) in reader.lines().enumerate() {
        let raw = line_result?;
        let record = raw.trim();
        if record.is_empty() || record.starts_with('#') {
            continue;
        }
        let columns: Vec<&str> = record.split('\t').collect();
        // Exactly two columns are required; anything else is malformed.
        let (parent_text, child_text) = match columns.as_slice() {
            [parent, child] => (*parent, *child),
            _ => {
                return Err(DatasetError::InvalidFormat(format!(
                    "{}:{}: expected tab-separated 'parent_id\\tchild_id', got '{}'",
                    path.display(),
                    line_num + 1,
                    record,
                )))
            }
        };
        let parent_id: usize = parent_text.parse().map_err(|_| {
            DatasetError::InvalidFormat(format!(
                "{}:{}: invalid parent ID '{}'",
                path.display(),
                line_num + 1,
                parent_text,
            ))
        })?;
        let child_id: usize = child_text.parse().map_err(|_| {
            DatasetError::InvalidFormat(format!(
                "{}:{}: invalid child ID '{}'",
                path.display(),
                line_num + 1,
                child_text,
            ))
        })?;
        // Both endpoints must have been declared in the terms file.
        if !node_index.contains_key(&parent_id) {
            return Err(DatasetError::InvalidFormat(format!(
                "{}:{}: parent ID {} not found in terms file",
                path.display(),
                line_num + 1,
                parent_id,
            )));
        }
        if !node_index.contains_key(&child_id) {
            return Err(DatasetError::InvalidFormat(format!(
                "{}:{}: child ID {} not found in terms file",
                path.display(),
                line_num + 1,
                child_id,
            )));
        }
        edges.push((parent_id, child_id));
    }
    Ok(edges)
}
/// Parses the dictionary file as a JSON object mapping term name -> gloss.
///
/// Only compiled with the `ndarray-backend` feature, which is what pulls in
/// the `serde_json` dependency; without it the stub below is used instead.
/// NOTE(review): gating dictionary parsing on an array-backend feature looks
/// incidental (dependency piggybacking) — confirm against Cargo.toml.
#[cfg(feature = "ndarray-backend")]
fn load_definitions(path: &Path) -> Result<HashMap<String, String>, DatasetError> {
    let file = File::open(path)?;
    let reader = io::BufReader::new(file);
    let map: HashMap<String, String> = serde_json::from_reader(reader).map_err(|e| {
        DatasetError::InvalidFormat(format!("Failed to parse dictionary JSON: {e}"))
    })?;
    Ok(map)
}
/// Stub used when the `ndarray-backend` feature (and thus `serde_json`) is
/// disabled: silently yields no definitions so `TaxonomyDataset::load` still
/// works, with every node's `definition` left as `None`.
#[cfg(not(feature = "ndarray-backend"))]
fn load_definitions(_path: &Path) -> Result<HashMap<String, String>, DatasetError> {
    Ok(HashMap::new())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;

    /// Writes `content` to `dir/name` and returns the created path.
    fn write_file(dir: &Path, name: &str, content: &str) -> std::path::PathBuf {
        let path = dir.join(name);
        let mut f = File::create(&path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
        path
    }

    // Happy path: four terms, three edges, no dictionary -> no definitions.
    #[test]
    fn load_small_taxonomy() {
        let dir = tempdir().unwrap();
        let terms = write_file(
            dir.path(),
            "test.terms",
            "0\tanimal\n1\tdog\n2\tcat\n3\tmammal\n",
        );
        let taxo = write_file(dir.path(), "test.taxo", "0\t3\n3\t1\n3\t2\n");
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        assert_eq!(ds.num_nodes(), 4);
        assert_eq!(ds.num_edges(), 3);
        let animal_idx = ds.node_index[&0];
        assert_eq!(ds.nodes[animal_idx].name, "animal");
        assert!(ds.nodes[animal_idx].definition.is_none());
    }

    // Triple orientation: child is head, parent is tail, relation fixed.
    #[test]
    fn to_triples_produces_hypernym_relation() {
        let dir = tempdir().unwrap();
        let terms = write_file(dir.path(), "t.terms", "10\tparent\n20\tchild\n");
        let taxo = write_file(dir.path(), "t.taxo", "10\t20\n");
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        let triples = ds.to_triples();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].head, "child");
        assert_eq!(triples[0].relation, "hypernym");
        assert_eq!(triples[0].tail, "parent");
    }

    // Split partitions the 9-edge chain without losing or duplicating edges;
    // bounds are loose because the per-part sizes come from rounding.
    #[test]
    fn split_covers_all_edges() {
        let dir = tempdir().unwrap();
        let terms_content: String = (0..10).map(|i| format!("{i}\tn{i}\n")).collect();
        let taxo_content: String = (0..9).map(|i| format!("{i}\t{}\n", i + 1)).collect();
        let terms = write_file(dir.path(), "s.terms", &terms_content);
        let taxo = write_file(dir.path(), "s.taxo", &taxo_content);
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        let (train, val, test) = ds.split(0.6, 0.2, 42);
        assert_eq!(train.len() + val.len() + test.len(), 9);
        assert!(
            train.len() >= 4 && train.len() <= 6,
            "train len = {}",
            train.len()
        );
        assert!(!val.is_empty() && val.len() <= 3, "val len = {}", val.len());
    }

    // Same seed -> byte-identical split across repeated calls.
    #[test]
    fn split_is_deterministic() {
        let dir = tempdir().unwrap();
        let terms_content: String = (0..20).map(|i| format!("{i}\tn{i}\n")).collect();
        let taxo_content: String = (0..19).map(|i| format!("{i}\t{}\n", i + 1)).collect();
        let terms = write_file(dir.path(), "d.terms", &terms_content);
        let taxo = write_file(dir.path(), "d.taxo", &taxo_content);
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        let (t1, v1, e1) = ds.split(0.7, 0.15, 123);
        let (t2, v2, e2) = ds.split(0.7, 0.15, 123);
        assert_eq!(t1, t2);
        assert_eq!(v1, v2);
        assert_eq!(e1, e2);
    }

    // A nonexistent terms path must surface as MissingFile, not InvalidFormat.
    #[test]
    fn missing_terms_file_errors() {
        let dir = tempdir().unwrap();
        let taxo = write_file(dir.path(), "x.taxo", "0\t1\n");
        let result = TaxonomyDataset::load(&dir.path().join("missing.terms"), &taxo, None);
        assert!(matches!(result, Err(DatasetError::MissingFile(_))));
    }

    // Edge referencing ID 99, which the terms file never declared.
    #[test]
    fn invalid_id_in_taxo_errors() {
        let dir = tempdir().unwrap();
        let terms = write_file(dir.path(), "e.terms", "0\ta\n1\tb\n");
        let taxo = write_file(dir.path(), "e.taxo", "0\t99\n");
        let result = TaxonomyDataset::load(&terms, &taxo, None);
        assert!(matches!(result, Err(DatasetError::InvalidFormat(_))));
    }

    // First terms line uses a space instead of a tab and must be rejected.
    #[test]
    fn test_load_rejects_malformed_terms() {
        let dir = tempdir().unwrap();
        let terms = write_file(dir.path(), "bad.terms", "0 animal\n1\tdog\n");
        let taxo = write_file(dir.path(), "bad.taxo", "0\t1\n");
        let result = TaxonomyDataset::load(&terms, &taxo, None);
        assert!(
            matches!(result, Err(DatasetError::InvalidFormat(_))),
            "should reject terms line without tab separator, got {result:?}"
        );
    }

    // Determinism across several seeds (including the u64::MAX edge case),
    // plus a sanity check that distinct seeds actually change the split.
    #[test]
    fn test_split_deterministic() {
        let dir = tempdir().unwrap();
        let terms_content: String = (0..50).map(|i| format!("{i}\tn{i}\n")).collect();
        let taxo_content: String = (0..49).map(|i| format!("{i}\t{}\n", i + 1)).collect();
        let terms = write_file(dir.path(), "det.terms", &terms_content);
        let taxo = write_file(dir.path(), "det.taxo", &taxo_content);
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        for seed in [0, 42, 12345, u64::MAX] {
            let (t1, v1, e1) = ds.split(0.6, 0.2, seed);
            let (t2, v2, e2) = ds.split(0.6, 0.2, seed);
            assert_eq!(t1, t2, "train differs for seed {seed}");
            assert_eq!(v1, v2, "val differs for seed {seed}");
            assert_eq!(e1, e2, "test differs for seed {seed}");
        }
        let (t_a, _, _) = ds.split(0.6, 0.2, 1);
        let (t_b, _, _) = ds.split(0.6, 0.2, 2);
        assert_ne!(
            t_a, t_b,
            "different seeds should (almost surely) produce different splits"
        );
    }

    // Same direction check as above but with non-contiguous IDs (100, 200).
    #[test]
    fn test_to_triples_parent_child_direction() {
        let dir = tempdir().unwrap();
        let terms = write_file(dir.path(), "dir.terms", "100\tanimal\n200\tdog\n");
        let taxo = write_file(dir.path(), "dir.taxo", "100\t200\n");
        let ds = TaxonomyDataset::load(&terms, &taxo, None).unwrap();
        let triples = ds.to_triples();
        assert_eq!(triples.len(), 1);
        assert_eq!(triples[0].head, "dog", "child should be head");
        assert_eq!(triples[0].tail, "animal", "parent should be tail");
        assert_eq!(triples[0].relation, "hypernym");
    }

    // Dictionary JSON is only parsed under the ndarray-backend feature;
    // here both terms should pick up their glosses by name.
    #[cfg(feature = "ndarray-backend")]
    #[test]
    fn load_with_definitions() {
        let dir = tempdir().unwrap();
        let terms = write_file(dir.path(), "def.terms", "0\tanimal\n1\tdog\n");
        let taxo = write_file(dir.path(), "def.taxo", "0\t1\n");
        let dict = write_file(
            dir.path(),
            "dic.json",
            r#"{"animal": "A living organism", "dog": "A domesticated canid"}"#,
        );
        let ds = TaxonomyDataset::load(&terms, &taxo, Some(&dict)).unwrap();
        assert_eq!(
            ds.nodes[ds.node_index[&0]].definition.as_deref(),
            Some("A living organism"),
        );
        assert_eq!(
            ds.nodes[ds.node_index[&1]].definition.as_deref(),
            Some("A domesticated canid"),
        );
    }
}