//! DistMult: Embedding Entities and Relations for Learning and Inference in Knowledge Bases
//!
//! DistMult is a bilinear model that scores triples by element-wise multiplication:
//! score(h, r, t) = h^T diag(r) t = sum(h * r * t)
//!
//! Note: the score is symmetric in h and t, so DistMult cannot distinguish a
//! triple from its reverse; it effectively treats every relation as symmetric.
//!
//! Reference: Yang et al. "Embedding Entities and Relations for Learning and
//! Inference in Knowledge Bases" (2014)
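//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; `alice`, `similar_to`, and `bob`
//! are placeholders, and a runnable version lives in the tests at the bottom
//! of this file):
//!
//! ```ignore
//! let config = ModelConfig::default().with_dimensions(50);
//! let mut model = DistMult::new(config);
//! model.add_triple(Triple::new(alice, similar_to, bob))?;
//! let stats = model.train(Some(100)).await?;
//! let score = model.score_triple(alice_iri, similar_to_iri, bob_iri)?;
//! ```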

use crate::models::{common::*, BaseModel};
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::ops::AddAssign;
use std::time::Instant;
use tracing::{debug, info, warn};
use uuid::Uuid;

/// DistMult embedding model
#[derive(Debug)]
pub struct DistMult {
    /// Base model functionality
    base: BaseModel,
    /// Entity embeddings matrix (num_entities × dimensions)
    entity_embeddings: Array2<f64>,
    /// Relation embeddings matrix (num_relations × dimensions)
    relation_embeddings: Array2<f64>,
    /// Whether embeddings have been initialized
    embeddings_initialized: bool,
    /// Dropout rate applied to embeddings during training (0.0 disables dropout)
    #[allow(dead_code)]
    dropout_rate: f64,
    /// Loss function type
    loss_function: LossFunction,
}

/// Loss function types for DistMult
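///
/// Selected in `DistMult::new` via the `"loss_function"` entry of
/// `ModelConfig::model_params` (0.0 = Logistic, 1.0 = MarginRanking,
/// 2.0 = SquaredLoss; defaults to Logistic).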
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LossFunction {
    /// Logistic loss (binary cross-entropy)
    Logistic,
    /// Margin-based ranking loss
    MarginRanking,
    /// Squared loss
    SquaredLoss,
}

impl DistMult {
    /// Create a new DistMult model
    pub fn new(config: ModelConfig) -> Self {
        let base = BaseModel::new(config.clone());

        // Get DistMult-specific parameters
        let dropout_rate = config
            .model_params
            .get("dropout_rate")
            .copied()
            .unwrap_or(0.0);
        // Select the loss function from the f64-valued `model_params` map,
        // comparing values explicitly instead of matching on float literals
        let loss_function = match config
            .model_params
            .get("loss_function")
            .copied()
            .unwrap_or(0.0)
        {
            x if x == 1.0 => LossFunction::MarginRanking,
            x if x == 2.0 => LossFunction::SquaredLoss,
            _ => LossFunction::Logistic, // default to logistic
        };

        Self {
            base,
            entity_embeddings: Array2::zeros((0, config.dimensions)),
            relation_embeddings: Array2::zeros((0, config.dimensions)),
            embeddings_initialized: false,
            dropout_rate,
            loss_function,
        }
    }

    /// Initialize embeddings
    fn initialize_embeddings(&mut self) {
        if self.embeddings_initialized {
            return;
        }

        let num_entities = self.base.num_entities();
        let num_relations = self.base.num_relations();
        let dimensions = self.base.config.dimensions;

        if num_entities == 0 || num_relations == 0 {
            return;
        }

        let mut rng = Random::default();

        // Initialize embeddings with Xavier initialization
        self.entity_embeddings =
            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);

        self.relation_embeddings = xavier_init(
            (num_relations, dimensions),
            dimensions,
            dimensions,
            &mut rng,
        );

        // For DistMult, it's common to normalize entity embeddings initially
        normalize_embeddings(&mut self.entity_embeddings);

        self.embeddings_initialized = true;
        debug!(
            "Initialized DistMult embeddings: {} entities, {} relations, {} dimensions",
            num_entities, num_relations, dimensions
        );
    }

    /// Score a triple using DistMult scoring function: sum(h * r * t)
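    ///
    /// For example, with `h = [1.0, 2.0]`, `r = [0.5, 0.5]`, and
    /// `t = [2.0, 1.0]`, the score is `1.0*0.5*2.0 + 2.0*0.5*1.0 = 2.0`.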
    fn score_triple_ids(
        &self,
        subject_id: usize,
        predicate_id: usize,
        object_id: usize,
    ) -> Result<f64> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let h = self.entity_embeddings.row(subject_id);
        let r = self.relation_embeddings.row(predicate_id);
        let t = self.entity_embeddings.row(object_id);

        // DistMult score: sum of element-wise multiplication
        let score = (&h * &r * t).sum();

        Ok(score)
    }

    /// Apply dropout to embeddings during training
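    ///
    /// Uses "inverted" dropout: surviving components are scaled by
    /// `1 / (1 - dropout_rate)` so the expected magnitude of the embedding
    /// is unchanged at training time.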
    #[allow(dead_code)]
    fn apply_dropout(&self, embeddings: &Array1<f64>, rng: &mut Random) -> Array1<f64> {
        if self.dropout_rate > 0.0 {
            embeddings.mapv(|x| {
                if rng.random_f64() < self.dropout_rate {
                    0.0
                } else {
                    x / (1.0 - self.dropout_rate)
                }
            })
        } else {
            embeddings.to_owned()
        }
    }

    /// Compute gradients for DistMult model
    fn compute_gradients(
        &self,
        pos_triple: (usize, usize, usize),
        neg_triple: (usize, usize, usize),
        pos_score: f64,
        neg_score: f64,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());

        match self.loss_function {
            LossFunction::Logistic => {
                // Logistic loss gradients
                let pos_sigmoid = sigmoid(pos_score);
                let neg_sigmoid = sigmoid(neg_score);

                let pos_grad_coeff = pos_sigmoid - 1.0; // d/dx -log(sigmoid(x)) = sigmoid(x) - 1
                let neg_grad_coeff = neg_sigmoid; // d/dx -log(1 - sigmoid(x)) = sigmoid(x)

                self.add_triple_gradients(
                    pos_triple,
                    pos_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
                self.add_triple_gradients(
                    neg_triple,
                    neg_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
            }
            LossFunction::MarginRanking => {
                // Margin ranking loss: max(0, margin + neg_score - pos_score)
                let margin = self
                    .base
                    .config
                    .model_params
                    .get("margin")
                    .copied()
                    .unwrap_or(1.0);
                let loss = margin + neg_score - pos_score;

                if loss > 0.0 {
                    // Gradients: -1 for positive triple, +1 for negative triple
                    self.add_triple_gradients(
                        pos_triple,
                        -1.0,
                        &mut entity_grads,
                        &mut relation_grads,
                    );
                    self.add_triple_gradients(
                        neg_triple,
                        1.0,
                        &mut entity_grads,
                        &mut relation_grads,
                    );
                }
            }
            LossFunction::SquaredLoss => {
                // Squared loss: (1 - pos_score)^2 + (0 - neg_score)^2
                let pos_grad_coeff = -2.0 * (1.0 - pos_score); // d/d(pos_score)
                let neg_grad_coeff = 2.0 * neg_score; // d/d(neg_score) of neg_score^2

                self.add_triple_gradients(
                    pos_triple,
                    pos_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
                self.add_triple_gradients(
                    neg_triple,
                    neg_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
            }
        }

        Ok((entity_grads, relation_grads))
    }

    /// Add gradients for a single triple
    fn add_triple_gradients(
        &self,
        triple: (usize, usize, usize),
        grad_coeff: f64,
        entity_grads: &mut Array2<f64>,
        relation_grads: &mut Array2<f64>,
    ) {
        let (s, p, o) = triple;

        let h = self.entity_embeddings.row(s);
        let r = self.relation_embeddings.row(p);
        let t = self.entity_embeddings.row(o);

        // DistMult gradients:
        // ∂score/∂h = r * t
        // ∂score/∂r = h * t
        // ∂score/∂t = h * r
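        // (Since score = Σ_i h_i·r_i·t_i, the partial derivative with respect
        // to each factor is the element-wise product of the other two.)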

        let h_grad = (&r * &t) * grad_coeff;
        let r_grad = (&h * &t) * grad_coeff;
        let t_grad = (&h * &r) * grad_coeff;

        entity_grads.row_mut(s).add_assign(&h_grad);
        relation_grads.row_mut(p).add_assign(&r_grad);
        entity_grads.row_mut(o).add_assign(&t_grad);
    }

    /// Perform one training epoch
    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
        let mut rng = Random::default();

        let mut total_loss = 0.0;
        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
            / self.base.config.batch_size;

        // Create shuffled batches
        let mut shuffled_triples = self.base.triples.clone();
        // Manual Fisher-Yates shuffle using scirs2-core
        for i in (1..shuffled_triples.len()).rev() {
            let j = rng.random_range(0..i + 1);
            shuffled_triples.swap(i, j);
        }

        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
            let mut batch_loss = 0.0;

            for &pos_triple in batch_triples {
                // Generate negative samples
                let neg_samples = self
                    .base
                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);

                for neg_triple in neg_samples {
                    // Compute scores
                    let pos_score =
                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
                    let neg_score =
                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;

                    // Compute loss
                    let triple_loss = match self.loss_function {
                        LossFunction::Logistic => {
                            logistic_loss(pos_score, 1.0) + logistic_loss(neg_score, -1.0)
                        }
                        LossFunction::MarginRanking => {
                            let margin = self
                                .base
                                .config
                                .model_params
                                .get("margin")
                                .copied()
                                .unwrap_or(1.0);
                            margin_loss(pos_score, neg_score, margin)
                        }
                        LossFunction::SquaredLoss => (1.0 - pos_score).powi(2) + neg_score.powi(2),
                    };

                    batch_loss += triple_loss;

                    // Compute and accumulate gradients
                    let (entity_grads, relation_grads) =
                        self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;

                    batch_entity_grads += &entity_grads;
                    batch_relation_grads += &relation_grads;
                }
            }

            // Apply gradients with L2 regularization
            gradient_update(
                &mut self.entity_embeddings,
                &batch_entity_grads,
                learning_rate,
                self.base.config.l2_reg,
            );

            gradient_update(
                &mut self.relation_embeddings,
                &batch_relation_grads,
                learning_rate,
                self.base.config.l2_reg,
            );

            // Optional: re-normalize entity embeddings
            if self
                .base
                .config
                .model_params
                .get("normalize_entities")
                .copied()
                .unwrap_or(0.0)
                > 0.0
            {
                normalize_embeddings(&mut self.entity_embeddings);
            }

            total_loss += batch_loss;
        }

        Ok(total_loss / num_batches as f64)
    }
}

#[async_trait]
impl EmbeddingModel for DistMult {
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }

    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }

    fn model_type(&self) -> &'static str {
        "DistMult"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        // Heuristic: warn when the relation name suggests an asymmetric
        // relation, which DistMult cannot model faithfully
        let predicate_str = triple.predicate.to_string();
        if predicate_str.contains("parent")
            || predicate_str.contains("child")
            || predicate_str.contains("born")
            || predicate_str.contains("founder")
        {
            warn!(
                "DistMult may not handle asymmetric relation well: {}",
                predicate_str
            );
        }

        self.base.add_triple(triple)
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = Instant::now();
        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);

        // Initialize embeddings if needed
        self.initialize_embeddings();

        if !self.embeddings_initialized {
            return Err(anyhow!("No training data available"));
        }

        let mut loss_history = Vec::new();
        let learning_rate = self.base.config.learning_rate;

        info!("Starting DistMult training for {} epochs", max_epochs);

        for epoch in 0..max_epochs {
            let epoch_loss = self.train_epoch(learning_rate).await?;
            loss_history.push(epoch_loss);

            if epoch % 100 == 0 {
                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
            }

            // Simple convergence check
            if epoch > 10 && epoch_loss < 1e-6 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        self.base.mark_trained();
        let training_time = start_time.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let entity_id = self
            .base
            .get_entity_id(entity)
            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;

        let embedding = self.entity_embeddings.row(entity_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let relation_id = self
            .base
            .get_relation_id(relation)
            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;

        let embedding = self.relation_embeddings.row(relation_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        self.score_triple_ids(subject_id, predicate_id, object_id)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;

        let mut scores = Vec::new();

        for object_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let object_name = self.base.get_entity(object_id).unwrap().clone();
            scores.push((object_name, score));
        }

        // Sort by descending score; treat NaN scores as equal rather than panicking
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for subject_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
            scores.push((subject_name, score));
        }

        // Sort by descending score; treat NaN scores as equal rather than panicking
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for predicate_id in 0..self.base.num_relations() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
            scores.push((predicate_name, score));
        }

        // Sort by descending score; treat NaN scores as equal rather than panicking
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }

    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }

    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("DistMult")
    }

    fn save(&self, path: &str) -> Result<()> {
        // NOTE: persistence is not yet implemented; this stub only logs
        info!("Saving DistMult model to {}", path);
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<()> {
        // NOTE: persistence is not yet implemented; this stub only logs
        info!("Loading DistMult model from {}", path);
        Ok(())
    }

    fn clear(&mut self) {
        self.base.clear();
        self.entity_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.relation_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.embeddings_initialized = false;
    }

    fn is_trained(&self) -> bool {
        self.base.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    async fn test_distmult_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = DistMult::new(config);

        // Add test triples (use symmetric relation for DistMult)
        let alice = NamedNode::new("http://example.org/alice")?;
        let similar_to = NamedNode::new("http://example.org/similarTo")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), similar_to.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), similar_to.clone(), alice.clone()))?;

        // Train
        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        // Test embeddings
        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 50);

        // Test scoring
        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/similarTo",
            "http://example.org/bob",
        )?;

        // Score should be a finite number
        assert!(score.is_finite());

        Ok(())
    }
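
    /// DistMult's score sum(h * r * t) is symmetric in subject and object,
    /// so a triple and its reverse should receive (numerically almost)
    /// identical scores. A minimal sketch checking that property, using the
    /// same scaffolding as `test_distmult_basic`.
    #[tokio::test]
    async fn test_distmult_symmetry() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(16)
            .with_max_epochs(5)
            .with_seed(7);

        let mut model = DistMult::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;
        model.add_triple(Triple::new(alice, knows, bob))?;

        model.train(Some(2)).await?;

        let forward = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;
        let backward = model.score_triple(
            "http://example.org/bob",
            "http://example.org/knows",
            "http://example.org/alice",
        )?;

        // Equal up to floating-point rounding, by symmetry of the score
        assert!((forward - backward).abs() < 1e-9);

        Ok(())
    }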
}