use std::collections::HashMap;
use subsume::{ConeEmbeddingTrainer, TrainingConfig};
/// Trains ConE-style cone embeddings on a small hand-built taxonomy and prints
/// the learned mean apertures plus containment distances for selected pairs.
///
/// Every pair below is `(head, tail)`. A positive pair means "head contains
/// tail" (e.g. `(0, 1)` = entity > animal); a negative pair means it must not.
fn main() {
    println!("=== Cone Embeddings (ConE): Training on a Taxonomy (18 entities, 4 levels) ===\n");

    let entity_names: HashMap<usize, &str> = [
        (0, "entity"),
        (1, "animal"),
        (2, "vehicle"),
        (3, "mammal"),
        (4, "bird"),
        (5, "fish"),
        (6, "land_vehicle"),
        (7, "aircraft"),
        (8, "dog"),
        (9, "cat"),
        (10, "whale"),
        (11, "eagle"),
        (12, "sparrow"),
        (13, "salmon"),
        (14, "tuna"),
        (15, "car"),
        (16, "truck"),
        (17, "helicopter"),
    ]
    .into_iter()
    .collect();

    // Direct parent -> child edges of the taxonomy. The listed order is also the
    // per-epoch update order, so keep it stable.
    let positive_pairs: &[(usize, usize)] = &[
        (0, 1), (0, 2),
        (1, 3), (1, 4), (1, 5), (2, 6), (2, 7),
        (3, 8), (3, 9), (3, 10), (4, 11), (4, 12), (5, 13), (5, 14),
        (6, 15), (6, 16), (7, 17),
    ];
    // Reversed edges (descendant > ancestor), siblings, and cross-branch pairs
    // that must NOT be contained. Order is kept as-is for the same reason.
    let negative_pairs: &[(usize, usize)] = &[
        (8, 0), (15, 0), (11, 1), (13, 1), (1, 2), (2, 1), (3, 6), (4, 7),
        (8, 9), (9, 8), (8, 10), (10, 8), (10, 9), (11, 12), (13, 14), (14, 13),
        (15, 16), (16, 15), (17, 15), (3, 1), (5, 1), (6, 2), (7, 2), (5, 6),
        (6, 5), (3, 7), (4, 6), (5, 3), (12, 14),
    ];

    // Catch taxonomy edits that reference an unnamed id before any training runs.
    assert!(
        positive_pairs
            .iter()
            .chain(negative_pairs)
            .all(|(head, tail)| entity_names.contains_key(head) && entity_names.contains_key(tail)),
        "every training pair must reference a named entity"
    );

    // `HashMap` iteration order is arbitrary and differs between runs, so derive
    // one sorted id list and use it wherever order matters.
    let mut entity_ids: Vec<usize> = entity_names.keys().copied().collect();
    entity_ids.sort_unstable();

    let n_entities = entity_ids.len();
    let dim = 16;
    let warmup_epochs = 50;
    let joint_epochs = 450;
    // Print progress every `log_every` epochs, plus the final epoch.
    let log_every = 125;

    let config = TrainingConfig {
        learning_rate: 0.02,
        margin: 1.0,
        regularization: 0.0,
        negative_weight: 0.5,
        ..Default::default()
    };
    // `config` is not used again, so hand it over instead of cloning it.
    let mut trainer = ConeEmbeddingTrainer::new(config, dim, None);
    // Register in ascending id order so any order-dependent initialization
    // inside the trainer is reproducible across runs.
    for &id in &entity_ids {
        trainer.ensure_entity(id);
    }

    let total_epochs = warmup_epochs + joint_epochs;
    println!(
        "Training for {} epochs ({} warmup + {} joint, dim={}, {} entities, {} pos + {} neg pairs)...\n",
        total_epochs,
        warmup_epochs,
        joint_epochs,
        dim,
        n_entities,
        positive_pairs.len(),
        negative_pairs.len()
    );

    for epoch in 0..total_epochs {
        let in_warmup = epoch < warmup_epochs;
        let mut epoch_loss = 0.0;
        let mut n_pairs = 0;

        for &(head, tail) in positive_pairs {
            epoch_loss += trainer.train_step(head, tail, true);
            n_pairs += 1;
        }
        // Warmup trains on positives only; negatives join in the joint phase.
        if !in_warmup {
            for &(head, tail) in negative_pairs {
                epoch_loss += trainer.train_step(head, tail, false);
                n_pairs += 1;
            }
        }

        // Note: the denominator differs between phases, so warmup and joint
        // averages are not directly comparable.
        let avg_loss = epoch_loss / n_pairs as f32;
        if epoch % log_every == 0 || epoch == total_epochs - 1 {
            let phase = if in_warmup { "warmup" } else { "joint" };
            println!(
                " Epoch {:>4} [{}]: avg_loss = {:.4}",
                epoch, phase, avg_loss
            );
        }
    }

    println!("\n--- Learned Cone Properties ---\n");
    println!("{:>14} {:>12} {:>12}", "entity", "mean_aper", "mean_deg");
    println!("{}", "-".repeat(40));
    for &id in &entity_ids {
        let mean_aper = trainer.cones()[&id].mean_aperture();
        println!(
            "{:>14} {:>12.4} {:>12.1}",
            entity_names[&id],
            mean_aper,
            mean_aper.to_degrees()
        );
    }

    println!("\n--- Selected Containment Distances (lower = better containment) ---\n");
    // Passed to `cone_distance` as its centre-weight argument (presumably the
    // weight on ConE's inside-distance term — confirm against subsume docs).
    let cen = 0.02;
    let selected_checks: &[(&str, usize, usize, bool)] = &[
        ("entity > animal", 0, 1, true),
        ("entity > vehicle", 0, 2, true),
        ("animal > mammal", 1, 3, true),
        ("animal > bird", 1, 4, true),
        ("mammal > dog", 3, 8, true),
        ("mammal > cat", 3, 9, true),
        ("bird > eagle", 4, 11, true),
        ("fish > salmon", 5, 13, true),
        ("land_vehicle > car", 6, 15, true),
        ("aircraft > helicopter", 7, 17, true),
        ("dog > entity (reverse)", 8, 0, false),
        ("dog > cat (sibling)", 8, 9, false),
        ("animal > vehicle (cross)", 1, 2, false),
        ("mammal > land_vehicle (cross)", 3, 6, false),
    ];

    let mut pos_dists = Vec::new();
    let mut neg_dists = Vec::new();
    for (label, head, tail, expect_low) in selected_checks {
        let d = trainer.cones()[head].cone_distance(&trainer.cones()[tail], cen);
        let status = if *expect_low { "POS" } else { "NEG" };
        println!(" [{:>3}] {:<30} dist = {:.4}", status, label, d);
        if *expect_low {
            pos_dists.push(d);
        } else {
            neg_dists.push(d);
        }
    }

    /// Arithmetic mean of `values` (NaN for an empty slice).
    fn mean(values: &[f32]) -> f32 {
        values.iter().sum::<f32>() / values.len() as f32
    }

    let avg_pos = mean(&pos_dists);
    let avg_neg = mean(&neg_dists);
    println!(
        "\nAvg positive distance: {:.4}, Avg negative distance: {:.4}",
        avg_pos, avg_neg
    );
    if avg_pos < avg_neg {
        println!("Positive pairs have lower distance than negatives (as expected).");
    } else {
        println!("Warning: separation not achieved. Consider more epochs or tuning.");
    }

    println!("\nKey takeaways:");
    println!(" - More general concepts (entity, animal) get wider mean apertures");
    println!(" - More specific concepts (dog, car) get narrower mean apertures");
    println!(" - Containment is directional: animal > mammal, but NOT mammal > animal");
    println!(" - Cross-branch distance (animal > vehicle) stays high");
    println!(" - Unlike boxes, cones support negation: complement of a cone is a cone");
}