use crate::gnn_embeddings::{GraphSAGE, GCN};
use crate::random_utils::NormalSampler as Normal;
use crate::Vector;
use anyhow::{anyhow, Result};
use nalgebra::{Complex, DVector};
use scirs2_core::random::{Random, Rng, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Selects which knowledge-graph embedding algorithm a [`KGEmbedding`]
/// facade instantiates.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum KGEmbeddingModelType {
    /// Translation-based model: `h + r ≈ t`.
    TransE,
    /// Complex-valued bilinear model (Hermitian product scoring).
    ComplEx,
    /// Rotation-in-complex-plane model: `t ≈ h ∘ r` with `|r_i| = 1`.
    RotatE,
    /// Graph convolutional network (delegated to `gnn_embeddings`).
    GCN,
    /// GraphSAGE aggregator model (delegated to `gnn_embeddings`).
    GraphSAGE,
}
/// Hyper-parameters shared by all KG embedding models in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KGEmbeddingConfig {
    /// Which embedding algorithm to instantiate.
    pub model: KGEmbeddingModelType,
    /// Embedding dimensionality `d`.
    pub dimensions: usize,
    /// SGD step size.
    pub learning_rate: f32,
    /// Margin used by the margin-ranking loss (TransE training).
    pub margin: f32,
    /// Number of corrupted triples generated per positive triple.
    pub negative_samples: usize,
    /// Mini-batch size used when iterating the training triples.
    pub batch_size: usize,
    /// Number of passes over the training set.
    pub epochs: usize,
    /// Distance norm: 1 selects L1; any other value is treated as L2.
    pub norm: usize,
    /// Seed for reproducible initialization; `None` falls back to 42.
    pub random_seed: Option<u64>,
    /// L2 regularization weight.
    /// NOTE(review): not read anywhere in this file — confirm it is used
    /// elsewhere or intended for future work.
    pub regularization: f32,
}
impl Default for KGEmbeddingConfig {
fn default() -> Self {
Self {
model: KGEmbeddingModelType::TransE,
dimensions: 100,
learning_rate: 0.01,
margin: 1.0,
negative_samples: 10,
batch_size: 100,
epochs: 100,
norm: 2,
random_seed: Some(42),
regularization: 0.0,
}
}
}
/// A single knowledge-graph statement: `(subject, predicate, object)`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Triple {
    /// Head entity.
    pub subject: String,
    /// Relation connecting subject to object.
    pub predicate: String,
    /// Tail entity.
    pub object: String,
}
impl Triple {
pub fn new(subject: String, predicate: String, object: String) -> Self {
Self {
subject,
predicate,
object,
}
}
}
/// Common interface implemented by every KG embedding model in this module.
pub trait KGEmbeddingModel: Send + Sync {
    /// Learns embeddings from the given triples.
    fn train(&mut self, triples: &[Triple]) -> Result<()>;
    /// Returns the learned embedding for `entity`, if known.
    fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
    /// Returns the learned embedding for `relation`, if known.
    fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
    /// Plausibility score for `triple`; higher means more plausible.
    fn score_triple(&self, triple: &Triple) -> f32;
    /// Top-`k` candidate objects for `(head, relation, ?)`, best first.
    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
    /// Top-`k` candidate subjects for `(?, relation, tail)`, best first.
    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
    /// All entity embeddings keyed by entity name.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
    /// All relation embeddings keyed by relation name.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
}
/// TransE model: relations are translations in entity space
/// (`h + r ≈ t` for true triples).
pub struct TransE {
    /// Shared hyper-parameters.
    config: KGEmbeddingConfig,
    /// Entity name → d-dimensional (unit-norm) vector.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Relation name → d-dimensional vector.
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Entity vocabulary collected during training.
    entities: Vec<String>,
    /// Relation vocabulary collected during training.
    relations: Vec<String>,
}
impl TransE {
    /// Creates an untrained TransE model with the given configuration.
    pub fn new(config: KGEmbeddingConfig) -> Self {
        Self {
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            entities: Vec::new(),
            relations: Vec::new(),
        }
    }

    /// Collects the entity/relation vocabulary from `triples` and draws
    /// initial embeddings uniformly from `[-6/√d, 6/√d]`; entity vectors are
    /// additionally L2-normalized.
    fn initialize_embeddings(&mut self, triples: &[Triple]) {
        let mut entities = std::collections::HashSet::new();
        let mut relations = std::collections::HashSet::new();
        for triple in triples {
            entities.insert(triple.subject.clone());
            entities.insert(triple.object.clone());
            relations.insert(triple.predicate.clone());
        }
        self.entities = entities.into_iter().collect();
        self.relations = relations.into_iter().collect();
        // A missing seed falls back to a fixed value so runs stay reproducible.
        let mut rng = if let Some(seed) = self.config.random_seed {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };
        let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
        let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
        for entity in &self.entities {
            let values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| rng.random_range(range_min..range_max))
                .collect();
            let mut embedding = DVector::from_vec(values);
            let norm = embedding.norm();
            if norm > 0.0 {
                embedding /= norm;
            }
            self.entity_embeddings.insert(entity.clone(), embedding);
        }
        for relation in &self.relations {
            let values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| rng.random_range(range_min..range_max))
                .collect();
            let embedding = DVector::from_vec(values);
            self.relation_embeddings.insert(relation.clone(), embedding);
        }
    }

    /// Produces `negative_samples` corrupted copies of `triple`, replacing
    /// either the subject or the object (chosen 50/50) with a different
    /// random entity. Returns an empty vector if no distinct replacement
    /// entity exists.
    #[allow(deprecated)]
    fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
        let mut negatives = Vec::new();
        // BUG FIX: with fewer than two known entities there is no distinct
        // replacement, and the rejection-sampling loops below would spin
        // forever. Bail out in that degenerate case.
        if self.entities.len() < 2 {
            return negatives;
        }
        for _ in 0..self.config.negative_samples {
            if rng.random_bool(0.5) {
                // Corrupt the subject.
                let mut negative = triple.clone();
                loop {
                    let idx = rng.random_range(0..self.entities.len());
                    let entity = &self.entities[idx];
                    if entity != &triple.subject {
                        negative.subject = entity.clone();
                        break;
                    }
                }
                negatives.push(negative);
            } else {
                // Corrupt the object.
                let mut negative = triple.clone();
                loop {
                    let idx = rng.random_range(0..self.entities.len());
                    let entity = &self.entities[idx];
                    if entity != &triple.object {
                        negative.object = entity.clone();
                        break;
                    }
                }
                negatives.push(negative);
            }
        }
        negatives
    }

    /// Translation distance `‖h + r - t‖` (L1 when `config.norm == 1`,
    /// otherwise L2). Panics if any element of the triple has no embedding.
    fn distance(&self, triple: &Triple) -> f32 {
        let h = self
            .entity_embeddings
            .get(&triple.subject)
            .expect("subject entity should have embedding");
        let r = self
            .relation_embeddings
            .get(&triple.predicate)
            .expect("predicate relation should have embedding");
        let t = self
            .entity_embeddings
            .get(&triple.object)
            .expect("object entity should have embedding");
        let translation = h + r - t;
        match self.config.norm {
            1 => translation.iter().map(|x| x.abs()).sum(),
            // Any other configured norm falls back to L2.
            _ => translation.norm(),
        }
    }

    /// One SGD step of the margin-ranking loss for `positive` against each
    /// of its `negatives`. Updated entity vectors are re-normalized to unit
    /// length, matching the initialization invariant.
    fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
        let pos_dist = self.distance(positive);
        for negative in negatives {
            let neg_dist = self.distance(negative);
            // Hinge loss: only update when the margin is violated.
            let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
            if loss > 0.0 {
                let h_pos = self
                    .entity_embeddings
                    .get(&positive.subject)
                    .expect("positive subject entity should have embedding")
                    .clone();
                let r = self
                    .relation_embeddings
                    .get(&positive.predicate)
                    .expect("positive predicate relation should have embedding")
                    .clone();
                let t_pos = self
                    .entity_embeddings
                    .get(&positive.object)
                    .expect("positive object entity should have embedding")
                    .clone();
                let h_neg = self
                    .entity_embeddings
                    .get(&negative.subject)
                    .expect("negative subject entity should have embedding")
                    .clone();
                let t_neg = self
                    .entity_embeddings
                    .get(&negative.object)
                    .expect("negative object entity should have embedding")
                    .clone();
                // Residuals h + r - t for the positive and negative triples,
                // normalized to unit length so step sizes stay bounded.
                let pos_grad = &h_pos + &r - &t_pos;
                let neg_grad = &h_neg + &r - &t_neg;
                let pos_norm = pos_grad.norm();
                let neg_norm = neg_grad.norm();
                let pos_grad_norm = if pos_norm > 0.0 {
                    &pos_grad / pos_norm
                } else {
                    pos_grad
                };
                let neg_grad_norm = if neg_norm > 0.0 {
                    &neg_grad / neg_norm
                } else {
                    neg_grad
                };
                let lr = self.config.learning_rate;
                // Pull the positive triple together: move h and t toward
                // satisfying h + r ≈ t.
                if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
                    *h -= lr * &pos_grad_norm;
                    let norm = h.norm();
                    if norm > 0.0 {
                        *h /= norm;
                    }
                }
                if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
                    *r -= lr * (&pos_grad_norm - &neg_grad_norm);
                }
                if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
                    *t += lr * &pos_grad_norm;
                    let norm = t.norm();
                    if norm > 0.0 {
                        *t /= norm;
                    }
                }
                // Push the negative triple apart — but only for the entity
                // that was actually corrupted.
                if positive.subject != negative.subject {
                    if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
                        *h += lr * &neg_grad_norm;
                        let norm = h.norm();
                        if norm > 0.0 {
                            *h /= norm;
                        }
                    }
                }
                if positive.object != negative.object {
                    if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
                        *t -= lr * &neg_grad_norm;
                        let norm = t.norm();
                        if norm > 0.0 {
                            *t /= norm;
                        }
                    }
                }
            }
        }
    }
}
impl KGEmbeddingModel for TransE {
    /// Trains TransE with the margin-ranking loss and SGD over shuffled
    /// triples.
    ///
    /// # Errors
    /// Returns an error if `triples` is empty.
    fn train(&mut self, triples: &[Triple]) -> Result<()> {
        if triples.is_empty() {
            return Err(anyhow!("No triples provided for training"));
        }
        self.initialize_embeddings(triples);
        // RNG for shuffling and negative sampling; falls back to a fixed
        // seed for reproducibility.
        let mut rng = if let Some(seed) = self.config.random_seed {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };
        for epoch in 0..self.config.epochs {
            let mut total_loss = 0.0;
            // Fisher-Yates shuffle so batch composition differs per epoch.
            let mut shuffled_triples = triples.to_vec();
            for i in (1..shuffled_triples.len()).rev() {
                let j = rng.random_range(0..i + 1);
                shuffled_triples.swap(i, j);
            }
            for batch in shuffled_triples.chunks(self.config.batch_size) {
                for triple in batch {
                    let negatives = self.generate_negative_samples(triple, &mut rng);
                    let pos_dist = self.distance(triple);
                    for negative in &negatives {
                        let neg_dist = self.distance(negative);
                        let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
                        total_loss += loss;
                    }
                    self.update_embeddings(triple, &negatives);
                }
            }
            if epoch % 10 == 0 {
                // BUG FIX: average over the number of (positive, negative)
                // pairs actually scored. The previous denominator
                // (batch_count * batch_size) overcounted a partial final
                // batch and ignored `negative_samples`.
                let pairs = (triples.len() * self.config.negative_samples).max(1);
                let avg_loss = total_loss / pairs as f32;
                tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
            }
        }
        Ok(())
    }

    /// Returns the entity's d-dimensional vector, if it was seen in training.
    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
        self.entity_embeddings
            .get(entity)
            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
    }

    /// Returns the relation's d-dimensional vector, if it was seen in training.
    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
        self.relation_embeddings
            .get(relation)
            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
    }

    /// Negated translation distance, so that higher = more plausible.
    fn score_triple(&self, triple: &Triple) -> f32 {
        -self.distance(triple)
    }

    /// Ranks all entities (except `head`) by distance to `h + r`, best first.
    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
        let h = match self.entity_embeddings.get(head) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        let r = match self.relation_embeddings.get(relation) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        // TransE predicts t ≈ h + r.
        let translation = h + r;
        let mut scores: Vec<(String, f32)> = self
            .entities
            .iter()
            .filter(|e| *e != head)
            .filter_map(|entity| {
                self.entity_embeddings.get(entity).map(|t| {
                    // Compute the residual once; pick the configured norm.
                    let diff = &translation - t;
                    let distance: f32 = match self.config.norm {
                        1 => diff.iter().map(|x| x.abs()).sum(),
                        _ => diff.norm(),
                    };
                    (entity.clone(), -distance)
                })
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// Ranks all entities (except `tail`) by distance to `t - r`, best first.
    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
        let t = match self.entity_embeddings.get(tail) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        let r = match self.relation_embeddings.get(relation) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        // TransE predicts h ≈ t - r.
        let target = t - r;
        let mut scores: Vec<(String, f32)> = self
            .entities
            .iter()
            .filter(|e| *e != tail)
            .filter_map(|entity| {
                self.entity_embeddings.get(entity).map(|h| {
                    let diff = h - &target;
                    let distance: f32 = match self.config.norm {
                        1 => diff.iter().map(|x| x.abs()).sum(),
                        _ => diff.norm(),
                    };
                    (entity.clone(), -distance)
                })
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// All entity embeddings keyed by name.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
        self.entity_embeddings
            .iter()
            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
            .collect()
    }

    /// All relation embeddings keyed by name.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
        self.relation_embeddings
            .iter()
            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
            .collect()
    }
}
/// ComplEx model: complex-valued embeddings scored with a Hermitian
/// trilinear product. Real and imaginary parts are stored as separate
/// real-valued vectors.
pub struct ComplEx {
    /// Shared hyper-parameters.
    config: KGEmbeddingConfig,
    /// Real part of each entity embedding.
    entity_embeddings_real: HashMap<String, DVector<f32>>,
    /// Imaginary part of each entity embedding.
    entity_embeddings_imag: HashMap<String, DVector<f32>>,
    /// Real part of each relation embedding.
    relation_embeddings_real: HashMap<String, DVector<f32>>,
    /// Imaginary part of each relation embedding.
    relation_embeddings_imag: HashMap<String, DVector<f32>>,
    /// Entity vocabulary collected during training.
    entities: Vec<String>,
    /// Relation vocabulary collected during training.
    relations: Vec<String>,
}
impl ComplEx {
    /// Creates an untrained ComplEx model with the given configuration.
    pub fn new(config: KGEmbeddingConfig) -> Self {
        Self {
            config,
            entity_embeddings_real: HashMap::new(),
            entity_embeddings_imag: HashMap::new(),
            relation_embeddings_real: HashMap::new(),
            relation_embeddings_imag: HashMap::new(),
            entities: Vec::new(),
            relations: Vec::new(),
        }
    }

    /// Collects the vocabulary from `triples` and initializes every real and
    /// imaginary component from a Gaussian with std dev `sqrt(2/d)`.
    ///
    /// Note: samples are drawn in a fixed order (real then imaginary, entities
    /// then relations), so results are reproducible for a given seed.
    fn initialize_embeddings(&mut self, triples: &[Triple]) {
        let mut entities = std::collections::HashSet::new();
        let mut relations = std::collections::HashSet::new();
        for triple in triples {
            entities.insert(triple.subject.clone());
            entities.insert(triple.object.clone());
            relations.insert(triple.predicate.clone());
        }
        self.entities = entities.into_iter().collect();
        self.relations = relations.into_iter().collect();
        // A missing seed falls back to a fixed value for reproducibility.
        let mut rng = if let Some(seed) = self.config.random_seed {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };
        let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
        let normal =
            Normal::new(0.0, std_dev).expect("normal distribution parameters should be valid");
        for entity in &self.entities {
            let real_values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| normal.sample(&mut rng))
                .collect();
            let imag_values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| normal.sample(&mut rng))
                .collect();
            self.entity_embeddings_real
                .insert(entity.clone(), DVector::from_vec(real_values));
            self.entity_embeddings_imag
                .insert(entity.clone(), DVector::from_vec(imag_values));
        }
        for relation in &self.relations {
            let real_values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| normal.sample(&mut rng))
                .collect();
            let imag_values: Vec<f32> = (0..self.config.dimensions)
                .map(|_| normal.sample(&mut rng))
                .collect();
            self.relation_embeddings_real
                .insert(relation.clone(), DVector::from_vec(real_values));
            self.relation_embeddings_imag
                .insert(relation.clone(), DVector::from_vec(imag_values));
        }
    }

    /// ComplEx score: the real part of the Hermitian trilinear product
    /// `Re(⟨h, r, conj(t)⟩)`, expanded into its four real-valued terms below.
    /// Panics if any element of the triple has no embedding.
    fn hermitian_dot(&self, triple: &Triple) -> f32 {
        let h_real = self
            .entity_embeddings_real
            .get(&triple.subject)
            .expect("subject entity should have real embedding");
        let h_imag = self
            .entity_embeddings_imag
            .get(&triple.subject)
            .expect("subject entity should have imag embedding");
        let r_real = self
            .relation_embeddings_real
            .get(&triple.predicate)
            .expect("predicate relation should have real embedding");
        let r_imag = self
            .relation_embeddings_imag
            .get(&triple.predicate)
            .expect("predicate relation should have imag embedding");
        let t_real = self
            .entity_embeddings_real
            .get(&triple.object)
            .expect("object entity should have real embedding");
        let t_imag = self
            .entity_embeddings_imag
            .get(&triple.object)
            .expect("object entity should have imag embedding");
        let mut score = 0.0;
        for i in 0..self.config.dimensions {
            score += h_real[i] * r_real[i] * t_real[i]
                + h_real[i] * r_imag[i] * t_imag[i]
                + h_imag[i] * r_real[i] * t_imag[i]
                - h_imag[i] * r_imag[i] * t_real[i];
        }
        score
    }
}
impl KGEmbeddingModel for ComplEx {
    /// Builds the vocabulary and randomly initializes all embeddings.
    ///
    /// NOTE(review): no gradient-update step is visible here — training is
    /// limited to initialization. Confirm whether ComplEx optimization is
    /// implemented elsewhere.
    ///
    /// # Errors
    /// Returns an error if `triples` is empty.
    fn train(&mut self, triples: &[Triple]) -> Result<()> {
        if triples.is_empty() {
            return Err(anyhow!("No triples provided for training"));
        }
        self.initialize_embeddings(triples);
        Ok(())
    }

    /// Entity embedding as `[real part .. imaginary part]` (length `2 * d`).
    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
        let real = self.entity_embeddings_real.get(entity)?;
        let imag = self.entity_embeddings_imag.get(entity)?;
        let values: Vec<f32> = real.iter().chain(imag.iter()).cloned().collect();
        Some(Vector::new(values))
    }

    /// Relation embedding as `[real part .. imaginary part]` (length `2 * d`).
    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
        let real = self.relation_embeddings_real.get(relation)?;
        let imag = self.relation_embeddings_imag.get(relation)?;
        let values: Vec<f32> = real.iter().chain(imag.iter()).cloned().collect();
        Some(Vector::new(values))
    }

    /// Hermitian trilinear score; higher means more plausible.
    fn score_triple(&self, triple: &Triple) -> f32 {
        self.hermitian_dot(triple)
    }

    /// Scores every entity except `head` as a candidate tail, keeping the
    /// `k` best.
    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
        let mut scores = Vec::with_capacity(self.entities.len());
        for tail in &self.entities {
            if tail.as_str() == head {
                continue;
            }
            let candidate = Triple::new(head.to_string(), relation.to_string(), tail.clone());
            scores.push((tail.clone(), self.score_triple(&candidate)));
        }
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// Scores every entity except `tail` as a candidate head, keeping the
    /// `k` best.
    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
        let mut scores = Vec::with_capacity(self.entities.len());
        for head in &self.entities {
            if head.as_str() == tail {
                continue;
            }
            let candidate = Triple::new(head.clone(), relation.to_string(), tail.to_string());
            scores.push((head.clone(), self.score_triple(&candidate)));
        }
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// NOTE(review): exports only the real component, unlike
    /// `get_entity_embedding`, which concatenates real and imaginary parts —
    /// confirm this asymmetry is intended.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
        let mut out = HashMap::with_capacity(self.entity_embeddings_real.len());
        for (name, real) in &self.entity_embeddings_real {
            out.insert(name.clone(), Vector::new(real.as_slice().to_vec()));
        }
        out
    }

    /// NOTE(review): exports only the real component, unlike
    /// `get_relation_embedding` — confirm this asymmetry is intended.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
        let mut out = HashMap::with_capacity(self.relation_embeddings_real.len());
        for (name, real) in &self.relation_embeddings_real {
            out.insert(name.clone(), Vector::new(real.as_slice().to_vec()));
        }
        out
    }
}
/// RotatE model: entities live in complex space and each relation is an
/// element-wise rotation applied to the head entity.
pub struct RotatE {
    /// Shared hyper-parameters.
    config: KGEmbeddingConfig,
    /// Per-entity unit-modulus complex coordinates.
    entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
    /// Per-relation rotation phase angles (radians).
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Entity vocabulary collected during training.
    entities: Vec<String>,
    /// Relation vocabulary collected during training.
    relations: Vec<String>,
}
impl RotatE {
    /// Creates an untrained RotatE model with the given configuration.
    pub fn new(config: KGEmbeddingConfig) -> Self {
        Self {
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            entities: Vec::new(),
            relations: Vec::new(),
        }
    }

    /// Collects the vocabulary from `triples` and initializes embeddings.
    ///
    /// Entities get unit-modulus complex coordinates `e^{iθ}` with phases
    /// drawn uniformly from `[-π, π)`; relations store raw phase angles.
    fn initialize_embeddings(&mut self, triples: &[Triple]) {
        let mut entities = std::collections::HashSet::new();
        let mut relations = std::collections::HashSet::new();
        for triple in triples {
            entities.insert(triple.subject.clone());
            entities.insert(triple.object.clone());
            relations.insert(triple.predicate.clone());
        }
        self.entities = entities.into_iter().collect();
        self.relations = relations.into_iter().collect();
        // A missing seed falls back to a fixed value for reproducibility.
        let mut rng = if let Some(seed) = self.config.random_seed {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };
        let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
        for entity in &self.entities {
            let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
                .map(|_| {
                    // e^{iθ} = cos θ + i sin θ — always unit modulus.
                    let phase = rng.random_range(phase_range.clone());
                    Complex::new(phase.cos(), phase.sin())
                })
                .collect();
            self.entity_embeddings
                .insert(entity.clone(), DVector::from_vec(phases));
        }
        for relation in &self.relations {
            let phases: Vec<f32> = (0..self.config.dimensions)
                .map(|_| rng.random_range(phase_range.clone()))
                .collect();
            self.relation_embeddings
                .insert(relation.clone(), DVector::from_vec(phases));
        }
    }

    /// Distance `Σ_i |h_i · e^{i·r_i} - t_i|` between the rotated head and
    /// the tail. Panics if any element of the triple has no embedding.
    fn distance(&self, triple: &Triple) -> f32 {
        let h = self
            .entity_embeddings
            .get(&triple.subject)
            .expect("subject entity should have embedding");
        let r_phases = self
            .relation_embeddings
            .get(&triple.predicate)
            .expect("predicate relation should have embedding");
        let t = self
            .entity_embeddings
            .get(&triple.object)
            .expect("object entity should have embedding");
        // Materialize the rotation e^{iθ} from the stored phases.
        let r: DVector<Complex<f32>> = DVector::from_iterator(
            self.config.dimensions,
            r_phases
                .iter()
                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
        );
        let rotated: DVector<Complex<f32>> = h.component_mul(&r);
        let diff = rotated - t;
        // Sum of element-wise complex moduli (L1 over complex components).
        diff.iter().map(|c| c.norm()).sum::<f32>()
    }
}
impl KGEmbeddingModel for RotatE {
    /// Builds the vocabulary and randomly initializes all embeddings.
    ///
    /// NOTE(review): no gradient-update step is visible here — training is
    /// limited to initialization. Confirm whether RotatE optimization is
    /// implemented elsewhere.
    ///
    /// # Errors
    /// Returns an error if `triples` is empty.
    fn train(&mut self, triples: &[Triple]) -> Result<()> {
        if triples.is_empty() {
            return Err(anyhow!("No triples provided for training"));
        }
        self.initialize_embeddings(triples);
        Ok(())
    }

    /// Entity embedding flattened as interleaved `[re, im, re, im, ...]`
    /// (length `2 * d`).
    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
        let complex_emb = self.entity_embeddings.get(entity)?;
        let mut values = Vec::with_capacity(self.config.dimensions * 2);
        for c in complex_emb.iter() {
            values.push(c.re);
            values.push(c.im);
        }
        Some(Vector::new(values))
    }

    /// Relation embedding as raw phase angles (length `d`).
    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
        self.relation_embeddings
            .get(relation)
            .map(|phases| Vector::new(phases.iter().cloned().collect()))
    }

    /// Margin-shifted score `γ - d(h ∘ r, t)`; higher is more plausible.
    fn score_triple(&self, triple: &Triple) -> f32 {
        // NOTE(review): γ is hard-coded rather than taken from
        // `config.margin` (default 1.0) — confirm the intended value.
        const GAMMA: f32 = 12.0;
        GAMMA - self.distance(triple)
    }

    /// Ranks all entities (except `head`) by distance to the rotated head
    /// `h ∘ r`, best first.
    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
        let h = match self.entity_embeddings.get(head) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        let r_phases = match self.relation_embeddings.get(relation) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        // Forward rotation e^{iθ}.
        let r: DVector<Complex<f32>> = DVector::from_iterator(
            self.config.dimensions,
            r_phases
                .iter()
                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
        );
        // Rotate the head once; compare every candidate tail against it.
        let rotated = h.component_mul(&r);
        let mut scores: Vec<(String, f32)> = self
            .entities
            .iter()
            .filter(|e| *e != head)
            .filter_map(|entity| {
                self.entity_embeddings.get(entity).map(|t| {
                    let diff = &rotated - t;
                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
                    (entity.clone(), -distance)
                })
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// Ranks all entities (except `tail`) as candidate heads, best first.
    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
        let t = match self.entity_embeddings.get(tail) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        let r_phases = match self.relation_embeddings.get(relation) {
            Some(emb) => emb,
            None => return Vec::new(),
        };
        // Inverse rotation: conj(e^{iθ}) = e^{-iθ}.
        let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
            self.config.dimensions,
            r_phases
                .iter()
                .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
        );
        // BUG FIX: apply the inverse rotation to the *tail* — from
        // h ∘ r ≈ t it follows that h ≈ t ∘ conj(r), and element-wise
        // |h - t·conj(r)| = |h·r - t|, matching `score_triple`. The previous
        // code rotated each candidate head by conj(r), which measures
        // |h - t·r| (rotation in the wrong direction) and ranked heads
        // inconsistently with the scoring function. Rotating the tail once
        // also hoists the multiplication out of the candidate loop.
        let target = t.component_mul(&r_inv);
        let mut scores: Vec<(String, f32)> = self
            .entities
            .iter()
            .filter(|e| *e != tail)
            .filter_map(|entity| {
                self.entity_embeddings.get(entity).map(|h| {
                    let diff = h - &target;
                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
                    (entity.clone(), -distance)
                })
            })
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);
        scores
    }

    /// NOTE(review): exports only the real components, unlike
    /// `get_entity_embedding` which interleaves real and imaginary parts —
    /// confirm this asymmetry is intended.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
        self.entity_embeddings
            .iter()
            .map(|(k, v)| {
                let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
                (k.clone(), Vector::new(real_values))
            })
            .collect()
    }

    /// All relation phase vectors keyed by name.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
        self.relation_embeddings
            .iter()
            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
            .collect()
    }
}
/// Facade that owns a boxed [`KGEmbeddingModel`] chosen by configuration
/// and forwards all operations to it.
pub struct KGEmbedding {
    /// The concrete model implementation selected by `config.model`.
    model: Box<dyn KGEmbeddingModel>,
    /// The configuration the model was built from.
    config: KGEmbeddingConfig,
}
impl KGEmbedding {
    /// Builds the concrete model selected by `config.model` and wraps it
    /// behind the trait object.
    pub fn new(config: KGEmbeddingConfig) -> Self {
        let model: Box<dyn KGEmbeddingModel> = match config.model {
            KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
            KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
            KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
            KGEmbeddingModelType::GCN => {
                let gcn = GCN::new(config.clone());
                Box::new(GCNAdapter::new(gcn))
            }
            KGEmbeddingModelType::GraphSAGE => {
                // GraphSAGE defaults to mean aggregation here.
                let graphsage = GraphSAGE::new(config.clone())
                    .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
                Box::new(GraphSAGEAdapter::new(graphsage))
            }
        };
        Self { model, config }
    }

    /// Trains the underlying model on `triples`.
    pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
        self.model.train(triples)
    }

    /// Returns the embedding for `entity`, if known.
    pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
        self.model.get_entity_embedding(entity)
    }

    /// Returns the embedding for `relation`, if known.
    pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
        self.model.get_relation_embedding(relation)
    }

    /// Plausibility score for `triple`; higher means more plausible.
    pub fn score_triple(&self, triple: &Triple) -> f32 {
        self.model.score_triple(triple)
    }

    /// Top-`k` candidate objects for `(head, relation, ?)`.
    pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
        self.model.predict_tail(head, relation, k)
    }

    /// Top-`k` candidate subjects for `(?, relation, tail)`.
    pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
        self.model.predict_head(relation, tail, k)
    }

    /// Binary triple classification: true when the score exceeds `threshold`.
    pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
        self.model.score_triple(triple) > threshold
    }
}
/// Adapter exposing a [`GCN`] through the [`KGEmbeddingModel`] trait.
pub struct GCNAdapter {
    // NOTE(review): the wrapped model is not used by the trait impl below —
    // the adapter currently returns placeholder values.
    gcn: GCN,
}
impl GCNAdapter {
    /// Wraps a GCN so it can be driven through the `KGEmbeddingModel` trait.
    pub fn new(gcn: GCN) -> Self {
        Self { gcn }
    }
}
impl KGEmbeddingModel for GCNAdapter {
    /// Stub: training is a no-op.
    /// NOTE(review): the wrapped GCN is never invoked — confirm whether real
    /// delegation is still to be wired up.
    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
        Ok(())
    }
    /// Placeholder: always a 128-dim zero vector, regardless of the wrapped
    /// GCN's configuration. TODO confirm / replace with real GCN output.
    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
        Some(Vector::new(vec![0.0; 128]))
    }
    /// Placeholder: always a 128-dim zero vector.
    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
        Some(Vector::new(vec![0.0; 128]))
    }
    /// Placeholder: constant score for every triple.
    fn score_triple(&self, _triple: &Triple) -> f32 {
        0.5
    }
    /// Placeholder: no predictions.
    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
        vec![]
    }
    /// Placeholder: no predictions.
    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
        vec![]
    }
    /// Placeholder: empty map.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
        HashMap::new()
    }
    /// Placeholder: empty map.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
        HashMap::new()
    }
}
/// Adapter exposing a [`GraphSAGE`] model through the [`KGEmbeddingModel`]
/// trait.
pub struct GraphSAGEAdapter {
    // Only `dimensions()` of the wrapped model is consulted by the trait
    // impl below; everything else is a placeholder.
    graphsage: GraphSAGE,
}
impl GraphSAGEAdapter {
    /// Wraps a GraphSAGE model so it can be driven through the
    /// `KGEmbeddingModel` trait.
    pub fn new(graphsage: GraphSAGE) -> Self {
        Self { graphsage }
    }
}
impl KGEmbeddingModel for GraphSAGEAdapter {
    /// Stub: training is a no-op.
    /// NOTE(review): the wrapped GraphSAGE model is never trained here —
    /// confirm whether real delegation is still to be wired up.
    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
        Ok(())
    }
    /// Placeholder: a zero vector sized by the wrapped model's dimensions.
    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
    }
    /// Placeholder: a zero vector sized by the wrapped model's dimensions.
    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
    }
    /// Placeholder: constant score for every triple.
    fn score_triple(&self, _triple: &Triple) -> f32 {
        0.5
    }
    /// Placeholder: no predictions.
    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
        vec![]
    }
    /// Placeholder: no predictions.
    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
        vec![]
    }
    /// Placeholder: empty map.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
        HashMap::new()
    }
    /// Placeholder: empty map.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
        HashMap::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Small fixture graph over three people and two foods.
    fn create_test_triples() -> Vec<Triple> {
        [
            ("Alice", "knows", "Bob"),
            ("Bob", "knows", "Charlie"),
            ("Alice", "likes", "Pizza"),
            ("Bob", "likes", "Pasta"),
            ("Charlie", "knows", "Alice"),
        ]
        .iter()
        .map(|(s, p, o)| Triple::new(s.to_string(), p.to_string(), o.to_string()))
        .collect()
    }

    #[test]
    fn test_transe() -> Result<()> {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::TransE,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };
        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();
        model.train(&triples)?;
        // Every entity/relation seen during training has an embedding.
        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());
        // Scores are finite and link prediction produces candidates.
        let score = model.score_triple(&triples[0]);
        assert!(score.is_finite());
        let predictions = model.predict_tail("Alice", "knows", 2);
        assert!(!predictions.is_empty());
        Ok(())
    }

    #[test]
    fn test_complex() -> Result<()> {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::ComplEx,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };
        let mut model = KGEmbedding::new(config);
        model.train(&create_test_triples())?;
        // ComplEx exposes real + imaginary parts: 2 * 50 = 100 dimensions.
        let emb = model
            .get_entity_embedding("Bob")
            .expect("Bob embedding should exist");
        assert_eq!(emb.dimensions, 100);
        Ok(())
    }

    #[test]
    fn test_rotate() -> Result<()> {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::RotatE,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };
        let mut model = KGEmbedding::new(config);
        model.train(&create_test_triples())?;
        assert!(model.get_entity_embedding("Charlie").is_some());
        // Relation embeddings store one phase angle per dimension.
        let rel_emb = model
            .get_relation_embedding("likes")
            .expect("likes relation embedding should exist");
        assert_eq!(rel_emb.dimensions, 50);
        Ok(())
    }
}