Skip to main content

oxirs_core/ai/embeddings/
distmult.rs

1use super::{
2    EmbeddingConfig, KnowledgeGraphEmbedding, KnowledgeGraphMetrics, TrainingConfig,
3    TrainingMetrics,
4};
5use crate::model::Triple;
6use anyhow::{anyhow, Result};
7use scirs2_core::ndarray_ext::Array1;
8use scirs2_core::random::{Random, RngExt};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13/// DistMult embedding model
14pub struct DistMult {
15    #[allow(dead_code)]
16    config: EmbeddingConfig,
17    entity_embeddings: Arc<RwLock<HashMap<String, Array1<f32>>>>,
18    relation_embeddings: Arc<RwLock<HashMap<String, Array1<f32>>>>,
19    #[allow(dead_code)]
20    entity_vocab: HashMap<String, usize>,
21    #[allow(dead_code)]
22    relation_vocab: HashMap<String, usize>,
23    trained: bool,
24}
25
26impl DistMult {
27    pub fn new(config: EmbeddingConfig) -> Self {
28        Self {
29            config,
30            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
31            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
32            entity_vocab: HashMap::new(),
33            relation_vocab: HashMap::new(),
34            trained: false,
35        }
36    }
37
38    /// Compute DistMult score: `<h, r, t>` = sum(h * r * t)
39    async fn compute_score(&self, head: &str, relation: &str, tail: &str) -> Result<f32> {
40        let entity_embs = self.entity_embeddings.read().await;
41        let relation_embs = self.relation_embeddings.read().await;
42
43        let h = entity_embs
44            .get(head)
45            .ok_or_else(|| anyhow!("Entity not found: {}", head))?;
46        let r = relation_embs
47            .get(relation)
48            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
49        let t = entity_embs
50            .get(tail)
51            .ok_or_else(|| anyhow!("Entity not found: {}", tail))?;
52
53        // Compute element-wise product and sum
54        let score = (h * r * t).sum();
55
56        Ok(score)
57    }
58
59    /// Initialize embeddings from vocabulary
60    async fn initialize_embeddings(&mut self, triples: &[Triple]) -> Result<()> {
61        let mut entities = HashSet::new();
62        let mut relations = HashSet::new();
63
64        // Collect vocabulary
65        for triple in triples {
66            entities.insert(triple.subject().to_string());
67            entities.insert(triple.object().to_string());
68            relations.insert(triple.predicate().to_string());
69        }
70
71        // Create vocabularies
72        self.entity_vocab = entities
73            .iter()
74            .enumerate()
75            .map(|(i, entity)| (entity.clone(), i))
76            .collect();
77
78        self.relation_vocab = relations
79            .iter()
80            .enumerate()
81            .map(|(i, relation)| (relation.clone(), i))
82            .collect();
83
84        // Initialize embeddings with Xavier initialization
85        let mut entity_embs = self.entity_embeddings.write().await;
86        let mut relation_embs = self.relation_embeddings.write().await;
87
88        let bound = (6.0 / self.config.embedding_dim as f32).sqrt();
89
90        for entity in entities {
91            let embedding = Array1::from_shape_simple_fn(self.config.embedding_dim, || {
92                let mut rng = Random::default();
93                rng.random::<f32>() * 2.0 * bound - bound
94            });
95            entity_embs.insert(entity, embedding);
96        }
97
98        for relation in relations {
99            let embedding = Array1::from_shape_simple_fn(self.config.embedding_dim, || {
100                let mut rng = Random::default();
101                rng.random::<f32>() * 2.0 * bound - bound
102            });
103            relation_embs.insert(relation, embedding);
104        }
105
106        Ok(())
107    }
108
109    /// Calculate accuracy on validation triples
110    async fn calculate_accuracy(&self, triples: &[(String, String, String)]) -> Result<f32> {
111        if triples.is_empty() {
112            return Ok(0.0);
113        }
114
115        let mut correct = 0;
116        let total = triples.len().min(100); // Sample for efficiency
117
118        for triple in triples.iter().take(total) {
119            let positive_score = self.compute_score(&triple.0, &triple.1, &triple.2).await?;
120
121            // Generate a random negative and compare
122            let entities: Vec<String> = self.entity_vocab.keys().cloned().collect();
123            if entities.len() >= 2 {
124                let corrupt_idx = {
125                    let mut rng = Random::default();
126                    rng.random_range(0..entities.len())
127                };
128                let corrupt_entity = &entities[corrupt_idx];
129
130                let should_corrupt_head = {
131                    let mut rng = Random::default();
132                    rng.random_bool_with_chance(0.5)
133                };
134                let negative_score = if should_corrupt_head {
135                    self.compute_score(corrupt_entity, &triple.1, &triple.2)
136                        .await?
137                } else {
138                    self.compute_score(&triple.0, &triple.1, corrupt_entity)
139                        .await?
140                };
141
142                // For DistMult, higher score is better
143                if positive_score > negative_score {
144                    correct += 1;
145                }
146            }
147        }
148
149        Ok(correct as f32 / total as f32)
150    }
151}
152
153#[async_trait::async_trait]
154impl KnowledgeGraphEmbedding for DistMult {
155    async fn generate_embeddings(&self, triples: &[Triple]) -> Result<Vec<Vec<f32>>> {
156        // Similar to TransE but with different scoring function
157        let entity_embs = self.entity_embeddings.read().await;
158        let mut embeddings = Vec::new();
159
160        for triple in triples {
161            let subject_str = triple.subject().to_string();
162            let object_str = triple.object().to_string();
163            let head_emb = entity_embs
164                .get(&subject_str)
165                .ok_or_else(|| anyhow!("Entity not found"))?;
166            let tail_emb = entity_embs
167                .get(&object_str)
168                .ok_or_else(|| anyhow!("Entity not found"))?;
169
170            let combined: Vec<f32> = head_emb
171                .iter()
172                .zip(tail_emb.iter())
173                .map(|(h, t)| h * t) // Element-wise product for DistMult
174                .collect();
175
176            embeddings.push(combined);
177        }
178
179        Ok(embeddings)
180    }
181
182    async fn score_triple(&self, head: &str, relation: &str, tail: &str) -> Result<f32> {
183        self.compute_score(head, relation, tail).await
184    }
185
186    async fn predict_links(
187        &self,
188        entities: &[String],
189        relations: &[String],
190    ) -> Result<Vec<(String, String, String, f32)>> {
191        let mut predictions = Vec::new();
192
193        for head in entities {
194            for relation in relations {
195                for tail in entities {
196                    if head != tail {
197                        let score = self.score_triple(head, relation, tail).await?;
198                        predictions.push((head.clone(), relation.clone(), tail.clone(), score));
199                    }
200                }
201            }
202        }
203
204        // Sort by score (higher is better for DistMult)
205        predictions.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
206
207        Ok(predictions)
208    }
209
210    async fn get_entity_embedding(&self, entity: &str) -> Result<Vec<f32>> {
211        let entity_embs = self.entity_embeddings.read().await;
212        let embedding = entity_embs
213            .get(entity)
214            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
215        Ok(embedding.to_vec())
216    }
217
218    async fn get_relation_embedding(&self, relation: &str) -> Result<Vec<f32>> {
219        let relation_embs = self.relation_embeddings.read().await;
220        let embedding = relation_embs
221            .get(relation)
222            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
223        Ok(embedding.to_vec())
224    }
225
226    async fn train(
227        &mut self,
228        triples: &[Triple],
229        _config: &TrainingConfig,
230    ) -> Result<TrainingMetrics> {
231        // Initialize embeddings similar to TransE
232        self.initialize_embeddings(triples).await?;
233
234        // Convert triples to string format
235        let triple_strings: Vec<(String, String, String)> = triples
236            .iter()
237            .map(|t| {
238                (
239                    t.subject().to_string(),
240                    t.predicate().to_string(),
241                    t.object().to_string(),
242                )
243            })
244            .collect();
245
246        let mut total_loss = 0.0;
247
248        for _epoch in 0..self.config.max_epochs {
249            let mut epoch_loss = 0.0;
250
251            // Simplified training - in practice would use proper SGD with gradients
252            for triple in &triple_strings {
253                let score = self.compute_score(&triple.0, &triple.1, &triple.2).await?;
254
255                // For DistMult, we want to maximize the score for positive triples
256                // This is a simplified loss - negative log-likelihood would be better
257                epoch_loss += (1.0 - score).max(0.0);
258            }
259
260            total_loss = epoch_loss / triple_strings.len() as f32;
261
262            // Early stopping
263            if total_loss < 1e-6 {
264                break;
265            }
266        }
267
268        self.trained = true;
269
270        // Calculate accuracy on validation set
271        let accuracy = self.calculate_accuracy(&triple_strings).await?;
272
273        Ok(TrainingMetrics {
274            loss: total_loss,
275            loss_history: vec![total_loss],
276            accuracy,
277            epochs: self.config.max_epochs,
278            time_elapsed: std::time::Duration::from_secs(0),
279            kg_metrics: KnowledgeGraphMetrics::default(),
280        })
281    }
282
283    async fn save(&self, _path: &str) -> Result<()> {
284        Ok(())
285    }
286
287    async fn load(&mut self, _path: &str) -> Result<()> {
288        Ok(())
289    }
290}