use crate::{StarResult, StarTerm, StarTriple};
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, instrument};
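/// Deterministic pseudo-random number in `[0, 1]` from a thread-local
/// linear congruential generator (the classic Numerical Recipes constants),
/// kept dependency-free so training runs are reproducible.
///
/// Note: the per-model `seed` field stored by `with_seed` is currently
/// unused; every thread starts from the same constant state (hence the
/// `#[allow(dead_code)]` on that field). Because the result can be exactly
/// 1.0, call sites that derive indices from it clamp the upper bound.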
fn random_uniform() -> f64 {
use std::cell::Cell;
thread_local! {
static SEED: Cell<u64> = const { Cell::new(42) };
}
SEED.with(|s| {
let mut seed = s.get();
seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
s.set(seed);
(seed as f64) / (u64::MAX as f64)
})
}
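/// Hyperparameters shared by all embedding models in this module.
///
/// `margin` is only consulted by the TransE hinge loss; `use_gpu` and
/// `enable_rdfstar_context` are accepted for forward compatibility but are
/// not used by the current CPU implementations.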
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
pub embedding_dim: usize,
pub learning_rate: f64,
pub margin: f64,
pub batch_size: usize,
pub num_negative_samples: usize,
pub use_gpu: bool,
pub enable_rdfstar_context: bool,
pub l2_reg: f64,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
embedding_dim: 128,
learning_rate: 0.01,
margin: 1.0,
batch_size: 128,
num_negative_samples: 10,
use_gpu: false,
enable_rdfstar_context: true,
l2_reg: 0.0001,
}
}
}
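/// Common interface implemented by the embedding models in this module
/// (TransE, DistMult, ComplEx).
///
/// A minimal usage sketch, assuming `triples: Vec<StarTriple>` is already
/// built; the entity and relation labels are illustrative:
///
/// ```ignore
/// let mut model = TransE::new(EmbeddingConfig::default());
/// let stats = model.train(&triples, 100)?;
/// println!("final loss: {:.4}", stats.final_loss);
/// let top5 = model.predict_tail("Alice", "knows", 5)?;
/// let sim = model.similarity("Alice", "Bob")?;
/// ```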
pub trait EmbeddingModel: Send + Sync {
fn train(&mut self, triples: &[StarTriple], epochs: usize) -> StarResult<TrainingStats>;
fn get_embedding(&self, entity: &str) -> Option<Array1<f64>>;
fn similarity(&self, entity1: &str, entity2: &str) -> StarResult<f64>;
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> StarResult<Vec<(String, f64)>>;
fn save(&self, path: &str) -> StarResult<()>;
fn load(&mut self, path: &str) -> StarResult<()>;
}
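/// Summary statistics returned by [`EmbeddingModel::train`].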
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
pub total_epochs: usize,
pub final_loss: f64,
pub losses_per_epoch: Vec<f64>,
pub training_time_secs: f64,
}
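/// Bidirectional mapping between entity/relation labels and the dense
/// indices used to address rows of the embedding matrices.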
#[derive(Debug, Clone)]
pub struct Vocabulary {
entity_to_idx: HashMap<String, usize>,
idx_to_entity: Vec<String>,
relation_to_idx: HashMap<String, usize>,
idx_to_relation: Vec<String>,
}
impl Vocabulary {
fn term_to_string(term: &StarTerm) -> String {
match term {
StarTerm::NamedNode(n) => n.iri.clone(),
StarTerm::BlankNode(b) => format!("_:{}", b.id),
StarTerm::Literal(l) => l.value.clone(),
StarTerm::Variable(v) => format!("?{}", v.name),
            // Key quoted triples by their full <<s p o>> form; keying on the
            // subject alone would conflate distinct quoted triples.
            StarTerm::QuotedTriple(t) => {
                format!("<<{} {} {}>>", t.subject, t.predicate, t.object)
            }
}
}
pub fn from_triples(triples: &[StarTriple]) -> Self {
let mut entity_to_idx = HashMap::new();
let mut idx_to_entity = Vec::new();
let mut relation_to_idx = HashMap::new();
let mut idx_to_relation = Vec::new();
for triple in triples {
let subject_str = Self::term_to_string(&triple.subject);
if !entity_to_idx.contains_key(&subject_str) {
entity_to_idx.insert(subject_str.clone(), idx_to_entity.len());
idx_to_entity.push(subject_str);
}
let predicate_str = Self::term_to_string(&triple.predicate);
if !relation_to_idx.contains_key(&predicate_str) {
relation_to_idx.insert(predicate_str.clone(), idx_to_relation.len());
idx_to_relation.push(predicate_str);
}
let object_str = Self::term_to_string(&triple.object);
if !entity_to_idx.contains_key(&object_str) {
entity_to_idx.insert(object_str.clone(), idx_to_entity.len());
idx_to_entity.push(object_str);
}
}
Self {
entity_to_idx,
idx_to_entity,
relation_to_idx,
idx_to_relation,
}
}
pub fn entity_idx(&self, entity: &str) -> Option<usize> {
self.entity_to_idx.get(entity).copied()
}
pub fn relation_idx(&self, relation: &str) -> Option<usize> {
self.relation_to_idx.get(relation).copied()
}
pub fn entity(&self, idx: usize) -> Option<&str> {
self.idx_to_entity.get(idx).map(|s| s.as_str())
}
pub fn relation(&self, idx: usize) -> Option<&str> {
self.idx_to_relation.get(idx).map(|s| s.as_str())
}
pub fn num_entities(&self) -> usize {
self.idx_to_entity.len()
}
pub fn num_relations(&self) -> usize {
self.idx_to_relation.len()
}
}
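/// TransE (Bordes et al., 2013). A triple (h, r, t) is plausible when the
/// translation `h + r` lands near `t`, scored here by the L1 distance
/// `‖h + r − t‖₁` (lower is better).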
pub struct TransE {
config: EmbeddingConfig,
entity_embeddings: Array2<f64>,
relation_embeddings: Array2<f64>,
vocab: Option<Vocabulary>,
#[allow(dead_code)]
seed: u64,
}
impl TransE {
pub fn new(config: EmbeddingConfig) -> Self {
Self::with_seed(config, 42)
}
pub fn with_seed(config: EmbeddingConfig, seed: u64) -> Self {
Self {
config,
entity_embeddings: Array2::zeros((0, 0)),
relation_embeddings: Array2::zeros((0, 0)),
vocab: None,
seed,
}
}
fn initialize_embeddings(&mut self, num_entities: usize, num_relations: usize) {
        let dim = self.config.embedding_dim;
        // Xavier/Glorot-style uniform initialization range: U(−scale, scale)
        // with scale = √(6 / dim).
        let scale = (6.0 / (dim as f64)).sqrt();
self.entity_embeddings = Array2::zeros((num_entities, dim));
for i in 0..num_entities {
for j in 0..dim {
self.entity_embeddings[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
self.relation_embeddings = Array2::zeros((num_relations, dim));
for i in 0..num_relations {
for j in 0..dim {
self.relation_embeddings[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
self.normalize_embeddings();
info!(
"Initialized TransE embeddings: {} entities, {} relations, dim={}",
num_entities, num_relations, dim
);
}
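    /// Project every entity embedding onto the unit L2 sphere, as in the
    /// original TransE formulation; relation embeddings are left free.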
fn normalize_embeddings(&mut self) {
for i in 0..self.entity_embeddings.nrows() {
let mut row = self.entity_embeddings.row_mut(i);
let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm > 1e-10 {
row.mapv_inplace(|x| x / norm);
}
}
}
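    /// TransE score: the L1 distance `‖h + r − t‖₁`; lower means more plausible.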
fn score(&self, head_idx: usize, rel_idx: usize, tail_idx: usize) -> f64 {
let h = self.entity_embeddings.row(head_idx);
let r = self.relation_embeddings.row(rel_idx);
let t = self.entity_embeddings.row(tail_idx);
let mut distance = 0.0;
for i in 0..self.config.embedding_dim {
let diff = h[i] + r[i] - t[i];
            distance += diff.abs();
        }
distance
}
    fn negative_sample(&self, num_entities: usize) -> usize {
        // Clamp so a generator output of exactly 1.0 cannot index out of bounds.
        ((random_uniform() * num_entities as f64) as usize).min(num_entities - 1)
    }
}
impl EmbeddingModel for TransE {
#[instrument(skip(self, triples))]
fn train(&mut self, triples: &[StarTriple], epochs: usize) -> StarResult<TrainingStats> {
let start = std::time::Instant::now();
let vocab = Vocabulary::from_triples(triples);
let num_entities = vocab.num_entities();
let num_relations = vocab.num_relations();
info!(
"Training TransE on {} triples ({} entities, {} relations) for {} epochs",
triples.len(),
num_entities,
num_relations,
epochs
);
self.initialize_embeddings(num_entities, num_relations);
self.vocab = Some(vocab.clone());
let mut losses = Vec::with_capacity(epochs);
for epoch in 0..epochs {
let mut epoch_loss = 0.0;
let mut num_batches = 0;
let mut triple_indices: Vec<usize> = (0..triples.len()).collect();
            // In-place Fisher–Yates shuffle driven by the deterministic generator.
            for i in (1..triple_indices.len()).rev() {
                let j = ((random_uniform() * (i + 1) as f64) as usize).min(i);
                triple_indices.swap(i, j);
            }
for batch_start in (0..triples.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(triples.len());
let mut batch_loss = 0.0;
for &triple_idx in &triple_indices[batch_start..batch_end] {
let triple = &triples[triple_idx];
let h_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.subject))
.expect("entity should be in vocabulary");
let r_idx = vocab
.relation_idx(&Vocabulary::term_to_string(&triple.predicate))
.expect("relation should be in vocabulary");
let t_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.object))
.expect("entity should be in vocabulary");
let pos_score = self.score(h_idx, r_idx, t_idx);
for _ in 0..self.config.num_negative_samples {
let neg_t_idx = self.negative_sample(num_entities);
let neg_score = self.score(h_idx, r_idx, neg_t_idx);
let loss = (self.config.margin + pos_score - neg_score).max(0.0);
batch_loss += loss;
                        if loss > 0.0 {
                            let lr = self.config.learning_rate;
                            let reg = 1.0 - self.config.l2_reg * lr;
                            for i in 0..self.config.embedding_dim {
                                // Sub-gradient of the L1 distance is the sign of
                                // (h + r - t), computed before any update this step.
                                let pos_grad = if self.entity_embeddings[[h_idx, i]]
                                    + self.relation_embeddings[[r_idx, i]]
                                    > self.entity_embeddings[[t_idx, i]]
                                {
                                    lr
                                } else {
                                    -lr
                                };
                                let neg_grad = if self.entity_embeddings[[h_idx, i]]
                                    + self.relation_embeddings[[r_idx, i]]
                                    > self.entity_embeddings[[neg_t_idx, i]]
                                {
                                    lr
                                } else {
                                    -lr
                                };
                                // Pull the positive triple together...
                                self.entity_embeddings[[h_idx, i]] -= pos_grad;
                                self.relation_embeddings[[r_idx, i]] -= pos_grad;
                                self.entity_embeddings[[t_idx, i]] += pos_grad;
                                // ...and push the corrupted tail apart; the hinge
                                // loss also depends on the negative score.
                                self.entity_embeddings[[h_idx, i]] += neg_grad;
                                self.relation_embeddings[[r_idx, i]] += neg_grad;
                                self.entity_embeddings[[neg_t_idx, i]] -= neg_grad;
                                // L2 weight decay on the touched rows.
                                self.entity_embeddings[[h_idx, i]] *= reg;
                                self.relation_embeddings[[r_idx, i]] *= reg;
                                self.entity_embeddings[[t_idx, i]] *= reg;
                                self.entity_embeddings[[neg_t_idx, i]] *= reg;
                            }
                        }
}
}
epoch_loss += batch_loss;
num_batches += 1;
}
self.normalize_embeddings();
let avg_loss = epoch_loss / num_batches as f64;
losses.push(avg_loss);
if epoch % 10 == 0 {
debug!("Epoch {}/{}: loss = {:.4}", epoch + 1, epochs, avg_loss);
}
}
let training_time = start.elapsed().as_secs_f64();
info!(
"Training complete in {:.2}s, final loss: {:.4}",
training_time,
losses.last().copied().unwrap_or(0.0)
);
Ok(TrainingStats {
total_epochs: epochs,
final_loss: losses.last().copied().unwrap_or(0.0),
losses_per_epoch: losses,
training_time_secs: training_time,
})
}
fn get_embedding(&self, entity: &str) -> Option<Array1<f64>> {
let vocab = self.vocab.as_ref()?;
let idx = vocab.entity_idx(entity)?;
Some(self.entity_embeddings.row(idx).to_owned())
}
fn similarity(&self, entity1: &str, entity2: &str) -> StarResult<f64> {
let e1 = self
.get_embedding(entity1)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity1),
query_fragment: None,
position: None,
suggestion: None,
})?;
let e2 = self
.get_embedding(entity2)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity2),
query_fragment: None,
position: None,
suggestion: None,
})?;
let dot: f64 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
let norm1: f64 = e1.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm2: f64 = e2.iter().map(|x| x * x).sum::<f64>().sqrt();
        // Guard against zero norms, matching the other models' similarity.
        Ok(dot / (norm1 * norm2 + 1e-10))
}
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> StarResult<Vec<(String, f64)>> {
let vocab = self
.vocab
.as_ref()
.ok_or_else(|| crate::StarError::QueryError {
message: "Model not trained".to_string(),
query_fragment: None,
position: None,
suggestion: Some("Train the model first".to_string()),
})?;
let h_idx = vocab
.entity_idx(head)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Head entity not found: {}", head),
query_fragment: None,
position: None,
suggestion: None,
})?;
let r_idx = vocab
.relation_idx(relation)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Relation not found: {}", relation),
query_fragment: None,
position: None,
suggestion: None,
})?;
let mut scores: Vec<(String, f64)> = Vec::new();
for t_idx in 0..vocab.num_entities() {
let score = self.score(h_idx, r_idx, t_idx);
let entity = vocab
.entity(t_idx)
.expect("entity index should be valid")
.to_string();
            // TransE scores are distances (lower is better); negate so the
            // descending sort below ranks the most plausible tails first.
            scores.push((entity, -score));
        }
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(scores.into_iter().take(k).collect())
}
    fn save(&self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
    fn load(&mut self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{model::NamedNode, StarTerm};
fn create_test_triples() -> Vec<StarTriple> {
vec![
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Alice".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
},
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Charlie".to_string(),
}),
},
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Alice".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "likes".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Coffee".to_string(),
}),
},
]
}
#[test]
fn test_vocabulary_creation() {
let triples = create_test_triples();
let vocab = Vocabulary::from_triples(&triples);
        assert_eq!(vocab.num_entities(), 4); // Alice, Bob, Charlie, Coffee
        assert_eq!(vocab.num_relations(), 2); // knows, likes
    }
#[test]
fn test_transe_initialization() {
let config = EmbeddingConfig {
embedding_dim: 64,
..Default::default()
};
let model = TransE::new(config);
assert_eq!(model.config.embedding_dim, 64);
}
#[test]
fn test_transe_training() {
let config = EmbeddingConfig {
embedding_dim: 32,
learning_rate: 0.01,
batch_size: 2,
num_negative_samples: 3,
..Default::default()
};
let mut model = TransE::new(config);
let triples = create_test_triples();
let stats = model.train(&triples, 10).unwrap();
assert_eq!(stats.total_epochs, 10);
assert!(stats.final_loss >= 0.0);
assert_eq!(stats.losses_per_epoch.len(), 10);
}
#[test]
fn test_get_embedding() {
let config = EmbeddingConfig::default();
let mut model = TransE::new(config);
let triples = create_test_triples();
model.train(&triples, 5).unwrap();
let emb = model.get_embedding("Alice");
assert!(emb.is_some());
        assert_eq!(emb.unwrap().len(), 128); // default embedding_dim
    }
#[test]
fn test_similarity() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = TransE::new(config);
let triples = create_test_triples();
model.train(&triples, 20).unwrap();
let sim = model.similarity("Alice", "Bob").unwrap();
assert!((-1.0..=1.0).contains(&sim));
}
#[test]
fn test_predict_tail() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = TransE::new(config);
let triples = create_test_triples();
model.train(&triples, 10).unwrap();
let predictions = model.predict_tail("Alice", "knows", 3).unwrap();
assert_eq!(predictions.len(), 3);
        assert!(predictions[0].1 >= predictions[1].1);
    }
#[test]
fn test_embedding_normalization() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = TransE::new(config);
let triples = create_test_triples();
model.train(&triples, 5).unwrap();
let emb = model.get_embedding("Alice").unwrap();
let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}
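/// DistMult (Yang et al., 2015). Scores a triple with the trilinear product
/// `Σᵢ hᵢ·rᵢ·tᵢ` (higher is better). The product is symmetric in head and
/// tail, so DistMult cannot distinguish a relation from its inverse.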
pub struct DistMult {
config: EmbeddingConfig,
entity_embeddings: Array2<f64>,
relation_embeddings: Array2<f64>,
vocab: Option<Vocabulary>,
#[allow(dead_code)]
seed: u64,
}
impl DistMult {
pub fn new(config: EmbeddingConfig) -> Self {
Self::with_seed(config, 42)
}
pub fn with_seed(config: EmbeddingConfig, seed: u64) -> Self {
Self {
config,
entity_embeddings: Array2::zeros((0, 0)),
relation_embeddings: Array2::zeros((0, 0)),
vocab: None,
seed,
}
}
fn initialize_embeddings(&mut self, num_entities: usize, num_relations: usize) {
let dim = self.config.embedding_dim;
let scale = (6.0 / (dim as f64)).sqrt();
self.entity_embeddings = Array2::zeros((num_entities, dim));
for i in 0..num_entities {
for j in 0..dim {
self.entity_embeddings[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
self.relation_embeddings = Array2::zeros((num_relations, dim));
for i in 0..num_relations {
for j in 0..dim {
self.relation_embeddings[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
info!(
"Initialized DistMult embeddings: {} entities, {} relations, dim={}",
num_entities, num_relations, dim
);
}
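    /// DistMult score: the trilinear product `Σᵢ hᵢ·rᵢ·tᵢ` (higher is better).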
fn score(&self, head_idx: usize, rel_idx: usize, tail_idx: usize) -> f64 {
let h = self.entity_embeddings.row(head_idx);
let r = self.relation_embeddings.row(rel_idx);
let t = self.entity_embeddings.row(tail_idx);
let mut score = 0.0;
for i in 0..self.config.embedding_dim {
score += h[i] * r[i] * t[i];
}
score
}
    fn negative_sample(&self, num_entities: usize) -> usize {
        // Clamp so a generator output of exactly 1.0 cannot index out of bounds.
        ((random_uniform() * num_entities as f64) as usize).min(num_entities - 1)
    }
}
impl EmbeddingModel for DistMult {
#[instrument(skip(self, triples))]
fn train(&mut self, triples: &[StarTriple], epochs: usize) -> StarResult<TrainingStats> {
let start = std::time::Instant::now();
let vocab = Vocabulary::from_triples(triples);
let num_entities = vocab.num_entities();
let num_relations = vocab.num_relations();
info!(
"Training DistMult on {} triples ({} entities, {} relations) for {} epochs",
triples.len(),
num_entities,
num_relations,
epochs
);
self.initialize_embeddings(num_entities, num_relations);
self.vocab = Some(vocab.clone());
let mut losses = Vec::with_capacity(epochs);
for epoch in 0..epochs {
let mut epoch_loss = 0.0;
let mut num_batches = 0;
let mut triple_indices: Vec<usize> = (0..triples.len()).collect();
            // In-place Fisher–Yates shuffle driven by the deterministic generator.
            for i in (1..triple_indices.len()).rev() {
                let j = ((random_uniform() * (i + 1) as f64) as usize).min(i);
                triple_indices.swap(i, j);
            }
for batch_start in (0..triples.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(triples.len());
let mut batch_loss = 0.0;
for &triple_idx in &triple_indices[batch_start..batch_end] {
let triple = &triples[triple_idx];
let h_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.subject))
.expect("entity should be in vocabulary");
let r_idx = vocab
.relation_idx(&Vocabulary::term_to_string(&triple.predicate))
.expect("relation should be in vocabulary");
let t_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.object))
.expect("entity should be in vocabulary");
let pos_score = self.score(h_idx, r_idx, t_idx);
for _ in 0..self.config.num_negative_samples {
                        // Corrupt either the head or the tail with equal
                        // probability, keeping the relation fixed.
                        let corrupt_head = random_uniform() > 0.5;
let (neg_h_idx, neg_t_idx) = if corrupt_head {
(self.negative_sample(num_entities), t_idx)
} else {
(h_idx, self.negative_sample(num_entities))
};
let neg_score = self.score(neg_h_idx, r_idx, neg_t_idx);
let margin_diff = pos_score - neg_score;
                        // Logistic (softplus) loss on the score difference:
                        // ln(1 + exp(-(pos - neg))).
                        let loss = (1.0 + (-margin_diff).exp()).ln();
                        batch_loss += loss;
                        // Skip near-zero losses: the gradient there is negligible.
                        if loss > 0.01 {
                            let lr = self.config.learning_rate;
                            // σ(neg − pos): the common gradient factor below; it
                            // shrinks as the positive score pulls ahead.
                            let sigmoid = 1.0 / (1.0 + margin_diff.exp());
for i in 0..self.config.embedding_dim {
let h_val = self.entity_embeddings[[h_idx, i]];
let r_val = self.relation_embeddings[[r_idx, i]];
let t_val = self.entity_embeddings[[t_idx, i]];
let grad_h_pos = sigmoid * r_val * t_val;
let grad_r_pos = sigmoid * h_val * t_val;
let grad_t_pos = sigmoid * h_val * r_val;
self.entity_embeddings[[h_idx, i]] += lr * grad_h_pos;
self.relation_embeddings[[r_idx, i]] += lr * grad_r_pos;
self.entity_embeddings[[t_idx, i]] += lr * grad_t_pos;
if corrupt_head {
let grad_neg_h = -sigmoid * r_val * t_val;
self.entity_embeddings[[neg_h_idx, i]] += lr * grad_neg_h;
} else {
let grad_neg_t = -sigmoid * h_val * r_val;
self.entity_embeddings[[neg_t_idx, i]] += lr * grad_neg_t;
}
self.entity_embeddings[[h_idx, i]] *= 1.0 - self.config.l2_reg * lr;
self.relation_embeddings[[r_idx, i]] *=
1.0 - self.config.l2_reg * lr;
}
}
}
}
epoch_loss += batch_loss;
num_batches += 1;
}
let avg_loss = epoch_loss / num_batches as f64;
losses.push(avg_loss);
if epoch % 10 == 0 {
debug!("Epoch {}/{}: loss = {:.4}", epoch + 1, epochs, avg_loss);
}
}
let training_time = start.elapsed().as_secs_f64();
info!(
"Training complete in {:.2}s, final loss: {:.4}",
training_time,
losses.last().copied().unwrap_or(0.0)
);
Ok(TrainingStats {
total_epochs: epochs,
final_loss: losses.last().copied().unwrap_or(0.0),
losses_per_epoch: losses,
training_time_secs: training_time,
})
}
fn get_embedding(&self, entity: &str) -> Option<Array1<f64>> {
let vocab = self.vocab.as_ref()?;
let idx = vocab.entity_idx(entity)?;
Some(self.entity_embeddings.row(idx).to_owned())
}
fn similarity(&self, entity1: &str, entity2: &str) -> StarResult<f64> {
let e1 = self
.get_embedding(entity1)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity1),
query_fragment: None,
position: None,
suggestion: None,
})?;
let e2 = self
.get_embedding(entity2)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity2),
query_fragment: None,
position: None,
suggestion: None,
})?;
let dot: f64 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
let norm1: f64 = e1.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm2: f64 = e2.iter().map(|x| x * x).sum::<f64>().sqrt();
Ok(dot / (norm1 * norm2 + 1e-10))
}
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> StarResult<Vec<(String, f64)>> {
let vocab = self
.vocab
.as_ref()
.ok_or_else(|| crate::StarError::QueryError {
message: "Model not trained".to_string(),
query_fragment: None,
position: None,
suggestion: Some("Train the model first".to_string()),
})?;
let h_idx = vocab
.entity_idx(head)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Head entity not found: {}", head),
query_fragment: None,
position: None,
suggestion: None,
})?;
let r_idx = vocab
.relation_idx(relation)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Relation not found: {}", relation),
query_fragment: None,
position: None,
suggestion: None,
})?;
let mut scores: Vec<(String, f64)> = Vec::new();
for t_idx in 0..vocab.num_entities() {
let score = self.score(h_idx, r_idx, t_idx);
let entity = vocab
.entity(t_idx)
.expect("entity index should be valid")
.to_string();
            // DistMult scores are similarities (higher is better); no negation needed.
            scores.push((entity, score));
        }
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(scores.into_iter().take(k).collect())
}
    fn save(&self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
    fn load(&mut self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
}
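/// ComplEx (Trouillon et al., 2016). Embeddings are complex vectors, stored
/// as separate real and imaginary matrices, and a triple is scored by
/// `Re(⟨h, r, conj(t)⟩)`. Unlike DistMult this can model asymmetric relations.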
pub struct ComplEx {
config: EmbeddingConfig,
entity_embeddings_real: Array2<f64>,
entity_embeddings_imag: Array2<f64>,
relation_embeddings_real: Array2<f64>,
relation_embeddings_imag: Array2<f64>,
vocab: Option<Vocabulary>,
#[allow(dead_code)]
seed: u64,
}
impl ComplEx {
pub fn new(config: EmbeddingConfig) -> Self {
Self::with_seed(config, 42)
}
pub fn with_seed(config: EmbeddingConfig, seed: u64) -> Self {
Self {
config,
entity_embeddings_real: Array2::zeros((0, 0)),
entity_embeddings_imag: Array2::zeros((0, 0)),
relation_embeddings_real: Array2::zeros((0, 0)),
relation_embeddings_imag: Array2::zeros((0, 0)),
vocab: None,
seed,
}
}
fn initialize_embeddings(&mut self, num_entities: usize, num_relations: usize) {
        let dim = self.config.embedding_dim;
        // Each embedding has 2·dim real parameters (real + imaginary), so the
        // Xavier-style uniform range uses scale = √(6 / (2·dim)).
        let scale = (6.0 / (2.0 * dim as f64)).sqrt();
self.entity_embeddings_real = Array2::zeros((num_entities, dim));
self.entity_embeddings_imag = Array2::zeros((num_entities, dim));
for i in 0..num_entities {
for j in 0..dim {
self.entity_embeddings_real[[i, j]] = random_uniform() * 2.0 * scale - scale;
self.entity_embeddings_imag[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
self.relation_embeddings_real = Array2::zeros((num_relations, dim));
self.relation_embeddings_imag = Array2::zeros((num_relations, dim));
for i in 0..num_relations {
for j in 0..dim {
self.relation_embeddings_real[[i, j]] = random_uniform() * 2.0 * scale - scale;
self.relation_embeddings_imag[[i, j]] = random_uniform() * 2.0 * scale - scale;
}
}
info!(
"Initialized ComplEx embeddings: {} entities, {} relations, dim={} (complex)",
num_entities, num_relations, dim
);
}
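    /// ComplEx score: `Re(Σᵢ hᵢ·rᵢ·conj(tᵢ))`, expanded into real arithmetic
    /// as `(h_re·r_re − h_im·r_im)·t_re + (h_re·r_im + h_im·r_re)·t_im`.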
fn score(&self, head_idx: usize, rel_idx: usize, tail_idx: usize) -> f64 {
let h_re = self.entity_embeddings_real.row(head_idx);
let h_im = self.entity_embeddings_imag.row(head_idx);
let r_re = self.relation_embeddings_real.row(rel_idx);
let r_im = self.relation_embeddings_imag.row(rel_idx);
let t_re = self.entity_embeddings_real.row(tail_idx);
let t_im = self.entity_embeddings_imag.row(tail_idx);
let mut score = 0.0;
for i in 0..self.config.embedding_dim {
let hr_re = h_re[i] * r_re[i] - h_im[i] * r_im[i];
let hr_im = h_re[i] * r_im[i] + h_im[i] * r_re[i];
score += hr_re * t_re[i] + hr_im * t_im[i];
}
score
}
    fn negative_sample(&self, num_entities: usize) -> usize {
        // Clamp so a generator output of exactly 1.0 cannot index out of bounds.
        ((random_uniform() * num_entities as f64) as usize).min(num_entities - 1)
    }
}
impl EmbeddingModel for ComplEx {
#[instrument(skip(self, triples))]
fn train(&mut self, triples: &[StarTriple], epochs: usize) -> StarResult<TrainingStats> {
let start = std::time::Instant::now();
let vocab = Vocabulary::from_triples(triples);
let num_entities = vocab.num_entities();
let num_relations = vocab.num_relations();
info!(
"Training ComplEx on {} triples ({} entities, {} relations) for {} epochs",
triples.len(),
num_entities,
num_relations,
epochs
);
self.initialize_embeddings(num_entities, num_relations);
self.vocab = Some(vocab.clone());
let mut losses = Vec::with_capacity(epochs);
for epoch in 0..epochs {
let mut epoch_loss = 0.0;
let mut num_batches = 0;
let mut triple_indices: Vec<usize> = (0..triples.len()).collect();
            // In-place Fisher–Yates shuffle driven by the deterministic generator.
            for i in (1..triple_indices.len()).rev() {
                let j = ((random_uniform() * (i + 1) as f64) as usize).min(i);
                triple_indices.swap(i, j);
            }
for batch_start in (0..triples.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(triples.len());
let mut batch_loss = 0.0;
for &triple_idx in &triple_indices[batch_start..batch_end] {
let triple = &triples[triple_idx];
let h_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.subject))
.expect("entity should be in vocabulary");
let r_idx = vocab
.relation_idx(&Vocabulary::term_to_string(&triple.predicate))
.expect("relation should be in vocabulary");
let t_idx = vocab
.entity_idx(&Vocabulary::term_to_string(&triple.object))
.expect("entity should be in vocabulary");
let pos_score = self.score(h_idx, r_idx, t_idx);
for _ in 0..self.config.num_negative_samples {
let corrupt_head = random_uniform() > 0.5;
let (neg_h_idx, neg_t_idx) = if corrupt_head {
(self.negative_sample(num_entities), t_idx)
} else {
(h_idx, self.negative_sample(num_entities))
};
let neg_score = self.score(neg_h_idx, r_idx, neg_t_idx);
let margin_diff = pos_score - neg_score;
let loss = (1.0 + (-margin_diff).exp()).ln();
batch_loss += loss;
if loss > 0.01 {
let lr = self.config.learning_rate;
let sigmoid = 1.0 / (1.0 + margin_diff.exp());
for i in 0..self.config.embedding_dim {
let h_re = self.entity_embeddings_real[[h_idx, i]];
let h_im = self.entity_embeddings_imag[[h_idx, i]];
let r_re = self.relation_embeddings_real[[r_idx, i]];
let r_im = self.relation_embeddings_imag[[r_idx, i]];
let t_re = self.entity_embeddings_real[[t_idx, i]];
let t_im = self.entity_embeddings_imag[[t_idx, i]];
                                let grad_h_re = sigmoid * (r_re * t_re + r_im * t_im);
                                // From expanding Re((h·r)·conj(t)) above:
                                // ∂score/∂h_im = r_re·t_im − r_im·t_re.
                                let grad_h_im = sigmoid * (r_re * t_im - r_im * t_re);
                                let grad_r_re = sigmoid * (h_re * t_re + h_im * t_im);
                                let grad_r_im = sigmoid * (h_re * t_im - h_im * t_re);
                                let grad_t_re = sigmoid * (h_re * r_re - h_im * r_im);
                                let grad_t_im = sigmoid * (h_re * r_im + h_im * r_re);
self.entity_embeddings_real[[h_idx, i]] += lr * grad_h_re;
self.entity_embeddings_imag[[h_idx, i]] += lr * grad_h_im;
self.relation_embeddings_real[[r_idx, i]] += lr * grad_r_re;
self.relation_embeddings_imag[[r_idx, i]] += lr * grad_r_im;
self.entity_embeddings_real[[t_idx, i]] += lr * grad_t_re;
self.entity_embeddings_imag[[t_idx, i]] += lr * grad_t_im;
                                if corrupt_head {
                                    // Push the corrupted head away along the same
                                    // partial derivatives used for the true head.
                                    self.entity_embeddings_real[[neg_h_idx, i]] -=
                                        lr * sigmoid * (r_re * t_re + r_im * t_im);
                                    self.entity_embeddings_imag[[neg_h_idx, i]] -=
                                        lr * sigmoid * (r_re * t_im - r_im * t_re);
} else {
self.entity_embeddings_real[[neg_t_idx, i]] -=
lr * sigmoid * (h_re * r_re - h_im * r_im);
self.entity_embeddings_imag[[neg_t_idx, i]] -=
lr * sigmoid * (h_re * r_im + h_im * r_re);
}
let reg_factor = 1.0 - self.config.l2_reg * lr;
self.entity_embeddings_real[[h_idx, i]] *= reg_factor;
self.entity_embeddings_imag[[h_idx, i]] *= reg_factor;
self.relation_embeddings_real[[r_idx, i]] *= reg_factor;
self.relation_embeddings_imag[[r_idx, i]] *= reg_factor;
}
}
}
}
epoch_loss += batch_loss;
num_batches += 1;
}
let avg_loss = epoch_loss / num_batches as f64;
losses.push(avg_loss);
if epoch % 10 == 0 {
debug!("Epoch {}/{}: loss = {:.4}", epoch + 1, epochs, avg_loss);
}
}
let training_time = start.elapsed().as_secs_f64();
info!(
"Training complete in {:.2}s, final loss: {:.4}",
training_time,
losses.last().copied().unwrap_or(0.0)
);
Ok(TrainingStats {
total_epochs: epochs,
final_loss: losses.last().copied().unwrap_or(0.0),
losses_per_epoch: losses,
training_time_secs: training_time,
})
}
fn get_embedding(&self, entity: &str) -> Option<Array1<f64>> {
let vocab = self.vocab.as_ref()?;
let idx = vocab.entity_idx(entity)?;
let real = self.entity_embeddings_real.row(idx);
let imag = self.entity_embeddings_imag.row(idx);
        // Concatenate [real | imaginary] into a single vector of length 2·dim.
        let mut embedding = Array1::zeros(self.config.embedding_dim * 2);
for i in 0..self.config.embedding_dim {
embedding[i] = real[i];
embedding[i + self.config.embedding_dim] = imag[i];
}
Some(embedding)
}
fn similarity(&self, entity1: &str, entity2: &str) -> StarResult<f64> {
let e1 = self
.get_embedding(entity1)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity1),
query_fragment: None,
position: None,
suggestion: None,
})?;
let e2 = self
.get_embedding(entity2)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Entity not found: {}", entity2),
query_fragment: None,
position: None,
suggestion: None,
})?;
let dot: f64 = e1.iter().zip(e2.iter()).map(|(a, b)| a * b).sum();
let norm1: f64 = e1.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm2: f64 = e2.iter().map(|x| x * x).sum::<f64>().sqrt();
Ok(dot / (norm1 * norm2 + 1e-10))
}
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> StarResult<Vec<(String, f64)>> {
let vocab = self
.vocab
.as_ref()
.ok_or_else(|| crate::StarError::QueryError {
message: "Model not trained".to_string(),
query_fragment: None,
position: None,
suggestion: Some("Train the model first".to_string()),
})?;
let h_idx = vocab
.entity_idx(head)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Head entity not found: {}", head),
query_fragment: None,
position: None,
suggestion: None,
})?;
let r_idx = vocab
.relation_idx(relation)
.ok_or_else(|| crate::StarError::QueryError {
message: format!("Relation not found: {}", relation),
query_fragment: None,
position: None,
suggestion: None,
})?;
let mut scores: Vec<(String, f64)> = Vec::new();
for t_idx in 0..vocab.num_entities() {
let score = self.score(h_idx, r_idx, t_idx);
let entity = vocab
.entity(t_idx)
.expect("entity index should be valid")
.to_string();
            scores.push((entity, score));
        }
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(scores.into_iter().take(k).collect())
}
    fn save(&self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
    fn load(&mut self, _path: &str) -> StarResult<()> {
        // Persistence is not yet implemented; this is a no-op placeholder.
        Ok(())
    }
}
#[cfg(test)]
mod advanced_model_tests {
use super::*;
use crate::{model::NamedNode, StarTerm};
fn create_test_triples() -> Vec<StarTriple> {
vec![
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Alice".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
},
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Charlie".to_string(),
}),
},
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Alice".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "likes".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Coffee".to_string(),
}),
},
]
}
#[test]
fn test_distmult_initialization() {
let config = EmbeddingConfig {
embedding_dim: 64,
..Default::default()
};
let model = DistMult::new(config);
assert_eq!(model.config.embedding_dim, 64);
}
#[test]
fn test_distmult_training() {
let config = EmbeddingConfig {
embedding_dim: 32,
learning_rate: 0.01,
batch_size: 2,
num_negative_samples: 3,
..Default::default()
};
let mut model = DistMult::new(config);
let triples = create_test_triples();
let stats = model.train(&triples, 10).unwrap();
assert_eq!(stats.total_epochs, 10);
assert!(stats.final_loss >= 0.0);
assert_eq!(stats.losses_per_epoch.len(), 10);
}
#[test]
fn test_distmult_get_embedding() {
let config = EmbeddingConfig::default();
let mut model = DistMult::new(config);
let triples = create_test_triples();
model.train(&triples, 5).unwrap();
let emb = model.get_embedding("Alice");
assert!(emb.is_some());
assert_eq!(emb.unwrap().len(), 128);
}
#[test]
fn test_distmult_similarity() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = DistMult::new(config);
let triples = create_test_triples();
model.train(&triples, 20).unwrap();
let sim = model.similarity("Alice", "Bob").unwrap();
assert!((-1.0..=1.0).contains(&sim));
}
#[test]
fn test_distmult_predict_tail() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = DistMult::new(config);
let triples = create_test_triples();
model.train(&triples, 10).unwrap();
let predictions = model.predict_tail("Alice", "knows", 3).unwrap();
assert_eq!(predictions.len(), 3);
        assert!(predictions[0].1 >= predictions[1].1);
    }
#[test]
fn test_complex_initialization() {
let config = EmbeddingConfig {
embedding_dim: 64,
..Default::default()
};
let model = ComplEx::new(config);
assert_eq!(model.config.embedding_dim, 64);
}
#[test]
fn test_complex_training() {
let config = EmbeddingConfig {
embedding_dim: 32,
learning_rate: 0.01,
batch_size: 2,
num_negative_samples: 3,
..Default::default()
};
let mut model = ComplEx::new(config);
let triples = create_test_triples();
let stats = model.train(&triples, 10).unwrap();
assert_eq!(stats.total_epochs, 10);
assert!(stats.final_loss >= 0.0);
assert_eq!(stats.losses_per_epoch.len(), 10);
}
#[test]
fn test_complex_get_embedding() {
let config = EmbeddingConfig::default();
let mut model = ComplEx::new(config);
let triples = create_test_triples();
model.train(&triples, 5).unwrap();
let emb = model.get_embedding("Alice");
assert!(emb.is_some());
assert_eq!(emb.unwrap().len(), 128 * 2);
}
#[test]
fn test_complex_similarity() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = ComplEx::new(config);
let triples = create_test_triples();
model.train(&triples, 20).unwrap();
let sim = model.similarity("Alice", "Bob").unwrap();
assert!((-1.0..=1.0).contains(&sim));
}
#[test]
fn test_complex_predict_tail() {
let config = EmbeddingConfig {
embedding_dim: 32,
..Default::default()
};
let mut model = ComplEx::new(config);
let triples = create_test_triples();
model.train(&triples, 10).unwrap();
let predictions = model.predict_tail("Alice", "knows", 3).unwrap();
assert_eq!(predictions.len(), 3);
        assert!(predictions[0].1 >= predictions[1].1);
    }
#[test]
fn test_model_comparison() {
let config = EmbeddingConfig {
embedding_dim: 32,
learning_rate: 0.01,
..Default::default()
};
let triples = create_test_triples();
let mut transe = TransE::new(config.clone());
let mut distmult = DistMult::new(config.clone());
let mut complex = ComplEx::new(config);
let stats_transe = transe.train(&triples, 20).unwrap();
let stats_distmult = distmult.train(&triples, 20).unwrap();
let stats_complex = complex.train(&triples, 20).unwrap();
assert!(stats_transe.final_loss < stats_transe.losses_per_epoch[0]);
assert!(stats_distmult.final_loss < stats_distmult.losses_per_epoch[0]);
assert!(stats_complex.final_loss < stats_complex.losses_per_epoch[0]);
let pred_transe = transe.predict_tail("Alice", "knows", 1).unwrap();
let pred_distmult = distmult.predict_tail("Alice", "knows", 1).unwrap();
let pred_complex = complex.predict_tail("Alice", "knows", 1).unwrap();
assert_eq!(pred_transe.len(), 1);
assert_eq!(pred_distmult.len(), 1);
assert_eq!(pred_complex.len(), 1);
}
}