use std::collections::HashMap;
use scirs2_core::random::{Rng, RngExt};
use crate::error::{GraphError, Result};
/// An in-memory knowledge-graph dataset: indexed (head, relation, tail)
/// triples plus entity/relation counts and human-readable labels.
#[derive(Debug, Clone)]
pub struct KGDataset {
    /// Triples stored as (head entity, relation, tail entity) indices.
    pub triples: Vec<(usize, usize, usize)>,
    /// Number of distinct entities; every entity index is `< n_entities`.
    pub n_entities: usize,
    /// Number of distinct relations; every relation index is `< n_relations`.
    pub n_relations: usize,
    /// Display label for each entity, indexed by entity id.
    pub entity_labels: Vec<String>,
    /// Display label for each relation, indexed by relation id.
    pub relation_labels: Vec<String>,
}
impl KGDataset {
    /// Validates and wraps pre-indexed triples into a dataset; labels
    /// default to synthetic `e{i}` / `r{i}` placeholders.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when any triple references an
    /// entity index `>= n_entities` or a relation index `>= n_relations`.
    pub fn new(
        triples: Vec<(usize, usize, usize)>,
        n_entities: usize,
        n_relations: usize,
    ) -> Result<Self> {
        // Reject any triple whose indices fall outside the declared ranges.
        for &(h, r, t) in &triples {
            if h >= n_entities || t >= n_entities {
                return Err(GraphError::InvalidParameter {
                    param: "triples".to_string(),
                    value: format!("entity index ({h},{t}) out of range"),
                    expected: format!("< n_entities={n_entities}"),
                    context: "KGDataset::new".to_string(),
                });
            }
            if r >= n_relations {
                return Err(GraphError::InvalidParameter {
                    param: "triples".to_string(),
                    value: format!("relation index {r} out of range"),
                    expected: format!("< n_relations={n_relations}"),
                    context: "KGDataset::new".to_string(),
                });
            }
        }
        Ok(KGDataset {
            triples,
            n_entities,
            n_relations,
            entity_labels: (0..n_entities).map(|i| format!("e{i}")).collect(),
            relation_labels: (0..n_relations).map(|i| format!("r{i}")).collect(),
        })
    }

    /// Builds a dataset from string triples, interning entities and
    /// relations in first-seen order (head before tail), so indices are
    /// assigned deterministically.
    pub fn from_str_triples(triples: &[(&str, &str, &str)]) -> Self {
        let mut entity_map: HashMap<String, usize> = HashMap::new();
        let mut relation_map: HashMap<String, usize> = HashMap::new();
        let mut entity_labels: Vec<String> = Vec::new();
        let mut relation_labels: Vec<String> = Vec::new();
        // Interns `name` as an entity, returning its stable index.
        let mut intern_entity = |name: &str| -> usize {
            match entity_map.get(name) {
                Some(&idx) => idx,
                None => {
                    let idx = entity_labels.len();
                    entity_map.insert(name.to_string(), idx);
                    entity_labels.push(name.to_string());
                    idx
                }
            }
        };
        let mut indexed: Vec<(usize, usize, usize)> = Vec::with_capacity(triples.len());
        for &(h, r, t) in triples {
            let hi = intern_entity(h);
            let ti = intern_entity(t);
            let ri = match relation_map.get(r) {
                Some(&idx) => idx,
                None => {
                    let idx = relation_labels.len();
                    relation_map.insert(r.to_string(), idx);
                    relation_labels.push(r.to_string());
                    idx
                }
            };
            indexed.push((hi, ri, ti));
        }
        KGDataset {
            n_entities: entity_labels.len(),
            n_relations: relation_labels.len(),
            triples: indexed,
            entity_labels,
            relation_labels,
        }
    }

    /// Number of triples in the dataset.
    pub fn len(&self) -> usize {
        self.triples.len()
    }

    /// True when the dataset holds no triples.
    pub fn is_empty(&self) -> bool {
        self.triples.is_empty()
    }

    /// Produces a negative sample by replacing either the head or the tail
    /// (50/50 choice) with a random entity; when more than one entity
    /// exists, the replacement is nudged to differ from the original.
    /// The relation is never corrupted.
    pub fn corrupt_triple(&self, triple: (usize, usize, usize)) -> (usize, usize, usize) {
        let (h, r, t) = triple;
        let mut rng = scirs2_core::random::rng();
        // First draw picks the side to corrupt; second draw picks the entity.
        if rng.random::<f64>() < 0.5 {
            let mut candidate = (rng.random::<f64>() * self.n_entities as f64) as usize;
            candidate = candidate.min(self.n_entities - 1);
            if candidate == h && self.n_entities > 1 {
                candidate = (candidate + 1) % self.n_entities;
            }
            (candidate, r, t)
        } else {
            let mut candidate = (rng.random::<f64>() * self.n_entities as f64) as usize;
            candidate = candidate.min(self.n_entities - 1);
            if candidate == t && self.n_entities > 1 {
                candidate = (candidate + 1) % self.n_entities;
            }
            (h, r, candidate)
        }
    }
}
/// Creates `n_items` embedding rows of dimension `dim`, sampled uniformly
/// from `[-scale, scale]` and then L2-normalized (norm clamped to 1e-12 to
/// avoid division by zero on an all-zero draw).
fn init_embeddings(n_items: usize, dim: usize, scale: f64) -> Vec<Vec<f64>> {
    let mut rng = scirs2_core::random::rng();
    let mut table = Vec::with_capacity(n_items);
    for _ in 0..n_items {
        let mut row: Vec<f64> = Vec::with_capacity(dim);
        for _ in 0..dim {
            row.push(rng.random::<f64>() * 2.0 * scale - scale);
        }
        let norm = row.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
        for x in &mut row {
            *x /= norm;
        }
        table.push(row);
    }
    table
}
/// Euclidean (L2) norm of `v`; returns 0.0 for an empty slice.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for &x in v {
        sum_sq += x * x;
    }
    sum_sq.sqrt()
}
/// Scales `v` in place to unit L2 norm.
///
/// The norm is clamped to 1e-12 so an all-zero vector stays (near) zero
/// instead of producing NaNs from a 0/0 division.
fn l2_normalize(v: &mut [f64]) {
    // Takes `&mut [f64]` rather than `&mut Vec<f64>` (clippy `ptr_arg`):
    // existing `&mut Vec<f64>` call sites still work via deref coercion,
    // and any mutable slice is now accepted.
    let norm = l2_norm(v).max(1e-12);
    v.iter_mut().for_each(|x| *x /= norm);
}
/// TransE link-prediction model: a relation acts as a translation in
/// embedding space, scoring a triple by `-||h + r - t||`.
#[derive(Debug, Clone)]
pub struct TransE {
    /// One embedding row per entity.
    pub entity_embeddings: Vec<Vec<f64>>,
    /// One embedding row per relation.
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Embedding dimensionality shared by all rows.
    pub dim: usize,
    /// Distance norm: 1 selects L1; any other value is treated as L2.
    pub norm_order: u32,
}
impl TransE {
    /// Creates a TransE model with embeddings drawn uniformly from
    /// `[-1/sqrt(dim), 1/sqrt(dim)]` and L2-normalized.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `dim == 0`.
    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
        if dim == 0 {
            return Err(GraphError::InvalidParameter {
                param: "dim".to_string(),
                value: "0".to_string(),
                expected: "> 0".to_string(),
                context: "TransE::new".to_string(),
            });
        }
        let scale = 1.0 / (dim as f64).sqrt();
        Ok(TransE {
            entity_embeddings: init_embeddings(n_entities, dim, scale),
            relation_embeddings: init_embeddings(n_relations, dim, scale),
            dim,
            // Default to L2 distance; callers may set 1 for L1.
            norm_order: 2,
        })
    }

    /// Plausibility score for (h, r, t): the negated translation distance
    /// `-||h + r - t||`. Higher is more plausible; 0 is the maximum.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` for out-of-range indices.
    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.validate_indices(h, r, t)?;
        let he = &self.entity_embeddings[h];
        let re = &self.relation_embeddings[r];
        let te = &self.entity_embeddings[t];
        Ok(-translation_distance(he, re, te, self.norm_order))
    }

    /// Returns the indices of the `k` highest-scoring tail entities for the
    /// query (h, r, ?); yields fewer than `k` when there are fewer entities.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `h` or `r` is out of range.
    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
        let n = self.entity_embeddings.len();
        if h >= n {
            return Err(GraphError::InvalidParameter {
                param: "h".to_string(),
                value: format!("{h}"),
                expected: format!("< {n}"),
                context: "TransE::predict_tails".to_string(),
            });
        }
        // Bug fix: an out-of-range relation index previously panicked on the
        // embedding lookup below; validate it like `score` does.
        let nr = self.relation_embeddings.len();
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "r".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "TransE::predict_tails".to_string(),
            });
        }
        let he = &self.entity_embeddings[h];
        let re = &self.relation_embeddings[r];
        let mut scores: Vec<(usize, f64)> = (0..n)
            .map(|t| {
                let te = &self.entity_embeddings[t];
                (t, -translation_distance(he, re, te, self.norm_order))
            })
            .collect();
        // Sort descending by score; incomparable values (not expected here)
        // are treated as equal.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
    }

    /// Alias for [`score`](Self::score), kept for a uniform model interface.
    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.score(h, r, t)
    }

    /// Runs one epoch of margin-ranking SGD over the dataset and returns the
    /// summed hinge loss `max(0, margin + d_pos - d_neg)`.
    ///
    /// Per-dimension gradients use `sign(h + r - t)` — the exact subgradient
    /// for L1 distance, used here as a simplification even when `norm_order`
    /// is 2. NOTE(review): the corrupted relation receives no negative-side
    /// gradient; since corruption only swaps entities, `nr == r` and the
    /// positive-side update still reaches it — confirm this is intended.
    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
        let mut total_loss = 0.0;
        for &(h, r, t) in &dataset.triples {
            // Negative sample: same relation, one entity replaced at random.
            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
            let pos_score = {
                let he = &self.entity_embeddings[h];
                let re = &self.relation_embeddings[r];
                let te = &self.entity_embeddings[t];
                translation_distance(he, re, te, self.norm_order)
            };
            let neg_score = {
                let he = &self.entity_embeddings[nh];
                let re = &self.relation_embeddings[nr];
                let te = &self.entity_embeddings[nt];
                translation_distance(he, re, te, self.norm_order)
            };
            let loss = (margin + pos_score - neg_score).max(0.0);
            total_loss += loss;
            if loss > 0.0 {
                let dim = self.dim;
                // Sign gradients of the positive and negative distances.
                let g_pos: Vec<f64> = (0..dim)
                    .map(|k| {
                        let diff = self.entity_embeddings[h][k]
                            + self.relation_embeddings[r][k]
                            - self.entity_embeddings[t][k];
                        if diff >= 0.0 { 1.0 } else { -1.0 }
                    })
                    .collect();
                let g_neg: Vec<f64> = (0..dim)
                    .map(|k| {
                        let diff = self.entity_embeddings[nh][k]
                            + self.relation_embeddings[nr][k]
                            - self.entity_embeddings[nt][k];
                        if diff >= 0.0 { 1.0 } else { -1.0 }
                    })
                    .collect();
                // Shrink the positive distance...
                for k in 0..dim {
                    self.entity_embeddings[h][k] -= lr * g_pos[k];
                    self.entity_embeddings[t][k] += lr * g_pos[k];
                    self.relation_embeddings[r][k] -= lr * g_pos[k];
                }
                // ...and grow the negative one.
                for k in 0..dim {
                    self.entity_embeddings[nh][k] += lr * g_neg[k];
                    self.entity_embeddings[nt][k] -= lr * g_neg[k];
                }
                // Project touched entities back onto the unit sphere.
                l2_normalize(&mut self.entity_embeddings[h]);
                l2_normalize(&mut self.entity_embeddings[t]);
                l2_normalize(&mut self.entity_embeddings[nh]);
                l2_normalize(&mut self.entity_embeddings[nt]);
            }
        }
        total_loss
    }

    /// Checks that entity indices (h, t) and relation index r are within
    /// the embedding tables.
    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
        let ne = self.entity_embeddings.len();
        let nr = self.relation_embeddings.len();
        if h >= ne || t >= ne {
            return Err(GraphError::InvalidParameter {
                param: "entity_index".to_string(),
                value: format!("({h},{t})"),
                expected: format!("< {ne}"),
                context: "TransE score".to_string(),
            });
        }
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "relation_index".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "TransE score".to_string(),
            });
        }
        Ok(())
    }
}
/// Distance `||h + r - t||` under the requested norm: L1 when
/// `norm_order == 1`, L2 otherwise. Iteration stops at the shortest slice.
fn translation_distance(h: &[f64], r: &[f64], t: &[f64], norm_order: u32) -> f64 {
    let mut acc = 0.0;
    for ((&hi, &ri), &ti) in h.iter().zip(r).zip(t) {
        let d = hi + ri - ti;
        acc += if norm_order == 1 { d.abs() } else { d * d };
    }
    if norm_order == 1 {
        acc
    } else {
        acc.sqrt()
    }
}
/// DistMult link-prediction model: scores a triple with the trilinear
/// product `sum_k h_k * r_k * t_k` (symmetric in head and tail).
#[derive(Debug, Clone)]
pub struct DistMult {
    /// One embedding row per entity.
    pub entity_embeddings: Vec<Vec<f64>>,
    /// One (diagonal) embedding row per relation.
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Embedding dimensionality shared by all rows.
    pub dim: usize,
}
impl DistMult {
    /// Creates a DistMult model with embeddings drawn uniformly from
    /// `[-1/sqrt(dim), 1/sqrt(dim)]` (not normalized, unlike TransE).
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `dim == 0`.
    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
        if dim == 0 {
            return Err(GraphError::InvalidParameter {
                param: "dim".to_string(),
                value: "0".to_string(),
                expected: "> 0".to_string(),
                context: "DistMult::new".to_string(),
            });
        }
        let mut rng = scirs2_core::random::rng();
        let scale = 1.0 / (dim as f64).sqrt();
        // Shared table builder so entities and relations use the same scheme.
        let mut mk_table = |n: usize| -> Vec<Vec<f64>> {
            (0..n)
                .map(|_| {
                    (0..dim)
                        .map(|_| rng.random::<f64>() * 2.0 * scale - scale)
                        .collect()
                })
                .collect()
        };
        Ok(DistMult {
            entity_embeddings: mk_table(n_entities),
            relation_embeddings: mk_table(n_relations),
            dim,
        })
    }

    /// Trilinear plausibility score for (h, r, t); higher is more plausible.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` for out-of-range indices.
    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.validate_indices(h, r, t)?;
        Ok(distmult_score(
            &self.entity_embeddings[h],
            &self.relation_embeddings[r],
            &self.entity_embeddings[t],
        ))
    }

    /// Alias for [`score`](Self::score), kept for a uniform model interface.
    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.score(h, r, t)
    }

    /// Returns the indices of the `k` highest-scoring tail entities for the
    /// query (h, r, ?); yields fewer than `k` when there are fewer entities.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `h` or `r` is out of range.
    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
        let n = self.entity_embeddings.len();
        if h >= n {
            return Err(GraphError::InvalidParameter {
                param: "h".to_string(),
                value: format!("{h}"),
                expected: format!("< {n}"),
                context: "DistMult::predict_tails".to_string(),
            });
        }
        // Bug fix: an out-of-range relation index previously panicked on the
        // embedding lookup below; validate it like `score` does.
        let nr = self.relation_embeddings.len();
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "r".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "DistMult::predict_tails".to_string(),
            });
        }
        let he = &self.entity_embeddings[h];
        let re = &self.relation_embeddings[r];
        let mut scores: Vec<(usize, f64)> = (0..n)
            .map(|ti| {
                let te = &self.entity_embeddings[ti];
                (ti, distmult_score(he, re, te))
            })
            .collect();
        // Sort descending by score; incomparable values treated as equal.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
    }

    /// Runs one epoch of margin-ranking SGD and returns the summed hinge
    /// loss `max(0, margin - pos + neg)`.
    ///
    /// NOTE(review): only the head embeddings receive gradient updates here
    /// (`d(score)/dh_k = r_k * t_k`); tail and relation embeddings are left
    /// frozen. This looks like a deliberate simplification — confirm, or
    /// extend with the tail/relation gradient terms.
    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
        let mut total_loss = 0.0;
        for &(h, r, t) in &dataset.triples {
            // Negative sample: same relation, one entity replaced at random.
            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
            let pos = distmult_score(
                &self.entity_embeddings[h],
                &self.relation_embeddings[r],
                &self.entity_embeddings[t],
            );
            let neg = distmult_score(
                &self.entity_embeddings[nh],
                &self.relation_embeddings[nr],
                &self.entity_embeddings[nt],
            );
            let loss = (margin - pos + neg).max(0.0);
            total_loss += loss;
            if loss > 0.0 {
                let dim = self.dim;
                // Raise the positive score via the head embedding...
                for k in 0..dim {
                    let re = self.relation_embeddings[r][k];
                    let te = self.entity_embeddings[t][k];
                    self.entity_embeddings[h][k] += lr * re * te;
                }
                // ...and lower the negative score via its head embedding.
                for k in 0..dim {
                    let re = self.relation_embeddings[nr][k];
                    let te = self.entity_embeddings[nt][k];
                    self.entity_embeddings[nh][k] -= lr * re * te;
                }
            }
        }
        total_loss
    }

    /// Checks that entity indices (h, t) and relation index r are within
    /// the embedding tables.
    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
        let ne = self.entity_embeddings.len();
        let nr = self.relation_embeddings.len();
        if h >= ne || t >= ne {
            return Err(GraphError::InvalidParameter {
                param: "entity_index".to_string(),
                value: format!("({h},{t})"),
                expected: format!("< {ne}"),
                context: "DistMult score".to_string(),
            });
        }
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "relation_index".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "DistMult score".to_string(),
            });
        }
        Ok(())
    }
}
/// Trilinear DistMult score `sum_k h_k * r_k * t_k`; iteration stops at the
/// shortest of the three slices.
fn distmult_score(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
    h.iter()
        .zip(r)
        .zip(t)
        .fold(0.0, |acc, ((hk, rk), tk)| acc + hk * rk * tk)
}
/// ComplEx link-prediction model: embeddings live in C^dim, stored as
/// separate real (`*_re`) and imaginary (`*_im`) parts, scoring a triple by
/// `Re(<h, r, conj(t)>)` — which, unlike DistMult, can be asymmetric.
#[derive(Debug, Clone)]
pub struct ComplEx {
    /// Real parts of the entity embeddings.
    pub entity_re: Vec<Vec<f64>>,
    /// Imaginary parts of the entity embeddings.
    pub entity_im: Vec<Vec<f64>>,
    /// Real parts of the relation embeddings.
    pub relation_re: Vec<Vec<f64>>,
    /// Imaginary parts of the relation embeddings.
    pub relation_im: Vec<Vec<f64>>,
    /// Embedding dimensionality shared by all rows.
    pub dim: usize,
}
impl ComplEx {
    /// Creates a ComplEx model; all four tables are drawn uniformly from
    /// `[-1/sqrt(dim), 1/sqrt(dim)]` and row-wise L2-normalized.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `dim == 0`.
    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
        if dim == 0 {
            return Err(GraphError::InvalidParameter {
                param: "dim".to_string(),
                value: "0".to_string(),
                expected: "> 0".to_string(),
                context: "ComplEx::new".to_string(),
            });
        }
        let scale = 1.0 / (dim as f64).sqrt();
        Ok(ComplEx {
            entity_re: init_embeddings(n_entities, dim, scale),
            entity_im: init_embeddings(n_entities, dim, scale),
            relation_re: init_embeddings(n_relations, dim, scale),
            relation_im: init_embeddings(n_relations, dim, scale),
            dim,
        })
    }

    /// Plausibility score `Re(<h, r, conj(t)>)` for (h, r, t); higher is
    /// more plausible.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` for out-of-range indices.
    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.validate_indices(h, r, t)?;
        Ok(complex_score(
            &self.entity_re[h],
            &self.entity_im[h],
            &self.relation_re[r],
            &self.relation_im[r],
            &self.entity_re[t],
            &self.entity_im[t],
        ))
    }

    /// Alias for [`score`](Self::score), kept for a uniform model interface.
    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        self.score(h, r, t)
    }

    /// Returns the indices of the `k` highest-scoring tail entities for the
    /// query (h, r, ?); yields fewer than `k` when there are fewer entities.
    ///
    /// # Errors
    /// Returns `GraphError::InvalidParameter` when `h` or `r` is out of range.
    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
        let n = self.entity_re.len();
        if h >= n {
            return Err(GraphError::InvalidParameter {
                param: "h".to_string(),
                value: format!("{h}"),
                expected: format!("< {n}"),
                context: "ComplEx::predict_tails".to_string(),
            });
        }
        // Bug fix: an out-of-range relation index previously panicked on the
        // embedding lookup below; validate it like `score` does.
        let nr = self.relation_re.len();
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "r".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "ComplEx::predict_tails".to_string(),
            });
        }
        let mut scores: Vec<(usize, f64)> = (0..n)
            .map(|ti| {
                let s = complex_score(
                    &self.entity_re[h],
                    &self.entity_im[h],
                    &self.relation_re[r],
                    &self.relation_im[r],
                    &self.entity_re[ti],
                    &self.entity_im[ti],
                );
                (ti, s)
            })
            .collect();
        // Sort descending by score; incomparable values treated as equal.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
    }

    /// Runs one epoch of margin-ranking SGD and returns the summed hinge
    /// loss `max(0, margin - pos + neg)`.
    ///
    /// NOTE(review): only the head embeddings (re/im) receive gradient
    /// updates; tail and relation embeddings are left frozen, matching the
    /// simplification in `DistMult::train_epoch` — confirm this is intended.
    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
        let mut total_loss = 0.0;
        for &(h, r, t) in &dataset.triples {
            // Negative sample: same relation, one entity replaced at random.
            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
            let pos = complex_score(
                &self.entity_re[h],
                &self.entity_im[h],
                &self.relation_re[r],
                &self.relation_im[r],
                &self.entity_re[t],
                &self.entity_im[t],
            );
            let neg = complex_score(
                &self.entity_re[nh],
                &self.entity_im[nh],
                &self.relation_re[nr],
                &self.relation_im[nr],
                &self.entity_re[nt],
                &self.entity_im[nt],
            );
            let loss = (margin - pos + neg).max(0.0);
            total_loss += loss;
            if loss > 0.0 {
                let dim = self.dim;
                for k in 0..dim {
                    // d(score)/d(h_re), d(score)/d(h_im) for the positive
                    // triple; push the positive head toward a higher score.
                    let re_r = self.relation_re[r][k];
                    let im_r = self.relation_im[r][k];
                    let re_t = self.entity_re[t][k];
                    let im_t = self.entity_im[t][k];
                    let g_re_h = re_r * re_t + im_r * im_t;
                    let g_im_h = re_r * im_t - im_r * re_t;
                    self.entity_re[h][k] += lr * g_re_h;
                    self.entity_im[h][k] += lr * g_im_h;
                    // Same gradient for the corrupted head, applied with the
                    // opposite sign to lower the negative score.
                    let re_rn = self.relation_re[nr][k];
                    let im_rn = self.relation_im[nr][k];
                    let re_tn = self.entity_re[nt][k];
                    let im_tn = self.entity_im[nt][k];
                    let g_re_hn = re_rn * re_tn + im_rn * im_tn;
                    let g_im_hn = re_rn * im_tn - im_rn * re_tn;
                    self.entity_re[nh][k] -= lr * g_re_hn;
                    self.entity_im[nh][k] -= lr * g_im_hn;
                }
            }
        }
        total_loss
    }

    /// Checks that entity indices (h, t) and relation index r are within
    /// the embedding tables.
    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
        let ne = self.entity_re.len();
        let nr = self.relation_re.len();
        if h >= ne || t >= ne {
            return Err(GraphError::InvalidParameter {
                param: "entity_index".to_string(),
                value: format!("({h},{t})"),
                expected: format!("< {ne}"),
                context: "ComplEx score".to_string(),
            });
        }
        if r >= nr {
            return Err(GraphError::InvalidParameter {
                param: "relation_index".to_string(),
                value: format!("{r}"),
                expected: format!("< {nr}"),
                context: "ComplEx score".to_string(),
            });
        }
        Ok(())
    }
}
/// Real part of the Hermitian triple product `<h, r, conj(t)>` over C^dim,
/// with complex entries split into separate re/im slices. Iteration stops
/// at the shortest slice.
fn complex_score(
    h_re: &[f64],
    h_im: &[f64],
    r_re: &[f64],
    r_im: &[f64],
    t_re: &[f64],
    t_im: &[f64],
) -> f64 {
    let mut total = 0.0;
    for (((((hre, him), rre), rim), tre), tim) in h_re
        .iter()
        .zip(h_im)
        .zip(r_re)
        .zip(r_im)
        .zip(t_re)
        .zip(t_im)
    {
        // Expansion of Re((hre + i*him) * (rre + i*rim) * (tre - i*tim)).
        total += hre * rre * tre + him * rre * tim + hre * rim * tim - him * rim * tre;
    }
    total
}
/// Type-erasing wrapper over the concrete knowledge-graph embedding models,
/// allowing callers to score triples without knowing which model is in use.
pub enum KgModel {
    /// Translation-based scoring: `-||h + r - t||`.
    TransE(TransE),
    /// Trilinear (diagonal bilinear) scoring.
    DistMult(DistMult),
    /// Complex-valued bilinear scoring.
    ComplEx(ComplEx),
}
impl KgModel {
    /// Delegates link-prediction scoring for (h, r, t) to the wrapped model.
    ///
    /// # Errors
    /// Propagates the wrapped model's `GraphError::InvalidParameter` for
    /// out-of-range indices.
    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
        match self {
            Self::TransE(model) => model.link_prediction_score(h, r, t),
            Self::DistMult(model) => model.link_prediction_score(h, r, t),
            Self::ComplEx(model) => model.link_prediction_score(h, r, t),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 4 entities, 2 relations, 4 valid triples.
    fn simple_dataset() -> KGDataset {
        KGDataset::new(vec![(0, 0, 1), (1, 0, 2), (2, 1, 3), (0, 1, 3)], 4, 2).expect("dataset")
    }

    #[test]
    fn test_dataset_creation() {
        let dataset = simple_dataset();
        assert_eq!(dataset.n_entities, 4);
        assert_eq!(dataset.n_relations, 2);
        assert_eq!(dataset.len(), 4);
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_dataset_from_str_triples() {
        let raw = vec![
            ("Alice", "knows", "Bob"),
            ("Bob", "likes", "Carol"),
            ("Alice", "likes", "Carol"),
        ];
        let dataset = KGDataset::from_str_triples(&raw);
        // Three unique entities and two unique relations get interned.
        assert_eq!(dataset.n_entities, 3);
        assert_eq!(dataset.n_relations, 2);
        assert_eq!(dataset.len(), 3);
    }

    #[test]
    fn test_dataset_out_of_bounds() {
        // Head index 10 exceeds n_entities = 4, so construction must fail.
        let result = KGDataset::new(vec![(10, 0, 1)], 4, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_corrupt_triple_changes_entity() {
        let dataset = simple_dataset();
        let corrupted = dataset.corrupt_triple((0, 0, 1));
        // The relation is never corrupted; one entity side should differ.
        assert_eq!(corrupted.1, 0);
        assert!(corrupted.0 != 0 || corrupted.2 != 1);
    }

    #[test]
    fn test_transe_score_finite() {
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        assert!(model.score(0, 0, 1).expect("score").is_finite());
    }

    #[test]
    fn test_transe_score_range() {
        // TransE scores are negated distances, so they are never positive.
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        assert!(model.score(0, 0, 1).expect("score") <= 0.0);
    }

    #[test]
    fn test_transe_predict_tails_length() {
        let model = TransE::new(10, 3, 16).expect("TransE");
        let preds = model.predict_tails(0, 0, 5).expect("predict_tails");
        assert_eq!(preds.len(), 5);
        // Every prediction must be a valid entity index.
        for &idx in &preds {
            assert!(idx < 10);
        }
    }

    #[test]
    fn test_transe_train_epoch_reduces_loss() {
        let dataset = simple_dataset();
        let mut model = TransE::new(4, 2, 8).expect("TransE");
        let loss0 = model.train_epoch(&dataset, 0.01, 1.0);
        let loss1 = model.train_epoch(&dataset, 0.01, 1.0);
        assert!(loss0.is_finite());
        assert!(loss1.is_finite());
    }

    #[test]
    fn test_transe_invalid_index() {
        let model = TransE::new(4, 2, 8).expect("TransE");
        assert!(model.score(10, 0, 1).is_err());
    }

    #[test]
    fn test_distmult_score_finite() {
        let model = DistMult::new(4, 2, 8).expect("DistMult");
        assert!(model.score(0, 0, 1).expect("score").is_finite());
    }

    #[test]
    fn test_distmult_predict_tails() {
        let model = DistMult::new(10, 3, 16).expect("DistMult");
        let preds = model.predict_tails(0, 1, 3).expect("predict");
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_distmult_train_epoch() {
        let dataset = simple_dataset();
        let mut model = DistMult::new(4, 2, 8).expect("DistMult");
        assert!(model.train_epoch(&dataset, 0.01, 1.0).is_finite());
    }

    #[test]
    fn test_complex_score_finite() {
        let model = ComplEx::new(4, 2, 8).expect("ComplEx");
        assert!(model.score(0, 0, 1).expect("score").is_finite());
    }

    #[test]
    fn test_complex_predict_tails() {
        let model = ComplEx::new(10, 3, 16).expect("ComplEx");
        let preds = model.predict_tails(0, 0, 4).expect("predict");
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complex_train_epoch() {
        let dataset = simple_dataset();
        let mut model = ComplEx::new(4, 2, 8).expect("ComplEx");
        assert!(model.train_epoch(&dataset, 0.01, 1.0).is_finite());
    }

    #[test]
    fn test_complex_antisymmetry() {
        // Scores in both directions exist and are finite (ComplEx can, but
        // need not, score them differently with random embeddings).
        let model = ComplEx::new(4, 2, 16).expect("ComplEx");
        let s1 = model.score(0, 0, 1).expect("s1");
        let s2 = model.score(1, 0, 0).expect("s2");
        assert!(s1.is_finite());
        assert!(s2.is_finite());
    }

    #[test]
    fn test_kgmodel_dispatch() {
        let transe = TransE::new(4, 2, 8).expect("TransE");
        let model = KgModel::TransE(transe);
        let score = model.link_prediction_score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_multi_epoch_training_transe() {
        let dataset = simple_dataset();
        let mut model = TransE::new(4, 2, 16).expect("TransE");
        let mut losses = Vec::new();
        for _ in 0..5 {
            losses.push(model.train_epoch(&dataset, 0.01, 1.0));
        }
        for loss in &losses {
            assert!(loss.is_finite());
        }
    }

    #[test]
    fn test_complex_score_symmetry_check() {
        // Hand-constructed embeddings with a known score:
        // k=0: 1*1*1 = 1; k=1: 1*1*1 = 1; total = 2.
        let mut model = ComplEx::new(2, 1, 2).expect("ComplEx");
        model.entity_re[0] = vec![1.0, 0.0];
        model.entity_im[0] = vec![0.0, 1.0];
        model.relation_re[0] = vec![1.0, 1.0];
        model.relation_im[0] = vec![0.0, 0.0];
        model.entity_re[1] = vec![1.0, 0.0];
        model.entity_im[1] = vec![0.0, 1.0];
        let score = model.score(0, 0, 1).expect("manual score");
        assert!((score - 2.0).abs() < 1e-10, "expected 2.0, got {score}");
    }
}