oxirs_core/ai/embeddings/
mod.rs1use crate::model::Triple;
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9
10pub mod complex;
11pub mod distmult;
12pub mod evaluation;
13pub mod transe;
14
15pub use complex::ComplEx;
16pub use distmult::DistMult;
17pub use evaluation::{
18 ConfidenceIntervals, KnowledgeGraphMetrics, LinkPredictionMetrics, StatisticalTestResults,
19 TaskBreakdownMetrics, TrainingMetrics,
20};
21pub use transe::TransE;
22
23#[async_trait::async_trait]
25pub trait KnowledgeGraphEmbedding: Send + Sync {
26 async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>>;
28
29 async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32>;
31
32 async fn predict_links(
34 &self,
35 entities: &[String],
36 relations: &[String],
37 ) -> Result<Vec<(String, String, String, f32)>>;
38
39 async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>>;
41
42 async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>>;
44
45 async fn train(
47 &mut self,
48 triples: &[Triple],
49 config: &TrainingConfig,
50 ) -> Result<TrainingMetrics>;
51
52 async fn save(&self, path: &str) -> Result<()>;
54
55 async fn load(&mut self, path: &str) -> Result<()>;
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct EmbeddingConfig {
62 pub model_type: EmbeddingModelType,
64
65 pub embedding_dim: usize,
67
68 pub learning_rate: f32,
70
71 pub l2_weight: f32,
73
74 pub negative_sampling_ratio: f32,
76
77 pub batch_size: usize,
79
80 pub max_epochs: usize,
82
83 pub patience: usize,
85
86 pub validation_split: f32,
88
89 pub use_gpu: bool,
91
92 pub seed: u64,
94}
95
96impl Default for EmbeddingConfig {
97 fn default() -> Self {
98 Self {
99 model_type: EmbeddingModelType::TransE,
100 embedding_dim: 100,
101 learning_rate: 0.001,
102 l2_weight: 1e-5,
103 negative_sampling_ratio: 1.0,
104 batch_size: 1024,
105 max_epochs: 1000,
106 patience: 50,
107 validation_split: 0.1,
108 use_gpu: true,
109 seed: 42,
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
116pub enum EmbeddingModelType {
117 TransE,
119
120 DistMult,
122
123 ComplEx,
125
126 RotatE,
128
129 HypE,
131
132 TuckER,
134
135 ConvE,
137
138 KGTransformer,
140
141 NeuralTensorNetwork,
143
144 SimplE,
146}
147
148pub fn create_embedding_model(
150 config: EmbeddingConfig,
151) -> anyhow::Result<std::sync::Arc<dyn KnowledgeGraphEmbedding>> {
152 match config.model_type {
153 EmbeddingModelType::TransE => Ok(std::sync::Arc::new(TransE::new(config))),
154 EmbeddingModelType::DistMult => Ok(std::sync::Arc::new(DistMult::new(config))),
155 EmbeddingModelType::ComplEx => Ok(std::sync::Arc::new(ComplEx::new(config))),
156 _ => Err(anyhow::anyhow!("Embedding model not yet implemented")),
157 }
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct TrainingConfig {
163 pub batch_size: usize,
165
166 pub learning_rate: f32,
168
169 pub max_epochs: usize,
171
172 pub validation_split: f32,
174
175 pub patience: usize,
177}
178
179impl Default for TrainingConfig {
180 fn default() -> Self {
181 Self {
182 batch_size: 1024,
183 learning_rate: 0.001,
184 max_epochs: 1000,
185 validation_split: 0.1,
186 patience: 50,
187 }
188 }
189}