//! DistMult knowledge-graph embedding model.
//! Path: oxirs_core/ai/embeddings/distmult.rs

use super::{
    EmbeddingConfig, KnowledgeGraphEmbedding, KnowledgeGraphMetrics, TrainingConfig,
    TrainingMetrics,
};
use crate::model::Triple;
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::random::{Random, RngExt};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;

/// DistMult knowledge-graph embedding model.
///
/// Scores a triple (h, r, t) as the sum of the element-wise product
/// h ⊙ r ⊙ t of the head, relation and tail embedding vectors (see
/// `compute_score`). Embedding tables are behind async `RwLock`s so
/// scoring can proceed concurrently with other readers.
pub struct DistMult {
    // Model hyper-parameters (embedding dimension, max epochs, ...).
    #[allow(dead_code)]
    config: EmbeddingConfig,
    // Entity name -> embedding vector of length `config.embedding_dim`.
    entity_embeddings: Arc<RwLock<HashMap<String, Array1<f32>>>>,
    // Relation name -> embedding vector of length `config.embedding_dim`.
    relation_embeddings: Arc<RwLock<HashMap<String, Array1<f32>>>>,
    // Entity -> dense index, rebuilt by `initialize_embeddings`.
    #[allow(dead_code)]
    entity_vocab: HashMap<String, usize>,
    // Relation -> dense index, rebuilt by `initialize_embeddings`.
    #[allow(dead_code)]
    relation_vocab: HashMap<String, usize>,
    // Set to `true` once `train` has completed.
    trained: bool,
}
25
26impl DistMult {
27 pub fn new(config: EmbeddingConfig) -> Self {
28 Self {
29 config,
30 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
31 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
32 entity_vocab: HashMap::new(),
33 relation_vocab: HashMap::new(),
34 trained: false,
35 }
36 }
37
38 async fn compute_score(&self, head: &str, relation: &str, tail: &str) -> Result<f32> {
40 let entity_embs = self.entity_embeddings.read().await;
41 let relation_embs = self.relation_embeddings.read().await;
42
43 let h = entity_embs
44 .get(head)
45 .ok_or_else(|| anyhow!("Entity not found: {}", head))?;
46 let r = relation_embs
47 .get(relation)
48 .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
49 let t = entity_embs
50 .get(tail)
51 .ok_or_else(|| anyhow!("Entity not found: {}", tail))?;
52
53 let score = (h * r * t).sum();
55
56 Ok(score)
57 }
58
59 async fn initialize_embeddings(&mut self, triples: &[Triple]) -> Result<()> {
61 let mut entities = HashSet::new();
62 let mut relations = HashSet::new();
63
64 for triple in triples {
66 entities.insert(triple.subject().to_string());
67 entities.insert(triple.object().to_string());
68 relations.insert(triple.predicate().to_string());
69 }
70
71 self.entity_vocab = entities
73 .iter()
74 .enumerate()
75 .map(|(i, entity)| (entity.clone(), i))
76 .collect();
77
78 self.relation_vocab = relations
79 .iter()
80 .enumerate()
81 .map(|(i, relation)| (relation.clone(), i))
82 .collect();
83
84 let mut entity_embs = self.entity_embeddings.write().await;
86 let mut relation_embs = self.relation_embeddings.write().await;
87
88 let bound = (6.0 / self.config.embedding_dim as f32).sqrt();
89
90 for entity in entities {
91 let embedding = Array1::from_shape_simple_fn(self.config.embedding_dim, || {
92 let mut rng = Random::default();
93 rng.random::<f32>() * 2.0 * bound - bound
94 });
95 entity_embs.insert(entity, embedding);
96 }
97
98 for relation in relations {
99 let embedding = Array1::from_shape_simple_fn(self.config.embedding_dim, || {
100 let mut rng = Random::default();
101 rng.random::<f32>() * 2.0 * bound - bound
102 });
103 relation_embs.insert(relation, embedding);
104 }
105
106 Ok(())
107 }
108
109 async fn calculate_accuracy(&self, triples: &[(String, String, String)]) -> Result<f32> {
111 if triples.is_empty() {
112 return Ok(0.0);
113 }
114
115 let mut correct = 0;
116 let total = triples.len().min(100); for triple in triples.iter().take(total) {
119 let positive_score = self.compute_score(&triple.0, &triple.1, &triple.2).await?;
120
121 let entities: Vec<String> = self.entity_vocab.keys().cloned().collect();
123 if entities.len() >= 2 {
124 let corrupt_idx = {
125 let mut rng = Random::default();
126 rng.random_range(0..entities.len())
127 };
128 let corrupt_entity = &entities[corrupt_idx];
129
130 let should_corrupt_head = {
131 let mut rng = Random::default();
132 rng.random_bool_with_chance(0.5)
133 };
134 let negative_score = if should_corrupt_head {
135 self.compute_score(corrupt_entity, &triple.1, &triple.2)
136 .await?
137 } else {
138 self.compute_score(&triple.0, &triple.1, corrupt_entity)
139 .await?
140 };
141
142 if positive_score > negative_score {
144 correct += 1;
145 }
146 }
147 }
148
149 Ok(correct as f32 / total as f32)
150 }
151}
152
153#[async_trait::async_trait]
154impl KnowledgeGraphEmbedding for DistMult {
155 async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>> {
156 let entity_embs = self.entity_embeddings.read().await;
158 let mut embeddings = Vec::new();
159
160 for triple in triples {
161 let subject_str = triple.subject().to_string();
162 let object_str = triple.object().to_string();
163 let head_emb = entity_embs
164 .get(&subject_str)
165 .ok_or_else(|| anyhow!("Entity not found"))?;
166 let tail_emb = entity_embs
167 .get(&object_str)
168 .ok_or_else(|| anyhow!("Entity not found"))?;
169
170 let combined: Vec<f32> = head_emb
171 .iter()
172 .zip(tail_emb.iter())
173 .map(|(h, t)| h * t) .collect();
175
176 embeddings.push(combined);
177 }
178
179 Ok(embeddings)
180 }
181
182 async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32> {
183 self.compute_score(head, relation, tail).await
184 }
185
186 async fn predict_links(
187 &self,
188 entities: &[String],
189 relations: &[String],
190 ) -> Result<Vec<(String, String, String, f32)>> {
191 let mut predictions = Vec::new();
192
193 for head in entities {
194 for relation in relations {
195 for tail in entities {
196 if head != tail {
197 let score = self.score_triple(head, relation, tail).await?;
198 predictions.push((head.clone(), relation.clone(), tail.clone(), score));
199 }
200 }
201 }
202 }
203
204 predictions.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
206
207 Ok(predictions)
208 }
209
210 async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>> {
211 let entity_embs = self.entity_embeddings.read().await;
212 let embedding = entity_embs
213 .get(entity)
214 .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
215 Ok(embedding.to_vec())
216 }
217
218 async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>> {
219 let relation_embs = self.relation_embeddings.read().await;
220 let embedding = relation_embs
221 .get(relation)
222 .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
223 Ok(embedding.to_vec())
224 }
225
226 async fn train(
227 &mut self,
228 triples: &[Triple],
229 _config: &TrainingConfig,
230 ) -> Result<TrainingMetrics> {
231 self.initialize_embeddings(triples).await?;
233
234 let triple_strings: Vec<(String, String, String)> = triples
236 .iter()
237 .map(|t| {
238 (
239 t.subject().to_string(),
240 t.predicate().to_string(),
241 t.object().to_string(),
242 )
243 })
244 .collect();
245
246 let mut total_loss = 0.0;
247
248 for _epoch in 0..self.config.max_epochs {
249 let mut epoch_loss = 0.0;
250
251 for triple in &triple_strings {
253 let score = self.compute_score(&triple.0, &triple.1, &triple.2).await?;
254
255 epoch_loss += (1.0 - score).max(0.0);
258 }
259
260 total_loss = epoch_loss / triple_strings.len() as f32;
261
262 if total_loss < 1e-6 {
264 break;
265 }
266 }
267
268 self.trained = true;
269
270 let accuracy = self.calculate_accuracy(&triple_strings).await?;
272
273 Ok(TrainingMetrics {
274 loss: total_loss,
275 loss_history: vec![total_loss],
276 accuracy,
277 epochs: self.config.max_epochs,
278 time_elapsed: std::time::Duration::from_secs(0),
279 kg_metrics: KnowledgeGraphMetrics::default(),
280 })
281 }
282
283 async fn save(&self, _path: &str) -> Result<()> {
284 Ok(())
285 }
286
287 async fn load(&mut self, _path: &str) -> Result<()> {
288 Ok(())
289 }
290}