Skip to main content

oxirs_core/ai/embeddings/
mod.rs

1//! Knowledge Graph Embeddings for RDF
2//!
3//! This module implements various knowledge graph embedding models including
4//! TransE, DistMult, ComplEx, RotatE, and other state-of-the-art approaches.
5
6use crate::model::Triple;
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9
10pub mod complex;
11pub mod distmult;
12pub mod evaluation;
13pub mod transe;
14
15pub use complex::ComplEx;
16pub use distmult::DistMult;
17pub use evaluation::{
18    ConfidenceIntervals, KnowledgeGraphMetrics, LinkPredictionMetrics, StatisticalTestResults,
19    TaskBreakdownMetrics, TrainingMetrics,
20};
21pub use transe::TransE;
22
/// Knowledge graph embedding trait
///
/// Common async interface implemented by every embedding model in this
/// module (TransE, DistMult, ComplEx, ...). `Send + Sync` is required so
/// a model can be shared across threads, e.g. behind an `Arc`.
#[async_trait::async_trait]
pub trait KnowledgeGraphEmbedding: Send + Sync {
    /// Generate embeddings for entities and relations
    ///
    /// NOTE(review): the mapping from `triples` to the returned vectors
    /// (ordering, entities vs. relations) is implementation-defined here —
    /// confirm against the concrete models.
    async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>>;

    /// Score a triple (head, relation, tail)
    ///
    /// Returns a plausibility score for `(head, relation, tail)`; whether
    /// higher or lower means "more plausible" depends on the model's
    /// scoring function (distance-based vs. similarity-based).
    async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32>;

    /// Predict missing links
    ///
    /// Returns candidate `(head, relation, tail, score)` tuples drawn from
    /// the supplied entity and relation vocabularies.
    async fn predict_links(
        &self,
        entities: &[String],
        relations: &[String],
    ) -> Result<Vec<(String, String, String, f32)>>;

    /// Get entity embedding
    ///
    /// # Errors
    /// Implementations are expected to fail for entities unknown to the model.
    async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>>;

    /// Get relation embedding
    ///
    /// # Errors
    /// Implementations are expected to fail for relations unknown to the model.
    async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>>;

    /// Train the embedding model
    ///
    /// Fits the model on `triples` using the hyperparameters in `config`
    /// and returns the collected [`TrainingMetrics`].
    async fn train(
        &mut self,
        triples: &[Triple],
        config: &TrainingConfig,
    ) -> Result<TrainingMetrics>;

    /// Save model to file
    async fn save(&self, path: &str) -> Result<()>;

    /// Load model from file
    async fn load(&mut self, path: &str) -> Result<()>;
}
58
/// Embedding model configuration
///
/// Hyperparameters shared by all embedding models in this module; see the
/// `Default` impl for the baseline values (TransE, 100 dimensions, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Model type (which scoring function / architecture to instantiate)
    pub model_type: EmbeddingModelType,

    /// Embedding dimension (length of each entity/relation vector)
    pub embedding_dim: usize,

    /// Learning rate for the optimizer
    pub learning_rate: f32,

    /// L2 regularization weight
    pub l2_weight: f32,

    /// Negative sampling ratio (presumably negatives per positive triple —
    /// confirm against the concrete trainers)
    pub negative_sampling_ratio: f32,

    /// Training batch size
    pub batch_size: usize,

    /// Maximum training epochs
    pub max_epochs: usize,

    /// Early stopping patience (epochs without improvement before stopping)
    pub patience: usize,

    /// Validation split (fraction of data held out for validation)
    pub validation_split: f32,

    /// Enable GPU acceleration
    pub use_gpu: bool,

    /// Random seed, for reproducible initialization/sampling
    pub seed: u64,
}
95
96impl Default for EmbeddingConfig {
97    fn default() -> Self {
98        Self {
99            model_type: EmbeddingModelType::TransE,
100            embedding_dim: 100,
101            learning_rate: 0.001,
102            l2_weight: 1e-5,
103            negative_sampling_ratio: 1.0,
104            batch_size: 1024,
105            max_epochs: 1000,
106            patience: 50,
107            validation_split: 0.1,
108            use_gpu: true,
109            seed: 42,
110        }
111    }
112}
113
114/// Embedding model types
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
116pub enum EmbeddingModelType {
117    /// Translation-based model (Bordes et al., 2013)
118    TransE,
119
120    /// Bilinear model (Yang et al., 2014)
121    DistMult,
122
123    /// Complex embeddings (Trouillon et al., 2016)
124    ComplEx,
125
126    /// Rotation-based model (Sun et al., 2019)
127    RotatE,
128
129    /// Hyperbolic embeddings (Balazevic et al., 2019)
130    HypE,
131
132    /// Tucker decomposition (Balazevic et al., 2019)
133    TuckER,
134
135    /// Convolutional model (Dettmers et al., 2018)
136    ConvE,
137
138    /// Transformer-based model
139    KGTransformer,
140
141    /// Neural tensor network (Socher et al., 2013)
142    NeuralTensorNetwork,
143
144    /// SimplE (Kazemi & Poole, 2018)
145    SimplE,
146}
147
148/// Create embedding model based on configuration
149pub fn create_embedding_model(
150    config: EmbeddingConfig,
151) -> anyhow::Result<std::sync::Arc<dyn KnowledgeGraphEmbedding>> {
152    match config.model_type {
153        EmbeddingModelType::TransE => Ok(std::sync::Arc::new(TransE::new(config))),
154        EmbeddingModelType::DistMult => Ok(std::sync::Arc::new(DistMult::new(config))),
155        EmbeddingModelType::ComplEx => Ok(std::sync::Arc::new(ComplEx::new(config))),
156        _ => Err(anyhow::anyhow!("Embedding model not yet implemented")),
157    }
158}
159
/// Training configuration
///
/// Per-run training hyperparameters passed to
/// `KnowledgeGraphEmbedding::train`. Overlaps with the training-related
/// fields of `EmbeddingConfig`; the `Default` values match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Batch size
    pub batch_size: usize,

    /// Learning rate for the optimizer
    pub learning_rate: f32,

    /// Maximum epochs
    pub max_epochs: usize,

    /// Validation split (fraction of data held out for validation)
    pub validation_split: f32,

    /// Early stopping patience (epochs without improvement before stopping)
    pub patience: usize,
}
178
179impl Default for TrainingConfig {
180    fn default() -> Self {
181        Self {
182            batch_size: 1024,
183            learning_rate: 0.001,
184            max_epochs: 1000,
185            validation_split: 0.1,
186            patience: 50,
187        }
188    }
189}