use candle_core::{DType, Device, IndexOp, Result, Tensor, Var, D};
use candle_nn::optim::{AdamW, Optimizer, ParamsAdamW};
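/// Selects the optimizer used for embedding updates.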
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizerType {
AdamW,
Adagrad,
}
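/// Minimal hand-rolled Adagrad optimizer (per-parameter accumulated squared
/// gradients); only AdamW is taken from candle-nn's optim module.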
struct Adagrad {
vars: Vec<Var>,
sum_sq: Vec<Var>,
lr: f64,
eps: f64,
}
impl Adagrad {
fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
let sum_sq: Vec<Var> = vars
.iter()
.map(|v| Var::zeros(v.shape(), v.dtype(), v.device()))
.collect::<Result<_>>()?;
Ok(Self {
vars,
sum_sq,
lr,
eps: 1e-10,
})
}
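/// Applies one Adagrad step: G += g^2, then theta -= lr * g / (sqrt(G) + eps).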
fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
let grads = loss.backward()?;
for (var, ss) in self.vars.iter().zip(self.sum_sq.iter()) {
if let Some(grad) = grads.get(var) {
let new_ss = (ss.as_tensor() + grad.sqr()?)?;
let adjusted = (grad / (new_ss.sqrt()? + self.eps)?)?;
var.set(&(var.as_tensor() - (adjusted * self.lr)?)?)?;
ss.set(&new_ss)?;
}
}
Ok(())
}
fn set_learning_rate(&mut self, lr: f64) {
self.lr = lr;
}
}
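/// Supported knowledge-graph embedding models.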
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
TransE,
RotatE,
ComplEx,
DistMult,
}
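/// Hyperparameters for embedding training. Fields group roughly into model
/// choice (model_type, dim, gamma, distance_norm), loss setup (one_to_n,
/// num_negatives, label_smoothing, adversarial_temperature), regularization
/// (n3_reg, l2_reg, embedding_dropout), and schedule/bookkeeping (warmup,
/// cosine cycles, checkpoints, SWA, early stopping).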
#[derive(Debug, Clone)]
pub struct TrainConfig {
pub model_type: ModelType,
pub optimizer: OptimizerType,
pub dim: usize,
pub init_scale: f64,
pub num_negatives: usize,
pub one_to_n: bool,
pub label_smoothing: f32,
pub multi_hot: bool,
pub gamma: f32,
pub distance_norm: u32,
pub subsampling: bool,
pub adversarial_temperature: f32,
pub lr: f64,
pub embedding_dropout: f32,
pub n3_reg: f32,
pub l2_reg: f32,
pub batch_size: usize,
pub epochs: usize,
pub normalize_entities: bool,
pub warmup_epochs: usize,
pub cosine_cycles: usize,
pub cosine_min_lr_frac: f64,
pub log_interval: usize,
pub eval_interval: usize,
pub patience: usize,
pub checkpoint_dir: Option<std::path::PathBuf>,
pub checkpoint_interval: usize,
pub swa_start_epoch: usize,
pub relation_prediction_weight: f32,
}
impl Default for TrainConfig {
fn default() -> Self {
Self {
model_type: ModelType::TransE,
optimizer: OptimizerType::AdamW,
dim: 200,
init_scale: 1e-3,
num_negatives: 256,
one_to_n: false,
label_smoothing: 0.0,
multi_hot: false,
gamma: 12.0,
distance_norm: 1,
subsampling: false,
adversarial_temperature: 1.0,
lr: 0.001,
embedding_dropout: 0.0,
n3_reg: 0.0,
l2_reg: 0.0,
batch_size: 512,
epochs: 1000,
normalize_entities: false,
warmup_epochs: 0,
cosine_cycles: 0,
cosine_min_lr_frac: 0.1,
log_interval: 0,
eval_interval: 0,
patience: 5,
checkpoint_dir: None,
checkpoint_interval: 0,
swa_start_epoch: 0,
relation_prediction_weight: 0.0,
}
}
}
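/// Embedding tables under training. Column layout depends on the model:
/// TransE/DistMult use `dim` real columns; ComplEx uses `2 * dim` columns
/// (real parts then imaginary parts) in both tables; RotatE uses `2 * dim`
/// entity columns but `dim` relation columns holding rotation phases.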
pub struct TrainableModel {
entity_embeddings: Var,
relation_embeddings: Var,
model_type: ModelType,
dim: usize,
gamma: f32,
distance_norm: u32,
embedding_dropout: f32,
device: Device,
}
impl TrainableModel {
pub fn new(
num_entities: usize,
num_relations: usize,
config: &TrainConfig,
device: &Device,
) -> Result<Self> {
let dim = config.dim;
let gamma = config.gamma;
let s = config.init_scale;
let (entity_embeddings, relation_embeddings) = match config.model_type {
ModelType::TransE => {
let ent = Var::randn_f64(0.0, s, (num_entities, dim), DType::F32, device)?;
let rel = Var::randn_f64(0.0, s, (num_relations, dim), DType::F32, device)?;
(ent, rel)
}
ModelType::RotatE => {
let ent = Var::randn_f64(0.0, s, (num_entities, dim * 2), DType::F32, device)?;
let rel = Var::rand_f64(
-std::f64::consts::PI,
std::f64::consts::PI,
(num_relations, dim),
DType::F32,
device,
)?;
(ent, rel)
}
ModelType::ComplEx | ModelType::DistMult => {
let ent_cols = if config.model_type == ModelType::ComplEx {
dim * 2
} else {
dim
};
let rel_cols = ent_cols;
let ent = Var::randn_f64(0.0, s, (num_entities, ent_cols), DType::F32, device)?;
let rel = Var::randn_f64(0.0, s, (num_relations, rel_cols), DType::F32, device)?;
(ent, rel)
}
};
Ok(Self {
entity_embeddings,
relation_embeddings,
model_type: config.model_type,
dim,
gamma,
distance_norm: config.distance_norm,
embedding_dropout: config.embedding_dropout,
device: device.clone(),
})
}
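/// Scores a batch of (head, relation, tail) triples. Convention: lower is
/// more plausible. TransE/RotatE return distances (L1 when `distance_norm`
/// is 1, otherwise L2); ComplEx/DistMult return negated dot-product scores.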
pub fn score_batch(
&self,
heads: &Tensor,
relations: &Tensor,
tails: &Tensor,
) -> Result<Tensor> {
let mut h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
let mut r = self
.relation_embeddings
.as_tensor()
.index_select(relations, 0)?;
let mut t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
if self.embedding_dropout > 0.0 {
h = candle_nn::ops::dropout(&h, self.embedding_dropout)?;
r = candle_nn::ops::dropout(&r, self.embedding_dropout)?;
t = candle_nn::ops::dropout(&t, self.embedding_dropout)?;
}
match self.model_type {
ModelType::TransE => {
let diff = ((h + r)? - t)?;
match self.distance_norm {
1 => diff.abs()?.sum(D::Minus1),
_ => diff.sqr()?.sum(D::Minus1)?.sqrt(),
}
}
ModelType::RotatE => {
let dim = self.dim;
let h_re = h.i((.., ..dim))?;
let h_im = h.i((.., dim..))?;
let t_re = t.i((.., ..dim))?;
let t_im = t.i((.., dim..))?;
let r_cos = r.cos()?;
let r_sin = r.sin()?;
let hr_re = ((&h_re * &r_cos)? - (&h_im * &r_sin)?)?;
let hr_im = ((&h_re * &r_sin)? + (&h_im * &r_cos)?)?;
let d_re = (hr_re - t_re)?;
let d_im = (hr_im - t_im)?;
match self.distance_norm {
1 => {
let dist = (d_re.abs()? + d_im.abs()?)?;
dist.sum(D::Minus1)
}
_ => {
let dist_sq = (d_re.sqr()? + d_im.sqr()?)?;
dist_sq.sum(D::Minus1)?.sqrt()
}
}
}
ModelType::ComplEx => {
let dim = self.dim;
let h_re = h.i((.., ..dim))?;
let h_im = h.i((.., dim..))?;
let r_re = r.i((.., ..dim))?;
let r_im = r.i((.., dim..))?;
let t_re = t.i((.., ..dim))?;
let t_im = t.i((.., dim..))?;
let hr_re = ((&h_re * &r_re)? - (&h_im * &r_im)?)?;
let hr_im = ((&h_re * &r_im)? + (&h_im * &r_re)?)?;
let score = ((&hr_re * &t_re)? + (&hr_im * &t_im)?)?;
score.sum(D::Minus1)?.neg()
}
ModelType::DistMult => {
let score = ((&h * &r)? * &t)?;
score.sum(D::Minus1)?.neg()
}
}
}
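/// 1-N scoring: for each (head, relation) pair, scores every entity as a
/// candidate tail. Convention: higher is more plausible (softmax logits).
/// For TransE/RotatE the negative *squared* L2 distance is used regardless
/// of `distance_norm`, via the expansion |hr - e|^2 = |hr|^2 + |e|^2 - 2 hr.e.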
pub fn score_1n(&self, heads: &Tensor, relations: &Tensor) -> Result<Tensor> {
let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
let r = self
.relation_embeddings
.as_tensor()
.index_select(relations, 0)?;
let ent_matrix = self.entity_embeddings.as_tensor();
match self.model_type {
ModelType::TransE => {
let hr = (h + r)?;
let hr_sq = hr.sqr()?.sum(D::Minus1)?;
let ent_sq = ent_matrix.sqr()?.sum(D::Minus1)?;
let cross = hr.matmul(&ent_matrix.t()?)?;
let dist_sq = (hr_sq
.unsqueeze(D::Minus1)?
.broadcast_add(&ent_sq.unsqueeze(0)?)?
- (cross * 2.0)?)?;
dist_sq.neg()
}
ModelType::DistMult => {
let hr = (h * r)?;
hr.matmul(&ent_matrix.t()?)
}
ModelType::ComplEx => {
let dim = self.dim;
let h_re = h.i((.., ..dim))?;
let h_im = h.i((.., dim..))?;
let r_re = r.i((.., ..dim))?;
let r_im = r.i((.., dim..))?;
let hr_re = ((&h_re * &r_re)? - (&h_im * &r_im)?)?;
let hr_im = ((&h_re * &r_im)? + (&h_im * &r_re)?)?;
let e_re = ent_matrix.i((.., ..dim))?.contiguous()?;
let e_im = ent_matrix.i((.., dim..))?.contiguous()?;
let score = (hr_re.matmul(&e_re.t()?)? + hr_im.matmul(&e_im.t()?)?)?;
Ok(score)
}
ModelType::RotatE => {
let dim = self.dim;
let h_re = h.i((.., ..dim))?;
let h_im = h.i((.., dim..))?;
let r_cos = r.cos()?;
let r_sin = r.sin()?;
let hr_re = ((&h_re * &r_cos)? - (&h_im * &r_sin)?)?;
let hr_im = ((&h_re * &r_sin)? + (&h_im * &r_cos)?)?;
let hr = Tensor::cat(&[&hr_re, &hr_im], D::Minus1)?;
let hr_sq = hr.sqr()?.sum(D::Minus1)?;
let ent_sq = ent_matrix.sqr()?.sum(D::Minus1)?;
let cross = hr.matmul(&ent_matrix.t()?)?;
let dist_sq = (hr_sq
.unsqueeze(D::Minus1)?
.broadcast_add(&ent_sq.unsqueeze(0)?)?
- (cross * 2.0)?)?;
dist_sq.neg()
}
}
}
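/// 1-N scoring over head candidates for each (relation, tail) pair, applying
/// the inverse relation transform (t - r for TransE, rotation by -theta for
/// RotatE, conjugated products for ComplEx). Higher is more plausible.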
pub fn score_1n_heads(&self, relations: &Tensor, tails: &Tensor) -> Result<Tensor> {
let r = self
.relation_embeddings
.as_tensor()
.index_select(relations, 0)?;
let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
let ent_matrix = self.entity_embeddings.as_tensor();
match self.model_type {
ModelType::TransE => {
let tr = (t - r)?;
let tr_sq = tr.sqr()?.sum(D::Minus1)?;
let ent_sq = ent_matrix.sqr()?.sum(D::Minus1)?;
let cross = tr.matmul(&ent_matrix.t()?)?;
let dist_sq = (ent_sq
.unsqueeze(0)?
.broadcast_add(&tr_sq.unsqueeze(D::Minus1)?)?
- (cross * 2.0)?)?;
dist_sq.neg()
}
ModelType::DistMult => {
let rt = (r * t)?;
rt.matmul(&ent_matrix.t()?)
}
ModelType::ComplEx => {
let dim = self.dim;
let r_re = r.i((.., ..dim))?;
let r_im = r.i((.., dim..))?;
let t_re = t.i((.., ..dim))?;
let t_im = t.i((.., dim..))?;
let rc_re = ((&r_re * &t_re)? + (&r_im * &t_im)?)?;
// rc = conj(r) * t, so rc_re * e_re + rc_im * e_im = Re(e * r * conj(t)),
// matching score_batch's ComplEx convention for every head candidate e.
let rc_im = ((&r_re * &t_im)? - (&r_im * &t_re)?)?;
let e_re = ent_matrix.i((.., ..dim))?.contiguous()?;
let e_im = ent_matrix.i((.., dim..))?.contiguous()?;
let score = (rc_re.matmul(&e_re.t()?)? + rc_im.matmul(&e_im.t()?)?)?;
Ok(score)
}
ModelType::RotatE => {
let dim = self.dim;
let t_re = t.i((.., ..dim))?;
let t_im = t.i((.., dim..))?;
let r_cos = r.cos()?;
let r_sin = r.sin()?;
let tr_re = ((&t_re * &r_cos)? + (&t_im * &r_sin)?)?;
let tr_im = ((&t_im * &r_cos)? - (&t_re * &r_sin)?)?;
let tr = Tensor::cat(&[&tr_re, &tr_im], D::Minus1)?;
let tr_sq = tr.sqr()?.sum(D::Minus1)?;
let ent_sq = ent_matrix.sqr()?.sum(D::Minus1)?;
let cross = tr.matmul(&ent_matrix.t()?)?;
let dist_sq = (ent_sq
.unsqueeze(0)?
.broadcast_add(&tr_sq.unsqueeze(D::Minus1)?)?
- (cross * 2.0)?)?;
dist_sq.neg()
}
}
}
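/// Scores every relation for each (head, tail) pair, used by the auxiliary
/// relation-prediction loss. DistMult/ComplEx reduce to a matmul against the
/// relation table; TransE/RotatE fall back to one score_batch pass per
/// relation (negated so that higher is more plausible).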
pub fn score_1n_relations(
&self,
heads: &Tensor,
tails: &Tensor,
num_relations: usize,
) -> Result<Tensor> {
let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
let rel_matrix = self.relation_embeddings.as_tensor();
match self.model_type {
ModelType::DistMult => {
let ht = (h * t)?;
ht.matmul(&rel_matrix.t()?)
}
ModelType::ComplEx => {
let dim = self.dim;
let h_re = h.i((.., ..dim))?;
let h_im = h.i((.., dim..))?;
let t_re = t.i((.., ..dim))?;
let t_im = t.i((.., dim..))?;
let ht_re = ((&h_re * &t_re)? + (&h_im * &t_im)?)?;
// ht = conj(h) * t, so ht_re * r_re + ht_im * r_im = Re(h * r * conj(t)),
// matching score_batch's ComplEx convention for every relation candidate r.
let ht_im = ((&h_re * &t_im)? - (&h_im * &t_re)?)?;
let r_re = rel_matrix.i((.., ..dim))?.contiguous()?;
let r_im = rel_matrix.i((.., dim..))?.contiguous()?;
let score = (ht_re.matmul(&r_re.t()?)? + ht_im.matmul(&r_im.t()?)?)?;
Ok(score)
}
ModelType::TransE | ModelType::RotatE => {
let batch_size = h.dim(0)?;
let mut scores = Vec::with_capacity(num_relations);
for r_idx in 0..num_relations {
let r_ids = Tensor::full(r_idx as u32, batch_size, &self.device)?;
let s = self.score_batch(heads, &r_ids, tails)?;
scores.push(s.neg()?);
}
Tensor::stack(&scores, 1)
}
}
}
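/// N3 regularizer (Lacroix et al., 2018): mean cubed norm of the embeddings
/// touched by the batch; for ComplEx the complex modulus is cubed.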
fn n3_penalty(&self, heads: &Tensor, relations: &Tensor, tails: &Tensor) -> Result<Tensor> {
let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
let r = self
.relation_embeddings
.as_tensor()
.index_select(relations, 0)?;
let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
let cube_norm = |x: &Tensor, is_complex: bool, dim: usize| -> Result<Tensor> {
if is_complex {
let re = x.i((.., ..dim))?;
let im = x.i((.., dim..))?;
let moduli = (re.sqr()? + im.sqr()?)?.sqrt()?;
moduli
.powf(3.0)?
.sum_all()?
.affine(1.0 / x.dim(0)? as f64, 0.0)
} else {
x.abs()?.powf(3.0)?.mean_all()
}
};
let is_cx = self.model_type == ModelType::ComplEx;
let dim = self.dim;
let penalty =
(cube_norm(&h, is_cx, dim)? + cube_norm(&r, is_cx, dim)? + cube_norm(&t, is_cx, dim)?)?;
Ok(penalty)
}
pub fn model_type(&self) -> ModelType {
self.model_type
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn entity_embeddings(&self) -> &Tensor {
self.entity_embeddings.as_tensor()
}
pub fn relation_embeddings(&self) -> &Tensor {
self.relation_embeddings.as_tensor()
}
pub fn entity_vecs(&self) -> Result<Vec<Vec<f32>>> {
tensor_to_vecs(self.entity_embeddings.as_tensor())
}
pub fn relation_vecs(&self) -> Result<Vec<Vec<f32>>> {
tensor_to_vecs(self.relation_embeddings.as_tensor())
}
pub fn to_transe(&self) -> Result<crate::TransE> {
Ok(crate::TransE::from_vecs_with_norm(
self.entity_vecs()?,
self.relation_vecs()?,
self.dim,
self.distance_norm,
))
}
pub fn to_rotate(&self) -> Result<crate::RotatE> {
Ok(crate::RotatE::from_vecs(
self.entity_vecs()?,
self.relation_vecs()?,
self.dim,
self.gamma,
))
}
pub fn to_complex(&self) -> Result<crate::ComplEx> {
Ok(crate::ComplEx::from_vecs(
self.entity_vecs()?,
self.relation_vecs()?,
self.dim,
))
}
pub fn to_distmult(&self) -> Result<crate::DistMult> {
Ok(crate::DistMult::from_vecs(
self.entity_vecs()?,
self.relation_vecs()?,
self.dim,
))
}
}
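/// Embedding snapshot captured at the end of a cosine cycle, usable for
/// snapshot ensembling.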
pub struct Snapshot {
pub entity_vecs: Vec<Vec<f32>>,
pub relation_vecs: Vec<Vec<f32>>,
pub epoch: usize,
}
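/// Training outputs: the final (or best-on-validation) model, per-epoch
/// losses and timings, cycle snapshots, and optional SWA averages.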
pub struct TrainResult {
pub model: TrainableModel,
pub losses: Vec<f32>,
pub epoch_times: Vec<f32>,
pub snapshots: Vec<Snapshot>,
pub swa_entity_vecs: Option<Vec<Vec<f32>>>,
pub swa_relation_vecs: Option<Vec<Vec<f32>>>,
}
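/// Held-out triples plus the filter index used for filtered ranking metrics.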
pub struct ValidationData<'a> {
pub valid_triples: &'a [crate::dataset::TripleIds],
pub filter: &'a crate::dataset::FilterIndex,
}
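/// Learning-rate schedule: linear warmup over `warmup_epochs`, then either a
/// constant rate or cosine annealing with warm restarts when `cosine_cycles`
/// is positive, floored at `lr * cosine_min_lr_frac`.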
pub fn learning_rate(epoch: usize, config: &TrainConfig) -> f64 {
let base_lr = config.lr;
if config.warmup_epochs > 0 && epoch < config.warmup_epochs {
base_lr * (epoch + 1) as f64 / config.warmup_epochs as f64
} else if config.cosine_cycles > 0 {
let effective_epoch = epoch.saturating_sub(config.warmup_epochs);
let total_effective = config.epochs.saturating_sub(config.warmup_epochs);
let epochs_per_cycle = total_effective / config.cosine_cycles;
if epochs_per_cycle > 0 {
let cycle_pos = effective_epoch % epochs_per_cycle;
let t = cycle_pos as f64 / epochs_per_cycle as f64;
let min_lr = base_lr * config.cosine_min_lr_frac;
min_lr + 0.5 * (base_lr - min_lr) * (1.0 + (t * std::f64::consts::PI).cos())
} else {
base_lr
}
} else {
base_lr
}
}
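/// Convenience wrapper around [`train_with_validation`] without a validation set.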
pub fn train(
train_triples: &[crate::dataset::TripleIds],
num_entities: usize,
num_relations: usize,
config: &TrainConfig,
device: &Device,
) -> Result<TrainResult> {
train_with_validation(
train_triples,
num_entities,
num_relations,
config,
device,
None,
)
}
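/// Full training loop: shuffled mini-batches, 1-N or negative-sampling loss,
/// optional N3/L2 regularization, LR scheduling, snapshots, checkpoints, SWA,
/// and early stopping on validation MRR.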
pub fn train_with_validation(
train_triples: &[crate::dataset::TripleIds],
num_entities: usize,
num_relations: usize,
config: &TrainConfig,
device: &Device,
validation: Option<ValidationData<'_>>,
) -> Result<TrainResult> {
let model = TrainableModel::new(num_entities, num_relations, config, device)?;
let vars = vec![
model.entity_embeddings.clone(),
model.relation_embeddings.clone(),
];
enum Opt {
Adam(AdamW),
Adagrad(self::Adagrad),
}
impl Opt {
fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
match self {
Opt::Adam(o) => o.backward_step(loss),
Opt::Adagrad(o) => o.backward_step(loss),
}
}
fn set_learning_rate(&mut self, lr: f64) {
match self {
Opt::Adam(o) => o.set_learning_rate(lr),
Opt::Adagrad(o) => o.set_learning_rate(lr),
}
}
}
let mut optimizer = match config.optimizer {
OptimizerType::AdamW => Opt::Adam(AdamW::new(
vars,
ParamsAdamW {
lr: config.lr,
weight_decay: 0.0,
..ParamsAdamW::default()
},
)?),
OptimizerType::Adagrad => Opt::Adagrad(self::Adagrad::new(vars, config.lr)?),
};
let n_triples = train_triples.len();
let batch_size = config.batch_size.min(n_triples);
let gamma = config.gamma;
let alpha = config.adversarial_temperature;
let n3_coeff = config.n3_reg;
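// For 1-N training, index every known (head, relation) -> tails and
// (relation, tail) -> heads so multi-hot targets can be built per batch.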
let (known_tails, known_heads) = if config.one_to_n {
let mut kt: std::collections::HashMap<(usize, usize), Vec<usize>> =
std::collections::HashMap::new();
let mut kh: std::collections::HashMap<(usize, usize), Vec<usize>> =
std::collections::HashMap::new();
for triple in train_triples {
kt.entry((triple.head, triple.relation))
.or_default()
.push(triple.tail);
kh.entry((triple.relation, triple.tail))
.or_default()
.push(triple.head);
}
(Some(kt), Some(kh))
} else {
(None, None)
};
let mut tail_target_buf = if config.one_to_n && config.multi_hot {
vec![0.0_f32; batch_size * num_entities]
} else {
Vec::new()
};
let mut head_target_buf = if config.one_to_n && config.multi_hot {
vec![0.0_f32; batch_size * num_entities]
} else {
Vec::new()
};
let entity_freq = if config.subsampling {
let mut freq = vec![0u32; num_entities];
for triple in train_triples {
freq[triple.head] += 1;
freq[triple.tail] += 1;
}
Some(freq)
} else {
None
};
let mut losses = Vec::with_capacity(config.epochs);
let mut epoch_times = Vec::with_capacity(config.epochs);
let mut snapshots = Vec::new();
let mut shuffled: Vec<crate::dataset::TripleIds> = train_triples.to_vec();
let mut best_mrr = f32::NEG_INFINITY;
let mut patience_counter = 0_usize;
let mut best_entity_vecs: Option<Vec<Vec<f32>>> = None;
let mut best_relation_vecs: Option<Vec<Vec<f32>>> = None;
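// SWA running means, kept as flat f32 host buffers updated incrementally.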
let swa_active = config.swa_start_epoch > 0;
let mut swa_ent: Option<Vec<f32>> = None;
let mut swa_rel: Option<Vec<f32>> = None;
let mut swa_count = 0u64;
for _epoch in 0..config.epochs {
let lr = learning_rate(_epoch, config);
optimizer.set_learning_rate(lr);
let mut epoch_loss = 0.0_f64;
let mut n_batches = 0u32;
let epoch_start = std::time::Instant::now();
{
use rand::seq::SliceRandom;
shuffled.shuffle(&mut rand::rng());
}
let mut offset = 0;
while offset < n_triples {
let end = (offset + batch_size).min(n_triples);
let batch = &shuffled[offset..end];
let actual_bs = batch.len();
offset = end;
let heads_data: Vec<u32> = batch.iter().map(|t| t.head as u32).collect();
let rels_data: Vec<u32> = batch.iter().map(|t| t.relation as u32).collect();
let tails_data: Vec<u32> = batch.iter().map(|t| t.tail as u32).collect();
let heads = Tensor::from_vec(heads_data, actual_bs, &model.device)?;
let rels = Tensor::from_vec(rels_data, actual_bs, &model.device)?;
let tails = Tensor::from_vec(tails_data, actual_bs, &model.device)?;
let mut loss = if config.one_to_n {
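// 1-N path: cross-entropy over all entities in both directions
// (tail prediction and head prediction), averaged.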
let eps = config.label_smoothing as f64;
let tail_scores = model.score_1n(&heads, &rels)?;
let tail_log_probs = candle_nn::ops::log_softmax(&tail_scores, D::Minus1)?;
let head_scores = model.score_1n_heads(&rels, &tails)?;
let head_log_probs = candle_nn::ops::log_softmax(&head_scores, D::Minus1)?;
let (tail_nll, head_nll) = if config.multi_hot {
let kt = known_tails.as_ref().unwrap();
let tgt = &mut tail_target_buf[..actual_bs * num_entities];
tgt.fill(0.0);
for (i, triple) in batch.iter().enumerate() {
let tails = kt.get(&(triple.head, triple.relation)).unwrap();
let w = 1.0 / tails.len() as f32;
for &t in tails {
tgt[i * num_entities + t] = w;
}
}
let tail_t = Tensor::from_slice(tgt, (actual_bs, num_entities), &model.device)?;
let t_nll = (&tail_t * &tail_log_probs)?
.sum_all()?
.neg()?
.affine(1.0 / actual_bs as f64, 0.0)?;
let kh = known_heads.as_ref().unwrap();
let htgt = &mut head_target_buf[..actual_bs * num_entities];
htgt.fill(0.0);
for (i, triple) in batch.iter().enumerate() {
let heads = kh.get(&(triple.relation, triple.tail)).unwrap();
let w = 1.0 / heads.len() as f32;
for &h in heads {
htgt[i * num_entities + h] = w;
}
}
let head_t =
Tensor::from_slice(htgt, (actual_bs, num_entities), &model.device)?;
let h_nll = (&head_t * &head_log_probs)?
.sum_all()?
.neg()?
.affine(1.0 / actual_bs as f64, 0.0)?;
(t_nll, h_nll)
} else {
let tail_ids = Tensor::from_vec(
batch.iter().map(|t| t.tail as u32).collect::<Vec<_>>(),
actual_bs,
&model.device,
)?;
let t_nll = tail_log_probs
.gather(&tail_ids.unsqueeze(1)?, 1)?
.squeeze(1)?
.neg()?
.mean_all()?;
let head_ids = Tensor::from_vec(
batch.iter().map(|t| t.head as u32).collect::<Vec<_>>(),
actual_bs,
&model.device,
)?;
let h_nll = head_log_probs
.gather(&head_ids.unsqueeze(1)?, 1)?
.squeeze(1)?
.neg()?
.mean_all()?;
(t_nll, h_nll)
};
let nll = ((tail_nll + head_nll)? * 0.5)?;
let main_loss = if eps > 0.0 {
let tail_uniform = tail_log_probs.mean_all()?.neg()?;
let head_uniform = head_log_probs.mean_all()?.neg()?;
let uniform = ((tail_uniform + head_uniform)? * 0.5)?;
((nll * (1.0 - eps))? + (uniform * eps)?)?
} else {
nll
};
if config.relation_prediction_weight > 0.0 {
let rel_scores = model.score_1n_relations(&heads, &tails, num_relations)?;
let rel_log_probs = candle_nn::ops::log_softmax(&rel_scores, D::Minus1)?;
let rel_nll = rel_log_probs
.gather(&rels.unsqueeze(1)?, 1)?
.squeeze(1)?
.neg()?
.mean_all()?;
(main_loss + (rel_nll * config.relation_prediction_weight as f64)?)?
} else {
main_loss
}
} else {
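// Negative-sampling path: corrupt either the head or the tail of each
// triple uniformly, then apply the self-adversarial weighting of Sun et
// al. (RotatE) when adversarial_temperature > 0.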
let pos_scores = model.score_batch(&heads, &rels, &tails)?;
let neg_entities = Tensor::rand(
0.0_f32,
num_entities as f32,
(actual_bs, config.num_negatives),
&model.device,
)?
.to_dtype(DType::U32)?;
let corrupt_mask = Tensor::rand(
0.0_f32,
1.0_f32,
(actual_bs, config.num_negatives),
&model.device,
)?;
let half = Tensor::full(0.5_f32, (actual_bs, config.num_negatives), &model.device)?;
let corrupt_head = corrupt_mask.lt(&half)?;
let heads_exp = heads
.unsqueeze(1)?
.expand((actual_bs, config.num_negatives))?;
let rels_exp = rels
.unsqueeze(1)?
.expand((actual_bs, config.num_negatives))?;
let tails_exp = tails
.unsqueeze(1)?
.expand((actual_bs, config.num_negatives))?;
let neg_heads = corrupt_head.where_cond(&neg_entities, &heads_exp)?;
let neg_tails = corrupt_head.where_cond(&tails_exp, &neg_entities)?;
let neg_scores = model
.score_batch(
&neg_heads.flatten_all()?,
&rels_exp.flatten_all()?,
&neg_tails.flatten_all()?,
)?
.reshape((actual_bs, config.num_negatives))?;
let neg_weights = if alpha > 0.0 {
let scaled = (neg_scores.detach() * (-(alpha as f64)))?;
candle_nn::ops::softmax(&scaled, D::Minus1)?
} else {
Tensor::ones((actual_bs, config.num_negatives), DType::F32, &model.device)?
.affine(1.0 / config.num_negatives as f64, 0.0)?
};
let pos_loss = log_sigmoid(&(pos_scores.neg()? + gamma as f64)?)?.neg()?;
let neg_loss_per = log_sigmoid(&(neg_scores - gamma as f64)?)?;
let weighted_neg_loss = (&neg_weights * &neg_loss_per)?.sum(D::Minus1)?.neg()?;
let per_triple_loss = (pos_loss + weighted_neg_loss)?;
if let Some(ref freq) = entity_freq {
let subsample_w: Vec<f32> = batch
.iter()
.map(|triple| 1.0 / ((freq[triple.head] + freq[triple.tail]) as f32).sqrt())
.collect();
let subsample_t = Tensor::from_vec(subsample_w, actual_bs, &model.device)?;
(&per_triple_loss * &subsample_t)?.mean_all()?
} else {
per_triple_loss.mean_all()?
}
};
if n3_coeff > 0.0 {
let n3 = model.n3_penalty(&heads, &rels, &tails)?;
loss = (loss + (n3 * n3_coeff as f64)?)?;
}
if config.l2_reg > 0.0 {
let h = model
.entity_embeddings
.as_tensor()
.index_select(&heads, 0)?;
let r = model
.relation_embeddings
.as_tensor()
.index_select(&rels, 0)?;
let t = model
.entity_embeddings
.as_tensor()
.index_select(&tails, 0)?;
let l2 = ((h.sqr()?.mean_all()? + r.sqr()?.mean_all()?)? + t.sqr()?.mean_all()?)?;
loss = (loss + (l2 * config.l2_reg as f64)?)?;
}
optimizer.backward_step(&loss)?;
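// Optionally project entity embeddings back to the unit sphere after each
// step (the classic TransE norm constraint).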
if config.normalize_entities {
let ent = model.entity_embeddings.as_tensor();
let norms = ent.sqr()?.sum(D::Minus1)?.sqrt()?.unsqueeze(D::Minus1)?;
let normalized = ent.broadcast_div(&norms.clamp(1e-8, f64::MAX)?)?;
model.entity_embeddings.set(&normalized)?;
}
epoch_loss += loss.to_scalar::<f32>()? as f64;
n_batches += 1;
}
let avg_loss = (epoch_loss / n_batches as f64) as f32;
losses.push(avg_loss);
epoch_times.push(epoch_start.elapsed().as_secs_f32());
if config.cosine_cycles > 0 {
let effective_epoch = _epoch.saturating_sub(config.warmup_epochs);
let total_effective = config.epochs.saturating_sub(config.warmup_epochs);
let epochs_per_cycle = total_effective / config.cosine_cycles;
if epochs_per_cycle > 0
&& effective_epoch > 0
&& (effective_epoch + 1) % epochs_per_cycle == 0
{
if let (Ok(ev), Ok(rv)) = (model.entity_vecs(), model.relation_vecs()) {
eprintln!(
"Snapshot {} saved at epoch {}",
snapshots.len() + 1,
_epoch + 1
);
snapshots.push(Snapshot {
entity_vecs: ev,
relation_vecs: rv,
epoch: _epoch + 1,
});
}
}
}
if config.log_interval > 0 && (_epoch + 1) % config.log_interval == 0 {
let epoch_secs = epoch_start.elapsed().as_secs_f32();
let ent_norm = model
.entity_embeddings
.as_tensor()
.sqr()
.and_then(|t| t.mean_all())
.and_then(|t| t.to_scalar::<f32>())
.map(|v| v.sqrt())
.unwrap_or(0.0);
eprintln!(
"epoch {:>4} | loss {:.4} | {:.1}s | emb_rms {:.4}",
_epoch + 1,
avg_loss,
epoch_secs,
ent_norm,
);
}
if let Some(ref dir) = config.checkpoint_dir {
if config.checkpoint_interval > 0 && (_epoch + 1) % config.checkpoint_interval == 0 {
if let (Ok(ent), Ok(rel)) = (model.entity_vecs(), model.relation_vecs()) {
let ent_names: Vec<String> = (0..num_entities).map(|i| i.to_string()).collect();
let rel_names: Vec<String> =
(0..num_relations).map(|i| i.to_string()).collect();
let _ = crate::io::export_embeddings(dir, &ent_names, &ent, &rel_names, &rel);
eprintln!("Checkpoint saved to {}", dir.display());
}
}
}
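// Incrementally fold the current weights into the SWA running average once
// the start epoch has been reached.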
if swa_active && _epoch + 1 >= config.swa_start_epoch {
if let (Ok(ent_flat), Ok(rel_flat)) = (
model
.entity_embeddings
.as_tensor()
.flatten_all()?
.to_vec1::<f32>(),
model
.relation_embeddings
.as_tensor()
.flatten_all()?
.to_vec1::<f32>(),
) {
swa_count += 1;
let update = |avg: &mut Option<Vec<f32>>, current: &[f32]| match avg {
None => *avg = Some(current.to_vec()),
Some(ref mut buf) => {
for (a, &c) in buf.iter_mut().zip(current.iter()) {
*a += (c - *a) / swa_count as f32;
}
}
};
update(&mut swa_ent, &ent_flat);
update(&mut swa_rel, &rel_flat);
}
}
if let Some(ref val) = validation {
if config.eval_interval > 0 && (_epoch + 1) % config.eval_interval == 0 {
let scorer: Box<dyn crate::Scorer + Sync> = match model.model_type {
ModelType::TransE => Box::new(model.to_transe()?),
ModelType::RotatE => Box::new(model.to_rotate()?),
ModelType::ComplEx => Box::new(model.to_complex()?),
ModelType::DistMult => Box::new(model.to_distmult()?),
};
let metrics = crate::eval::evaluate_link_prediction(
scorer.as_ref(),
val.valid_triples,
val.filter,
num_entities,
);
if metrics.mrr > best_mrr {
best_mrr = metrics.mrr;
patience_counter = 0;
best_entity_vecs = model.entity_vecs().ok();
best_relation_vecs = model.relation_vecs().ok();
} else {
patience_counter += 1;
if patience_counter >= config.patience {
eprintln!(
"Early stopping at epoch {} (best MRR: {:.4})",
_epoch + 1,
best_mrr,
);
break;
}
}
}
}
}
if let (Some(ent_vecs), Some(rel_vecs)) = (best_entity_vecs, best_relation_vecs) {
let ent_flat: Vec<f32> = ent_vecs.iter().flat_map(|v| v.iter().copied()).collect();
let rel_flat: Vec<f32> = rel_vecs.iter().flat_map(|v| v.iter().copied()).collect();
let ent_shape = model.entity_embeddings.shape().clone();
let rel_shape = model.relation_embeddings.shape().clone();
let ent_t = Tensor::from_vec(ent_flat, ent_shape, &model.device)?;
let rel_t = Tensor::from_vec(rel_flat, rel_shape, &model.device)?;
model.entity_embeddings.set(&ent_t)?;
model.relation_embeddings.set(&rel_t)?;
}
let ent_cols = model.entity_embeddings.as_tensor().dim(1)?;
let rel_cols = model.relation_embeddings.as_tensor().dim(1)?;
let swa_entity_vecs =
swa_ent.map(|flat| flat.chunks_exact(ent_cols).map(|c| c.to_vec()).collect());
let swa_relation_vecs =
swa_rel.map(|flat| flat.chunks_exact(rel_cols).map(|c| c.to_vec()).collect());
Ok(TrainResult {
model,
losses,
epoch_times,
snapshots,
swa_entity_vecs,
swa_relation_vecs,
})
}
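/// Numerically stable log-sigmoid: log sigmoid(x) = min(x, 0) - log(1 + exp(-|x|)).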
fn log_sigmoid(x: &Tensor) -> Result<Tensor> {
let neg_x = x.neg()?;
let abs_x = x.abs()?;
let neg_abs = abs_x.neg()?;
let relu_neg = neg_x.relu()?;
let softplus = (neg_abs.exp()? + 1.0)?.log()?;
let result = (relu_neg.neg()? - softplus)?;
Ok(result)
}
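/// Copies a 2-D tensor to the CPU and splits it into per-row `Vec<f32>`s.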
fn tensor_to_vecs(t: &Tensor) -> Result<Vec<Vec<f32>>> {
let t = t.to_device(&Device::Cpu)?;
let rows = t.dim(0)?;
let cols = t.dim(1)?;
let data = t.flatten_all()?.to_vec1::<f32>()?;
Ok((0..rows)
.map(|i| data[i * cols..(i + 1) * cols].to_vec())
.collect())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::TripleIds;
use crate::Scorer;
fn tid(h: usize, r: usize, t: usize) -> TripleIds {
TripleIds::new(h, r, t)
}
#[test]
fn log_sigmoid_basic() {
let device = Device::Cpu;
let x = Tensor::new(&[0.0_f32, 10.0, -10.0], &device).unwrap();
let result = log_sigmoid(&x).unwrap().to_vec1::<f32>().unwrap();
assert!((result[0] - (-0.693)).abs() < 0.01, "got {}", result[0]);
assert!(result[1] > -0.001, "got {}", result[1]);
assert!((result[2] - (-10.0)).abs() < 0.01, "got {}", result[2]);
}
#[test]
fn train_transe_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
let config = TrainConfig {
model_type: ModelType::TransE,
dim: 8,
num_negatives: 4,
gamma: 6.0,
adversarial_temperature: 0.5,
lr: 0.01,
n3_reg: 0.0,
batch_size: 4,
epochs: 5,
..TrainConfig::default()
};
let result = train(&triples, 3, 2, &config, &device).unwrap();
assert_eq!(result.losses.len(), 5);
assert!(result.losses.iter().all(|l| l.is_finite()));
let model = result.model.to_transe().unwrap();
assert_eq!(model.num_entities(), 3);
}
#[test]
fn train_rotate_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
let config = TrainConfig {
model_type: ModelType::RotatE,
dim: 4,
num_negatives: 2,
gamma: 6.0,
adversarial_temperature: 1.0,
lr: 0.01,
n3_reg: 0.0,
batch_size: 2,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"RotatE loss should decrease: {first} -> {last}"
);
let model = result.model.to_rotate().unwrap();
assert_eq!(model.num_entities(), 3);
}
#[test]
fn train_complex_with_n3() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
let config = TrainConfig {
model_type: ModelType::ComplEx,
dim: 4,
num_negatives: 2,
gamma: 6.0,
adversarial_temperature: 1.0,
lr: 0.01,
n3_reg: 0.001,
batch_size: 2,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"ComplEx loss should decrease: {first} -> {last}"
);
let model = result.model.to_complex().unwrap();
assert_eq!(model.num_entities(), 3);
}
#[test]
fn train_distmult_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
num_negatives: 2,
gamma: 6.0,
adversarial_temperature: 0.0,
lr: 0.01,
n3_reg: 0.0,
batch_size: 2,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"DistMult loss should decrease: {first} -> {last}"
);
let model = result.model.to_distmult().unwrap();
assert_eq!(model.num_entities(), 3);
}
#[test]
fn loss_decreases() {
let device = Device::Cpu;
let triples: Vec<_> = (0..20).map(|i| tid(i % 10, i % 3, (i + 1) % 10)).collect();
let config = TrainConfig {
model_type: ModelType::TransE,
dim: 16,
num_negatives: 8,
gamma: 6.0,
adversarial_temperature: 0.5,
lr: 0.01,
n3_reg: 0.0,
batch_size: 10,
epochs: 50,
..TrainConfig::default()
};
let result = train(&triples, 10, 3, &config, &device).unwrap();
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"Loss should decrease: first={first}, last={last}"
);
}
#[test]
fn transe_achieves_nonzero_mrr_on_trivial_graph() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3), tid(3, 0, 4)];
let config = TrainConfig {
model_type: ModelType::TransE,
dim: 32,
num_negatives: 4,
gamma: 6.0,
adversarial_temperature: 0.0,
lr: 0.01,
n3_reg: 0.0,
batch_size: 4,
epochs: 500,
..TrainConfig::default()
};
let result = train(&triples, 5, 1, &config, &device).unwrap();
let model = result.model.to_transe().unwrap();
let ds = crate::dataset::Dataset::new(
triples
.iter()
.map(|t| {
crate::dataset::Triple::new(
t.head.to_string(),
t.relation.to_string(),
t.tail.to_string(),
)
})
.collect(),
Vec::new(),
Vec::new(),
)
.into_interned();
let filter = crate::dataset::FilterIndex::from_dataset(&ds);
let metrics = crate::eval::evaluate_link_prediction(&model, &triples, &filter, 5);
assert!(
metrics.mrr > 0.3,
"TransE should achieve MRR > 0.3 on trivial graph, got {:.4}",
metrics.mrr
);
}
#[test]
fn one_to_n_distmult_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
one_to_n: true,
label_smoothing: 0.1,
lr: 0.01,
batch_size: 4,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 2, &config, &device).unwrap();
assert_eq!(result.losses.len(), 10);
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(last < first, "1-N loss should decrease: {first} -> {last}");
}
#[test]
fn one_to_n_transe_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3)];
let config = TrainConfig {
model_type: ModelType::TransE,
dim: 8,
one_to_n: true,
label_smoothing: 0.1,
lr: 0.001,
batch_size: 3,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 4, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
}
#[test]
fn adagrad_optimizer_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
let config = TrainConfig {
model_type: ModelType::DistMult,
optimizer: OptimizerType::Adagrad,
dim: 8,
init_scale: 1e-3,
lr: 0.1,
one_to_n: true,
batch_size: 4,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 2, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"Adagrad loss should decrease: {first} -> {last}"
);
}
#[test]
fn multi_hot_labels_with_duplicate_tails() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(0, 0, 2), tid(1, 0, 0)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
one_to_n: true,
multi_hot: true,
batch_size: 3,
epochs: 5,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
}
#[test]
fn n3_regularization_complex_moduli() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
let config = TrainConfig {
model_type: ModelType::ComplEx,
dim: 4,
n3_reg: 0.1,
one_to_n: true,
batch_size: 2,
epochs: 5,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
}
#[test]
fn l2_regularization_reduces_embedding_norm() {
let device = Device::Cpu;
let triples: Vec<_> = (0..20).map(|i| tid(i % 5, 0, (i + 1) % 5)).collect();
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
l2_reg: 0.1,
one_to_n: true,
batch_size: 10,
epochs: 20,
..TrainConfig::default()
};
let result = train(&triples, 5, 1, &config, &device).unwrap();
let ent_vecs = result.model.entity_vecs().unwrap();
let max_norm: f32 = ent_vecs
.iter()
.map(|v| v.iter().map(|x| x * x).sum::<f32>().sqrt())
.fold(0.0_f32, f32::max);
assert!(
max_norm < 10.0,
"L2 reg should keep norms small, got max_norm={max_norm}"
);
}
#[test]
fn lr_warmup_ramps_linearly() {
let config = TrainConfig {
lr: 0.01,
warmup_epochs: 10,
epochs: 100,
..TrainConfig::default()
};
let lr0 = learning_rate(0, &config);
let lr5 = learning_rate(5, &config);
let lr9 = learning_rate(9, &config);
assert!((lr0 - 0.001).abs() < 1e-10, "epoch 0: {lr0}");
assert!((lr5 - 0.006).abs() < 1e-10, "epoch 5: {lr5}");
assert!((lr9 - 0.01).abs() < 1e-10, "epoch 9: {lr9}");
}
#[test]
fn lr_constant_after_warmup_without_cosine() {
let config = TrainConfig {
lr: 0.01,
warmup_epochs: 5,
cosine_cycles: 0,
epochs: 100,
..TrainConfig::default()
};
let lr = learning_rate(50, &config);
assert!((lr - 0.01).abs() < 1e-10, "should be base LR: {lr}");
}
#[test]
fn lr_cosine_starts_at_base_and_decays() {
let config = TrainConfig {
lr: 0.01,
warmup_epochs: 0,
cosine_cycles: 1,
cosine_min_lr_frac: 0.1,
epochs: 100,
..TrainConfig::default()
};
let lr_start = learning_rate(0, &config);
let lr_mid = learning_rate(50, &config);
let lr_end = learning_rate(99, &config);
assert!(
(lr_start - 0.01).abs() < 1e-6,
"cosine should start at base LR: {lr_start}"
);
assert!(
lr_mid < lr_start,
"mid-cycle LR should be below start: {lr_mid}"
);
assert!(
lr_end < lr_mid,
"end-of-cycle LR should be below mid: {lr_end}"
);
assert!(
lr_end >= 0.001 - 1e-10,
"LR should not drop below min: {lr_end}"
);
}
#[test]
fn lr_cosine_min_frac_respected() {
let config = TrainConfig {
lr: 0.1,
warmup_epochs: 0,
cosine_cycles: 1,
cosine_min_lr_frac: 0.1,
epochs: 100,
..TrainConfig::default()
};
for epoch in 0..100 {
let lr = learning_rate(epoch, &config);
assert!(lr >= 0.1 * 0.1 - 1e-10, "epoch {epoch}: LR {lr} below min");
assert!(lr <= 0.1 + 1e-10, "epoch {epoch}: LR {lr} above base");
}
}
#[test]
fn lr_always_positive() {
let config = TrainConfig {
lr: 0.001,
warmup_epochs: 10,
cosine_cycles: 3,
cosine_min_lr_frac: 0.1,
epochs: 300,
..TrainConfig::default()
};
for epoch in 0..300 {
let lr = learning_rate(epoch, &config);
assert!(lr > 0.0, "epoch {epoch}: LR must be positive, got {lr}");
}
}
fn make_trivial_graph() -> Vec<TripleIds> {
vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3), tid(3, 0, 4)]
}
fn eval_mrr(
triples: &[TripleIds],
model: &(dyn Scorer + Sync),
num_entities: usize,
) -> f32 {
let ds = crate::dataset::Dataset::new(
triples
.iter()
.map(|t| {
crate::dataset::Triple::new(
t.head.to_string(),
t.relation.to_string(),
t.tail.to_string(),
)
})
.collect(),
Vec::new(),
Vec::new(),
)
.into_interned();
let filter = crate::dataset::FilterIndex::from_dataset(&ds);
crate::eval::evaluate_link_prediction(model, triples, &filter, num_entities).mrr
}
#[test]
fn rotate_achieves_nonzero_mrr_on_trivial_graph() {
let device = Device::Cpu;
let triples = make_trivial_graph();
let config = TrainConfig {
model_type: ModelType::RotatE,
dim: 32,
num_negatives: 4,
gamma: 6.0,
adversarial_temperature: 0.0,
lr: 0.01,
batch_size: 4,
epochs: 500,
..TrainConfig::default()
};
let result = train(&triples, 5, 1, &config, &device).unwrap();
let model = result.model.to_rotate().unwrap();
let mrr = eval_mrr(&triples, &model, 5);
assert!(
mrr > 0.3,
"RotatE should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
);
}
#[test]
fn complex_achieves_nonzero_mrr_on_trivial_graph() {
let device = Device::Cpu;
let triples = make_trivial_graph();
let config = TrainConfig {
model_type: ModelType::ComplEx,
dim: 32,
one_to_n: true,
lr: 0.01,
batch_size: 4,
epochs: 200,
..TrainConfig::default()
};
let result = train(&triples, 5, 1, &config, &device).unwrap();
let model = result.model.to_complex().unwrap();
let mrr = eval_mrr(&triples, &model, 5);
assert!(
mrr > 0.3,
"ComplEx should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
);
}
#[test]
fn distmult_achieves_nonzero_mrr_on_trivial_graph() {
let device = Device::Cpu;
let triples = make_trivial_graph();
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 32,
one_to_n: true,
lr: 0.01,
batch_size: 4,
epochs: 200,
..TrainConfig::default()
};
let result = train(&triples, 5, 1, &config, &device).unwrap();
let model = result.model.to_distmult().unwrap();
let mrr = eval_mrr(&triples, &model, 5);
assert!(
mrr > 0.3,
"DistMult should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
);
}
#[test]
fn swa_produces_averaged_embeddings() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
one_to_n: true,
batch_size: 3,
epochs: 10,
swa_start_epoch: 5,
..TrainConfig::default()
};
let result = train(&triples, 3, 2, &config, &device).unwrap();
assert!(
result.swa_entity_vecs.is_some(),
"SWA should produce entity vecs"
);
assert!(
result.swa_relation_vecs.is_some(),
"SWA should produce relation vecs"
);
let swa_ent = result.swa_entity_vecs.unwrap();
assert_eq!(swa_ent.len(), 3);
assert_eq!(swa_ent[0].len(), 8);
let final_ent = result.model.entity_vecs().unwrap();
let differs = swa_ent
.iter()
.zip(final_ent.iter())
.any(|(a, b)| a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-8));
assert!(differs, "SWA average should differ from final model");
}
#[test]
fn swa_disabled_returns_none() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 4,
one_to_n: true,
batch_size: 2,
epochs: 5,
swa_start_epoch: 0,
..TrainConfig::default()
};
let result = train(&triples, 3, 1, &config, &device).unwrap();
assert!(result.swa_entity_vecs.is_none());
assert!(result.swa_relation_vecs.is_none());
}
#[test]
fn relation_prediction_loss_smoke() {
let device = Device::Cpu;
let triples = vec![tid(0, 0, 1), tid(1, 1, 2), tid(2, 0, 0), tid(0, 1, 2)];
let config = TrainConfig {
model_type: ModelType::DistMult,
dim: 8,
one_to_n: true,
relation_prediction_weight: 0.1,
batch_size: 4,
epochs: 10,
..TrainConfig::default()
};
let result = train(&triples, 3, 2, &config, &device).unwrap();
assert!(result.losses.iter().all(|l| l.is_finite()));
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"Loss with relation prediction should decrease: {first} -> {last}"
);
}
}