use std::io::Write;
use subsume::dataset::load_dataset;
use subsume::trainer::{BoxEmbeddingTrainer, FilteredTripleIndexIds};
use subsume::TrainingConfig;
/// Training split: 46 WordNet hypernym triples, one per line, each formatted as
/// `head<TAB>_hypernym<TAB>tail` using synset names (e.g.
/// `dog.n.01 _hypernym canine.n.02`).
///
/// The `"\` line continuation makes the string start directly with the first
/// triple; there is no trailing newline after the last one. `main` writes this
/// verbatim to `train.txt` before calling `load_dataset`.
const TRAIN_DATA: &str = "\
dog.n.01\t_hypernym\tcanine.n.02
canine.n.02\t_hypernym\tcarnivore.n.01
carnivore.n.01\t_hypernym\tplacental.n.01
placental.n.01\t_hypernym\tmammal.n.01
mammal.n.01\t_hypernym\tvertebrate.n.01
vertebrate.n.01\t_hypernym\tchordate.n.01
chordate.n.01\t_hypernym\tanimal.n.01
animal.n.01\t_hypernym\torganism.n.01
organism.n.01\t_hypernym\tentity.n.01
cat.n.01\t_hypernym\tfeline.n.01
feline.n.01\t_hypernym\tcarnivore.n.01
wolf.n.01\t_hypernym\tcanine.n.02
fox.n.01\t_hypernym\tcanine.n.02
lion.n.01\t_hypernym\tfeline.n.01
tiger.n.01\t_hypernym\tfeline.n.01
horse.n.01\t_hypernym\tequine.n.01
equine.n.01\t_hypernym\tplacental.n.01
eagle.n.01\t_hypernym\tbird_of_prey.n.01
bird_of_prey.n.01\t_hypernym\tbird.n.01
bird.n.01\t_hypernym\tvertebrate.n.01
sparrow.n.01\t_hypernym\tpasserine.n.01
passerine.n.01\t_hypernym\tbird.n.01
salmon.n.01\t_hypernym\tfish.n.01
fish.n.01\t_hypernym\tvertebrate.n.01
trout.n.01\t_hypernym\tfish.n.01
oak.n.01\t_hypernym\ttree.n.01
tree.n.01\t_hypernym\tplant.n.02
plant.n.02\t_hypernym\torganism.n.01
pine.n.01\t_hypernym\ttree.n.01
rose.n.01\t_hypernym\tflower.n.01
flower.n.01\t_hypernym\tplant.n.02
tulip.n.01\t_hypernym\tflower.n.01
car.n.01\t_hypernym\tvehicle.n.01
vehicle.n.01\t_hypernym\tartifact.n.01
artifact.n.01\t_hypernym\tentity.n.01
truck.n.01\t_hypernym\tvehicle.n.01
bicycle.n.01\t_hypernym\tvehicle.n.01
whale.n.01\t_hypernym\tplacental.n.01
dolphin.n.01\t_hypernym\tplacental.n.01
snake.n.01\t_hypernym\treptile.n.01
reptile.n.01\t_hypernym\tvertebrate.n.01
lizard.n.01\t_hypernym\treptile.n.01
penguin.n.01\t_hypernym\tbird.n.01
bat.n.01\t_hypernym\tplacental.n.01
spider.n.01\t_hypernym\tarthropod.n.01
arthropod.n.01\t_hypernym\tanimal.n.01";
/// Validation split: 7 `head<TAB>_hypernym<TAB>tail` triples that do not appear
/// in `TRAIN_DATA`.
///
/// Each triple follows from chaining two `TRAIN_DATA` edges (e.g.
/// wolf → canine → carnivore), so predicting it requires generalizing beyond
/// the observed edges. Every entity here also occurs in `TRAIN_DATA`. `main`
/// writes this to `valid.txt` and passes the loaded split to `fit`.
const VALID_DATA: &str = "\
wolf.n.01\t_hypernym\tcarnivore.n.01
horse.n.01\t_hypernym\tmammal.n.01
eagle.n.01\t_hypernym\tbird.n.01
oak.n.01\t_hypernym\tplant.n.02
car.n.01\t_hypernym\tartifact.n.01
snake.n.01\t_hypernym\tvertebrate.n.01
spider.n.01\t_hypernym\tanimal.n.01";
/// Test split: 7 `head<TAB>_hypernym<TAB>tail` triples that appear in neither
/// `TRAIN_DATA` nor `VALID_DATA`.
///
/// Like the validation split, each triple is a two-hop consequence of
/// `TRAIN_DATA` edges (e.g. fox → canine → carnivore), and every entity also
/// occurs in `TRAIN_DATA`. `main` writes this to `test.txt` and reports the
/// filtered evaluation metrics on it after training.
const TEST_DATA: &str = "\
fox.n.01\t_hypernym\tcarnivore.n.01
lion.n.01\t_hypernym\tcarnivore.n.01
tiger.n.01\t_hypernym\tcarnivore.n.01
trout.n.01\t_hypernym\tvertebrate.n.01
pine.n.01\t_hypernym\tplant.n.02
tulip.n.01\t_hypernym\tplant.n.02
truck.n.01\t_hypernym\tartifact.n.01";
/// Writes one dataset split file named `name` inside `dir`.
///
/// The file handle is dropped (closed) before this returns, so the data is
/// fully handed to the OS before anything reads the directory.
fn write_split(dir: &std::path::Path, name: &str, contents: &str) -> std::io::Result<()> {
    let mut file = std::fs::File::create(dir.join(name))?;
    file.write_all(contents.as_bytes())
}

/// Trains box embeddings on the embedded WordNet hypernym subset, reports
/// filtered link-prediction metrics on the test split, and saves a JSON
/// checkpoint to `pretrained/wordnet_subset.json` (relative to the current
/// working directory).
///
/// # Errors
///
/// Returns an error if staging or loading the dataset, training, evaluation,
/// serialization/deserialization, or writing the checkpoint fails.
///
/// # Panics
///
/// Panics if the reloaded checkpoint has a different number of boxes than the
/// trainer that produced it.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Training Box Embeddings on WordNet Subset ===\n");

    // `load_dataset` takes a directory path, so stage the embedded splits as
    // files in a temporary directory (tempfile's `TempDir`, removed on drop).
    let dir = tempfile::tempdir()?;
    for (name, contents) in [
        ("train.txt", TRAIN_DATA),
        ("valid.txt", VALID_DATA),
        ("test.txt", TEST_DATA),
    ] {
        write_split(dir.path(), name, contents)?;
    }

    let dataset = load_dataset(dir.path())?;
    let stats = dataset.stats();
    println!(
        "Dataset: {} entities, {} relations, {} train / {} valid / {} test triples",
        stats.num_entities, stats.num_relations, stats.num_train, stats.num_valid, stats.num_test
    );

    let interned = dataset.into_interned();
    // The trainer consumes plain (head, relation, tail) id tuples.
    let train_triples: Vec<(usize, usize, usize)> = interned
        .train
        .iter()
        .map(|t| (t.head, t.relation, t.tail))
        .collect();
    // Filter used for "filtered" ranking in validation and test evaluation.
    let filter = FilteredTripleIndexIds::from_dataset(&interned);

    let config = TrainingConfig {
        learning_rate: 0.02,
        epochs: 500,
        margin: 0.1,
        regularization: 0.0001,
        negative_weight: 1.0,
        early_stopping_patience: Some(50),
        warmup_epochs: 10,
        gumbel_beta: 10.0,
        gumbel_beta_final: 50.0,
        ..Default::default()
    };
    // Read the epoch budget before `config` moves into the trainer, so the
    // progress message can never drift from the actual configuration.
    let max_epochs = config.epochs;
    let dim = 16;
    let mut trainer = BoxEmbeddingTrainer::new(config, dim);
    println!(
        "\nTraining for up to {} epochs (dim={}, {} train triples)...\n",
        max_epochs,
        dim,
        train_triples.len()
    );

    let result = trainer.fit(
        &train_triples,
        Some((&interned.valid, &interned.entities)),
        Some(&filter),
    )?;
    // With early stopping, the loss history can be shorter than `max_epochs`.
    let actual_epochs = result.loss_history.len();
    println!(
        " Trained {} epochs (best epoch: {})",
        actual_epochs, result.best_epoch
    );
    if let (Some(&first), Some(&last)) = (result.loss_history.first(), result.loss_history.last()) {
        println!(" Loss: {:.4} -> {:.4}", first, last);
    }

    println!("\n--- Test Set Evaluation (filtered) ---\n");
    let test_results = trainer.evaluate(&interned.test, &interned.entities, Some(&filter))?;
    println!(" MRR: {:.4}", test_results.mrr);
    println!(" Hits@1: {:.4}", test_results.hits_at_1);
    println!(" Hits@3: {:.4}", test_results.hits_at_3);
    println!(" Hits@10: {:.4}", test_results.hits_at_10);
    println!(" Mean Rank: {:.1}", test_results.mean_rank);

    let checkpoint_path = std::path::Path::new("pretrained/wordnet_subset.json");
    if let Some(parent) = checkpoint_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let checkpoint = serde_json::to_string_pretty(&trainer)?;
    std::fs::write(checkpoint_path, &checkpoint)?;
    println!("\n--- Checkpoint ---\n");
    println!(" Saved to: {}", checkpoint_path.display());
    println!(
        " Size: {} bytes ({:.1} KB)",
        checkpoint.len(),
        checkpoint.len() as f64 / 1024.0
    );
    println!(" Entities: {}", trainer.boxes.len());

    // Sanity-check that the serialized checkpoint deserializes back into a
    // trainer with the same number of boxes.
    let reloaded: BoxEmbeddingTrainer = serde_json::from_str(&checkpoint)?;
    assert_eq!(reloaded.boxes.len(), trainer.boxes.len());
    println!(" Round-trip: OK");
    Ok(())
}