use crate::model::Triple;
use anyhow::Result;
use serde::{Deserialize, Serialize};
pub mod complex;
pub mod distmult;
pub mod evaluation;
pub mod transe;
pub use complex::ComplEx;
pub use distmult::DistMult;
pub use evaluation::{
ConfidenceIntervals, KnowledgeGraphMetrics, LinkPredictionMetrics, StatisticalTestResults,
TaskBreakdownMetrics, TrainingMetrics,
};
pub use transe::TransE;
/// Common asynchronous interface implemented by every knowledge-graph
/// embedding model in this module (`TransE`, `DistMult`, `ComplEx`).
///
/// The trait is object-safe via `async_trait` and requires `Send + Sync`, so
/// models can be shared as `Arc<dyn KnowledgeGraphEmbedding>` (which is what
/// `create_embedding_model` returns). Note that `train` and `load` take
/// `&mut self`; a model held behind an `Arc` needs exclusive access
/// (e.g. `Arc::get_mut`) before those methods can be called.
#[async_trait::async_trait]
pub trait KnowledgeGraphEmbedding: Send + Sync {
    /// Produces embedding vectors for the given triples.
    ///
    /// NOTE(review): the mapping between input triples and the returned
    /// vectors (one per triple? per entity?) is defined by each
    /// implementation; confirm before relying on ordering.
    async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>>;
    /// Scores a single `(head, relation, tail)` triple identified by name.
    ///
    /// NOTE(review): whether a higher score means "more plausible" is
    /// model-specific; check the implementation.
    async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32>;
    /// Predicts candidate links over the supplied entities and relations.
    ///
    /// Each result is a 4-tuple of three strings and an `f32`; presumably
    /// `(head, relation, tail, score)`. TODO confirm against implementations.
    async fn predict_links(
        &self,
        entities: &[String],
        relations: &[String],
    ) -> Result<Vec<(String, String, String, f32)>>;
    /// Returns the learned embedding vector for the named entity.
    async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>>;
    /// Returns the learned embedding vector for the named relation.
    async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>>;
    /// Trains the model on `triples` using the hyper-parameters in `config`,
    /// returning the collected [`TrainingMetrics`].
    ///
    /// Requires `&mut self`, i.e. exclusive access to the model.
    async fn train(
        &mut self,
        triples: &[Triple],
        config: &TrainingConfig,
    ) -> Result<TrainingMetrics>;
    /// Persists the model to `path`.
    async fn save(&self, path: &str) -> Result<()>;
    /// Restores model state from `path`, replacing the current state.
    ///
    /// Requires `&mut self`, i.e. exclusive access to the model.
    async fn load(&mut self, path: &str) -> Result<()>;
}
/// Configuration used to construct an embedding model via
/// `create_embedding_model`, which moves the whole value into the chosen
/// model's constructor.
///
/// Several fields (`batch_size`, `learning_rate`, `max_epochs`, `patience`,
/// `validation_split`) also exist on `TrainingConfig`.
/// NOTE(review): which copy wins when `train` is called is decided by each
/// model implementation; confirm before relying on either.
///
/// Field order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which model family to build (default: `TransE`).
    pub model_type: EmbeddingModelType,
    /// Length of each embedding vector (default: 100).
    pub embedding_dim: usize,
    /// Optimiser step size (default: 0.001).
    pub learning_rate: f32,
    /// Presumably the L2 regularisation weight (default: 1e-5).
    pub l2_weight: f32,
    /// Presumably negatives sampled per positive triple (default: 1.0).
    pub negative_sampling_ratio: f32,
    /// Number of triples per training batch (default: 1024).
    pub batch_size: usize,
    /// Upper bound on training epochs (default: 1000).
    pub max_epochs: usize,
    /// Presumably epochs without improvement before early stopping (default: 50).
    pub patience: usize,
    /// Presumably the fraction of triples held out for validation (default: 0.1).
    pub validation_split: f32,
    /// Whether a GPU should be requested (default: `true`).
    /// NOTE(review): fallback behaviour when no GPU exists is not visible here.
    pub use_gpu: bool,
    /// Seed for the model's random number generation (default: 42).
    pub seed: u64,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model_type: EmbeddingModelType::TransE,
embedding_dim: 100,
learning_rate: 0.001,
l2_weight: 1e-5,
negative_sampling_ratio: 1.0,
batch_size: 1024,
max_epochs: 1000,
patience: 50,
validation_split: 0.1,
use_gpu: true,
seed: 42,
}
}
}
/// Identifies a knowledge-graph embedding model family.
///
/// Only `TransE`, `DistMult` and `ComplEx` can currently be built by
/// `create_embedding_model`. The remaining variants are accepted in
/// configuration, but that factory returns an error for them.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the type can key a
/// `HashMap`/`HashSet`. Keep the variant order unchanged: index-based serde
/// formats (e.g. bincode) encode unit variants by position.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EmbeddingModelType {
    TransE,
    DistMult,
    ComplEx,
    RotatE,
    HypE,
    TuckER,
    ConvE,
    KGTransformer,
    NeuralTensorNetwork,
    SimplE,
}
/// Builds the embedding model selected by `config.model_type`.
///
/// The whole `config` is moved into the chosen model's constructor, and the
/// model is returned as a shared trait object. Because it sits behind an
/// `Arc`, calling the trait's `&mut self` methods (`train`, `load`) requires
/// exclusive access first, e.g. via `Arc::get_mut`.
///
/// # Errors
///
/// Returns an error for every model type other than `TransE`, `DistMult` and
/// `ComplEx`. The message starts with
/// `"Embedding model not yet implemented"` and ends with the requested type,
/// so callers can tell which model was unavailable.
pub fn create_embedding_model(
    config: EmbeddingConfig,
) -> anyhow::Result<std::sync::Arc<dyn KnowledgeGraphEmbedding>> {
    match config.model_type {
        EmbeddingModelType::TransE => Ok(std::sync::Arc::new(TransE::new(config))),
        EmbeddingModelType::DistMult => Ok(std::sync::Arc::new(DistMult::new(config))),
        EmbeddingModelType::ComplEx => Ok(std::sync::Arc::new(ComplEx::new(config))),
        // Listed explicitly instead of `_` so that adding a new variant is a
        // compile error here until someone decides how to construct it.
        EmbeddingModelType::RotatE
        | EmbeddingModelType::HypE
        | EmbeddingModelType::TuckER
        | EmbeddingModelType::ConvE
        | EmbeddingModelType::KGTransformer
        | EmbeddingModelType::NeuralTensorNetwork
        | EmbeddingModelType::SimplE => Err(anyhow::anyhow!(
            "Embedding model not yet implemented: {:?}",
            config.model_type
        )),
    }
}
/// Hyper-parameters passed to `KnowledgeGraphEmbedding::train`.
///
/// These fields share names and default values with the matching fields of
/// `EmbeddingConfig`.
/// NOTE(review): which copy a model actually uses during training is up to
/// each implementation; confirm before relying on either.
///
/// Field order is significant for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Number of triples per training batch (default: 1024).
    pub batch_size: usize,
    /// Optimiser step size (default: 0.001).
    pub learning_rate: f32,
    /// Upper bound on training epochs (default: 1000).
    pub max_epochs: usize,
    /// Presumably the fraction of triples held out for validation (default: 0.1).
    pub validation_split: f32,
    /// Presumably epochs without improvement before early stopping (default: 50).
    pub patience: usize,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
batch_size: 1024,
learning_rate: 0.001,
max_epochs: 1000,
validation_split: 0.1,
patience: 50,
}
}
}