use crate::random_utils::NormalSampler as Normal;
use crate::{
kg_embeddings::{KGEmbeddingConfig, KGEmbeddingModel, Triple},
Vector,
};
use anyhow::{anyhow, Result};
use nalgebra::{DMatrix, DVector};
use scirs2_core::random::{Random, Rng, RngExt};
use std::collections::HashMap;
pub struct GCN {
config: KGEmbeddingConfig,
entity_embeddings: HashMap<String, DVector<f32>>,
relation_embeddings: HashMap<String, DVector<f32>>,
entities: Vec<String>,
relations: Vec<String>,
adjacency_matrix: Option<DMatrix<f32>>,
weight_matrices: Vec<DMatrix<f32>>,
num_layers: usize,
}
impl GCN {
pub fn new(config: KGEmbeddingConfig) -> Self {
let num_layers = 2; Self {
config,
entity_embeddings: HashMap::new(),
relation_embeddings: HashMap::new(),
entities: Vec::new(),
relations: Vec::new(),
adjacency_matrix: None,
weight_matrices: Vec::new(),
num_layers,
}
}
pub fn with_layers(config: KGEmbeddingConfig, num_layers: usize) -> Self {
Self {
config,
entity_embeddings: HashMap::new(),
relation_embeddings: HashMap::new(),
entities: Vec::new(),
relations: Vec::new(),
adjacency_matrix: None,
weight_matrices: Vec::new(),
num_layers,
}
}
fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
let mut entities = std::collections::HashSet::new();
let mut relations = std::collections::HashSet::new();
for triple in triples {
entities.insert(triple.subject.clone());
entities.insert(triple.object.clone());
relations.insert(triple.predicate.clone());
}
self.entities = entities.into_iter().collect();
self.relations = relations.into_iter().collect();
let _num_entities = self.entities.len();
let mut rng = if let Some(seed) = self.config.random_seed {
Random::seed(seed)
} else {
Random::seed(42)
};
let normal = Normal::new(0.0, 0.1)
.map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
for entity in &self.entities {
let embedding: Vec<f32> = (0..self.config.dimensions)
.map(|_| normal.sample(&mut rng))
.collect();
self.entity_embeddings
.insert(entity.clone(), DVector::from_vec(embedding));
}
for relation in &self.relations {
let embedding: Vec<f32> = (0..self.config.dimensions)
.map(|_| normal.sample(&mut rng))
.collect();
self.relation_embeddings
.insert(relation.clone(), DVector::from_vec(embedding));
}
self.build_adjacency_matrix(triples)?;
self.weight_matrices.clear();
for _ in 0..self.num_layers {
let weight_matrix =
DMatrix::from_fn(self.config.dimensions, self.config.dimensions, |_, _| {
normal.sample(&mut rng)
});
self.weight_matrices.push(weight_matrix);
}
Ok(())
}
fn build_adjacency_matrix(&mut self, triples: &[Triple]) -> Result<()> {
let num_entities = self.entities.len();
let mut adj_matrix = DMatrix::zeros(num_entities, num_entities);
let entity_to_index: HashMap<String, usize> = self
.entities
.iter()
.enumerate()
.map(|(i, entity)| (entity.clone(), i))
.collect();
for triple in triples {
if let (Some(&subject_idx), Some(&object_idx)) = (
entity_to_index.get(&triple.subject),
entity_to_index.get(&triple.object),
) {
adj_matrix[(subject_idx, object_idx)] = 1.0;
adj_matrix[(object_idx, subject_idx)] = 1.0; }
}
for i in 0..num_entities {
adj_matrix[(i, i)] = 1.0;
}
self.adjacency_matrix = Some(self.normalize_adjacency_matrix(adj_matrix));
Ok(())
}
fn normalize_adjacency_matrix(&self, mut adj_matrix: DMatrix<f32>) -> DMatrix<f32> {
let num_nodes = adj_matrix.nrows();
let mut degrees = Vec::with_capacity(num_nodes);
for i in 0..num_nodes {
let degree: f32 = (0..num_nodes).map(|j| adj_matrix[(i, j)]).sum();
degrees.push(if degree > 0.0 {
1.0 / degree.sqrt()
} else {
0.0
});
}
for i in 0..num_nodes {
for j in 0..num_nodes {
adj_matrix[(i, j)] *= degrees[i] * degrees[j];
}
}
adj_matrix
}
fn forward_pass(&self, features: &DMatrix<f32>) -> Result<DMatrix<f32>> {
let adj_matrix = self
.adjacency_matrix
.as_ref()
.ok_or_else(|| anyhow!("Adjacency matrix not initialized"))?;
let mut hidden = features.clone();
for layer_idx in 0..self.num_layers {
let weight = &self.weight_matrices[layer_idx];
let linear_transform = &hidden * weight;
hidden = adj_matrix * &linear_transform;
if layer_idx < self.num_layers - 1 {
hidden = hidden.map(|x| x.max(0.0));
}
}
Ok(hidden)
}
fn train_gcn(&mut self, _triples: &[Triple]) -> Result<()> {
let num_entities = self.entities.len();
let mut features = DMatrix::zeros(num_entities, self.config.dimensions);
for (i, entity) in self.entities.iter().enumerate() {
if let Some(embedding) = self.entity_embeddings.get(entity) {
for (j, &value) in embedding.iter().enumerate() {
features[(i, j)] = value;
}
}
}
let updated_features = self.forward_pass(&features)?;
for (i, entity) in self.entities.iter().enumerate() {
let new_embedding: Vec<f32> = (0..self.config.dimensions)
.map(|j| updated_features[(i, j)])
.collect();
self.entity_embeddings
.insert(entity.clone(), DVector::from_vec(new_embedding));
}
Ok(())
}
}
impl KGEmbeddingModel for GCN {
fn train(&mut self, triples: &[Triple]) -> Result<()> {
self.initialize(triples)?;
for epoch in 0..self.config.epochs {
self.train_gcn(triples)?;
if epoch % 10 == 0 {
println!("GCN training epoch {}/{}", epoch, self.config.epochs);
}
}
Ok(())
}
fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
self.entity_embeddings
.get(entity)
.map(|embedding| Vector::new(embedding.as_slice().to_vec()))
}
fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
self.relation_embeddings
.get(relation)
.map(|embedding| Vector::new(embedding.as_slice().to_vec()))
}
fn score_triple(&self, triple: &Triple) -> f32 {
if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
self.get_entity_embedding(&triple.subject),
self.get_relation_embedding(&triple.predicate),
self.get_entity_embedding(&triple.object),
) {
let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
} else {
0.0
}
}
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
if let (Some(head_emb), Some(rel_emb)) = (
self.get_entity_embedding(head),
self.get_relation_embedding(relation),
) {
let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
let mut scores = Vec::new();
for entity in &self.entities {
if entity != head {
if let Some(entity_emb) = self.get_entity_embedding(entity) {
let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
scores.push((entity.clone(), score));
}
}
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.into_iter().take(k).collect()
} else {
Vec::new()
}
}
fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
if let (Some(rel_emb), Some(tail_emb)) = (
self.get_relation_embedding(relation),
self.get_entity_embedding(tail),
) {
let mut scores = Vec::new();
for entity in &self.entities {
if entity != tail {
if let Some(entity_emb) = self.get_entity_embedding(entity) {
let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
scores.push((entity.clone(), score));
}
}
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.into_iter().take(k).collect()
} else {
Vec::new()
}
}
fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
HashMap::new()
}
fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
HashMap::new()
}
}
pub struct GraphSAGE {
config: KGEmbeddingConfig,
entity_embeddings: HashMap<String, DVector<f32>>,
relation_embeddings: HashMap<String, DVector<f32>>,
entities: Vec<String>,
relations: Vec<String>,
graph: HashMap<String, Vec<String>>, aggregator_type: AggregatorType,
num_layers: usize,
sample_size: usize,
sampling_strategy: SamplingStrategy,
}
#[derive(Debug, Clone, Copy)]
pub enum AggregatorType {
Mean,
LSTM,
Pool,
Attention,
}
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
Uniform, Degree, PageRank, Recent, }
impl GraphSAGE {
pub fn new(config: KGEmbeddingConfig) -> Self {
Self {
config,
entity_embeddings: HashMap::new(),
relation_embeddings: HashMap::new(),
entities: Vec::new(),
relations: Vec::new(),
graph: HashMap::new(),
aggregator_type: AggregatorType::Mean,
num_layers: 2,
sample_size: 10, sampling_strategy: SamplingStrategy::Uniform,
}
}
pub fn with_aggregator(mut self, aggregator: AggregatorType) -> Self {
self.aggregator_type = aggregator;
self
}
pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
self.sampling_strategy = strategy;
self
}
pub fn with_sample_size(mut self, size: usize) -> Self {
self.sample_size = size;
self
}
pub fn dimensions(&self) -> usize {
self.config.dimensions
}
fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
let mut entities = std::collections::HashSet::new();
let mut relations = std::collections::HashSet::new();
for triple in triples {
entities.insert(triple.subject.clone());
entities.insert(triple.object.clone());
relations.insert(triple.predicate.clone());
}
self.entities = entities.into_iter().collect();
self.relations = relations.into_iter().collect();
self.build_graph(triples);
let mut rng = if let Some(seed) = self.config.random_seed {
Random::seed(seed)
} else {
Random::seed(42)
};
let normal = Normal::new(0.0, 0.1)
.map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
for entity in &self.entities {
let embedding: Vec<f32> = (0..self.config.dimensions)
.map(|_| normal.sample(&mut rng))
.collect();
self.entity_embeddings
.insert(entity.clone(), DVector::from_vec(embedding));
}
for relation in &self.relations {
let embedding: Vec<f32> = (0..self.config.dimensions)
.map(|_| normal.sample(&mut rng))
.collect();
self.relation_embeddings
.insert(relation.clone(), DVector::from_vec(embedding));
}
Ok(())
}
fn build_graph(&mut self, triples: &[Triple]) {
for triple in triples {
self.graph
.entry(triple.subject.clone())
.or_default()
.push(triple.object.clone());
self.graph
.entry(triple.object.clone())
.or_default()
.push(triple.subject.clone());
}
}
#[allow(deprecated)]
fn sample_neighbors(&self, node: &str, rng: &mut impl Rng) -> Vec<String> {
if let Some(neighbors) = self.graph.get(node) {
if neighbors.len() <= self.sample_size {
neighbors.clone()
} else {
match self.sampling_strategy {
SamplingStrategy::Uniform => {
let mut sampled = Vec::new();
let sample_size = std::cmp::min(self.sample_size, neighbors.len());
for (i, neighbor) in neighbors.iter().enumerate() {
if sampled.len() < sample_size {
sampled.push(neighbor.clone());
} else {
let j = rng.random_range(0..=i);
if j < sample_size {
sampled[j] = neighbor.clone();
}
}
}
sampled
}
SamplingStrategy::Degree => self.degree_based_sampling(neighbors, rng),
SamplingStrategy::PageRank => {
self.degree_based_sampling(neighbors, rng)
}
SamplingStrategy::Recent => {
neighbors
.iter()
.rev()
.take(self.sample_size)
.cloned()
.collect()
}
}
}
} else {
Vec::new()
}
}
#[allow(deprecated)]
fn degree_based_sampling(&self, neighbors: &[String], rng: &mut impl Rng) -> Vec<String> {
let mut neighbor_degrees: Vec<(String, usize)> = neighbors
.iter()
.map(|neighbor| {
let degree = self.graph.get(neighbor).map(|n| n.len()).unwrap_or(0);
(neighbor.clone(), degree)
})
.collect();
neighbor_degrees.sort_by(|a, b| {
let degree_cmp = b.1.cmp(&a.1);
if degree_cmp == std::cmp::Ordering::Equal {
if rng.random_bool(0.5) {
std::cmp::Ordering::Greater
} else {
std::cmp::Ordering::Less
}
} else {
degree_cmp
}
});
neighbor_degrees
.into_iter()
.take(self.sample_size)
.map(|(neighbor, _)| neighbor)
.collect()
}
fn aggregate_neighbors(&self, neighbors: &[String]) -> Result<DVector<f32>> {
if neighbors.is_empty() {
return Ok(DVector::zeros(self.config.dimensions));
}
match self.aggregator_type {
AggregatorType::Mean => {
let mut sum = DVector::zeros(self.config.dimensions);
let mut count = 0;
for neighbor in neighbors {
if let Some(embedding) = self.entity_embeddings.get(neighbor) {
sum += embedding;
count += 1;
}
}
if count > 0 {
Ok(sum / count as f32)
} else {
Ok(DVector::zeros(self.config.dimensions))
}
}
AggregatorType::Pool => {
let mut max_embedding =
DVector::from_element(self.config.dimensions, f32::NEG_INFINITY);
for neighbor in neighbors {
if let Some(embedding) = self.entity_embeddings.get(neighbor) {
for i in 0..self.config.dimensions {
max_embedding[i] = max_embedding[i].max(embedding[i]);
}
}
}
for i in 0..self.config.dimensions {
if max_embedding[i] == f32::NEG_INFINITY {
max_embedding[i] = 0.0;
}
}
Ok(max_embedding)
}
AggregatorType::LSTM => {
self.lstm_aggregate(neighbors)
}
AggregatorType::Attention => {
self.attention_aggregate(neighbors)
}
}
}
fn lstm_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
if neighbors.is_empty() {
return Ok(DVector::zeros(self.config.dimensions));
}
let mut cell_state = DVector::zeros(self.config.dimensions);
let mut hidden_state = DVector::zeros(self.config.dimensions);
for neighbor in neighbors {
if let Some(embedding) = self.entity_embeddings.get(neighbor) {
let forget_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp())); let input_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
let candidate = embedding.map(|x| x.tanh());
cell_state =
cell_state.component_mul(&forget_gate) + input_gate.component_mul(&candidate);
let output_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
hidden_state = output_gate.component_mul(&cell_state.map(|x| x.tanh()));
}
}
Ok(hidden_state)
}
fn attention_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
if neighbors.is_empty() {
return Ok(DVector::zeros(self.config.dimensions));
}
let neighbor_embeddings: Vec<&DVector<f32>> = neighbors
.iter()
.filter_map(|neighbor| self.entity_embeddings.get(neighbor))
.collect();
if neighbor_embeddings.is_empty() {
return Ok(DVector::zeros(self.config.dimensions));
}
let mut attention_scores = Vec::new();
let mut weighted_sum = DVector::zeros(self.config.dimensions);
let query = DVector::from_element(self.config.dimensions, 1.0);
for embedding in &neighbor_embeddings {
let score = query.dot(embedding).exp(); attention_scores.push(score);
}
let total_score: f32 = attention_scores.iter().sum();
if total_score > 0.0 {
for score in &mut attention_scores {
*score /= total_score;
}
}
for (embedding, &score) in neighbor_embeddings.iter().zip(attention_scores.iter()) {
weighted_sum += *embedding * score;
}
Ok(weighted_sum)
}
fn forward_node(&self, node: &str, rng: &mut impl Rng) -> Result<DVector<f32>> {
let neighbors = self.sample_neighbors(node, rng);
let neighbor_aggregate = self.aggregate_neighbors(&neighbors)?;
if let Some(node_embedding) = self.entity_embeddings.get(node) {
Ok(node_embedding + neighbor_aggregate)
} else {
Ok(neighbor_aggregate)
}
}
}
impl KGEmbeddingModel for GraphSAGE {
fn train(&mut self, triples: &[Triple]) -> Result<()> {
self.initialize(triples)?;
let mut rng = if let Some(seed) = self.config.random_seed {
Random::seed(seed)
} else {
Random::seed(42)
};
for epoch in 0..self.config.epochs {
let mut new_embeddings = HashMap::new();
for entity in &self.entities {
let new_embedding = self.forward_node(entity, &mut rng)?;
new_embeddings.insert(entity.clone(), new_embedding);
}
self.entity_embeddings = new_embeddings;
if epoch % 10 == 0 {
println!("GraphSAGE training epoch {}/{}", epoch, self.config.epochs);
}
}
Ok(())
}
fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
self.entity_embeddings
.get(entity)
.map(|embedding| Vector::new(embedding.as_slice().to_vec()))
}
fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
self.relation_embeddings
.get(relation)
.map(|embedding| Vector::new(embedding.as_slice().to_vec()))
}
fn score_triple(&self, triple: &Triple) -> f32 {
if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
self.get_entity_embedding(&triple.subject),
self.get_relation_embedding(&triple.predicate),
self.get_entity_embedding(&triple.object),
) {
let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
} else {
0.0
}
}
fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
if let (Some(head_emb), Some(rel_emb)) = (
self.get_entity_embedding(head),
self.get_relation_embedding(relation),
) {
let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
let mut scores = Vec::new();
for entity in &self.entities {
if entity != head {
if let Some(entity_emb) = self.get_entity_embedding(entity) {
let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
scores.push((entity.clone(), score));
}
}
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.into_iter().take(k).collect()
} else {
Vec::new()
}
}
fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
if let (Some(rel_emb), Some(tail_emb)) = (
self.get_relation_embedding(relation),
self.get_entity_embedding(tail),
) {
let mut scores = Vec::new();
for entity in &self.entities {
if entity != tail {
if let Some(entity_emb) = self.get_entity_embedding(entity) {
let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
scores.push((entity.clone(), score));
}
}
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.into_iter().take(k).collect()
} else {
Vec::new()
}
}
fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
HashMap::new()
}
fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
HashMap::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_gcn_creation() {
let config = KGEmbeddingConfig {
model: crate::kg_embeddings::KGEmbeddingModelType::GCN,
dimensions: 64,
learning_rate: 0.01,
margin: 1.0,
negative_samples: 5,
batch_size: 32,
epochs: 10,
norm: 2,
random_seed: Some(42),
regularization: 0.01,
};
let gcn = GCN::new(config);
assert_eq!(gcn.num_layers, 2);
}
#[test]
fn test_graphsage_creation() {
let config = KGEmbeddingConfig {
model: crate::kg_embeddings::KGEmbeddingModelType::GraphSAGE,
dimensions: 64,
learning_rate: 0.01,
margin: 1.0,
negative_samples: 5,
batch_size: 32,
epochs: 10,
norm: 2,
random_seed: Some(42),
regularization: 0.01,
};
let graphsage = GraphSAGE::new(config);
assert_eq!(graphsage.sample_size, 10);
}
#[test]
fn test_gnn_training() -> Result<()> {
let config = KGEmbeddingConfig {
model: crate::kg_embeddings::KGEmbeddingModelType::GCN,
dimensions: 32,
learning_rate: 0.01,
margin: 1.0,
negative_samples: 5,
batch_size: 16,
epochs: 5,
norm: 2,
random_seed: Some(42),
regularization: 0.01,
};
let mut gcn = GCN::new(config);
let triples = vec![
Triple::new(
"entity1".to_string(),
"relation1".to_string(),
"entity2".to_string(),
),
Triple::new(
"entity2".to_string(),
"relation2".to_string(),
"entity3".to_string(),
),
Triple::new(
"entity1".to_string(),
"relation3".to_string(),
"entity3".to_string(),
),
];
gcn.train(&triples)?;
assert!(gcn.get_entity_embedding("entity1").is_some());
assert!(gcn.get_entity_embedding("entity2").is_some());
assert!(gcn.get_entity_embedding("entity3").is_some());
Ok(())
}
}