use crate::error::{FusekiError, FusekiResult};
use scirs2_core::gpu::{GpuBackend, GpuBuffer, GpuContext, GpuKernel};
use scirs2_core::memory::BufferPool;
use scirs2_core::metrics::{Counter, Histogram, Timer};
use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::{Random, RngExt};
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModel {
TransE,
DistMult,
ComplEx,
RotatE,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuBackendType {
Cuda,
Metal,
Cpu,
}
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
pub embedding_dim: usize,
pub learning_rate: f32,
pub batch_size: usize,
pub num_negatives: usize,
pub model: EmbeddingModel,
pub backend: GpuBackendType,
pub use_mixed_precision: bool,
pub use_tensor_cores: bool,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
embedding_dim: 128,
learning_rate: 0.01,
batch_size: 1024,
num_negatives: 10,
model: EmbeddingModel::TransE,
backend: GpuBackendType::Cuda,
use_mixed_precision: true,
use_tensor_cores: true,
}
}
}
#[derive(Debug, Clone)]
pub struct KgTriple {
pub subject_id: usize,
pub predicate_id: usize,
pub object_id: usize,
}
pub struct GpuEmbeddingGenerator {
config: EmbeddingConfig,
entity_embeddings: Array2<f32>,
relation_embeddings: Array2<f32>,
entity_to_id: HashMap<String, usize>,
id_to_entity: HashMap<usize, String>,
relation_to_id: HashMap<String, usize>,
id_to_relation: HashMap<usize, String>,
gpu_context: Option<Arc<GpuContext>>,
rng: Random,
training_time_histogram: Histogram,
gpu_operations_counter: Counter,
tensor_core_ops_counter: Counter,
}
impl GpuEmbeddingGenerator {
pub fn new(config: EmbeddingConfig) -> FusekiResult<Self> {
let rng = Random::default();
let entity_embeddings = Array2::zeros((0, config.embedding_dim));
let relation_embeddings = Array2::zeros((0, config.embedding_dim));
let gpu_context = match config.backend {
GpuBackendType::Cuda | GpuBackendType::Metal => {
None }
GpuBackendType::Cpu => None,
};
Ok(Self {
config,
entity_embeddings,
relation_embeddings,
entity_to_id: HashMap::new(),
id_to_entity: HashMap::new(),
relation_to_id: HashMap::new(),
id_to_relation: HashMap::new(),
gpu_context,
rng,
training_time_histogram: Histogram::new("embedding_training_time_ms".to_string()),
gpu_operations_counter: Counter::new("gpu_embedding_operations".to_string()),
tensor_core_ops_counter: Counter::new("tensor_core_operations".to_string()),
})
}
pub fn initialize_from_triples(
&mut self,
triples: &[(String, String, String)],
) -> FusekiResult<()> {
for (subject, predicate, object) in triples {
if !self.entity_to_id.contains_key(subject) {
let id = self.entity_to_id.len();
self.entity_to_id.insert(subject.clone(), id);
self.id_to_entity.insert(id, subject.clone());
}
if !self.relation_to_id.contains_key(predicate) {
let id = self.relation_to_id.len();
self.relation_to_id.insert(predicate.clone(), id);
self.id_to_relation.insert(id, predicate.clone());
}
if !self.entity_to_id.contains_key(object) {
let id = self.entity_to_id.len();
self.entity_to_id.insert(object.clone(), id);
self.id_to_entity.insert(id, object.clone());
}
}
let num_entities = self.entity_to_id.len();
let num_relations = self.relation_to_id.len();
self.entity_embeddings = self.random_embeddings(num_entities, self.config.embedding_dim);
self.relation_embeddings = self.random_embeddings(num_relations, self.config.embedding_dim);
Ok(())
}
fn random_embeddings(&mut self, num_vectors: usize, dim: usize) -> Array2<f32> {
let mut embeddings = Array2::zeros((num_vectors, dim));
for i in 0..num_vectors {
for j in 0..dim {
embeddings[[i, j]] = self.rng.random_range(-0.1..0.1);
}
let norm: f32 = embeddings.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for j in 0..dim {
embeddings[[i, j]] /= norm;
}
}
}
embeddings
}
pub fn train(
&mut self,
triples: &[(String, String, String)],
epochs: usize,
) -> FusekiResult<TrainingMetrics> {
let start_time = std::time::Instant::now();
let kg_triples: Vec<KgTriple> = triples
.iter()
.filter_map(|(s, p, o)| {
let subject_id = *self.entity_to_id.get(s)?;
let predicate_id = *self.relation_to_id.get(p)?;
let object_id = *self.entity_to_id.get(o)?;
Some(KgTriple {
subject_id,
predicate_id,
object_id,
})
})
.collect();
let mut total_loss = 0.0;
let mut num_batches = 0;
for epoch in 0..epochs {
let epoch_loss = if self.gpu_context.is_some() {
self.train_epoch_gpu(&kg_triples)?
} else {
self.train_epoch_cpu(&kg_triples)?
};
total_loss += epoch_loss;
num_batches += 1;
if epoch % 10 == 0 {
tracing::info!("Epoch {}/{}: loss = {:.4}", epoch, epochs, epoch_loss);
}
}
let training_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
self.training_time_histogram
.observe(training_time_ms / 1000.0);
Ok(TrainingMetrics {
epochs,
average_loss: total_loss / num_batches as f64,
training_time_ms,
gpu_accelerated: self.gpu_context.is_some(),
tensor_core_used: false, })
}
fn train_epoch_gpu(&mut self, triples: &[KgTriple]) -> FusekiResult<f64> {
self.gpu_operations_counter.inc();
self.train_epoch_cpu(triples)
}
fn train_epoch_cpu(&mut self, triples: &[KgTriple]) -> FusekiResult<f64> {
let mut total_loss = 0.0;
let mut num_samples = 0;
for batch_start in (0..triples.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(triples.len());
let batch = &triples[batch_start..batch_end];
for triple in batch {
let pos_score = self.score_triple(triple);
let mut neg_scores = Vec::new();
for _ in 0..self.config.num_negatives {
let neg_triple = if self.rng.random::<f32>() < 0.5 {
KgTriple {
subject_id: self.rng.random_range(0..self.entity_embeddings.nrows()),
predicate_id: triple.predicate_id,
object_id: triple.object_id,
}
} else {
KgTriple {
subject_id: triple.subject_id,
predicate_id: triple.predicate_id,
object_id: self.rng.random_range(0..self.entity_embeddings.nrows()),
}
};
neg_scores.push(self.score_triple(&neg_triple));
}
let margin = 1.0;
for neg_score in neg_scores {
let loss = (margin + pos_score - neg_score).max(0.0);
total_loss += loss as f64;
num_samples += 1;
}
}
}
Ok(total_loss / num_samples as f64)
}
fn score_triple(&self, triple: &KgTriple) -> f32 {
match self.config.model {
EmbeddingModel::TransE => self.score_transe(triple),
EmbeddingModel::DistMult => self.score_distmult(triple),
EmbeddingModel::ComplEx => self.score_complex(triple),
EmbeddingModel::RotatE => self.score_rotate(triple),
}
}
fn score_transe(&self, triple: &KgTriple) -> f32 {
let h = self.entity_embeddings.row(triple.subject_id);
let r = self.relation_embeddings.row(triple.predicate_id);
let t = self.entity_embeddings.row(triple.object_id);
let mut distance = 0.0;
for i in 0..self.config.embedding_dim {
let diff = h[i] + r[i] - t[i];
distance += diff * diff;
}
-distance.sqrt() }
fn score_distmult(&self, triple: &KgTriple) -> f32 {
let h = self.entity_embeddings.row(triple.subject_id);
let r = self.relation_embeddings.row(triple.predicate_id);
let t = self.entity_embeddings.row(triple.object_id);
let mut score = 0.0;
for i in 0..self.config.embedding_dim {
score += h[i] * r[i] * t[i];
}
score
}
fn score_complex(&self, triple: &KgTriple) -> f32 {
self.score_distmult(triple)
}
fn score_rotate(&self, triple: &KgTriple) -> f32 {
self.score_transe(triple)
}
pub fn get_entity_embedding(&self, entity: &str) -> Option<Array1<f32>> {
let id = self.entity_to_id.get(entity)?;
Some(self.entity_embeddings.row(*id).to_owned())
}
pub fn get_relation_embedding(&self, relation: &str) -> Option<Array1<f32>> {
let id = self.relation_to_id.get(relation)?;
Some(self.relation_embeddings.row(*id).to_owned())
}
pub fn find_similar_entities(&self, entity: &str, top_k: usize) -> Vec<(String, f32)> {
let Some(embedding) = self.get_entity_embedding(entity) else {
return Vec::new();
};
let mut similarities: Vec<(String, f32)> = self
.entity_to_id
.iter()
.filter(|(e, _)| *e != entity)
.map(|(e, &id)| {
let other_embedding = self.entity_embeddings.row(id);
let similarity = self.cosine_similarity(embedding.view(), other_embedding);
(e.clone(), similarity)
})
.collect();
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.truncate(top_k);
similarities
}
fn cosine_similarity(&self, a: ArrayView1<f32>, b: ArrayView1<f32>) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot / (norm_a * norm_b)
}
}
pub fn get_statistics(&self) -> EmbeddingStatistics {
EmbeddingStatistics {
num_entities: self.entity_to_id.len(),
num_relations: self.relation_to_id.len(),
embedding_dim: self.config.embedding_dim,
model: self.config.model,
gpu_enabled: self.gpu_context.is_some(),
tensor_core_enabled: false, total_parameters: (self.entity_to_id.len() + self.relation_to_id.len())
* self.config.embedding_dim,
}
}
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct TrainingMetrics {
pub epochs: usize,
pub average_loss: f64,
pub training_time_ms: f64,
pub gpu_accelerated: bool,
pub tensor_core_used: bool,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct EmbeddingStatistics {
pub num_entities: usize,
pub num_relations: usize,
pub embedding_dim: usize,
#[serde(skip)]
pub model: EmbeddingModel,
pub gpu_enabled: bool,
pub tensor_core_enabled: bool,
pub total_parameters: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedding_generator_creation() {
let config = EmbeddingConfig::default();
let generator = GpuEmbeddingGenerator::new(config).unwrap();
let stats = generator.get_statistics();
assert_eq!(stats.num_entities, 0);
assert_eq!(stats.num_relations, 0);
}
#[test]
fn test_initialize_from_triples() {
let config = EmbeddingConfig::default();
let mut generator = GpuEmbeddingGenerator::new(config).unwrap();
let triples = vec![
("Alice".to_string(), "knows".to_string(), "Bob".to_string()),
(
"Bob".to_string(),
"knows".to_string(),
"Charlie".to_string(),
),
];
generator.initialize_from_triples(&triples).unwrap();
let stats = generator.get_statistics();
assert_eq!(stats.num_entities, 3); assert_eq!(stats.num_relations, 1); }
#[test]
fn test_get_entity_embedding() {
let config = EmbeddingConfig::default();
let mut generator = GpuEmbeddingGenerator::new(config).unwrap();
let triples = vec![("Alice".to_string(), "knows".to_string(), "Bob".to_string())];
generator.initialize_from_triples(&triples).unwrap();
let embedding = generator.get_entity_embedding("Alice");
assert!(embedding.is_some());
assert_eq!(embedding.unwrap().len(), 128);
}
#[test]
fn test_training() {
let config = EmbeddingConfig {
backend: GpuBackendType::Cpu, batch_size: 2,
..Default::default()
};
let mut generator = GpuEmbeddingGenerator::new(config).unwrap();
let triples = vec![
("Alice".to_string(), "knows".to_string(), "Bob".to_string()),
(
"Bob".to_string(),
"knows".to_string(),
"Charlie".to_string(),
),
(
"Charlie".to_string(),
"works_at".to_string(),
"Company".to_string(),
),
];
generator.initialize_from_triples(&triples).unwrap();
let metrics = generator.train(&triples, 5).unwrap();
assert_eq!(metrics.epochs, 5);
assert!(metrics.average_loss >= 0.0);
}
#[test]
fn test_find_similar_entities() {
let config = EmbeddingConfig {
backend: GpuBackendType::Cpu,
..Default::default()
};
let mut generator = GpuEmbeddingGenerator::new(config).unwrap();
let triples = vec![
("Alice".to_string(), "knows".to_string(), "Bob".to_string()),
(
"Bob".to_string(),
"knows".to_string(),
"Charlie".to_string(),
),
("David".to_string(), "knows".to_string(), "Eve".to_string()),
];
generator.initialize_from_triples(&triples).unwrap();
generator.train(&triples, 2).unwrap();
let similar = generator.find_similar_entities("Alice", 2);
assert!(similar.len() <= 2);
}
#[test]
fn test_cosine_similarity() {
let config = EmbeddingConfig::default();
let generator = GpuEmbeddingGenerator::new(config).unwrap();
let a = Array1::from_vec(vec![1.0, 0.0, 0.0]);
let b = Array1::from_vec(vec![0.0, 1.0, 0.0]);
let c = Array1::from_vec(vec![1.0, 0.0, 0.0]);
let sim_ab = generator.cosine_similarity(a.view(), b.view());
let sim_ac = generator.cosine_similarity(a.view(), c.view());
assert!((sim_ab - 0.0).abs() < 0.001);
assert!((sim_ac - 1.0).abs() < 0.001);
}
}