use crate::optimizer::AMSGradState;
use crate::trainable::TrainableCone;
use crate::BoxError;
use std::collections::{HashMap, HashSet};
use super::TrainingConfig;
// Weight passed as the second argument to `cone_distance` when comparing two
// cones (exact semantics defined by the dense Cone type — presumably the
// contribution of center separation; confirm against `cone_distance`).
const CONE_CENTER_WEIGHT: f32 = 0.02;
// Base multiplier applied to every analytically-derived axis/aperture gradient.
const CONE_GRADIENT_STRENGTH: f32 = 0.2;
// Extra scale-down applied only to cone B's aperture gradient on positive pairs.
const CONE_APERTURE_GRADIENT: f32 = 0.05;
// Upper bound on the per-dimension violation magnitude that scales gradients,
// preventing a single badly-violated dimension from producing a huge step.
const CONE_VIOLATION_CLAMP: f32 = 1.0;
/// Computes the training loss for a single cone pair.
///
/// Positive pairs are scored by their cone distance plus an aperture-size
/// regularizer (clamped to be non-negative); negative pairs incur a squared
/// hinge penalty, scaled by `config.negative_weight`, whenever their distance
/// falls inside `config.margin`.
pub fn compute_cone_pair_loss(
    cone_a: &TrainableCone,
    cone_b: &TrainableCone,
    is_positive: bool,
    config: &TrainingConfig,
) -> f32 {
    let a = cone_a.to_cone();
    let b = cone_b.to_cone();
    // Both branches need the same pairwise distance.
    let dist = a.cone_distance(&b, CONE_CENTER_WEIGHT);
    if is_positive {
        // Penalize wide cones so apertures do not grow without bound.
        let aper_mean_a = a.apertures.iter().sum::<f32>() / a.dim() as f32;
        let aper_mean_b = b.apertures.iter().sum::<f32>() / b.dim() as f32;
        let regularizer = config.regularization * (aper_mean_a + aper_mean_b);
        f32::max(dist + regularizer, 0.0)
    } else {
        // Squared hinge: only pairs closer than the margin are penalized.
        let hinge = if dist < config.margin {
            (config.margin - dist).powi(2)
        } else {
            0.0
        };
        config.negative_weight * hinge
    }
}
/// Computes per-dimension heuristic gradients for one cone pair.
///
/// Returns `(grad_axes_a, grad_aper_a, grad_axes_b, grad_aper_b)`, each of
/// length `cone_a.dim()`. Dimensions that need no adjustment are left at 0.
///
/// * Positive pairs: any dimension where `b`'s axis falls outside `a`'s
///   aperture is pulled toward containment.
/// * Negative pairs: only when the overall cone distance is inside
///   `config.margin`, dimensions where `b` sits inside `a` are pushed apart,
///   scaled by how deeply the pair violates the margin.
pub fn compute_cone_analytical_gradients(
    cone_a: &TrainableCone,
    cone_b: &TrainableCone,
    is_positive: bool,
    config: &TrainingConfig,
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let dim = cone_a.dim();
    let mut axes_a = vec![0.0f32; dim];
    let mut aper_a = vec![0.0f32; dim];
    let mut axes_b = vec![0.0f32; dim];
    let mut aper_b = vec![0.0f32; dim];
    let a = cone_a.to_cone();
    let b = cone_b.to_cone();
    if is_positive {
        for i in 0..dim {
            let delta = b.axes[i] - a.axes[i];
            // Half-angle sine magnitudes: axis separation vs. a's aperture reach.
            let axis_sep = (delta / 2.0).sin().abs();
            let aperture_reach = (a.apertures[i] / 2.0).sin().abs();
            if axis_sep >= aperture_reach {
                // b's axis lies outside a's aperture in this dimension.
                let overshoot = (axis_sep - aperture_reach).min(CONE_VIOLATION_CLAMP);
                let step = CONE_GRADIENT_STRENGTH * overshoot;
                let dir = delta.signum();
                axes_a[i] = -step * dir;
                axes_b[i] = step * dir;
                // NOTE(review): aperture gradients are asymmetric — a at full
                // step, b scaled by CONE_APERTURE_GRADIENT — presumably
                // intentional; confirm against the training design.
                aper_a[i] = -step;
                aper_b[i] = CONE_APERTURE_GRADIENT * step;
            }
        }
    } else {
        let dist = a.cone_distance(&b, CONE_CENTER_WEIGHT);
        if dist < config.margin {
            // Closer pairs violate the margin more and repel more strongly.
            let urgency = (config.margin - dist) / config.margin;
            for i in 0..dim {
                let delta = b.axes[i] - a.axes[i];
                let axis_sep = (delta / 2.0).sin().abs();
                let aperture_reach = (a.apertures[i] / 2.0).sin().abs();
                if axis_sep < aperture_reach {
                    // b currently sits inside a in this dimension: push apart.
                    let slack = (aperture_reach - axis_sep).min(CONE_VIOLATION_CLAMP);
                    let step = CONE_GRADIENT_STRENGTH * urgency * slack;
                    let dir = delta.signum();
                    axes_a[i] = step * dir;
                    axes_b[i] = -step * dir;
                    // b's aperture is left untouched on negative pairs.
                    aper_a[i] = step;
                }
            }
        }
    }
    (axes_a, aper_a, axes_b, aper_b)
}
/// Trainer that learns per-entity cone embeddings with AMSGrad updates.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConeEmbeddingTrainer {
    /// Training hyperparameters (learning rate, margin, regularization, …).
    pub config: TrainingConfig,
    /// Trainable cone for each entity, keyed by entity id.
    pub cones: HashMap<usize, TrainableCone>,
    /// AMSGrad optimizer state per entity; keys mirror `cones`.
    pub optimizer_states: HashMap<usize, AMSGradState>,
    /// Embedding dimensionality shared by all cones.
    pub dim: usize,
}
impl ConeEmbeddingTrainer {
/// Creates a trainer of dimensionality `dim`.
///
/// When `initial_embeddings` is given, each vector seeds a cone (with an
/// initial aperture of π/2) plus a fresh AMSGrad state for that entity.
///
/// # Panics
/// Panics if any initial embedding's length differs from `dim`.
pub fn new(
    config: TrainingConfig,
    dim: usize,
    initial_embeddings: Option<HashMap<usize, Vec<f32>>>,
) -> Self {
    let mut cones = HashMap::new();
    let mut optimizer_states = HashMap::new();
    if let Some(embeddings) = initial_embeddings {
        // Size both maps once up front instead of growing incrementally.
        cones.reserve(embeddings.len());
        optimizer_states.reserve(embeddings.len());
        for (entity_id, vector) in embeddings {
            // Name the offending entity in the panic instead of just the lengths.
            assert_eq!(
                vector.len(),
                dim,
                "initial embedding for entity {entity_id} has wrong dimension"
            );
            let cone = TrainableCone::from_vector(&vector, std::f32::consts::FRAC_PI_2);
            let n_params = cone.num_parameters();
            cones.insert(entity_id, cone);
            optimizer_states.insert(entity_id, AMSGradState::new(n_params, config.learning_rate));
        }
    }
    Self {
        config,
        cones,
        optimizer_states,
        dim,
    }
}
/// Inserts a default cone and optimizer state for `id` if not present.
///
/// New entities start from a one-hot vector (hot at `id % dim`) with an
/// initial aperture of π/2.
pub fn ensure_entity(&mut self, id: usize) {
    if self.cones.contains_key(&id) {
        return;
    }
    let mut seed = vec![0.0f32; self.dim];
    // Guard against a zero-dimensional trainer (modulo by zero would panic).
    if self.dim > 0 {
        seed[id % self.dim] = 1.0;
    }
    let cone = TrainableCone::from_vector(&seed, std::f32::consts::FRAC_PI_2);
    let state = AMSGradState::new(cone.num_parameters(), self.config.learning_rate);
    self.cones.insert(id, cone);
    self.optimizer_states.insert(id, state);
}
/// Runs one optimization step on the pair `(id_a, id_b)` and returns the loss
/// measured *before* the parameter update. Unknown entities are created on
/// the fly via `ensure_entity`.
pub fn train_step(&mut self, id_a: usize, id_b: usize, is_positive: bool) -> f32 {
    self.ensure_entity(id_a);
    self.ensure_entity(id_b);
    // Snapshot both cones so gradients for a and b are computed against the
    // same pre-update state. Indexing cannot panic: ensure_entity above
    // guarantees both keys exist.
    let cone_a = self.cones[&id_a].clone();
    let cone_b = self.cones[&id_b].clone();
    let loss = compute_cone_pair_loss(&cone_a, &cone_b, is_positive, &self.config);
    let (ga_axes, ga_aper, gb_axes, gb_aper) =
        compute_cone_analytical_gradients(&cone_a, &cone_b, is_positive, &self.config);
    // Apply each endpoint's gradient through its own optimizer state.
    for (id, axes, aper) in [(id_a, &ga_axes, &ga_aper), (id_b, &gb_axes, &gb_aper)] {
        if let (Some(cone), Some(state)) =
            (self.cones.get_mut(&id), self.optimizer_states.get_mut(&id))
        {
            cone.update_amsgrad(axes, aper, state);
        }
    }
    loss
}
/// Trains on a batch of `(head, relation, tail)` triples (the relation is
/// currently unused).
///
/// Each triple produces one positive step on `(h, t)` and, when the batch
/// contains more than one distinct entity, one negative step against a
/// deterministically chosen corrupted tail. Returns the mean loss over all
/// steps taken.
///
/// # Errors
/// Returns `BoxError::Internal` when `triples` is empty.
pub fn train_step_batch(&mut self, triples: &[(usize, usize, usize)]) -> Result<f32, BoxError> {
    if triples.is_empty() {
        return Err(BoxError::Internal("empty triple set".to_string()));
    }
    // Distinct entities touched by this batch. Sorted because HashSet
    // iteration order is randomized per process (seeded SipHash), which
    // previously made the "deterministic" hash-indexed negative sampling
    // below nondeterministic across runs.
    let mut entity_ids: Vec<usize> = triples
        .iter()
        .flat_map(|&(h, _, t)| [h, t])
        .collect::<HashSet<_>>()
        .into_iter()
        .collect();
    entity_ids.sort_unstable();
    let mut total_loss = 0.0f32;
    let mut count = 0usize;
    for &(h, _r, t) in triples {
        total_loss += self.train_step(h, t, true);
        count += 1;
        if entity_ids.len() > 1 {
            // Cheap deterministic pseudo-random pick; skip to the next entity
            // if the pick happens to be the true tail.
            let idx = (h.wrapping_mul(31).wrapping_add(t).wrapping_add(7)) % entity_ids.len();
            let candidate = entity_ids[idx];
            let neg_t = if candidate == t {
                entity_ids[(idx + 1) % entity_ids.len()]
            } else {
                candidate
            };
            total_loss += self.train_step(h, neg_t, false);
            count += 1;
        }
    }
    // count >= 1 because triples is non-empty.
    Ok(total_loss / count as f32)
}
/// Returns `entity_id`'s cone converted to the ndarray backend, or `None`
/// when the entity is unknown or the conversion fails.
#[cfg(feature = "ndarray-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
pub fn get_cone(&self, entity_id: usize) -> Option<crate::ndarray_backend::NdarrayCone> {
    let cone = self.cones.get(&entity_id)?;
    cone.to_ndarray_cone().ok()
}
/// Converts every stored cone to the ndarray backend, silently skipping any
/// entity whose conversion fails.
#[cfg(feature = "ndarray-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
pub fn get_all_cones(&self) -> HashMap<usize, crate::ndarray_backend::NdarrayCone> {
    let mut converted = HashMap::with_capacity(self.cones.len());
    for (&id, cone) in &self.cones {
        if let Ok(nd_cone) = cone.to_ndarray_cone() {
            converted.insert(id, nd_cone);
        }
    }
    converted
}
/// Exports all embeddings as flat row-major buffers.
///
/// Returns `(ids, axes, apertures)` where `ids` is sorted ascending and the
/// `k`-th `dim`-sized row of each flat buffer belongs to `ids[k]`.
pub fn export_embeddings(&self) -> (Vec<usize>, Vec<f32>, Vec<f32>) {
    let mut ids: Vec<usize> = self.cones.keys().copied().collect();
    ids.sort_unstable();
    let mut axes = Vec::with_capacity(ids.len() * self.dim);
    let mut apertures = Vec::with_capacity(ids.len() * self.dim);
    for cone in ids.iter().map(|id| &self.cones[id]) {
        axes.extend_from_slice(&cone.axes());
        apertures.extend_from_slice(&cone.apertures());
    }
    (ids, axes, apertures)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trainable::TrainableCone;

    /// A contained cone pair should score a lower positive loss than a
    /// clearly separated one.
    #[test]
    fn cone_pair_loss_positive_prefers_containment() {
        let cfg = TrainingConfig::default();
        let outer = TrainableCone::new(vec![0.0, 0.0], vec![2.0, 2.0]).unwrap();
        let contained = TrainableCone::new(vec![0.0, 0.0], vec![-2.0, -2.0]).unwrap();
        let separated = TrainableCone::new(vec![3.0, 3.0], vec![-2.0, -2.0]).unwrap();
        let loss_contained = compute_cone_pair_loss(&outer, &contained, true, &cfg);
        let loss_separated = compute_cone_pair_loss(&outer, &separated, true, &cfg);
        assert!(loss_contained.is_finite() && loss_separated.is_finite());
        assert!(
            loss_contained < loss_separated,
            "positive loss should be lower for contained cones (got l_in={l_in}, l_out={l_out})",
            l_in = loss_contained,
            l_out = loss_separated
        );
    }

    /// Both positive and negative steps must complete with finite losses.
    #[test]
    fn cone_trainer_train_step_does_not_panic() {
        let mut trainer = ConeEmbeddingTrainer::new(TrainingConfig::default(), 4, None);
        let positive = trainer.train_step(0, 1, true);
        assert!(positive.is_finite(), "loss must be finite, got {}", positive);
        let negative = trainer.train_step(0, 2, false);
        assert!(
            negative.is_finite(),
            "negative loss must be finite, got {}",
            negative
        );
    }

    /// Repeated positive steps on the same pair should not increase the loss
    /// (up to a small tolerance).
    #[test]
    fn cone_trainer_reduces_loss_over_steps() {
        let cfg = TrainingConfig {
            learning_rate: 0.01,
            temperature: 1.0,
            regularization: 0.0,
            ..Default::default()
        };
        let mut trainer = ConeEmbeddingTrainer::new(cfg, 4, None);
        let losses: Vec<f32> = (0..50).map(|_| trainer.train_step(0, 1, true)).collect();
        let early_avg: f32 = losses[..10].iter().sum::<f32>() / 10.0;
        let late_avg: f32 = losses[40..].iter().sum::<f32>() / 10.0;
        assert!(
            late_avg <= early_avg + 0.5,
            "loss should generally decrease: early_avg={early_avg}, late_avg={late_avg}"
        );
    }
}