use std::collections::HashMap;
use std::fmt;
/// Errors surfaced by the knowledge-graph embedding models in this module.
#[derive(Debug)]
pub enum KgError {
    /// An operation requiring trained embeddings was called before `train`.
    NotTrained,
    /// Entity id out of range for the trained embedding table.
    UnknownEntity(EntityId),
    /// Relation id out of range for the trained embedding table.
    UnknownRelation(RelationId),
    /// `embedding_dim` in the config was zero.
    InvalidDimension,
    /// `train` was called with an empty triple slice.
    NoTrainingData,
    /// Catch-all for numeric or parse failures (also used by deserialization).
    NumericalError(String),
    /// `top_k == 0` was passed to a prediction method.
    InvalidTopK,
}
impl fmt::Display for KgError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
KgError::NotTrained => write!(f, "model has not been trained"),
KgError::UnknownEntity(id) => write!(f, "unknown entity id {id}"),
KgError::UnknownRelation(id) => write!(f, "unknown relation id {id}"),
KgError::InvalidDimension => write!(f, "embedding dimension must be > 0"),
KgError::NoTrainingData => write!(f, "no training triples provided"),
KgError::NumericalError(msg) => write!(f, "numerical error: {msg}"),
KgError::InvalidTopK => write!(f, "top_k must be > 0"),
}
}
}
// Marker impl so `KgError` works with `Box<dyn Error>` and `?` conversions.
impl std::error::Error for KgError {}
/// Convenience alias for results produced by this module.
pub type KgResult<T> = Result<T, KgError>;
/// Dense index of an entity in the embedding tables.
pub type EntityId = usize;
/// Dense index of a relation in the embedding tables.
pub type RelationId = usize;
/// A (head, relation, tail) fact from the knowledge graph, in dense-id form.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KgTriple {
    // Subject entity id.
    pub head: EntityId,
    // Relation id.
    pub relation: RelationId,
    // Object entity id.
    pub tail: EntityId,
}
impl KgTriple {
    /// Build a triple from its three dense ids (no range validation here).
    pub fn new(head: EntityId, relation: RelationId, tail: EntityId) -> Self {
        Self {
            head,
            relation,
            tail,
        }
    }
}
/// Hyper-parameters shared by the TransE / DistMult / RotatE trainers.
#[derive(Debug, Clone)]
pub struct KgEmbeddingConfig {
    // Dimensionality of entity/relation vectors; must be > 0.
    pub embedding_dim: usize,
    // SGD step size.
    pub learning_rate: f64,
    // Number of full passes over the training triples.
    pub num_epochs: usize,
    // NOTE(review): not read by any trainer in this file — updates are
    // applied per triple, not per batch. Confirm whether this is intended.
    pub batch_size: usize,
    // Negative samples drawn per positive triple.
    pub neg_samples: usize,
    // Margin for the ranking losses (TransE, RotatE).
    pub margin: f64,
    // L2 regularization coefficient.
    pub regularization: f64,
    // Seed for the deterministic internal RNG.
    pub seed: u64,
}
impl Default for KgEmbeddingConfig {
    /// Small defaults suitable for toy graphs and tests.
    fn default() -> Self {
        Self {
            embedding_dim: 50,
            learning_rate: 0.01,
            num_epochs: 100,
            batch_size: 32,
            neg_samples: 1,
            margin: 1.0,
            regularization: 1e-4,
            seed: 42,
        }
    }
}
/// Trained embedding tables plus optional name-to-id maps.
#[derive(Debug, Clone)]
pub struct KgEmbeddings {
    // One row per entity, indexed by `EntityId`.
    pub entity_embeddings: Vec<Vec<f64>>,
    // One row per relation, indexed by `RelationId`.
    pub relation_embeddings: Vec<Vec<f64>>,
    // Name lookup; left empty by the trainers in this file.
    pub entity_to_id: HashMap<String, EntityId>,
    // Name lookup; left empty by the trainers in this file.
    pub relation_to_id: HashMap<String, RelationId>,
}
/// Per-epoch loss trace returned by the `train` methods.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    // Mean loss per epoch, in order.
    pub losses: Vec<f64>,
    // Mean loss of the last epoch (0.0 if no epochs ran).
    pub final_loss: f64,
    // Number of epochs actually executed.
    pub epochs_trained: usize,
}
/// Common scoring/ranking interface over the trained models.
/// Convention: larger scores mean more plausible triples.
pub trait KgModel {
    /// Plausibility score for a single triple.
    fn score(&self, triple: &KgTriple) -> KgResult<f64>;
    /// Top-`top_k` candidate tails for `(head, relation, ?)`, best first.
    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;
    /// Top-`top_k` candidate heads for `(?, relation, tail)`, best first.
    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;
}
/// Minimal deterministic linear congruential generator used for reproducible
/// initialization and negative sampling (no external RNG dependency).
#[derive(Debug, Clone)]
struct Lcg {
    state: u64,
}
impl Lcg {
    // Multiplier/increment pair of the classic 64-bit LCG (Knuth's MMIX).
    const MUL: u64 = 6_364_136_223_846_793_005;
    const INC: u64 = 1_442_695_040_888_963_407;
    /// Seed the generator; the +1 keeps a zero seed from starting at state 0.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }
    /// Next sample, uniform in [0, 1): advance the state and scale the top
    /// 53 bits into the unit interval.
    fn next_f64(&mut self) -> f64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::INC);
        (self.state >> 11) as f64 / (1u64 << 53) as f64
    }
    /// Uniform index in `0..n`. Panics if `n == 0` (mod by zero).
    fn next_usize(&mut self, n: usize) -> usize {
        let scaled = (self.next_f64() * n as f64) as usize;
        scaled % n
    }
}
/// Euclidean (L2) norm of `v`.
fn l2_norm(v: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for &x in v {
        sum_sq += x * x;
    }
    sum_sq.sqrt()
}
/// Euclidean distance between `a` and `b` (paired up to the shorter length).
fn l2_dist(a: &[f64], b: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        let d = x - y;
        sum_sq += d * d;
    }
    sum_sq.sqrt()
}
/// Dot product of `a` and `b` (paired up to the shorter length).
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
/// Clamp every element of `v` into `[lo, hi]`, in place.
fn clamp_vec(v: &mut [f64], lo: f64, hi: f64) {
    for x in v {
        *x = x.clamp(lo, hi);
    }
}
/// Scale `v` to unit L2 norm in place; near-zero vectors are left untouched
/// to avoid division blow-up.
fn normalize_vec(v: &mut [f64]) {
    let norm = l2_norm(v);
    if norm > 1e-12 {
        for x in v {
            *x /= norm;
        }
    }
}
/// Produce a negative example by replacing the head or the tail of `triple`
/// with a random entity, rejecting candidates that are known positives.
/// Falls back to a deterministic tail shift after 20 failed attempts.
/// RNG call order per attempt: one coin flip, then one entity draw.
fn corrupt_triple(
    triple: &KgTriple,
    num_entities: usize,
    positive_set: &std::collections::HashSet<(usize, usize, usize)>,
    rng: &mut Lcg,
) -> KgTriple {
    for _attempt in 0..20 {
        let replace_head = rng.next_usize(2) == 0;
        let sampled = rng.next_usize(num_entities);
        let (h, t) = if replace_head {
            (sampled, triple.tail)
        } else {
            (triple.head, sampled)
        };
        if !positive_set.contains(&(h, triple.relation, t)) {
            return KgTriple::new(h, triple.relation, t);
        }
    }
    // Deterministic fallback: shift the tail by one (mod num_entities).
    KgTriple::new(triple.head, triple.relation, (triple.tail + 1) % num_entities)
}
/// TransE model: plausibility of (h, r, t) is measured by the distance
/// ||h + r - t|| in a shared embedding space (lower distance = more plausible).
#[derive(Debug, Clone)]
pub struct TransE {
    // Training hyper-parameters.
    pub config: KgEmbeddingConfig,
    // Populated by `train`; `None` until then.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
impl TransE {
pub fn new(config: KgEmbeddingConfig) -> Self {
Self {
config,
embeddings: None,
num_entities: 0,
num_relations: 0,
}
}
    /// Train TransE with SGD on the margin ranking loss
    /// `max(0, margin + d(h + r, t) - d(h' + r, t'))`, where (h', r, t') is a
    /// corrupted (negative) triple sampled per positive.
    ///
    /// `triples` must use dense ids below `num_entities` / `num_relations`;
    /// out-of-range ids are not validated and panic on indexing. The learned
    /// tables are stored in `self.embeddings` (name maps left empty).
    ///
    /// # Errors
    /// `NoTrainingData` for an empty `triples`; `InvalidDimension` when
    /// `config.embedding_dim == 0`.
    pub fn train(
        &mut self,
        triples: &[KgTriple],
        num_entities: usize,
        num_relations: usize,
    ) -> KgResult<TrainingHistory> {
        if triples.is_empty() {
            return Err(KgError::NoTrainingData);
        }
        if self.config.embedding_dim == 0 {
            return Err(KgError::InvalidDimension);
        }
        self.num_entities = num_entities;
        self.num_relations = num_relations;
        let dim = self.config.embedding_dim;
        // Deterministic RNG: identical seed => identical init and sampling.
        let mut rng = Lcg::new(self.config.seed);
        // Uniform init bound 6/sqrt(dim).
        let bound = 6.0 / (dim as f64).sqrt();
        // Entity vectors start on the unit sphere (normalized right away).
        let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| {
                let mut v: Vec<f64> = (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect();
                normalize_vec(&mut v);
                v
            })
            .collect();
        // Relation (translation) vectors are left unnormalized.
        let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
            .map(|_| {
                (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect()
            })
            .collect();
        // All observed positives, so negative sampling can reject true triples.
        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
            .iter()
            .map(|t| (t.head, t.relation, t.tail))
            .collect();
        let lr = self.config.learning_rate;
        let margin = self.config.margin;
        let reg = self.config.regularization;
        let mut losses = Vec::with_capacity(self.config.num_epochs);
        for _epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0_f64;
            let mut count = 0usize;
            for pos in triples {
                for _ in 0..self.config.neg_samples {
                    // Negative example: corrupt the head or the tail.
                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
                    let h_pos = &ent_emb[pos.head];
                    let r = &rel_emb[pos.relation];
                    let t_pos = &ent_emb[pos.tail];
                    let h_neg = &ent_emb[neg.head];
                    let t_neg = &ent_emb[neg.tail];
                    // Residuals h + r - t, computed against the CURRENT
                    // embeddings; reused below as gradient directions.
                    let pos_diff: Vec<f64> = (0..dim).map(|i| h_pos[i] + r[i] - t_pos[i]).collect();
                    let neg_diff: Vec<f64> = (0..dim).map(|i| h_neg[i] + r[i] - t_neg[i]).collect();
                    let d_pos = l2_norm(&pos_diff);
                    let d_neg = l2_norm(&neg_diff);
                    // Hinge: zero loss once the margin separates pos from neg.
                    let loss = (margin + d_pos - d_neg).max(0.0);
                    epoch_loss += loss;
                    count += 1;
                    if loss > 0.0 {
                        // d||x||/dx = x/||x||, guarded against a zero norm.
                        let grad_pos: Vec<f64> = if d_pos > 1e-12 {
                            pos_diff.iter().map(|x| x / d_pos).collect()
                        } else {
                            vec![0.0; dim]
                        };
                        let grad_neg: Vec<f64> = if d_neg > 1e-12 {
                            neg_diff.iter().map(|x| x / d_neg).collect()
                        } else {
                            vec![0.0; dim]
                        };
                        // Positive triple: step decreases d_pos, with L2
                        // weight decay on each parameter.
                        for i in 0..dim {
                            let g = grad_pos[i];
                            ent_emb[pos.head][i] -= lr * (g + reg * ent_emb[pos.head][i]);
                            rel_emb[pos.relation][i] -= lr * (g + reg * rel_emb[pos.relation][i]);
                            ent_emb[pos.tail][i] += lr * (g - reg * ent_emb[pos.tail][i]);
                        }
                        // Negative triple: step increases d_neg.
                        // NOTE(review): the relation vector receives no
                        // gradient from the negative term, and the `reg` signs
                        // here amplify rather than decay the entity vectors —
                        // verify intent; the renormalization below masks this
                        // for entities.
                        for i in 0..dim {
                            let g = grad_neg[i];
                            ent_emb[neg.head][i] += lr * (g + reg * ent_emb[neg.head][i]);
                            ent_emb[neg.tail][i] -= lr * (g - reg * ent_emb[neg.tail][i]);
                        }
                    }
                    // TransE constraint: keep entities on the unit sphere.
                    normalize_vec(&mut ent_emb[pos.head]);
                    normalize_vec(&mut ent_emb[pos.tail]);
                    normalize_vec(&mut ent_emb[neg.head]);
                    normalize_vec(&mut ent_emb[neg.tail]);
                }
            }
            let mean_loss = if count > 0 {
                epoch_loss / count as f64
            } else {
                0.0
            };
            losses.push(mean_loss);
        }
        let final_loss = losses.last().copied().unwrap_or(0.0);
        let epochs_trained = losses.len();
        // Name maps are not known at this level; callers may fill them in.
        self.embeddings = Some(KgEmbeddings {
            entity_embeddings: ent_emb,
            relation_embeddings: rel_emb,
            entity_to_id: HashMap::new(),
            relation_to_id: HashMap::new(),
        });
        Ok(TrainingHistory {
            losses,
            final_loss,
            epochs_trained,
        })
    }
pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let h = emb
.entity_embeddings
.get(triple.head)
.ok_or(KgError::UnknownEntity(triple.head))?;
let r = emb
.relation_embeddings
.get(triple.relation)
.ok_or(KgError::UnknownRelation(triple.relation))?;
let t = emb
.entity_embeddings
.get(triple.tail)
.ok_or(KgError::UnknownEntity(triple.tail))?;
Ok(-Self::score_fn(h, r, t))
}
pub fn predict_tail(
&self,
head: EntityId,
relation: RelationId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let h = emb
.entity_embeddings
.get(head)
.ok_or(KgError::UnknownEntity(head))?;
let r = emb
.relation_embeddings
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let mut scored: Vec<(EntityId, f64)> = emb
.entity_embeddings
.iter()
.enumerate()
.map(|(id, t)| (id, -Self::score_fn(h, r, t)))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
pub fn predict_head(
&self,
relation: RelationId,
tail: EntityId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let r = emb
.relation_embeddings
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let t = emb
.entity_embeddings
.get(tail)
.ok_or(KgError::UnknownEntity(tail))?;
let mut scored: Vec<(EntityId, f64)> = emb
.entity_embeddings
.iter()
.enumerate()
.map(|(id, h)| (id, -Self::score_fn(h, r, t)))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
pub fn normalize_entities(&mut self) {
if let Some(ref mut emb) = self.embeddings {
for v in emb.entity_embeddings.iter_mut() {
normalize_vec(v);
}
}
}
fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
let diff: Vec<f64> = (0..h.len()).map(|i| h[i] + r[i] - t[i]).collect();
l2_norm(&diff)
}
}
// Trait plumbing: delegate to the inherent methods so TransE can be used
// behind `&dyn KgModel`. Inherent methods take precedence over trait methods
// in resolution, so these calls do not recurse.
impl KgModel for TransE {
    fn score(&self, triple: &KgTriple) -> KgResult<f64> {
        self.score(triple)
    }
    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_tail(head, relation, top_k)
    }
    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_head(relation, tail, top_k)
    }
}
/// DistMult model: plausibility of (h, r, t) is the trilinear product
/// `sum_i h_i * r_i * t_i` (higher = more plausible).
#[derive(Debug, Clone)]
pub struct DistMult {
    // Training hyper-parameters.
    pub config: KgEmbeddingConfig,
    // Populated by `train`; `None` until then.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
impl DistMult {
pub fn new(config: KgEmbeddingConfig) -> Self {
Self {
config,
embeddings: None,
num_entities: 0,
num_relations: 0,
}
}
    /// Train DistMult with SGD on a pointwise logistic loss:
    /// `-ln(sigmoid(s))` for positives and `-ln(sigmoid(-s))` for sampled
    /// negatives, where `s = sum_i h_i * r_i * t_i`.
    ///
    /// Ids in `triples` must be below `num_entities` / `num_relations`;
    /// out-of-range ids are not validated and panic on indexing. Learned
    /// tables are stored in `self.embeddings` (name maps left empty).
    ///
    /// # Errors
    /// `NoTrainingData` for an empty `triples`; `InvalidDimension` when
    /// `config.embedding_dim == 0`.
    pub fn train(
        &mut self,
        triples: &[KgTriple],
        num_entities: usize,
        num_relations: usize,
    ) -> KgResult<TrainingHistory> {
        if triples.is_empty() {
            return Err(KgError::NoTrainingData);
        }
        if self.config.embedding_dim == 0 {
            return Err(KgError::InvalidDimension);
        }
        self.num_entities = num_entities;
        self.num_relations = num_relations;
        let dim = self.config.embedding_dim;
        // Deterministic RNG for reproducible runs.
        let mut rng = Lcg::new(self.config.seed);
        // Uniform init in [-1/sqrt(dim), 1/sqrt(dim)]; magnitudes are bounded
        // by the clamping below rather than by normalization.
        let bound = 1.0 / (dim as f64).sqrt();
        let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| {
                (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect()
            })
            .collect();
        let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
            .map(|_| {
                (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect()
            })
            .collect();
        // Observed positives; negative sampling rejects these.
        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
            .iter()
            .map(|t| (t.head, t.relation, t.tail))
            .collect();
        let lr = self.config.learning_rate;
        let reg = self.config.regularization;
        let mut losses = Vec::with_capacity(self.config.num_epochs);
        for _epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0_f64;
            let mut count = 0usize;
            for pos in triples {
                {
                    // Positive phase: push sigmoid(s) toward 1.
                    let s = Self::score_fn(
                        &ent_emb[pos.head],
                        &rel_emb[pos.relation],
                        &ent_emb[pos.tail],
                    );
                    let sig = sigmoid(s);
                    // -ln(sig); the ln is clamped at -100 so sig ~ 0 cannot
                    // produce an infinite loss.
                    let loss = -sig.ln().max(-100.0);
                    epoch_loss += loss;
                    count += 1;
                    // d(-ln sigmoid(s))/ds = -(1 - sigmoid(s)).
                    let g = -(1.0 - sig);
                    for i in 0..dim {
                        // Snapshot pre-update values so the three updates for
                        // dimension i all use consistent operands.
                        let h_i = ent_emb[pos.head][i];
                        let r_i = rel_emb[pos.relation][i];
                        let t_i = ent_emb[pos.tail][i];
                        ent_emb[pos.head][i] -= lr * (g * r_i * t_i + reg * h_i);
                        rel_emb[pos.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
                        ent_emb[pos.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
                    }
                    // Bound magnitudes in lieu of normalization.
                    clamp_vec(&mut ent_emb[pos.head], -10.0, 10.0);
                    clamp_vec(&mut rel_emb[pos.relation], -10.0, 10.0);
                    clamp_vec(&mut ent_emb[pos.tail], -10.0, 10.0);
                }
                for _ in 0..self.config.neg_samples {
                    // Negative phase: push sigmoid(s) toward 0 for a
                    // corrupted triple.
                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
                    let s = Self::score_fn(
                        &ent_emb[neg.head],
                        &rel_emb[neg.relation],
                        &ent_emb[neg.tail],
                    );
                    let sig = sigmoid(-s);
                    let loss = -sig.ln().max(-100.0);
                    epoch_loss += loss;
                    count += 1;
                    // d(-ln sigmoid(-s))/ds = 1 - sigmoid(-s) = sigmoid(s).
                    let g = 1.0 - sig; for i in 0..dim {
                        let h_i = ent_emb[neg.head][i];
                        let r_i = rel_emb[neg.relation][i];
                        let t_i = ent_emb[neg.tail][i];
                        ent_emb[neg.head][i] -= lr * (g * r_i * t_i + reg * h_i);
                        rel_emb[neg.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
                        ent_emb[neg.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
                    }
                    // NOTE(review): rel_emb[neg.relation] is updated above but
                    // not re-clamped here (only the entity rows are) — confirm.
                    clamp_vec(&mut ent_emb[neg.head], -10.0, 10.0);
                    clamp_vec(&mut ent_emb[neg.tail], -10.0, 10.0);
                }
            }
            let mean_loss = if count > 0 {
                epoch_loss / count as f64
            } else {
                0.0
            };
            losses.push(mean_loss);
        }
        let final_loss = losses.last().copied().unwrap_or(0.0);
        let epochs_trained = losses.len();
        // Name maps are not known at this level; callers may fill them in.
        self.embeddings = Some(KgEmbeddings {
            entity_embeddings: ent_emb,
            relation_embeddings: rel_emb,
            entity_to_id: HashMap::new(),
            relation_to_id: HashMap::new(),
        });
        Ok(TrainingHistory {
            losses,
            final_loss,
            epochs_trained,
        })
    }
pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let h = emb
.entity_embeddings
.get(triple.head)
.ok_or(KgError::UnknownEntity(triple.head))?;
let r = emb
.relation_embeddings
.get(triple.relation)
.ok_or(KgError::UnknownRelation(triple.relation))?;
let t = emb
.entity_embeddings
.get(triple.tail)
.ok_or(KgError::UnknownEntity(triple.tail))?;
Ok(Self::score_fn(h, r, t))
}
pub fn predict_tail(
&self,
head: EntityId,
relation: RelationId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let h = emb
.entity_embeddings
.get(head)
.ok_or(KgError::UnknownEntity(head))?;
let r = emb
.relation_embeddings
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let mut scored: Vec<(EntityId, f64)> = emb
.entity_embeddings
.iter()
.enumerate()
.map(|(id, t)| (id, Self::score_fn(h, r, t)))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
pub fn predict_head(
&self,
relation: RelationId,
tail: EntityId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
let r = emb
.relation_embeddings
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let t = emb
.entity_embeddings
.get(tail)
.ok_or(KgError::UnknownEntity(tail))?;
let mut scored: Vec<(EntityId, f64)> = emb
.entity_embeddings
.iter()
.enumerate()
.map(|(id, h)| (id, Self::score_fn(h, r, t)))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
h.iter()
.zip(r.iter())
.zip(t.iter())
.map(|((hi, ri), ti)| hi * ri * ti)
.sum()
}
}
// Trait plumbing: delegate to the inherent methods so DistMult can be used
// behind `&dyn KgModel`. Inherent methods take precedence over trait methods
// in resolution, so these calls do not recurse.
impl KgModel for DistMult {
    fn score(&self, triple: &KgTriple) -> KgResult<f64> {
        self.score(triple)
    }
    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_tail(head, relation, top_k)
    }
    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_head(relation, tail, top_k)
    }
}
/// RotatE model state. Entities are complex vectors stored as split
/// real/imaginary tables; each relation is a vector of rotation phases.
#[derive(Debug, Clone)]
pub struct RotatE {
    // Training hyper-parameters.
    pub config: KgEmbeddingConfig,
    // Real parts of entity embeddings; `None` until trained.
    pub entity_re: Option<Vec<Vec<f64>>>,
    // Imaginary parts of entity embeddings; `None` until trained.
    pub entity_im: Option<Vec<Vec<f64>>>,
    // Per-relation rotation phases in radians; `None` until trained.
    pub relation_phases: Option<Vec<Vec<f64>>>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
impl RotatE {
pub fn new(config: KgEmbeddingConfig) -> Self {
Self {
config,
entity_re: None,
entity_im: None,
relation_phases: None,
num_entities: 0,
num_relations: 0,
}
}
    /// Train RotatE with SGD on the margin ranking loss. Each relation acts as
    /// a per-coordinate complex rotation e^{i*phase} applied to the head;
    /// plausibility is measured by |h * r - t| over the complex coordinates.
    ///
    /// `embedding_dim` is interpreted as a total real dimension and split into
    /// `ceil(dim / 2)` complex coordinates. Results are stored in
    /// `entity_re` / `entity_im` / `relation_phases`.
    ///
    /// # Errors
    /// `NoTrainingData` for an empty `triples`; `InvalidDimension` when
    /// `config.embedding_dim == 0`. Triple ids are not range-checked and will
    /// panic on out-of-range indexing.
    pub fn train(
        &mut self,
        triples: &[KgTriple],
        num_entities: usize,
        num_relations: usize,
    ) -> KgResult<TrainingHistory> {
        if triples.is_empty() {
            return Err(KgError::NoTrainingData);
        }
        if self.config.embedding_dim == 0 {
            return Err(KgError::InvalidDimension);
        }
        self.num_entities = num_entities;
        self.num_relations = num_relations;
        // ceil(dim / 2) complex coordinates.
        let half_dim = (self.config.embedding_dim + 1) / 2;
        let mut rng = Lcg::new(self.config.seed);
        let pi = std::f64::consts::PI;
        let mut ent_re: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
            .collect();
        let mut ent_im: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
            .collect();
        // Scale each complex component to unit modulus.
        for i in 0..num_entities {
            for k in 0..half_dim {
                let norm = (ent_re[i][k].powi(2) + ent_im[i][k].powi(2))
                    .sqrt()
                    .max(1e-12);
                ent_re[i][k] /= norm;
                ent_im[i][k] /= norm;
            }
        }
        // Relation phases start uniform in [-pi, pi].
        let mut rel_phases: Vec<Vec<f64>> = (0..num_relations)
            .map(|_| {
                (0..half_dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * pi)
                    .collect()
            })
            .collect();
        // Observed positives; negative sampling rejects these.
        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
            .iter()
            .map(|t| (t.head, t.relation, t.tail))
            .collect();
        let lr = self.config.learning_rate;
        let margin = self.config.margin;
        let reg = self.config.regularization;
        let mut losses = Vec::with_capacity(self.config.num_epochs);
        for _epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0_f64;
            let mut count = 0usize;
            for pos in triples {
                for _ in 0..self.config.neg_samples {
                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
                    let d_pos = Self::dist_fn(
                        &ent_re[pos.head],
                        &ent_im[pos.head],
                        &rel_phases[pos.relation],
                        &ent_re[pos.tail],
                        &ent_im[pos.tail],
                    );
                    let d_neg = Self::dist_fn(
                        &ent_re[neg.head],
                        &ent_im[neg.head],
                        &rel_phases[neg.relation],
                        &ent_re[neg.tail],
                        &ent_im[neg.tail],
                    );
                    // Margin ranking loss over the rotation distance.
                    let loss = (margin + d_pos - d_neg).max(0.0);
                    epoch_loss += loss;
                    count += 1;
                    // NOTE(review): gradients are applied to the positive
                    // triple's parameters only; the corrupted triple enters the
                    // loss via the margin but is never pushed away — verify
                    // this is intended.
                    if loss > 0.0 && d_pos > 1e-12 {
                        // Precompute cos/sin of each relation phase.
                        let r_re: Vec<f64> = rel_phases[pos.relation]
                            .iter()
                            .map(|&ph| ph.cos())
                            .collect();
                        let r_im: Vec<f64> = rel_phases[pos.relation]
                            .iter()
                            .map(|&ph| ph.sin())
                            .collect();
                        for k in 0..half_dim {
                            // Rotate head component k by the relation phase.
                            let (res_re, res_im) = Self::complex_multiply(
                                ent_re[pos.head][k],
                                ent_im[pos.head][k],
                                r_re[k],
                                r_im[k],
                            );
                            // Residual (h * r - t) for this component.
                            let err_re = res_re - ent_re[pos.tail][k];
                            let err_im = res_im - ent_im[pos.tail][k];
                            // Chain rule through d = sqrt(sum |err|^2).
                            let g_scale = 1.0 / d_pos;
                            let d_h_re = g_scale * (err_re * r_re[k] + err_im * r_im[k]);
                            let d_h_im = g_scale * (err_im * r_re[k] - err_re * r_im[k]);
                            // NOTE(review): the analytic derivative of the
                            // rotated head w.r.t. the phase is
                            // (-h_re*sin - h_im*cos, h_re*cos - h_im*sin);
                            // the signs below differ from that — confirm with
                            // a numerical gradient check before relying on it.
                            let d_ph = g_scale
                                * ((-ent_re[pos.head][k] * r_im[k]
                                    + ent_im[pos.head][k] * r_re[k])
                                    * err_re
                                    + (-ent_re[pos.head][k] * r_re[k]
                                        - ent_im[pos.head][k] * r_im[k])
                                        * err_im);
                            let d_t_re = g_scale * (-err_re);
                            let d_t_im = g_scale * (-err_im);
                            // SGD step with L2 weight decay on each parameter.
                            ent_re[pos.head][k] -= lr * (d_h_re + reg * ent_re[pos.head][k]);
                            ent_im[pos.head][k] -= lr * (d_h_im + reg * ent_im[pos.head][k]);
                            rel_phases[pos.relation][k] -=
                                lr * (d_ph + reg * rel_phases[pos.relation][k]);
                            ent_re[pos.tail][k] -= lr * (d_t_re + reg * ent_re[pos.tail][k]);
                            ent_im[pos.tail][k] -= lr * (d_t_im + reg * ent_im[pos.tail][k]);
                        }
                        // Keep phases in a bounded range.
                        for ph in rel_phases[pos.relation].iter_mut() {
                            *ph = ph.clamp(-2.0 * pi, 2.0 * pi);
                        }
                    }
                }
            }
            let mean_loss = if count > 0 {
                epoch_loss / count as f64
            } else {
                0.0
            };
            losses.push(mean_loss);
        }
        let final_loss = losses.last().copied().unwrap_or(0.0);
        let epochs_trained = losses.len();
        self.entity_re = Some(ent_re);
        self.entity_im = Some(ent_im);
        self.relation_phases = Some(rel_phases);
        Ok(TrainingHistory {
            losses,
            final_loss,
            epochs_trained,
        })
    }
pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
let h_re = ent_re
.get(triple.head)
.ok_or(KgError::UnknownEntity(triple.head))?;
let h_im = ent_im
.get(triple.head)
.ok_or(KgError::UnknownEntity(triple.head))?;
let ph = phases
.get(triple.relation)
.ok_or(KgError::UnknownRelation(triple.relation))?;
let t_re = ent_re
.get(triple.tail)
.ok_or(KgError::UnknownEntity(triple.tail))?;
let t_im = ent_im
.get(triple.tail)
.ok_or(KgError::UnknownEntity(triple.tail))?;
Ok(-Self::dist_fn(h_re, h_im, ph, t_re, t_im))
}
pub fn predict_tail(
&self,
head: EntityId,
relation: RelationId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
let h_re = ent_re.get(head).ok_or(KgError::UnknownEntity(head))?;
let h_im = ent_im.get(head).ok_or(KgError::UnknownEntity(head))?;
let ph = phases
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let num = ent_re.len();
let mut scored: Vec<(EntityId, f64)> = (0..num)
.map(|id| (id, -Self::dist_fn(h_re, h_im, ph, &ent_re[id], &ent_im[id])))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
pub fn predict_head(
&self,
relation: RelationId,
tail: EntityId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
if top_k == 0 {
return Err(KgError::InvalidTopK);
}
let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
let ph = phases
.get(relation)
.ok_or(KgError::UnknownRelation(relation))?;
let t_re = ent_re.get(tail).ok_or(KgError::UnknownEntity(tail))?;
let t_im = ent_im.get(tail).ok_or(KgError::UnknownEntity(tail))?;
let num = ent_re.len();
let mut scored: Vec<(EntityId, f64)> = (0..num)
.map(|id| (id, -Self::dist_fn(&ent_re[id], &ent_im[id], ph, t_re, t_im)))
.collect();
scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
fn dist_fn(h_re: &[f64], h_im: &[f64], phases: &[f64], t_re: &[f64], t_im: &[f64]) -> f64 {
let sum_sq: f64 = phases
.iter()
.enumerate()
.map(|(k, &ph)| {
let (res_re, res_im) = Self::complex_multiply(h_re[k], h_im[k], ph.cos(), ph.sin());
(res_re - t_re[k]).powi(2) + (res_im - t_im[k]).powi(2)
})
.sum();
sum_sq.sqrt()
}
pub fn complex_multiply(a_re: f64, a_im: f64, b_re: f64, b_im: f64) -> (f64, f64) {
(a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
}
}
impl KgModel for RotatE {
fn score(&self, triple: &KgTriple) -> KgResult<f64> {
self.score(triple)
}
fn predict_tail(
&self,
head: EntityId,
relation: RelationId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
self.predict_tail(head, relation, top_k)
}
fn predict_head(
&self,
relation: RelationId,
tail: EntityId,
top_k: usize,
) -> KgResult<Vec<(EntityId, f64)>> {
self.predict_head(relation, tail, top_k)
}
}
/// Rank-based evaluation helpers (Hits@k, mean rank, MRR) over any `KgModel`.
pub struct LinkPredictionEvaluator;
impl LinkPredictionEvaluator {
pub fn hits_at_k(model: &dyn KgModel, test_triples: &[KgTriple], k: usize) -> f64 {
if test_triples.is_empty() || k == 0 {
return 0.0;
}
let hits: usize = test_triples
.iter()
.filter(|t| {
model
.predict_tail(t.head, t.relation, k)
.map(|preds| preds.iter().any(|(eid, _)| *eid == t.tail))
.unwrap_or(false)
})
.count();
hits as f64 / test_triples.len() as f64
}
pub fn mean_rank(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
if test_triples.is_empty() {
return 0.0;
}
let total: usize = test_triples
.iter()
.map(|t| {
model
.predict_tail(t.head, t.relation, num_entities)
.map(|preds| {
preds
.iter()
.position(|(eid, _)| *eid == t.tail)
.map(|p| p + 1)
.unwrap_or(num_entities + 1)
})
.unwrap_or(num_entities + 1)
})
.sum();
total as f64 / test_triples.len() as f64
}
pub fn mrr(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
if test_triples.is_empty() {
return 0.0;
}
let sum: f64 = test_triples
.iter()
.map(|t| {
model
.predict_tail(t.head, t.relation, num_entities)
.map(|preds| {
preds
.iter()
.position(|(eid, _)| *eid == t.tail)
.map(|p| 1.0 / (p as f64 + 1.0))
.unwrap_or(0.0)
})
.unwrap_or(0.0)
})
.sum();
sum / test_triples.len() as f64
}
}
/// Serialize embedding tables to a plain-text byte buffer:
/// an `ENTITIES <n>` header followed by `n` comma-separated rows (8 decimal
/// places each), then a `RELATIONS <m>` header and `m` rows.
/// The name maps are NOT serialized.
pub fn serialize_embeddings(emb: &KgEmbeddings) -> Vec<u8> {
    // Append one "<header> <count>" line plus one CSV line per row.
    fn push_section(out: &mut String, header: &str, rows: &[Vec<f64>]) {
        out.push_str(&format!("{header} {}\n", rows.len()));
        for row in rows {
            let cells: Vec<String> = row.iter().map(|x| format!("{x:.8}")).collect();
            out.push_str(&cells.join(","));
            out.push('\n');
        }
    }
    let mut out = String::new();
    push_section(&mut out, "ENTITIES", &emb.entity_embeddings);
    push_section(&mut out, "RELATIONS", &emb.relation_embeddings);
    out.into_bytes()
}
/// Parse embeddings previously produced by `serialize_embeddings`.
///
/// Expected layout: an `ENTITIES <n>` header followed by `n` comma-separated
/// rows, then a `RELATIONS <m>` header followed by `m` rows. The name maps
/// are not part of the format and come back empty.
///
/// # Errors
/// Returns `KgError::NumericalError` when the data is not UTF-8, a header is
/// missing or malformed, a row fails to parse as `f64`, or the data is
/// truncated.
///
/// NOTE(review): row lengths are not validated for mutual consistency.
pub fn deserialize_embeddings(data: &[u8]) -> Result<KgEmbeddings, KgError> {
    let text = std::str::from_utf8(data)
        .map_err(|e| KgError::NumericalError(format!("utf8 error: {e}")))?;
    let mut lines = text.lines();
    // Parse "<prefix><count>", e.g. "ENTITIES 12".
    let parse_section_header = |line: &str, prefix: &str| -> Result<usize, KgError> {
        let rest = line
            .strip_prefix(prefix)
            .ok_or_else(|| KgError::NumericalError(format!("expected '{prefix}', got '{line}'")))?;
        rest.trim()
            .parse::<usize>()
            .map_err(|e| KgError::NumericalError(e.to_string()))
    };
    // Parse one comma-separated row of floats; fails fast on the first bad cell.
    let parse_row = |line: &str| -> Result<Vec<f64>, KgError> {
        line.split(',')
            .map(|s| {
                s.trim()
                    .parse::<f64>()
                    .map_err(|e| KgError::NumericalError(e.to_string()))
            })
            .collect()
    };
    // `ok_or_else` keeps the error Strings from being allocated on every
    // successful line read (the eager `ok_or(...)` form built them always).
    let ent_header = lines
        .next()
        .ok_or_else(|| KgError::NumericalError("empty data".into()))?;
    let num_ent = parse_section_header(ent_header, "ENTITIES ")?;
    let mut entity_embeddings = Vec::with_capacity(num_ent);
    for _ in 0..num_ent {
        let line = lines
            .next()
            .ok_or_else(|| KgError::NumericalError("truncated entity data".into()))?;
        entity_embeddings.push(parse_row(line)?);
    }
    let rel_header = lines
        .next()
        .ok_or_else(|| KgError::NumericalError("missing RELATIONS header".into()))?;
    let num_rel = parse_section_header(rel_header, "RELATIONS ")?;
    let mut relation_embeddings = Vec::with_capacity(num_rel);
    for _ in 0..num_rel {
        let line = lines
            .next()
            .ok_or_else(|| KgError::NumericalError("truncated relation data".into()))?;
        relation_embeddings.push(parse_row(line)?);
    }
    Ok(KgEmbeddings {
        entity_embeddings,
        relation_embeddings,
        entity_to_id: HashMap::new(),
        relation_to_id: HashMap::new(),
    })
}
/// Logistic function `1 / (1 + e^{-x})`, mapping all reals into (0, 1).
#[inline]
fn sigmoid(x: f64) -> f64 {
    let neg_exp = (-x).exp();
    1.0 / (1.0 + neg_exp)
}
#[cfg(test)]
#[path = "kg_embeddings_tests.rs"]
mod tests;