use candle_core::{Device, Result, Tensor, Var};
use crate::el_training::{Axiom, Ontology};
/// Geometric (box-embedding) trainer for EL-ontology axioms built on candle.
///
/// Concepts are axis-aligned boxes (a center plus per-dimension offsets)
/// together with a translation "bump" vector; roles are head/tail boxes
/// stored as concatenated `[center | offset]` rows of width `2 * dim`.
pub struct CandleElTrainer {
    /// Concept box centers, shape `(num_concepts, dim)`.
    pub concept_centers: Var,
    /// Concept box half-widths, shape `(num_concepts, dim)`; passed through
    /// `abs` wherever they are used so effective widths are non-negative.
    pub concept_offsets: Var,
    /// Per-concept translation ("bump") vectors, shape `(num_concepts, dim)`.
    pub bumps: Var,
    /// Role head embeddings: `[center | offset]` rows, shape
    /// `(max(num_roles, 1), 2 * dim)` — padded to at least one row.
    pub role_heads: Var,
    /// Role tail embeddings, same layout as `role_heads`.
    pub role_tails: Var,
    /// Embedding dimensionality of one box center / offset.
    pub dim: usize,
    /// Number of concept rows in the concept tables.
    pub num_concepts: usize,
    /// Number of roles actually in use (tables may hold one padding row).
    pub num_roles: usize,
    /// Margin used inside the inclusion / separation losses.
    pub margin: f32,
    /// Target separation score pushed onto negative samples during training.
    pub neg_dist: f32,
    /// Device all parameter tensors live on.
    pub device: Device,
}
impl CandleElTrainer {
pub fn new(
num_concepts: usize,
num_roles: usize,
dim: usize,
margin: f32,
neg_dist: f32,
device: &Device,
) -> Result<Self> {
let cc_raw = Tensor::rand(-1.0_f32, 1.0, (num_concepts, dim), device)?;
let cc_norm =
cc_raw.broadcast_div(&cc_raw.sqr()?.sum(1)?.sqrt()?.reshape((num_concepts, 1))?)?;
let concept_centers = Var::from_tensor(&cc_norm)?;
let co_raw = Tensor::rand(-1.0_f32, 1.0, (num_concepts, dim), device)?;
let co_norm =
co_raw.broadcast_div(&co_raw.sqr()?.sum(1)?.sqrt()?.reshape((num_concepts, 1))?)?;
let concept_offsets = Var::from_tensor(&co_norm)?;
let bump_raw = Tensor::rand(-1.0_f32, 1.0, (num_concepts, dim), device)?;
let bump_norm =
bump_raw.broadcast_div(&bump_raw.sqr()?.sum(1)?.sqrt()?.reshape((num_concepts, 1))?)?;
let bumps = Var::from_tensor(&bump_norm)?;
let nr = num_roles.max(1);
let rh_raw = Tensor::rand(-1.0_f32, 1.0, (nr, dim * 2), device)?;
let rh_norm = rh_raw.broadcast_div(&rh_raw.sqr()?.sum(1)?.sqrt()?.reshape((nr, 1))?)?;
let role_heads = Var::from_tensor(&rh_norm)?;
let rt_raw = Tensor::rand(-1.0_f32, 1.0, (nr, dim * 2), device)?;
let rt_norm = rt_raw.broadcast_div(&rt_raw.sqr()?.sum(1)?.sqrt()?.reshape((nr, 1))?)?;
let role_tails = Var::from_tensor(&rt_norm)?;
Ok(Self {
concept_centers,
concept_offsets,
bumps,
role_heads,
role_tails,
dim,
num_concepts,
num_roles,
margin,
neg_dist,
device: device.clone(),
})
}
/// Soft box-inclusion penalty per batch row: how far box A protrudes out of
/// box B beyond `margin`, reduced to an L2 norm over dimensions.
fn inclusion_loss(
    centers_a: &Tensor,
    offsets_a: &Tensor,
    centers_b: &Tensor,
    offsets_b: &Tensor,
    margin: f32,
) -> Result<Tensor> {
    // Per dimension: |c_a - c_b| + o_a - o_b - margin, clipped at zero —
    // positive exactly where A is not contained in B (within the margin).
    let center_gap = centers_a.sub(centers_b)?.abs()?;
    let overhang = center_gap.add(offsets_a)?.sub(offsets_b)?;
    let violation = overhang.affine(1.0, -(margin as f64))?.relu()?;
    // L2 norm over the feature axis; epsilon keeps sqrt differentiable at 0.
    let sum_sq = violation.sqr()?.sum(1)?;
    sum_sq.affine(1.0, 1e-8)?.sqrt()
}
/// Soft separation score per batch row: positive where the center distance
/// exceeds the combined half-widths minus `margin`, reduced to an L2 norm.
/// Used to push negatively sampled box pairs toward a target distance.
fn neg_loss_fn(
    centers_a: &Tensor,
    offsets_a: &Tensor,
    centers_b: &Tensor,
    offsets_b: &Tensor,
    margin: f32,
) -> Result<Tensor> {
    // Per dimension: |c_a - c_b| - o_a - o_b + margin, clipped at zero.
    let center_gap = centers_a.sub(centers_b)?.abs()?;
    let separation = center_gap
        .sub(offsets_a)?
        .sub(offsets_b)?
        .affine(1.0, margin as f64)?
        .relu()?;
    // Same epsilon-stabilized norm as `inclusion_loss`.
    let sum_sq = separation.sqr()?.sum(1)?;
    sum_sq.affine(1.0, 1e-8)?.sqrt()
}
fn disjointness_score(
centers_a: &Tensor,
offsets_a: &Tensor,
centers_b: &Tensor,
offsets_b: &Tensor,
margin: f32,
) -> Result<Tensor> {
let diffs = centers_a.sub(centers_b)?.abs()?;
let gap = diffs
.sub(offsets_a)?
.sub(offsets_b)?
.affine(1.0, margin as f64)?
.relu()?;
let norm_sq = gap.sqr()?.sum(1)?;
norm_sq.affine(1.0, 1e-8)?.sqrt()
}
/// Gathers (center, offset) rows for the concept ids in `ids`.
/// Offsets go through `abs` so boxes always have non-negative half-widths.
fn concept_boxes(&self, ids: &Tensor) -> Result<(Tensor, Tensor)> {
    let center_table = self.concept_centers.as_tensor();
    let offset_table = self.concept_offsets.as_tensor();
    let centers = center_table.index_select(ids, 0)?;
    let offsets = offset_table.index_select(ids, 0)?.abs()?;
    Ok((centers, offsets))
}
/// Gathers the bump (translation) vectors for the concept ids in `ids`.
fn concept_bumps(&self, ids: &Tensor) -> Result<Tensor> {
    let bump_table = self.bumps.as_tensor();
    bump_table.index_select(ids, 0)
}
/// Gathers (center, offset) boxes for the role ids in `ids`.
///
/// Role rows store `[center | offset]` concatenated along dim 1; `head`
/// selects the head table, otherwise the tail table is used.
fn role_box(&self, ids: &Tensor, head: bool) -> Result<(Tensor, Tensor)> {
    let table = if head { &self.role_heads } else { &self.role_tails };
    let rows = table.as_tensor().index_select(ids, 0)?;
    // First `dim` columns are the center, the next `dim` the half-widths.
    let centers = rows.narrow(1, 0, self.dim)?;
    let offsets = rows.narrow(1, self.dim, self.dim)?.abs()?;
    Ok((centers, offsets))
}
/// Trains all embedding tables on `ontology` with AdamW.
///
/// Each epoch draws one mini-batch per axiom normal form (NF1–NF4 buckets),
/// accumulates margin-based positive losses plus `negative_samples`
/// uniformly sampled negatives per batch, adds optional bump-norm
/// regularization, and takes a single optimizer step on the summed loss.
/// The learning rate follows a cosine schedule from `lr` down to `lr * 0.01`.
///
/// Returns the loss recorded for every completed epoch; training stops
/// early if the loss becomes NaN or infinite.
#[allow(clippy::too_many_arguments)]
pub fn fit(
    &self,
    ontology: &Ontology,
    epochs: usize,
    lr: f64,
    batch_size: usize,
    negative_samples: usize,
    reg_factor: f32,
) -> Result<Vec<f32>> {
    use candle_nn::{AdamW, Optimizer, ParamsAdamW};
    // NOTE(review): this takes `&self` and relies on `Var::clone` sharing
    // storage with the trainer's fields, so optimizer steps update the
    // trainer's parameters in place — confirm against candle's Var docs.
    let mut vars = vec![
        self.concept_centers.clone(),
        self.concept_offsets.clone(),
        self.bumps.clone(),
    ];
    // Role tables are only optimized when the ontology actually has roles.
    if self.num_roles > 0 {
        vars.push(self.role_heads.clone());
        vars.push(self.role_tails.clone());
    }
    let params = ParamsAdamW {
        lr,
        weight_decay: 0.0,
        ..Default::default()
    };
    let mut opt = AdamW::new(vars, params)?;
    // Bucket the ontology's axioms by normal form.
    let mut nf2_axioms: Vec<(usize, usize)> = Vec::new();
    let mut nf1_axioms: Vec<(usize, usize, usize)> = Vec::new();
    let mut nf3_axioms: Vec<(usize, usize, usize)> = Vec::new();
    let mut nf4_axioms: Vec<(usize, usize, usize)> = Vec::new();
    for ax in &ontology.axioms {
        match *ax {
            Axiom::SubClassOf { sub, sup } => nf2_axioms.push((sub, sup)),
            Axiom::Intersection { c1, c2, target } => nf1_axioms.push((c1, c2, target)),
            Axiom::ExistentialRight { sub, role, filler } => {
                nf3_axioms.push((sub, role, filler))
            }
            Axiom::Existential {
                role,
                filler,
                target,
            } => nf4_axioms.push((role, filler, target)),
            // Any other axiom kinds are ignored by this trainer.
            _ => {}
        }
    }
    let nc = self.num_concepts;
    let mut epoch_losses = Vec::with_capacity(epochs);
    // Fixed-seed 64-bit LCG so axiom/negative sampling is reproducible.
    let mut rng: u64 = 42;
    let lcg = |s: &mut u64| -> usize {
        *s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (*s >> 33) as usize
    };
    let lr_min = lr * 0.01;
    for epoch in 0..epochs {
        // Cosine-annealed learning rate from `lr` down toward `lr_min`.
        let progress = epoch as f64 / epochs.max(1) as f64;
        let current_lr =
            lr_min + 0.5 * (lr - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos());
        opt.set_learning_rate(current_lr);
        let mut epoch_loss = Tensor::zeros((), candle_core::DType::F32, &self.device)?;
        // --- NF2: C ⊑ D (plain subsumption) ---
        if !nf2_axioms.is_empty() {
            let bs = batch_size.min(nf2_axioms.len());
            let mut sub_ids = Vec::with_capacity(bs);
            let mut sup_ids = Vec::with_capacity(bs);
            // Sample `bs` axioms with replacement.
            for _ in 0..bs {
                let idx = lcg(&mut rng) % nf2_axioms.len();
                let (s, d) = nf2_axioms[idx];
                sub_ids.push(s as u32);
                sup_ids.push(d as u32);
            }
            let sub_t = Tensor::from_vec(sub_ids, (bs,), &self.device)?;
            let sup_t = Tensor::from_vec(sup_ids, (bs,), &self.device)?;
            let (c_sub, o_sub) = self.concept_boxes(&sub_t)?;
            let (c_sup, o_sup) = self.concept_boxes(&sup_t)?;
            // Positive: push the sub box inside the sup box (squared penalty).
            let pos_loss = Self::inclusion_loss(&c_sub, &o_sub, &c_sup, &o_sup, self.margin)?
                .sqr()?
                .mean(0)?;
            let mut neg_loss_sum = Tensor::zeros((), candle_core::DType::F32, &self.device)?;
            // Negatives: pull the separation score of random concepts toward
            // `neg_dist`. Corrupted pairs may occasionally be true axioms;
            // no filtering is applied.
            for _ in 0..negative_samples {
                let neg_ids: Vec<u32> = (0..bs).map(|_| (lcg(&mut rng) % nc) as u32).collect();
                let neg_t = Tensor::from_vec(neg_ids, (bs,), &self.device)?;
                let (c_neg, o_neg) = self.concept_boxes(&neg_t)?;
                let disj =
                    Self::disjointness_score(&c_sub, &o_sub, &c_neg, &o_neg, self.margin)?;
                let target = Tensor::full(self.neg_dist, disj.shape(), &self.device)?;
                let gap = target.sub(&disj)?;
                let neg_loss = gap.sqr()?.mean(0)?;
                neg_loss_sum = neg_loss_sum.add(&neg_loss)?;
            }
            let batch_loss = pos_loss.add(&neg_loss_sum)?;
            epoch_loss = epoch_loss.add(&batch_loss)?;
        }
        // --- NF1: C1 ⊓ C2 ⊑ D (intersection) ---
        if !nf1_axioms.is_empty() {
            let bs = batch_size.min(nf1_axioms.len());
            let mut c1_ids = Vec::with_capacity(bs);
            let mut c2_ids = Vec::with_capacity(bs);
            let mut d_ids = Vec::with_capacity(bs);
            for _ in 0..bs {
                let idx = lcg(&mut rng) % nf1_axioms.len();
                let (c1, c2, d) = nf1_axioms[idx];
                c1_ids.push(c1 as u32);
                c2_ids.push(c2 as u32);
                d_ids.push(d as u32);
            }
            let c1_t = Tensor::from_vec(c1_ids, (bs,), &self.device)?;
            let c2_t = Tensor::from_vec(c2_ids, (bs,), &self.device)?;
            let d_t = Tensor::from_vec(d_ids, (bs,), &self.device)?;
            let (cc1, oc1) = self.concept_boxes(&c1_t)?;
            let (cc2, oc2) = self.concept_boxes(&c2_t)?;
            let (cd, od) = self.concept_boxes(&d_t)?;
            // Compute the axis-aligned intersection of the two boxes.
            let min1 = cc1.sub(&oc1)?;
            let max1 = cc1.add(&oc1)?;
            let min2 = cc2.sub(&oc2)?;
            let max2 = cc2.add(&oc2)?;
            let inter_min = min1.maximum(&min2)?;
            let inter_max = max1.minimum(&max2)?;
            // Clamp so an empty intersection degenerates to a zero-width box.
            let inter_max = inter_max.maximum(&inter_min)?;
            let inter_center = inter_min.add(&inter_max)?.affine(0.5, 0.0)?;
            let inter_offset = inter_max.sub(&inter_min)?.affine(0.5, 0.0)?;
            // Push the intersection box inside the target box D.
            let nf1_loss =
                Self::inclusion_loss(&inter_center, &inter_offset, &cd, &od, self.margin)?
                    .mean(0)?;
            epoch_loss = epoch_loss.add(&nf1_loss)?;
        }
        // --- NF3: C ⊑ ∃R.D (existential on the right) ---
        if !nf3_axioms.is_empty() {
            let bs = batch_size.min(nf3_axioms.len());
            let mut sub_ids = Vec::with_capacity(bs);
            let mut role_ids = Vec::with_capacity(bs);
            let mut filler_ids = Vec::with_capacity(bs);
            for _ in 0..bs {
                let idx = lcg(&mut rng) % nf3_axioms.len();
                let (s, r, f) = nf3_axioms[idx];
                sub_ids.push(s as u32);
                role_ids.push(r as u32);
                filler_ids.push(f as u32);
            }
            let sub_t = Tensor::from_vec(sub_ids, (bs,), &self.device)?;
            let role_t = Tensor::from_vec(role_ids, (bs,), &self.device)?;
            let filler_t = Tensor::from_vec(filler_ids, (bs,), &self.device)?;
            let (c_sub, o_sub) = self.concept_boxes(&sub_t)?;
            let (c_filler, o_filler) = self.concept_boxes(&filler_t)?;
            let bump_sub = self.concept_bumps(&sub_t)?;
            let bump_filler = self.concept_bumps(&filler_t)?;
            let (c_head, o_head) = self.role_box(&role_t, true)?;
            let (c_tail, o_tail) = self.role_box(&role_t, false)?;
            // Sub box translated by the filler's bump must sit in the role
            // head box; filler box translated by the sub's bump in the tail.
            let c_sub_bumped = c_sub.add(&bump_filler)?;
            let dist1 =
                Self::inclusion_loss(&c_sub_bumped, &o_sub, &c_head, &o_head, self.margin)?;
            let c_filler_bumped = c_filler.add(&bump_sub)?;
            let dist2 = Self::inclusion_loss(
                &c_filler_bumped,
                &o_filler,
                &c_tail,
                &o_tail,
                self.margin,
            )?;
            let nf3_loss = dist1.add(&dist2)?.affine(0.5, 0.0)?.mean(0)?;
            let mut nf3_neg_sum = Tensor::zeros((), candle_core::DType::F32, &self.device)?;
            // Negatives: corrupt the filler (head side) and the sub (tail
            // side) independently and push both toward `neg_dist`.
            for _ in 0..negative_samples {
                let neg_tail_ids: Vec<u32> =
                    (0..bs).map(|_| (lcg(&mut rng) % nc) as u32).collect();
                let neg_tail_t = Tensor::from_vec(neg_tail_ids, (bs,), &self.device)?;
                let bump_neg_tail = self.concept_bumps(&neg_tail_t)?;
                let c_sub_bumped_neg = c_sub.add(&bump_neg_tail)?;
                let neg_loss1 = Self::neg_loss_fn(
                    &c_sub_bumped_neg,
                    &o_sub,
                    &c_head,
                    &o_head,
                    self.margin,
                )?;
                let neg_head_ids: Vec<u32> =
                    (0..bs).map(|_| (lcg(&mut rng) % nc) as u32).collect();
                let neg_head_t = Tensor::from_vec(neg_head_ids, (bs,), &self.device)?;
                let (c_neg_head, o_neg_head) = self.concept_boxes(&neg_head_t)?;
                let c_neg_bumped = c_neg_head.add(&bump_sub)?;
                let neg_loss2 = Self::neg_loss_fn(
                    &c_neg_bumped,
                    &o_neg_head,
                    &c_tail,
                    &o_tail,
                    self.margin,
                )?;
                let target1 = Tensor::full(self.neg_dist, neg_loss1.shape(), &self.device)?;
                let target2 = Tensor::full(self.neg_dist, neg_loss2.shape(), &self.device)?;
                let nl1 = target1.sub(&neg_loss1)?.sqr()?.mean(0)?;
                let nl2 = target2.sub(&neg_loss2)?.sqr()?.mean(0)?;
                nf3_neg_sum = nf3_neg_sum.add(&nl1)?.add(&nl2)?;
            }
            let nf3_total = nf3_loss.add(&nf3_neg_sum)?;
            epoch_loss = epoch_loss.add(&nf3_total)?;
        }
        // --- NF4: ∃R.C ⊑ D (existential on the left) ---
        if !nf4_axioms.is_empty() {
            let bs = batch_size.min(nf4_axioms.len());
            let mut role_ids = Vec::with_capacity(bs);
            let mut filler_ids = Vec::with_capacity(bs);
            let mut target_ids = Vec::with_capacity(bs);
            for _ in 0..bs {
                let idx = lcg(&mut rng) % nf4_axioms.len();
                let (r, f, t) = nf4_axioms[idx];
                role_ids.push(r as u32);
                filler_ids.push(f as u32);
                target_ids.push(t as u32);
            }
            let role_t = Tensor::from_vec(role_ids, (bs,), &self.device)?;
            let filler_t = Tensor::from_vec(filler_ids, (bs,), &self.device)?;
            let target_t = Tensor::from_vec(target_ids, (bs,), &self.device)?;
            let (c_target, o_target) = self.concept_boxes(&target_t)?;
            let bump_filler = self.concept_bumps(&filler_t)?;
            let (c_head, o_head) = self.role_box(&role_t, true)?;
            // Role head box shifted back by the filler's bump must sit
            // inside the target box (squared penalty, as in NF2).
            let c_head_shifted = c_head.sub(&bump_filler)?;
            let nf4_loss = Self::inclusion_loss(
                &c_head_shifted,
                &o_head,
                &c_target,
                &o_target,
                self.margin,
            )?
            .sqr()?
            .mean(0)?;
            let mut nf4_neg_sum = Tensor::zeros((), candle_core::DType::F32, &self.device)?;
            // Negatives: corrupt the target and the filler independently.
            for _ in 0..negative_samples {
                let neg_target_ids: Vec<u32> =
                    (0..bs).map(|_| (lcg(&mut rng) % nc) as u32).collect();
                let neg_target_t = Tensor::from_vec(neg_target_ids, (bs,), &self.device)?;
                let (c_neg_target, o_neg_target) = self.concept_boxes(&neg_target_t)?;
                let neg_loss1 = Self::neg_loss_fn(
                    &c_head_shifted,
                    &o_head,
                    &c_neg_target,
                    &o_neg_target,
                    self.margin,
                )?;
                let neg_filler_ids: Vec<u32> =
                    (0..bs).map(|_| (lcg(&mut rng) % nc) as u32).collect();
                let neg_filler_t = Tensor::from_vec(neg_filler_ids, (bs,), &self.device)?;
                let bump_neg_filler = self.concept_bumps(&neg_filler_t)?;
                let c_head_neg_shifted = c_head.sub(&bump_neg_filler)?;
                let neg_loss2 = Self::neg_loss_fn(
                    &c_head_neg_shifted,
                    &o_head,
                    &c_target,
                    &o_target,
                    self.margin,
                )?;
                let target1 = Tensor::full(self.neg_dist, neg_loss1.shape(), &self.device)?;
                let target2 = Tensor::full(self.neg_dist, neg_loss2.shape(), &self.device)?;
                let nl1 = target1.sub(&neg_loss1)?.sqr()?.mean(0)?;
                let nl2 = target2.sub(&neg_loss2)?.sqr()?.mean(0)?;
                nf4_neg_sum = nf4_neg_sum.add(&nl1)?.add(&nl2)?;
            }
            let nf4_total = nf4_loss.add(&nf4_neg_sum)?;
            epoch_loss = epoch_loss.add(&nf4_total)?;
        }
        // Optional L2-norm regularization of the bump vectors.
        if reg_factor > 0.0 {
            let bump_reg = self
                .bumps
                .as_tensor()
                .sqr()?
                .sum(1)?
                .sqrt()?
                .mean(0)?
                .affine(reg_factor as f64, 0.0)?;
            epoch_loss = epoch_loss.add(&bump_reg)?;
        }
        // Read the scalar before the step; the step mutates the parameters.
        let loss_val = epoch_loss.to_scalar::<f32>()?;
        opt.backward_step(&epoch_loss)?;
        epoch_losses.push(loss_val);
        if loss_val.is_nan() || loss_val.is_infinite() {
            eprintln!(
                "  WARNING: loss diverged at epoch {} (loss={loss_val}). Stopping.",
                epoch + 1
            );
            break;
        }
        // Progress report on the first epoch and every 100th thereafter.
        if (epoch + 1) % 100 == 0 || epoch == 0 {
            let c_mean = self
                .concept_centers
                .as_tensor()
                .abs()?
                .mean_all()?
                .to_scalar::<f32>()?;
            let o_mean = self
                .concept_offsets
                .as_tensor()
                .abs()?
                .mean_all()?
                .to_scalar::<f32>()?;
            eprintln!(
                "  epoch {:>5}/{epochs}: loss={loss_val:.4} lr={current_lr:.6} |c|={c_mean:.3} |o|={o_mean:.3}",
                epoch + 1
            );
        }
    }
    Ok(epoch_losses)
}
/// Ranks every other concept by center distance for each `(sub, sup)` test
/// pair and reports `(hits@1, hits@10, MRR)` over the valid pairs.
/// Pairs with out-of-range ids are skipped.
pub fn evaluate_subsumption(
    &self,
    test_axioms: &[(usize, usize)],
) -> Result<(f32, f32, f32)> {
    // Pull centers to a host-side row-major buffer for fast scoring.
    let centers: Vec<f32> = self
        .concept_centers
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let nc = self.num_concepts;
    let dim = self.dim;
    // Euclidean distance between the center rows of concepts `a` and `b`.
    let euclid = |a: usize, b: usize| -> f32 {
        let (a0, b0) = (a * dim, b * dim);
        (0..dim)
            .map(|i| {
                let diff = centers[a0 + i] - centers[b0 + i];
                diff * diff
            })
            .sum::<f32>()
            .sqrt()
    };
    let (mut hits1, mut hits10, mut rr_sum, mut total) = (0usize, 0usize, 0.0f32, 0usize);
    for &(sub, sup) in test_axioms {
        if sub >= nc || sup >= nc {
            continue;
        }
        let mut ranked: Vec<(usize, f32)> = (0..nc)
            .filter(|&c| c != sub)
            .map(|c| (c, euclid(sub, c)))
            .collect();
        // Stable sort so ties keep their candidate-id order.
        ranked.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
        // 1-based rank of the true superclass; worst rank if absent.
        let rank = ranked
            .iter()
            .position(|&(c, _)| c == sup)
            .map_or(nc, |p| p + 1);
        hits1 += usize::from(rank == 1);
        hits10 += usize::from(rank <= 10);
        rr_sum += 1.0 / rank as f32;
        total += 1;
    }
    if total == 0 {
        return Ok((0.0, 0.0, 0.0));
    }
    Ok((
        hits1 as f32 / total as f32,
        hits10 as f32 / total as f32,
        rr_sum / total as f32,
    ))
}
/// Evaluates NF1 axioms `(c1, c2, d)`: builds the intersection box of
/// `c1` and `c2`, ranks all other concepts by distance from its midpoint,
/// and reports `(hits@1, hits@10, MRR)`. Out-of-range triples are skipped.
pub fn evaluate_nf1(
    &self,
    test_axioms: &[(usize, usize, usize)],
) -> Result<(f32, f32, f32)> {
    let centers: Vec<f32> = self
        .concept_centers
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    // `abs` mirrors training: half-widths are always non-negative.
    let offsets: Vec<f32> = self
        .concept_offsets
        .as_tensor()
        .abs()?
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let nc = self.num_concepts;
    let dim = self.dim;
    let (mut hits1, mut hits10, mut rr_sum, mut total) = (0usize, 0usize, 0.0f32, 0usize);
    for &(c1, c2, d) in test_axioms {
        if c1 >= nc || c2 >= nc || d >= nc {
            continue;
        }
        let (a, b) = (c1 * dim, c2 * dim);
        // Midpoint of the clamped box intersection, per dimension.
        let inter_center: Vec<f32> = (0..dim)
            .map(|i| {
                let lo = (centers[a + i] - offsets[a + i]).max(centers[b + i] - offsets[b + i]);
                let hi = (centers[a + i] + offsets[a + i])
                    .min(centers[b + i] + offsets[b + i])
                    .max(lo);
                (lo + hi) / 2.0
            })
            .collect();
        let mut ranked: Vec<(usize, f32)> = (0..nc)
            .filter(|&c| c != c1 && c != c2)
            .map(|c| {
                let c0 = c * dim;
                let dist_sq: f32 = (0..dim)
                    .map(|i| {
                        let diff = inter_center[i] - centers[c0 + i];
                        diff * diff
                    })
                    .sum();
                (c, dist_sq.sqrt())
            })
            .collect();
        ranked.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
        let rank = ranked
            .iter()
            .position(|&(c, _)| c == d)
            .map_or(nc, |p| p + 1);
        hits1 += usize::from(rank == 1);
        hits10 += usize::from(rank <= 10);
        rr_sum += 1.0 / rank as f32;
        total += 1;
    }
    if total == 0 {
        return Ok((0.0, 0.0, 0.0));
    }
    Ok((
        hits1 as f32 / total as f32,
        hits10 as f32 / total as f32,
        rr_sum / total as f32,
    ))
}
/// Evaluates NF3 axioms `(sub, role, filler)`: for each triple, ranks every
/// concept `d` by the distance between `center(sub) + bump(d)` and the role
/// head center, and reports `(hits@1, hits@10, MRR)` for the true filler.
pub fn evaluate_nf3(
    &self,
    test_axioms: &[(usize, usize, usize)],
) -> Result<(f32, f32, f32)> {
    let centers: Vec<f32> = self
        .concept_centers
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let bump_vecs: Vec<f32> = self
        .bumps
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let role_heads_data: Vec<f32> = self
        .role_heads
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let nc = self.num_concepts;
    let nr = self.num_roles;
    let dim = self.dim;
    let (mut hits1, mut hits10, mut rr_sum, mut total) = (0usize, 0usize, 0.0f32, 0usize);
    for &(sub, role, filler) in test_axioms {
        if sub >= nc || filler >= nc || role >= nr {
            continue;
        }
        let sub_off = sub * dim;
        // Role rows are 2*dim wide; only the first `dim` entries (the head
        // center) are read here.
        let rh_off = role * dim * 2;
        let score_of = |cand: usize| -> f32 {
            let bump_off = cand * dim;
            (0..dim)
                .map(|i| {
                    let bumped = centers[sub_off + i] + bump_vecs[bump_off + i];
                    let diff = bumped - role_heads_data[rh_off + i];
                    diff * diff
                })
                .sum::<f32>()
                .sqrt()
        };
        let mut ranked: Vec<(usize, f32)> = (0..nc).map(|cand| (cand, score_of(cand))).collect();
        ranked.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
        let rank = ranked
            .iter()
            .position(|&(c, _)| c == filler)
            .map_or(nc, |p| p + 1);
        hits1 += usize::from(rank == 1);
        hits10 += usize::from(rank <= 10);
        rr_sum += 1.0 / rank as f32;
        total += 1;
    }
    if total == 0 {
        return Ok((0.0, 0.0, 0.0));
    }
    Ok((
        hits1 as f32 / total as f32,
        hits10 as f32 / total as f32,
        rr_sum / total as f32,
    ))
}
/// Evaluates NF4 axioms `(role, filler, target)`: forms the query point
/// `head_center(role) - bump(filler)`, ranks all concepts by center distance
/// from it, and reports `(hits@1, hits@10, MRR)` for the true target.
pub fn evaluate_nf4(
    &self,
    test_axioms: &[(usize, usize, usize)],
) -> Result<(f32, f32, f32)> {
    let centers: Vec<f32> = self
        .concept_centers
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let bump_vecs: Vec<f32> = self
        .bumps
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let role_heads_data: Vec<f32> = self
        .role_heads
        .as_tensor()
        .to_vec2::<f32>()?
        .into_iter()
        .flatten()
        .collect();
    let nc = self.num_concepts;
    let nr = self.num_roles;
    let dim = self.dim;
    let (mut hits1, mut hits10, mut rr_sum, mut total) = (0usize, 0usize, 0.0f32, 0usize);
    for &(role, filler, target) in test_axioms {
        if filler >= nc || target >= nc || role >= nr {
            continue;
        }
        let rh_off = role * dim * 2;
        let bump_off = filler * dim;
        // Query point: role-head center shifted back by the filler's bump.
        let query_center: Vec<f32> = (0..dim)
            .map(|i| role_heads_data[rh_off + i] - bump_vecs[bump_off + i])
            .collect();
        let mut ranked: Vec<(usize, f32)> = (0..nc)
            .map(|cand| {
                let c0 = cand * dim;
                let dist_sq: f32 = (0..dim)
                    .map(|i| {
                        let diff = query_center[i] - centers[c0 + i];
                        diff * diff
                    })
                    .sum();
                (cand, dist_sq.sqrt())
            })
            .collect();
        ranked.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
        let rank = ranked
            .iter()
            .position(|&(c, _)| c == target)
            .map_or(nc, |p| p + 1);
        hits1 += usize::from(rank == 1);
        hits10 += usize::from(rank <= 10);
        rr_sum += 1.0 / rank as f32;
        total += 1;
    }
    if total == 0 {
        return Ok((0.0, 0.0, 0.0));
    }
    Ok((
        hits1 as f32 / total as f32,
        hits10 as f32 / total as f32,
        rr_sum / total as f32,
    ))
}
/// Writes all five parameter tables to `path` in safetensors format.
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
    let mut tensors = std::collections::HashMap::new();
    tensors.insert(
        "concept_centers".to_string(),
        self.concept_centers.as_tensor().clone(),
    );
    tensors.insert(
        "concept_offsets".to_string(),
        self.concept_offsets.as_tensor().clone(),
    );
    tensors.insert("bumps".to_string(), self.bumps.as_tensor().clone());
    tensors.insert("role_heads".to_string(), self.role_heads.as_tensor().clone());
    tensors.insert("role_tails".to_string(), self.role_tails.as_tensor().clone());
    candle_core::safetensors::save(&tensors, path)?;
    Ok(())
}
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
let tensors = candle_core::safetensors::load(path, &self.device)?;
let get = |name: &str| -> Result<Tensor> {
tensors
.get(name)
.cloned()
.ok_or_else(|| candle_core::Error::Msg(format!("missing tensor: {name}")))
};
self.concept_centers = Var::from_tensor(&get("concept_centers")?)?;
self.concept_offsets = Var::from_tensor(&get("concept_offsets")?)?;
self.bumps = Var::from_tensor(&get("bumps")?)?;
self.role_heads = Var::from_tensor(&get("role_heads")?)?;
self.role_tails = Var::from_tensor(&get("role_tails")?)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor smoke test: dimensions are recorded as given.
    #[test]
    fn test_candle_el_trainer_creates() {
        let device = Device::Cpu;
        let trainer = CandleElTrainer::new(100, 5, 32, 0.1, 2.0, &device).unwrap();
        assert_eq!(trainer.num_concepts, 100);
        assert_eq!(trainer.num_roles, 5);
    }

    /// Training on a tiny mixed-NF ontology runs all epochs and reduces the
    /// loss. Deterministic because `fit` uses a fixed-seed LCG sampler.
    #[test]
    fn test_candle_el_trainer_fits() {
        let device = Device::Cpu;
        let trainer = CandleElTrainer::new(20, 3, 16, 0.1, 2.0, &device).unwrap();
        let mut ont = Ontology::new();
        for i in 0..20 {
            ont.concept(&format!("C{i}"));
        }
        for i in 0..3 {
            ont.role(&format!("R{i}"));
        }
        // One axiom of each trained normal form except NF4.
        ont.axioms.push(Axiom::SubClassOf { sub: 0, sup: 1 });
        ont.axioms.push(Axiom::SubClassOf { sub: 2, sup: 3 });
        ont.axioms.push(Axiom::Intersection {
            c1: 0,
            c2: 2,
            target: 4,
        });
        ont.axioms.push(Axiom::ExistentialRight {
            sub: 5,
            role: 0,
            filler: 6,
        });
        let losses = trainer.fit(&ont, 50, 0.01, 4, 1, 0.0).unwrap();
        assert_eq!(losses.len(), 50);
        assert!(losses[0].is_finite());
        assert!(losses.last().unwrap() < &losses[0], "loss should decrease");
    }

    /// After training on a subsumption chain, ranking the training pairs
    /// should produce a positive mean reciprocal rank.
    #[test]
    fn test_candle_el_eval_works() {
        let device = Device::Cpu;
        let trainer = CandleElTrainer::new(20, 2, 16, 0.1, 2.0, &device).unwrap();
        let mut ont = Ontology::new();
        for i in 0..20 {
            ont.concept(&format!("C{i}"));
        }
        ont.role("R0");
        ont.role("R1");
        for i in 0..15 {
            ont.axioms.push(Axiom::SubClassOf {
                sub: i,
                sup: (i + 1) % 20,
            });
        }
        let _losses = trainer.fit(&ont, 200, 0.01, 8, 2, 0.0).unwrap();
        // Evaluate on the training pairs themselves (sanity, not held-out).
        let test_pairs: Vec<(usize, usize)> = ont
            .axioms
            .iter()
            .filter_map(|a| match a {
                Axiom::SubClassOf { sub, sup } => Some((*sub, *sup)),
                _ => None,
            })
            .collect();
        let (h1, h10, mrr) = trainer.evaluate_subsumption(&test_pairs).unwrap();
        assert!(
            mrr > 0.0,
            "MRR should be positive on training data, got {mrr}"
        );
        eprintln!("CandleElTrainer eval: H@1={h1:.3} H@10={h10:.3} MRR={mrr:.4}");
    }
}