use burn::module::{Module, Param, ParamId};
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::tensor::activation;
use burn::tensor::backend::AutodiffBackend;
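/// ComplEx embeddings with real and imaginary parts stored as separate tables,
/// one row per entity or relation and `dim` columns each.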
#[derive(Module, Debug)]
pub struct BurnComplEx<B: Backend> {
entity_re: Param<Tensor<B, 2>>,
entity_im: Param<Tensor<B, 2>>,
relation_re: Param<Tensor<B, 2>>,
relation_im: Param<Tensor<B, 2>>,
}
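/// Hyperparameters for `train_complex`; `Default` gives the values shown below.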
#[derive(Debug, Clone)]
pub struct BurnTrainConfig {
pub dim: usize,
pub init_scale: f64,
pub lr: f64,
pub label_smoothing: f64,
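/// N3 regularization weight; not currently applied by `train_complex`.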
pub n3_reg: f64,
pub batch_size: usize,
pub epochs: usize,
pub log_interval: usize,
}
impl Default for BurnTrainConfig {
fn default() -> Self {
Self {
dim: 200,
init_scale: 1e-3,
lr: 0.001,
label_smoothing: 0.1,
n3_reg: 0.0,
batch_size: 512,
epochs: 100,
log_interval: 10,
}
}
}
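/// Trained embeddings flattened to plain `f32` vectors (each row is `[re || im]`,
/// length `2 * dim`) together with the average loss per epoch.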
pub struct BurnTrainResult {
pub entity_vecs: Vec<Vec<f32>>,
pub relation_vecs: Vec<Vec<f32>>,
pub dim: usize,
pub losses: Vec<f32>,
}
impl BurnTrainResult {
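/// Converts the trained embeddings into the crate's native `ComplEx` scorer.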
pub fn to_complex(&self) -> crate::ComplEx {
crate::ComplEx::from_vecs(
self.entity_vecs.clone(),
self.relation_vecs.clone(),
self.dim,
)
}
}
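/// Builds the four embedding tables with entries drawn from N(0, `init_scale`)
/// and gradient tracking enabled.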
fn init_model<B: AutodiffBackend>(
num_entities: usize,
num_relations: usize,
dim: usize,
init_scale: f64,
device: &B::Device,
) -> BurnComplEx<B> {
let mk = |rows, cols| {
Param::initialized(
ParamId::new(),
Tensor::<B, 2>::random(
[rows, cols],
burn::tensor::Distribution::Normal(0.0, init_scale),
device,
)
.require_grad(),
)
};
BurnComplEx {
entity_re: mk(num_entities, dim),
entity_im: mk(num_entities, dim),
relation_re: mk(num_relations, dim),
relation_im: mk(num_relations, dim),
}
}
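/// Scores every entity as a candidate tail for each (head, relation) pair:
/// phi(h, r, t) = Re(<h * r, conj(t)>), computed by composing h with r and
/// taking two matmuls against the full entity tables.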
fn score_1n<B: Backend>(
model: &BurnComplEx<B>,
heads: &Tensor<B, 1, Int>,
rels: &Tensor<B, 1, Int>,
) -> Tensor<B, 2> {
let h_re = model.entity_re.val().select(0, heads.clone());
let h_im = model.entity_im.val().select(0, heads.clone());
let r_re = model.relation_re.val().select(0, rels.clone());
let r_im = model.relation_im.val().select(0, rels.clone());
let hr_re = h_re.clone() * r_re.clone() - h_im.clone() * r_im.clone();
let hr_im = h_re * r_im + h_im * r_re;
let e_re = model.entity_re.val();
let e_im = model.entity_im.val();
hr_re.matmul(e_re.transpose()) + hr_im.matmul(e_im.transpose())
}
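/// Scores every entity as a candidate head for each (relation, tail) pair.
/// Uses Re(<h * r, conj(t)>) = Re(<t * conj(r), conj(h)>), so the batch composes
/// t with conj(r) and reuses the same two-matmul pattern as `score_1n`.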
fn score_1n_heads<B: Backend>(
model: &BurnComplEx<B>,
rels: &Tensor<B, 1, Int>,
tails: &Tensor<B, 1, Int>,
) -> Tensor<B, 2> {
let r_re = model.relation_re.val().select(0, rels.clone());
let r_im = model.relation_im.val().select(0, rels.clone());
let t_re = model.entity_re.val().select(0, tails.clone());
let t_im = model.entity_im.val().select(0, tails.clone());
// Compose t with conj(r): re = r_re*t_re + r_im*t_im, im = r_re*t_im - r_im*t_re,
// so that rc_re * h_re + rc_im * h_im reproduces Re(<h * r, conj(t)>) for every head.
let rc_re = r_re.clone() * t_re.clone() + r_im.clone() * t_im.clone();
let rc_im = r_re * t_im - r_im * t_re;
let e_re = model.entity_re.val();
let e_im = model.entity_im.val();
rc_re.matmul(e_re.transpose()) + rc_im.matmul(e_im.transpose())
}
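/// Trains ComplEx with 1-N scoring: full softmax cross-entropy over all entities
/// in both directions (tail and head prediction), optional label smoothing, and
/// AdamW updates. Batches with a non-finite loss are skipped.
///
/// A minimal usage sketch (illustrative only; it assumes the `burn-cpu` feature so
/// the `burn_ndarray` backend is available):
///
/// ```ignore
/// use burn::backend::Autodiff;
///
/// type B = Autodiff<burn_ndarray::NdArray>;
///
/// let triples = vec![TripleIds::new(0, 0, 1), TripleIds::new(1, 0, 2)];
/// let config = BurnTrainConfig { dim: 32, epochs: 50, ..Default::default() };
/// let device = burn_ndarray::NdArrayDevice::Cpu;
/// // 3 entities, 1 relation in this toy graph.
/// let result = train_complex::<B>(&triples, 3, 1, &config, &device);
/// let scorer = result.to_complex();
/// ```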
pub fn train_complex<B: AutodiffBackend>(
train_triples: &[crate::dataset::TripleIds],
num_entities: usize,
num_relations: usize,
config: &BurnTrainConfig,
device: &B::Device,
) -> BurnTrainResult {
let mut model = init_model::<B>(
num_entities,
num_relations,
config.dim,
config.init_scale,
device,
);
let mut optim = AdamWConfig::new()
.with_epsilon(1e-8)
.with_weight_decay(0.0)
.init::<B, BurnComplEx<B>>();
let n_triples = train_triples.len();
let batch_size = config.batch_size.min(n_triples);
let eps = config.label_smoothing;
let mut losses = Vec::with_capacity(config.epochs);
let mut indices: Vec<usize> = (0..n_triples).collect();
for epoch in 0..config.epochs {
let epoch_start = std::time::Instant::now();
{
use rand::seq::SliceRandom;
indices.shuffle(&mut rand::rng());
}
let mut epoch_loss = 0.0_f64;
let mut n_batches = 0u32;
let mut offset = 0;
while offset < n_triples {
let end = (offset + batch_size).min(n_triples);
let batch_idx = &indices[offset..end];
let actual_bs = batch_idx.len();
offset = end;
let heads_data: Vec<i64> = batch_idx
.iter()
.map(|&i| train_triples[i].head as i64)
.collect();
let rels_data: Vec<i64> = batch_idx
.iter()
.map(|&i| train_triples[i].relation as i64)
.collect();
let tails_data: Vec<i64> = batch_idx
.iter()
.map(|&i| train_triples[i].tail as i64)
.collect();
let heads = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(heads_data, [actual_bs]),
device,
);
let rels = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(rels_data, [actual_bs]),
device,
);
let tails = Tensor::<B, 1, Int>::from_data(
burn::tensor::TensorData::new(tails_data.clone(), [actual_bs]),
device,
);
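// Work on a clone of the model so the update can be skipped (and the previous
// parameters kept) if this batch produces a non-finite loss.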
let current = model.clone();
let tail_scores = score_1n(&current, &heads, &rels);
let tail_log_probs = activation::log_softmax(tail_scores, 1);
let head_scores = score_1n_heads(&current, &rels, &tails);
let head_log_probs = activation::log_softmax(head_scores, 1);
let tail_ids = tails.clone().unsqueeze_dim(1);
let t_nll = tail_log_probs
.clone()
.gather(1, tail_ids)
.squeeze::<1>()
.neg()
.mean();
let head_ids = heads.clone().unsqueeze_dim(1);
let h_nll = head_log_probs
.clone()
.gather(1, head_ids)
.squeeze::<1>()
.neg()
.mean();
let nll = (t_nll + h_nll) / 2.0;
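// Label smoothing: blend the one-hot NLL with the cross-entropy against a
// uniform target over all entities.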
let loss = if eps > 0.0 {
let tail_uniform = tail_log_probs.mean().neg();
let head_uniform = head_log_probs.mean().neg();
let uniform = (tail_uniform + head_uniform) / 2.0;
nll * (1.0 - eps) + uniform * eps
} else {
nll
};
let loss_val: f32 = loss.clone().inner().into_scalar().to_f32();
let grads = GradientsParams::from_grads(loss.backward(), &current);
if loss_val.is_finite() {
model = optim.step(config.lr, current, grads);
}
epoch_loss += loss_val as f64;
n_batches += 1;
}
let avg_loss = (epoch_loss / n_batches as f64) as f32;
losses.push(avg_loss);
if config.log_interval > 0 && (epoch + 1) % config.log_interval == 0 {
eprintln!(
"epoch {:>4} | loss {:.4} | {:.1}s",
epoch + 1,
avg_loss,
epoch_start.elapsed().as_secs_f32(),
);
}
}
let dim = config.dim;
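// Flatten each embedding row back into a single `[re || im]` vector.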
let extract = |re: &Param<Tensor<B, 2>>, im: &Param<Tensor<B, 2>>| -> Vec<Vec<f32>> {
let re_data: Vec<f32> = re.val().into_data().to_vec().unwrap();
let im_data: Vec<f32> = im.val().into_data().to_vec().unwrap();
let n = re_data.len() / dim;
(0..n)
.map(|i| {
let mut v = Vec::with_capacity(dim * 2);
v.extend_from_slice(&re_data[i * dim..(i + 1) * dim]);
v.extend_from_slice(&im_data[i * dim..(i + 1) * dim]);
v
})
.collect()
};
BurnTrainResult {
entity_vecs: extract(&model.entity_re, &model.entity_im),
relation_vecs: extract(&model.relation_re, &model.relation_im),
dim,
losses,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::TripleIds;
use crate::Scorer;
fn tid(h: usize, r: usize, t: usize) -> TripleIds {
TripleIds::new(h, r, t)
}
#[cfg(feature = "burn-cpu")]
type TestBackend = burn::backend::Autodiff<burn_ndarray::NdArray>;
#[cfg(feature = "burn-cpu")]
fn test_device() -> <TestBackend as Backend>::Device {
burn_ndarray::NdArrayDevice::Cpu
}
#[test]
#[cfg(feature = "burn-cpu")]
fn burn_complex_smoke() {
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
let config = BurnTrainConfig {
dim: 8,
epochs: 10,
batch_size: 4,
..BurnTrainConfig::default()
};
let result = train_complex::<TestBackend>(&triples, 3, 2, &config, &test_device());
assert_eq!(result.losses.len(), 10);
assert!(result.losses.iter().all(|l| l.is_finite()));
let model = result.to_complex();
assert_eq!(model.num_entities(), 3);
}
#[test]
#[cfg(feature = "burn-cpu")]
fn burn_complex_loss_decreases() {
let triples: Vec<_> = (0..20).map(|i| tid(i % 5, i % 2, (i + 1) % 5)).collect();
let config = BurnTrainConfig {
dim: 16,
epochs: 30,
batch_size: 10,
lr: 0.001,
..BurnTrainConfig::default()
};
let result = train_complex::<TestBackend>(&triples, 5, 2, &config, &test_device());
let first = result.losses[0];
let last = *result.losses.last().unwrap();
assert!(
last < first,
"Burn ComplEx loss should decrease: {first} -> {last}"
);
}
#[test]
#[cfg(feature = "burn-cpu")]
fn burn_complex_achieves_nonzero_mrr() {
let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3), tid(3, 0, 4)];
let config = BurnTrainConfig {
dim: 32,
epochs: 200,
batch_size: 4,
lr: 0.001,
..BurnTrainConfig::default()
};
let result = train_complex::<TestBackend>(&triples, 5, 1, &config, &test_device());
let model = result.to_complex();
let ds = crate::dataset::Dataset::new(
triples
.iter()
.map(|t| {
crate::dataset::Triple::new(
t.head.to_string(),
t.relation.to_string(),
t.tail.to_string(),
)
})
.collect(),
Vec::new(),
Vec::new(),
)
.into_interned();
let filter = crate::dataset::FilterIndex::from_dataset(&ds);
let metrics = crate::eval::evaluate_link_prediction(&model, &triples, &filter, 5);
assert!(
metrics.mrr > 0.3,
"Burn ComplEx should achieve MRR > 0.3, got {:.4}",
metrics.mrr
);
}
}