use crate::taxobell::{CombinedLossResult, TaxoBellConfig};
use crate::BoxError;
use candle_core::{DType, Device, Result as CResult, Tensor, Var, D};
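/// Xavier/Glorot uniform initialization bound: `sqrt(6 / (fan_in + fan_out))`.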
fn xavier_uniform_scale(fan_in: usize, fan_out: usize) -> f32 {
(6.0 / (fan_in + fan_out) as f64).sqrt() as f32
}
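/// Numerically stable softplus: `softplus(x) = max(x, 0) + ln(1 + exp(-|x|))`.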
fn softplus(x: &Tensor) -> CResult<Tensor> {
    let zero = x.zeros_like()?;
    // max(x, 0) carries the linear part for large positive x.
    let relu_x = x.maximum(&zero)?;
    // ln(1 + exp(-|x|)) is bounded in (0, ln 2], so the exp cannot overflow.
    let neg_abs = x.abs()?.neg()?;
    let log_term = neg_abs.exp()?.affine(1.0, 1.0)?.log()?;
    relu_x.add(&log_term)
}
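/// A two-layer perceptron: `y = relu(x W1^T + b1) W2^T + b2`.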
#[derive(Debug)]
pub struct Mlp {
w1: Var,
b1: Var,
w2: Var,
b2: Var,
pub input_dim: usize,
pub hidden_dim: usize,
pub output_dim: usize,
}
impl Mlp {
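    /// Creates an MLP with Xavier-uniform weights and zero biases.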
pub fn new(
input_dim: usize,
hidden_dim: usize,
output_dim: usize,
device: &Device,
) -> Result<Self, BoxError> {
let map_err = |e: candle_core::Error| BoxError::Internal(e.to_string());
let scale1 = xavier_uniform_scale(input_dim, hidden_dim);
let w1 = Var::rand(-scale1, scale1, (hidden_dim, input_dim), device).map_err(map_err)?;
let b1 = Var::zeros(hidden_dim, DType::F32, device).map_err(map_err)?;
let scale2 = xavier_uniform_scale(hidden_dim, output_dim);
let w2 = Var::rand(-scale2, scale2, (output_dim, hidden_dim), device).map_err(map_err)?;
let b2 = Var::zeros(output_dim, DType::F32, device).map_err(map_err)?;
Ok(Self {
w1,
b1,
w2,
b2,
input_dim,
hidden_dim,
output_dim,
})
}
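    /// Forward pass for a batch `x` of shape `(n, input_dim)`, returning
    /// `(n, output_dim)`.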
pub fn forward(&self, x: &Tensor) -> CResult<Tensor> {
let h = x
.matmul(&self.w1.as_tensor().t()?)?
.broadcast_add(self.b1.as_tensor())?
.relu()?;
h.matmul(&self.w2.as_tensor().t()?)?
.broadcast_add(self.b2.as_tensor())
}
pub fn vars(&self) -> Vec<&Var> {
vec![&self.w1, &self.b1, &self.w2, &self.b2]
}
#[must_use]
pub fn num_params(&self) -> usize {
self.hidden_dim * self.input_dim
+ self.hidden_dim
+ self.output_dim * self.hidden_dim
+ self.output_dim
}
}
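/// Encodes node embeddings as Gaussian boxes: one MLP predicts the center
/// `mu`, a second predicts a raw offset that softplus maps to a strictly
/// positive scale `sigma`.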
#[derive(Debug)]
pub struct TaxoBellEncoder {
pub center_mlp: Mlp,
pub offset_mlp: Mlp,
pub embed_dim: usize,
pub box_dim: usize,
pub device: Device,
}
impl TaxoBellEncoder {
pub fn new(
embed_dim: usize,
hidden_dim: usize,
box_dim: usize,
device: &Device,
) -> Result<Self, BoxError> {
let center_mlp = Mlp::new(embed_dim, hidden_dim, box_dim, device)?;
let offset_mlp = Mlp::new(embed_dim, hidden_dim, box_dim, device)?;
Ok(Self {
center_mlp,
offset_mlp,
embed_dim,
box_dim,
device: device.clone(),
})
}
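    /// Encodes a batch of embeddings `(n, embed_dim)` into box parameters
    /// `(mu, sigma)`, each of shape `(n, box_dim)`, with `sigma > 0`.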
pub fn encode(&self, embeddings: &Tensor) -> CResult<(Tensor, Tensor)> {
let mu = self.center_mlp.forward(embeddings)?;
let raw_offset = self.offset_mlp.forward(embeddings)?;
let sigma = softplus(&raw_offset)?;
Ok((mu, sigma))
}
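    /// Encodes a single embedding into a `GaussianBox`.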
pub fn encode_one(&self, embedding: &[f32]) -> Result<crate::gaussian::GaussianBox, BoxError> {
let map_err = |e: candle_core::Error| BoxError::Internal(e.to_string());
let t = Tensor::new(embedding, &self.device).map_err(map_err)?;
        let t = t.unsqueeze(0).map_err(map_err)?;
        let (mu, sigma) = self.encode(&t).map_err(map_err)?;
let mu_vec: Vec<f32> = mu.squeeze(0).map_err(map_err)?.to_vec1().map_err(map_err)?;
let sigma_vec: Vec<f32> = sigma
.squeeze(0)
.map_err(map_err)?
.to_vec1()
.map_err(map_err)?;
crate::gaussian::GaussianBox::new(mu_vec, sigma_vec)
}
pub fn vars(&self) -> Vec<&Var> {
let mut v = self.center_mlp.vars();
v.extend(self.offset_mlp.vars());
v
}
#[must_use]
pub fn num_params(&self) -> usize {
self.center_mlp.num_params() + self.offset_mlp.num_params()
}
}
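/// Row-wise Bhattacharyya coefficient `BC = exp(-BD)` between diagonal
/// Gaussians, where `BD = 1/8 * sum(dmu^2 / s_avg) + 1/2 * sum(ln s_avg)
/// - 1/4 * sum(ln v1) - 1/4 * sum(ln v2)` with variances `v = sigma^2` and
/// `s_avg = (v1 + v2) / 2`.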
fn bhattacharyya_coeff_batch(
    mu1: &Tensor,
    s1: &Tensor,
    mu2: &Tensor,
    s2: &Tensor,
) -> CResult<Tensor> {
    let v1 = s1.sqr()?;
    let v2 = s2.sqr()?;
    let sigma_avg = v1.add(&v2)?.affine(0.5, 0.0)?;
    let mu_diff = mu1.sub(mu2)?;
    // Mean term of the Bhattacharyya distance: 1/8 * dmu^T * sigma_avg^-1 * dmu.
    let t1 = mu_diff
        .sqr()?
        .div(&sigma_avg)?
        .sum(D::Minus1)?
        .affine(0.125, 0.0)?;
    // Covariance term: 1/2 ln det(sigma_avg) - 1/4 ln det(v1) - 1/4 ln det(v2).
    let t2 = sigma_avg.log()?.sum(D::Minus1)?.affine(0.5, 0.0)?;
    let t3 = v1.log()?.sum(D::Minus1)?.affine(0.25, 0.0)?;
    let t4 = v2.log()?.sum(D::Minus1)?.affine(0.25, 0.0)?;
    let bd = t1.add(&t2)?.sub(&t3)?.sub(&t4)?;
    bd.neg()?.exp()
}
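/// Row-wise `KL(q || p)` between diagonal Gaussians:
/// `1/2 * sum(vq / vp + (mu_p - mu_q)^2 / vp - 1 + ln(vp / vq))`.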
fn kl_divergence_batch(
mu_q: &Tensor,
s_q: &Tensor,
mu_p: &Tensor,
s_p: &Tensor,
) -> CResult<Tensor> {
let vq = s_q.sqr()?;
let vp = s_p.sqr()?;
let mu_diff = mu_p.sub(mu_q)?;
let ratio = vq.div(&vp)?;
let mu_term = mu_diff.sqr()?.div(&vp)?;
let log_term = vp.div(&vq)?.log()?;
let per_dim = ratio.add(&mu_term)?.add(&log_term)?.affine(1.0, -1.0)?;
per_dim.sum(D::Minus1)?.affine(0.5, 0.0)
}
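/// Log-volume proxy of a box: `sum(ln sigma)` over the last dimension.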
fn log_volume_batch(sigma: &Tensor) -> CResult<Tensor> {
sigma.log()?.sum(D::Minus1)
}
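/// Squared-hinge floor on per-dimension variances, averaged over dimensions:
/// `mean(max(min_var - sigma^2, 0)^2)`.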
fn volume_reg_batch(sigma: &Tensor, min_var: f32) -> CResult<Tensor> {
let var = sigma.sqr()?;
let zero = var.zeros_like()?;
let gap = var.affine(-1.0, min_var as f64)?;
let hinge = gap.maximum(&zero)?;
let d = sigma.dim(D::Minus1)? as f64;
hinge.sqr()?.sum(D::Minus1)?.affine(1.0 / d, 0.0)
}
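/// Hinge ceiling on per-dimension variances, averaged over dimensions:
/// `mean(max(sigma^2 - max_var, 0))`.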
fn sigma_ceiling_batch(sigma: &Tensor, max_var: f32) -> CResult<Tensor> {
let var = sigma.sqr()?;
let zero = var.zeros_like()?;
let gap = var.affine(1.0, -(max_var as f64))?;
let hinge = gap.maximum(&zero)?;
let d = sigma.dim(D::Minus1)? as f64;
hinge.sum(D::Minus1)?.affine(1.0 / d, 0.0)
}
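/// Combined training loss
/// `alpha * L_sym + beta * L_asym + gamma * L_reg + delta * L_clip`:
/// a contrastive log-loss over Bhattacharyya coefficients, a margin loss on
/// `KL(child || parent)` with an optional divergence term, and variance
/// floor/ceiling regularizers. Returns the differentiable total alongside
/// the detached per-term scalars.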
#[allow(clippy::too_many_arguments)]
fn combined_loss_tensor(
mu_child: &Tensor,
s_child: &Tensor,
mu_parent: &Tensor,
s_parent: &Tensor,
mu_anchor: &Tensor,
s_anchor: &Tensor,
mu_pos: &Tensor,
s_pos: &Tensor,
mu_neg: &Tensor,
s_neg: &Tensor,
mu_all: &Tensor,
s_all: &Tensor,
config: &TaxoBellConfig,
) -> CResult<(Tensor, CombinedLossResult)> {
let device = mu_child.device();
let eps_val = 1e-7f64;
let one_minus_eps = 1.0 - eps_val;
let n_neg = mu_neg.dim(0)?;
let l_sym_scalar;
let l_sym_t;
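    // Symmetric contrastive term: pull anchor-positive BC toward 1 and
    // anchor-negative BC toward 0 via clamped log-losses.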
if n_neg > 0 {
let bc_pos = bhattacharyya_coeff_batch(mu_anchor, s_anchor, mu_pos, s_pos)?;
let bc_neg = bhattacharyya_coeff_batch(mu_anchor, s_anchor, mu_neg, s_neg)?;
let eps_t = Tensor::new(&[eps_val as f32], device)?.broadcast_as(bc_pos.shape())?;
let one_me = Tensor::new(&[one_minus_eps as f32], device)?.broadcast_as(bc_pos.shape())?;
let bc_pos_c = bc_pos.maximum(&eps_t)?.minimum(&one_me)?;
let bc_neg_c = bc_neg.maximum(&eps_t)?.minimum(&one_me)?;
let term1 = bc_pos_c.log()?.neg()?;
let one_t = Tensor::new(&[1.0f32], device)?.broadcast_as(bc_neg_c.shape())?;
let term2 = one_t.sub(&bc_neg_c)?.log()?.neg()?;
let per_sample = term1.add(&term2)?;
l_sym_t = per_sample.mean_all()?;
l_sym_scalar = l_sym_t.to_vec0::<f32>()?;
} else {
l_sym_t = Tensor::new(0.0f32, device)?;
l_sym_scalar = 0.0;
}
let n_edges = mu_child.dim(0)?;
let l_asym_scalar;
let l_asym_t;
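    // Asymmetric containment term: hinge on KL(child || parent) under a
    // margin, plus an optional reverse-KL divergence penalty.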
if n_edges > 0 {
let kl = kl_divergence_batch(mu_child, s_child, mu_parent, s_parent)?;
let zero_edges = kl.zeros_like()?;
let l_align = kl
.affine(1.0, -(config.asymmetric_margin as f64))?
.maximum(&zero_edges)?;
let l_diverge = if config.asymmetric_diverge_c > 0.0 {
let kl_rev = kl_divergence_batch(mu_parent, s_parent, mu_child, s_child)?;
let lv_parent = log_volume_batch(s_parent)?;
let lv_child = log_volume_batch(s_child)?;
let d_rep = lv_parent.sub(&lv_child)?;
let zero_d = d_rep.zeros_like()?;
d_rep
.affine(config.asymmetric_diverge_c as f64, 0.0)?
.sub(&kl_rev)?
.maximum(&zero_d)?
} else {
kl.zeros_like()?
};
let asym_per = l_align.add(&l_diverge.affine(config.diverge_lambda as f64, 0.0)?)?;
l_asym_t = asym_per.mean_all()?;
l_asym_scalar = l_asym_t.to_vec0::<f32>()?;
} else {
l_asym_t = Tensor::new(0.0f32, device)?;
l_asym_scalar = 0.0;
}
let n_all = mu_all.dim(0)?;
let l_reg_scalar;
let l_reg_t;
let l_clip_scalar;
let l_clip_t;
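    // Variance floor and ceiling regularizers over all boxes.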
if n_all > 0 {
l_reg_t = volume_reg_batch(s_all, config.min_var)?.mean_all()?;
l_clip_t = sigma_ceiling_batch(s_all, config.max_var)?.mean_all()?;
l_reg_scalar = l_reg_t.to_vec0::<f32>()?;
l_clip_scalar = l_clip_t.to_vec0::<f32>()?;
} else {
l_reg_t = Tensor::new(0.0f32, device)?;
l_clip_t = Tensor::new(0.0f32, device)?;
l_reg_scalar = 0.0;
l_clip_scalar = 0.0;
}
let total = l_sym_t
.affine(config.alpha as f64, 0.0)?
.add(&l_asym_t.affine(config.beta as f64, 0.0)?)?
.add(&l_reg_t.affine(config.gamma as f64, 0.0)?)?
.add(&l_clip_t.affine(config.delta as f64, 0.0)?)?;
let total_scalar = total.to_vec0::<f32>()?;
let result = CombinedLossResult {
total: total_scalar,
l_sym: l_sym_scalar,
l_asym: l_asym_scalar,
l_reg: l_reg_scalar,
l_clip: l_clip_scalar,
};
Ok((total, result))
}
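/// AMSGrad optimizer: Adam with a running element-wise maximum of the second
/// moment. Bias correction is applied to the first moment only.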
struct AmsGrad {
m: Vec<Tensor>,
v: Vec<Tensor>,
v_hat: Vec<Tensor>,
beta1: f64,
beta2: f64,
eps: f64,
t: usize,
}
impl AmsGrad {
fn new(vars: &[&Var], beta1: f64, beta2: f64, eps: f64) -> CResult<Self> {
let mut m = Vec::with_capacity(vars.len());
let mut v = Vec::with_capacity(vars.len());
let mut v_hat = Vec::with_capacity(vars.len());
for var in vars {
let z = var.as_tensor().zeros_like()?;
m.push(z.clone());
v.push(z.clone());
v_hat.push(z);
}
Ok(Self {
m,
v,
v_hat,
beta1,
beta2,
eps,
t: 0,
})
}
fn step(
&mut self,
vars: &[&Var],
grads: &candle_core::backprop::GradStore,
lr: f64,
) -> CResult<()> {
        self.t += 1;
        // Bias-correct the first moment; the running max `v_hat` of the
        // second moment is used without bias correction.
        let bc1 = 1.0 - self.beta1.powi(self.t as i32);
for (i, var) in vars.iter().enumerate() {
if let Some(grad) = grads.get(var.as_tensor()) {
self.m[i] = self.m[i]
.affine(self.beta1, 0.0)?
.add(&grad.affine(1.0 - self.beta1, 0.0)?)?;
let v_new = self.v[i]
.affine(self.beta2, 0.0)?
.add(&grad.sqr()?.affine(1.0 - self.beta2, 0.0)?)?;
self.v[i] = v_new.clone();
self.v_hat[i] = self.v_hat[i].maximum(&v_new)?;
let m_hat = self.m[i].affine(1.0 / bc1, 0.0)?;
let denom = self.v_hat[i].sqrt()?.affine(1.0, self.eps)?;
let update = m_hat.affine(lr, 0.0)?.div(&denom)?;
let new_val = var.as_tensor().sub(&update)?;
var.set(&new_val)?;
}
}
Ok(())
}
}
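/// Hyperparameters for `train_taxobell`.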
#[derive(Debug, Clone)]
pub struct TaxoBellTrainingConfig {
pub learning_rate: f32,
pub epochs: usize,
pub num_negatives: usize,
pub loss_config: TaxoBellConfig,
pub hidden_dim: usize,
pub box_dim: usize,
pub seed: u64,
pub warmup_epochs: usize,
}
impl Default for TaxoBellTrainingConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
epochs: 100,
num_negatives: 3,
loss_config: TaxoBellConfig::default(),
hidden_dim: 64,
box_dim: 16,
seed: 42,
warmup_epochs: 5,
}
}
}
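/// Parent-ranking metrics: mean reciprocal rank and Hits@k.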
#[derive(Debug, Clone)]
pub struct TaxoBellEvalResult {
pub mrr: f32,
pub hits_at_1: f32,
pub hits_at_3: f32,
pub hits_at_10: f32,
}
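/// Per-epoch loss breakdown and learning rate recorded during training.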
#[derive(Debug, Clone)]
pub struct TrainingSnapshot {
pub epoch: usize,
pub loss: CombinedLossResult,
pub lr: f32,
}
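/// Trains a `TaxoBellEncoder` with full-batch AMSGrad: each epoch re-encodes
/// all nodes, gathers `(parent_id, child_id)` edge pairs plus uniformly
/// sampled negatives, and descends the combined loss. `node_index` maps node
/// ids to row indices in `embeddings`.
///
/// A minimal sketch, mirroring the unit-test fixture at the bottom of this
/// file:
///
/// ```ignore
/// use std::collections::HashMap;
/// let node_ids = vec![0usize, 1, 2];
/// let edges = vec![(0, 1), (0, 2)]; // (parent_id, child_id)
/// let embeddings = vec![vec![0.0f32; 4], vec![1.0; 4], vec![0.5; 4]];
/// let node_index: HashMap<usize, usize> =
///     node_ids.iter().enumerate().map(|(i, &id)| (id, i)).collect();
/// let config = TaxoBellTrainingConfig::default();
/// let (encoder, snapshots) =
///     train_taxobell(&embeddings, &edges, &node_ids, &node_index, &config)?;
/// ```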
pub fn train_taxobell(
embeddings: &[Vec<f32>],
edges: &[(usize, usize)],
all_node_ids: &[usize],
node_index: &std::collections::HashMap<usize, usize>,
config: &TaxoBellTrainingConfig,
) -> Result<(TaxoBellEncoder, Vec<TrainingSnapshot>), BoxError> {
    let map_err = |e: candle_core::Error| BoxError::Internal(e.to_string());
    if embeddings.is_empty() {
        return Err(BoxError::Internal(
            "train_taxobell requires at least one embedding".to_string(),
        ));
    }
    let device = Device::Cpu;
    let embed_dim = embeddings[0].len();
let flat: Vec<f32> = embeddings.iter().flat_map(|e| e.iter().copied()).collect();
let n_nodes = embeddings.len();
let all_embeds = Tensor::from_vec(flat, (n_nodes, embed_dim), &device).map_err(map_err)?;
let child_indices: Vec<u32> = edges
.iter()
.map(|&(_, child_id)| node_index[&child_id] as u32)
.collect();
let parent_indices: Vec<u32> = edges
.iter()
.map(|&(parent_id, _)| node_index[&parent_id] as u32)
.collect();
let child_idx_t =
Tensor::from_vec(child_indices.clone(), edges.len(), &device).map_err(map_err)?;
let parent_idx_t =
Tensor::from_vec(parent_indices.clone(), edges.len(), &device).map_err(map_err)?;
let encoder = TaxoBellEncoder::new(embed_dim, config.hidden_dim, config.box_dim, &device)?;
let vars = encoder.vars();
let mut opt = AmsGrad::new(&vars, 0.9, 0.999, 1e-8).map_err(map_err)?;
let mut snapshots = Vec::with_capacity(config.epochs);
    // Deterministic xorshift64 PRNG; the state must be non-zero, so guard
    // against `seed.wrapping_add(1)` wrapping to zero.
    let mut rng_state = config.seed.wrapping_add(1);
    if rng_state == 0 {
        rng_state = 1;
    }
    let mut rng_next = |bound: usize| -> usize {
        rng_state ^= rng_state << 13;
        rng_state ^= rng_state >> 7;
        rng_state ^= rng_state << 17;
        (rng_state as usize) % bound
    };
let n_edges = edges.len();
let n_neg = config.num_negatives;
let n_total_neg = n_edges * n_neg;
for epoch in 0..config.epochs {
let lr = crate::optimizer::get_learning_rate(
epoch,
config.epochs,
config.learning_rate,
config.warmup_epochs,
);
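        // Draw `num_negatives` uniform negatives per edge; anchors and
        // positives are the edge endpoints repeated to align with them.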
let neg_node_indices: Vec<u32> = (0..n_total_neg)
.map(|_| {
let node_id = all_node_ids[rng_next(all_node_ids.len())];
node_index[&node_id] as u32
})
.collect();
let anchor_indices: Vec<u32> = child_indices
.iter()
.flat_map(|&idx| std::iter::repeat(idx).take(n_neg))
.collect();
let pos_indices: Vec<u32> = parent_indices
.iter()
.flat_map(|&idx| std::iter::repeat(idx).take(n_neg))
.collect();
let (mu_all, s_all) = encoder.encode(&all_embeds).map_err(map_err)?;
let mu_child = mu_all.index_select(&child_idx_t, 0).map_err(map_err)?;
let s_child = s_all.index_select(&child_idx_t, 0).map_err(map_err)?;
let mu_parent = mu_all.index_select(&parent_idx_t, 0).map_err(map_err)?;
let s_parent = s_all.index_select(&parent_idx_t, 0).map_err(map_err)?;
let (mu_anchor, s_anchor, mu_pos, s_pos, mu_neg_t, s_neg_t) = if n_total_neg > 0 {
let anchor_t =
Tensor::from_vec(anchor_indices, n_total_neg, &device).map_err(map_err)?;
let pos_t = Tensor::from_vec(pos_indices, n_total_neg, &device).map_err(map_err)?;
let neg_t =
Tensor::from_vec(neg_node_indices, n_total_neg, &device).map_err(map_err)?;
(
mu_all.index_select(&anchor_t, 0).map_err(map_err)?,
s_all.index_select(&anchor_t, 0).map_err(map_err)?,
mu_all.index_select(&pos_t, 0).map_err(map_err)?,
s_all.index_select(&pos_t, 0).map_err(map_err)?,
mu_all.index_select(&neg_t, 0).map_err(map_err)?,
s_all.index_select(&neg_t, 0).map_err(map_err)?,
)
} else {
let empty = Tensor::zeros((0, config.box_dim), DType::F32, &device).map_err(map_err)?;
(
empty.clone(),
empty.clone(),
empty.clone(),
empty.clone(),
empty.clone(),
empty,
)
};
let (loss_t, loss_result) = combined_loss_tensor(
&mu_child,
&s_child,
&mu_parent,
&s_parent,
&mu_anchor,
&s_anchor,
&mu_pos,
&s_pos,
&mu_neg_t,
&s_neg_t,
            &mu_all,
            &s_all,
&config.loss_config,
)
.map_err(map_err)?;
snapshots.push(TrainingSnapshot {
epoch,
loss: loss_result,
lr,
});
let grads = loss_t.backward().map_err(map_err)?;
        opt.step(&vars, &grads, lr as f64).map_err(map_err)?;
}
Ok((encoder, snapshots))
}
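/// Evaluates parent ranking on held-out edges: for each `(parent, child)`
/// test edge, every other node is scored by `KL(child || candidate)` (lower
/// is better) and the rank of the true parent feeds MRR and Hits@{1,3,10}.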
pub fn evaluate_taxobell(
encoder: &TaxoBellEncoder,
embeddings: &[Vec<f32>],
test_edges: &[(usize, usize)],
all_node_ids: &[usize],
node_index: &std::collections::HashMap<usize, usize>,
) -> Result<TaxoBellEvalResult, BoxError> {
if test_edges.is_empty() {
return Ok(TaxoBellEvalResult {
mrr: 0.0,
hits_at_1: 0.0,
hits_at_3: 0.0,
hits_at_10: 0.0,
});
}
let boxes: Vec<crate::gaussian::GaussianBox> = all_node_ids
.iter()
.map(|&id| {
let idx = node_index[&id];
encoder.encode_one(&embeddings[idx])
})
.collect::<Result<Vec<_>, _>>()?;
let mut reciprocal_ranks = Vec::with_capacity(test_edges.len());
let mut hits1 = 0usize;
let mut hits3 = 0usize;
let mut hits10 = 0usize;
for &(parent_id, child_id) in test_edges {
let child_idx = node_index[&child_id];
let child_box = encoder.encode_one(&embeddings[child_idx])?;
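        // Score all candidate parents except the child itself; `pos` indexes
        // `boxes`, which is aligned with `all_node_ids`.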
let mut scores: Vec<(usize, f32)> = all_node_ids
.iter()
.enumerate()
.filter(|(_, &cand_id)| cand_id != child_id)
.map(|(pos, &cand_id)| {
let kl =
crate::gaussian::kl_divergence(&child_box, &boxes[pos]).unwrap_or(f32::MAX);
(cand_id, kl)
})
.collect();
scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        // 1-based rank of the true parent; if it is absent (self-loop edge),
        // fall back to the worst possible rank.
        let rank = scores
            .iter()
            .position(|&(id, _)| id == parent_id)
            .map(|r| r + 1)
            .unwrap_or(scores.len());
reciprocal_ranks.push(1.0 / rank as f32);
if rank <= 1 {
hits1 += 1;
}
if rank <= 3 {
hits3 += 1;
}
if rank <= 10 {
hits10 += 1;
}
}
let n = test_edges.len() as f32;
Ok(TaxoBellEvalResult {
mrr: reciprocal_ranks.iter().sum::<f32>() / n,
hits_at_1: hits1 as f32 / n,
hits_at_3: hits3 as f32 / n,
hits_at_10: hits10 as f32 / n,
})
}
#[cfg(test)]
mod tests {
use super::*;
fn device() -> Device {
Device::Cpu
}
#[test]
fn mlp_forward_shape() {
let mlp = Mlp::new(8, 16, 4, &device()).unwrap();
let input = Tensor::ones((3, 8), DType::F32, &device()).unwrap();
let output = mlp.forward(&input).unwrap();
assert_eq!(output.dims(), &[3, 4]);
}
#[test]
fn encoder_produces_positive_sigma() {
let enc = TaxoBellEncoder::new(8, 16, 4, &device()).unwrap();
let embed = Tensor::ones((2, 8), DType::F32, &device()).unwrap();
let (_mu, sigma) = enc.encode(&embed).unwrap();
let vals: Vec<Vec<f32>> = sigma.to_vec2().unwrap();
for row in &vals {
for &s in row {
assert!(s > 0.0, "sigma must be positive, got {s}");
}
}
}
#[test]
fn encoder_different_inputs_different_outputs() {
let enc = TaxoBellEncoder::new(8, 16, 4, &device()).unwrap();
let a = Tensor::zeros((1, 8), DType::F32, &device()).unwrap();
let b = Tensor::ones((1, 8), DType::F32, &device()).unwrap();
let (mu_a, _) = enc.encode(&a).unwrap();
let (mu_b, _) = enc.encode(&b).unwrap();
let va: Vec<f32> = mu_a.squeeze(0).unwrap().to_vec1().unwrap();
let vb: Vec<f32> = mu_b.squeeze(0).unwrap().to_vec1().unwrap();
assert_ne!(va, vb, "different inputs should produce different centers");
}
#[test]
fn backward_produces_gradients() {
let enc = TaxoBellEncoder::new(4, 8, 2, &device()).unwrap();
let embed = Tensor::ones((2, 4), DType::F32, &device()).unwrap();
let (mu, sigma) = enc.encode(&embed).unwrap();
let loss = mu
.sum_all()
.unwrap()
.add(&sigma.sum_all().unwrap())
.unwrap();
let grads = loss.backward().unwrap();
for var in enc.vars() {
assert!(
grads.get(var.as_tensor()).is_some(),
"missing gradient for var"
);
}
}
#[test]
fn train_loss_decreases() {
let node_ids = vec![0usize, 1, 2];
let edges = vec![(0, 1), (0, 2)];
let embeddings = vec![
vec![0.0, 0.5, 1.0, 0.2],
vec![1.0, 0.0, 0.5, 0.8],
vec![0.5, 1.0, 0.0, 0.3],
];
let node_index: std::collections::HashMap<usize, usize> = node_ids
.iter()
.enumerate()
.map(|(i, &id)| (id, i))
.collect();
let config = TaxoBellTrainingConfig {
learning_rate: 5e-3,
epochs: 30,
num_negatives: 1,
hidden_dim: 8,
box_dim: 4,
seed: 42,
warmup_epochs: 3,
..Default::default()
};
let (_, snapshots) =
train_taxobell(&embeddings, &edges, &node_ids, &node_index, &config).unwrap();
assert_eq!(snapshots.len(), 30);
let first = snapshots[0].loss.total;
let last = snapshots.last().unwrap().loss.total;
assert!(first.is_finite());
assert!(last.is_finite());
assert!(
last < first,
"loss should decrease: first={first}, last={last}"
);
}
#[test]
fn evaluate_returns_valid_metrics() {
let node_ids = vec![0usize, 1, 2, 3];
let embeddings = vec![
vec![0.0, 0.5],
vec![1.0, 0.0],
vec![0.5, 1.0],
vec![0.2, 0.8],
];
let node_index: std::collections::HashMap<usize, usize> = node_ids
.iter()
.enumerate()
.map(|(i, &id)| (id, i))
.collect();
let encoder = TaxoBellEncoder::new(2, 4, 2, &Device::Cpu).unwrap();
let test_edges = vec![(0, 1), (0, 2)];
let result =
evaluate_taxobell(&encoder, &embeddings, &test_edges, &node_ids, &node_index).unwrap();
assert!(result.mrr >= 0.0 && result.mrr <= 1.0);
assert!(result.hits_at_1 >= 0.0 && result.hits_at_1 <= 1.0);
assert!(result.hits_at_10 >= 0.0 && result.hits_at_10 <= 1.0);
}
#[test]
fn evaluate_empty_edges() {
let node_ids = vec![0usize, 1];
let embeddings = vec![vec![0.0, 0.5], vec![1.0, 0.0]];
let node_index: std::collections::HashMap<usize, usize> = node_ids
.iter()
.enumerate()
.map(|(i, &id)| (id, i))
.collect();
let encoder = TaxoBellEncoder::new(2, 4, 2, &Device::Cpu).unwrap();
let result = evaluate_taxobell(&encoder, &embeddings, &[], &node_ids, &node_index).unwrap();
assert_eq!(result.mrr, 0.0);
}
#[test]
fn encode_one_matches_batch() {
let enc = TaxoBellEncoder::new(4, 8, 2, &device()).unwrap();
let embed = vec![0.5f32, -0.3, 1.0, 0.2];
let t = Tensor::new(&embed[..], &device())
.unwrap()
.unsqueeze(0)
.unwrap();
let (mu_batch, sigma_batch) = enc.encode(&t).unwrap();
let mu_b: Vec<f32> = mu_batch.squeeze(0).unwrap().to_vec1().unwrap();
let sigma_b: Vec<f32> = sigma_batch.squeeze(0).unwrap().to_vec1().unwrap();
let gb = enc.encode_one(&embed).unwrap();
for (a, b) in mu_b.iter().zip(gb.mu().iter()) {
assert!((a - b).abs() < 1e-5, "mu mismatch: {a} vs {b}");
}
for (a, b) in sigma_b.iter().zip(gb.sigma().iter()) {
assert!((a - b).abs() < 1e-5, "sigma mismatch: {a} vs {b}");
}
}
}