use crate::model::Triple;
use anyhow::Result;
use serde::{Deserialize, Serialize};
pub mod complex;
pub mod conve;
pub mod distmult;
pub mod evaluation;
pub mod hype;
pub mod kgtransformer;
pub mod neural_tensor_network;
pub mod rotate;
pub mod simple;
pub mod transe;
pub mod tucker;
pub use complex::ComplEx;
pub use conve::ConvE;
pub use distmult::DistMult;
pub use evaluation::{
ConfidenceIntervals, KnowledgeGraphMetrics, LinkPredictionMetrics, StatisticalTestResults,
TaskBreakdownMetrics, TrainingMetrics,
};
pub use hype::HypE;
pub use kgtransformer::KGTransformer;
pub use neural_tensor_network::NeuralTensorNetwork;
pub use rotate::RotatE;
pub use simple::SimplE;
pub use transe::TransE;
pub use tucker::TuckER;
/// Common async interface shared by the knowledge-graph embedding models in this module.
///
/// Every method is fallible and returns an `anyhow::Result`. Implementors must be
/// `Send + Sync`. Trait objects of this type are what `create_embedding_model` returns.
///
/// NOTE(review): `train` and `load` take `&mut self`, while `create_embedding_model`
/// returns an `Arc<dyn KnowledgeGraphEmbedding>`. An `Arc` only hands out shared
/// references, so callers need exclusive access (e.g. `Arc::get_mut` on a freshly
/// created, not-yet-cloned `Arc`) or implementor-side interior mutability to call them.
#[async_trait::async_trait]
pub trait KnowledgeGraphEmbedding: Send + Sync {
/// Produces embedding vectors for the given triples.
///
/// NOTE(review): how output rows correspond to the input triples (one per triple,
/// one per entity, ...) is decided by each implementor — confirm before relying on it.
async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>>;
/// Returns a scalar score for the `(head, relation, tail)` triple.
///
/// NOTE(review): the scale and direction of the score are implementor-specific —
/// confirm per model.
async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32>;
/// Predicts links over the supplied entities and relations.
///
/// Returns a list of `(String, String, String, f32)` tuples — presumably
/// `(head, relation, tail, score)`; verify against the implementors.
async fn predict_links(
&self,
entities: &[String],
relations: &[String],
) -> Result<Vec<(String, String, String, f32)>>;
/// Looks up the embedding vector for `entity`.
async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>>;
/// Looks up the embedding vector for `relation`.
async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>>;
/// Trains the model on `triples` using the hyper-parameters in `config` and returns
/// the resulting `TrainingMetrics`.
///
/// Requires exclusive (`&mut`) access to the model.
async fn train(
&mut self,
triples: &[Triple],
config: &TrainingConfig,
) -> Result<TrainingMetrics>;
/// Persists the model to `path`. The on-disk format is implementor-defined.
async fn save(&self, path: &str) -> Result<()>;
/// Restores model state from `path`, replacing the current state.
///
/// Requires exclusive (`&mut`) access to the model.
async fn load(&mut self, path: &str) -> Result<()>;
}
/// Construction-time configuration handed to `create_embedding_model`.
///
/// Serializable via serde. `Default` supplies a TransE configuration; the default
/// values quoted on each field come from that `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Which model `create_embedding_model` builds (default: `TransE`).
pub model_type: EmbeddingModelType,
/// Embedding dimensionality (default: 100).
pub embedding_dim: usize,
/// Learning rate (default: 0.001).
pub learning_rate: f32,
/// L2 weight (default: 1e-5) — presumably the regularization strength; confirm
/// against the model implementations.
pub l2_weight: f32,
/// Negative sampling ratio (default: 1.0) — presumably negatives per positive
/// triple; confirm against the model implementations.
pub negative_sampling_ratio: f32,
/// Batch size (default: 1024).
pub batch_size: usize,
/// Upper bound on training epochs (default: 1000).
pub max_epochs: usize,
/// Patience (default: 50) — presumably the early-stopping window in epochs;
/// confirm against the training loops.
pub patience: usize,
/// Validation split (default: 0.1) — presumably the fraction of data held out
/// for validation; confirm against the training loops.
pub validation_split: f32,
/// Whether GPU use is requested (default: `true`). How each model honors this
/// flag is not visible in this file.
pub use_gpu: bool,
/// Seed value (default: 42) — presumably for random number generation; confirm.
pub seed: u64,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model_type: EmbeddingModelType::TransE,
embedding_dim: 100,
learning_rate: 0.001,
l2_weight: 1e-5,
negative_sampling_ratio: 1.0,
batch_size: 1024,
max_epochs: 1000,
patience: 50,
validation_split: 0.1,
use_gpu: true,
seed: 42,
}
}
}
/// Selects which knowledge-graph embedding model to build.
///
/// Each variant names one of the model types re-exported from this module and is
/// what `create_embedding_model` dispatches on. The enum is fieldless, so it derives
/// `Eq` and `Hash` in addition to `PartialEq`; this allows it to be used as a
/// `HashMap`/`HashSet` key (for example, to index per-model results).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EmbeddingModelType {
    /// Build a `TransE` model.
    TransE,
    /// Build a `DistMult` model.
    DistMult,
    /// Build a `ComplEx` model.
    ComplEx,
    /// Build a `RotatE` model.
    RotatE,
    /// Build a `HypE` model.
    HypE,
    /// Build a `TuckER` model.
    TuckER,
    /// Build a `ConvE` model.
    ConvE,
    /// Build a `KGTransformer` model.
    KGTransformer,
    /// Build a `NeuralTensorNetwork` model.
    NeuralTensorNetwork,
    /// Build a `SimplE` model.
    SimplE,
}
/// Builds the embedding model selected by `config.model_type`.
///
/// The whole `config` is moved into the chosen model's `new` constructor, and the
/// model is returned behind an `Arc` as a `KnowledgeGraphEmbedding` trait object.
///
/// # Errors
///
/// The present implementation always returns `Ok`; no arm of the dispatch produces
/// an error.
pub fn create_embedding_model(
    config: EmbeddingConfig,
) -> anyhow::Result<std::sync::Arc<dyn KnowledgeGraphEmbedding>> {
    use std::sync::Arc;

    // The annotated binding gives every arm the same target type, so each concrete
    // `Arc<Model>` is coerced to the trait object in one place.
    let model: Arc<dyn KnowledgeGraphEmbedding> = match config.model_type {
        EmbeddingModelType::TransE => Arc::new(TransE::new(config)),
        EmbeddingModelType::DistMult => Arc::new(DistMult::new(config)),
        EmbeddingModelType::ComplEx => Arc::new(ComplEx::new(config)),
        EmbeddingModelType::RotatE => Arc::new(RotatE::new(config)),
        EmbeddingModelType::HypE => Arc::new(HypE::new(config)),
        EmbeddingModelType::TuckER => Arc::new(TuckER::new(config)),
        EmbeddingModelType::ConvE => Arc::new(ConvE::new(config)),
        EmbeddingModelType::KGTransformer => Arc::new(KGTransformer::new(config)),
        EmbeddingModelType::NeuralTensorNetwork => Arc::new(NeuralTensorNetwork::new(config)),
        EmbeddingModelType::SimplE => Arc::new(SimplE::new(config)),
    };

    Ok(model)
}
/// Hyper-parameters passed to `KnowledgeGraphEmbedding::train`.
///
/// Serializable via serde. The field names and default values mirror the
/// same-named fields of `EmbeddingConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Batch size (default: 1024).
pub batch_size: usize,
/// Learning rate (default: 0.001).
pub learning_rate: f32,
/// Upper bound on training epochs (default: 1000).
pub max_epochs: usize,
/// Validation split (default: 0.1) — presumably the fraction of triples held out
/// for validation; confirm against the training loops.
pub validation_split: f32,
/// Patience (default: 50) — presumably the early-stopping window in epochs;
/// confirm against the training loops.
pub patience: usize,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
batch_size: 1024,
learning_rate: 0.001,
max_epochs: 1000,
validation_split: 0.1,
patience: 50,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The factory must return `Ok` for every `EmbeddingModelType` variant.
    #[test]
    fn test_create_embedding_model_all_variants() {
        let all_model_types = [
            EmbeddingModelType::TransE,
            EmbeddingModelType::DistMult,
            EmbeddingModelType::ComplEx,
            EmbeddingModelType::RotatE,
            EmbeddingModelType::HypE,
            EmbeddingModelType::TuckER,
            EmbeddingModelType::ConvE,
            EmbeddingModelType::KGTransformer,
            EmbeddingModelType::NeuralTensorNetwork,
            EmbeddingModelType::SimplE,
        ];

        all_model_types.iter().for_each(|model_type| {
            // Only construction is checked here, using a small embedding dimension
            // on top of the default configuration.
            let small_config = EmbeddingConfig {
                embedding_dim: 8,
                model_type: model_type.clone(),
                ..EmbeddingConfig::default()
            };

            assert!(
                create_embedding_model(small_config).is_ok(),
                "create_embedding_model should succeed for {:?}",
                model_type
            );
        });
    }
}