use crate::models::{common::*, BaseModel};
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::ops::AddAssign;
use std::time::Instant;
use tracing::{debug, info, warn};
use uuid::Uuid;

/// DistMult knowledge graph embedding model.
///
/// Scores a triple (h, r, t) with the trilinear product
/// sum_i h_i * r_i * t_i. The product is symmetric in h and t, so DistMult
/// handles symmetric relations well but cannot distinguish a triple from
/// its reverse.
#[derive(Debug)]
pub struct DistMult {
    /// Shared vocabulary, triples, and configuration.
    base: BaseModel,
    /// Entity embedding matrix (num_entities x dimensions).
    entity_embeddings: Array2<f64>,
    /// Relation embedding matrix (num_relations x dimensions).
    relation_embeddings: Array2<f64>,
    /// Whether the embedding matrices have been initialized.
    embeddings_initialized: bool,
    /// Dropout rate applied to embeddings during training.
    #[allow(dead_code)]
    dropout_rate: f64,
    /// Loss function used for training.
    loss_function: LossFunction,
}

/// Loss functions supported by DistMult.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LossFunction {
    /// Logistic loss over positive and negative samples.
    Logistic,
    /// Margin-based ranking loss: max(0, margin + neg_score - pos_score).
    MarginRanking,
    /// Squared loss: (1 - pos_score)^2 + neg_score^2.
    SquaredLoss,
}

impl DistMult {
    /// Create a new DistMult model with the given configuration.
    pub fn new(config: ModelConfig) -> Self {
        let base = BaseModel::new(config.clone());

        let dropout_rate = config
            .model_params
            .get("dropout_rate")
            .copied()
            .unwrap_or(0.0);

        // `model_params` values are f64; copy out of the Option<&f64> and
        // compare via match guards rather than float literal patterns.
        let loss_function = match config.model_params.get("loss_function").copied() {
            Some(v) if v == 0.0 => LossFunction::Logistic,
            Some(v) if v == 1.0 => LossFunction::MarginRanking,
            Some(v) if v == 2.0 => LossFunction::SquaredLoss,
            _ => LossFunction::Logistic,
        };

        Self {
            base,
            entity_embeddings: Array2::zeros((0, config.dimensions)),
            relation_embeddings: Array2::zeros((0, config.dimensions)),
            embeddings_initialized: false,
            dropout_rate,
            loss_function,
        }
    }

    /// Initialize entity and relation embeddings with Xavier initialization.
    fn initialize_embeddings(&mut self) {
        if self.embeddings_initialized {
            return;
        }

        let num_entities = self.base.num_entities();
        let num_relations = self.base.num_relations();
        let dimensions = self.base.config.dimensions;

        if num_entities == 0 || num_relations == 0 {
            return;
        }

        let mut rng = Random::default();

        self.entity_embeddings =
            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);

        self.relation_embeddings = xavier_init(
            (num_relations, dimensions),
            dimensions,
            dimensions,
            &mut rng,
        );

        // Entities start on the unit sphere; relations are left unnormalized.
        normalize_embeddings(&mut self.entity_embeddings);

        self.embeddings_initialized = true;
        debug!(
            "Initialized DistMult embeddings: {} entities, {} relations, {} dimensions",
            num_entities, num_relations, dimensions
        );
    }

    /// Score a triple by entity/relation IDs using the DistMult trilinear
    /// product: score(h, r, t) = sum_i h_i * r_i * t_i.
    fn score_triple_ids(
        &self,
        subject_id: usize,
        predicate_id: usize,
        object_id: usize,
    ) -> Result<f64> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let h = self.entity_embeddings.row(subject_id);
        let r = self.relation_embeddings.row(predicate_id);
        let t = self.entity_embeddings.row(object_id);

        // Element-wise product of the three vectors, summed.
        let score = (&h * &r * t).sum();

        Ok(score)
    }

    /// Apply inverted dropout: zero each component with probability
    /// `dropout_rate` and scale the survivors by 1 / (1 - rate) so the
    /// expected value is unchanged.
    #[allow(dead_code)]
    fn apply_dropout(&self, embeddings: &Array1<f64>, rng: &mut Random) -> Array1<f64> {
        if self.dropout_rate > 0.0 {
            embeddings.mapv(|x| {
                if rng.random_f64() < self.dropout_rate {
                    0.0
                } else {
                    x / (1.0 - self.dropout_rate)
                }
            })
        } else {
            embeddings.to_owned()
        }
    }

    /// Compute embedding gradients for a positive/negative triple pair under
    /// the configured loss function.
    fn compute_gradients(
        &self,
        pos_triple: (usize, usize, usize),
        neg_triple: (usize, usize, usize),
        pos_score: f64,
        neg_score: f64,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());

        match self.loss_function {
            LossFunction::Logistic => {
                let pos_sigmoid = sigmoid(pos_score);
                let neg_sigmoid = sigmoid(neg_score);

                // d/ds [-log sigmoid(s)] = sigmoid(s) - 1 for positives;
                // d/ds [-log sigmoid(-s)] = sigmoid(s) for negatives.
                let pos_grad_coeff = pos_sigmoid - 1.0;
                let neg_grad_coeff = neg_sigmoid;

                self.add_triple_gradients(
                    pos_triple,
                    pos_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
                self.add_triple_gradients(
                    neg_triple,
                    neg_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
            }
            LossFunction::MarginRanking => {
                let margin = self
                    .base
                    .config
                    .model_params
                    .get("margin")
                    .copied()
                    .unwrap_or(1.0);
                let loss = margin + neg_score - pos_score;

                // Hinge: only update when the margin is violated.
                if loss > 0.0 {
                    self.add_triple_gradients(
                        pos_triple,
                        -1.0,
                        &mut entity_grads,
                        &mut relation_grads,
                    );
                    self.add_triple_gradients(
                        neg_triple,
                        1.0,
                        &mut entity_grads,
                        &mut relation_grads,
                    );
                }
            }
            LossFunction::SquaredLoss => {
                // Loss = (1 - pos_score)^2 + neg_score^2, so
                // dL/d(pos_score) = -2 * (1 - pos_score) and
                // dL/d(neg_score) = 2 * neg_score.
                let pos_grad_coeff = -2.0 * (1.0 - pos_score);
                let neg_grad_coeff = 2.0 * neg_score;

                self.add_triple_gradients(
                    pos_triple,
                    pos_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
                self.add_triple_gradients(
                    neg_triple,
                    neg_grad_coeff,
                    &mut entity_grads,
                    &mut relation_grads,
                );
            }
        }

        Ok((entity_grads, relation_grads))
    }

    /// Accumulate gradients of the score function for one triple, scaled by
    /// `grad_coeff` (the derivative of the loss w.r.t. the score).
    fn add_triple_gradients(
        &self,
        triple: (usize, usize, usize),
        grad_coeff: f64,
        entity_grads: &mut Array2<f64>,
        relation_grads: &mut Array2<f64>,
    ) {
        let (s, p, o) = triple;

        let h = self.entity_embeddings.row(s);
        let r = self.relation_embeddings.row(p);
        let t = self.entity_embeddings.row(o);

        // For score = sum_i h_i * r_i * t_i:
        //   d(score)/dh = r * t,  d(score)/dr = h * t,  d(score)/dt = h * r.
        let h_grad = (&r * &t) * grad_coeff;
        let r_grad = (&h * &t) * grad_coeff;
        let t_grad = (&h * &r) * grad_coeff;

        entity_grads.row_mut(s).add_assign(&h_grad);
        relation_grads.row_mut(p).add_assign(&r_grad);
        entity_grads.row_mut(o).add_assign(&t_grad);
    }

    /// Run one training epoch over all triples and return the mean batch loss.
    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
        let mut rng = Random::default();

        let mut total_loss = 0.0;
        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
            / self.base.config.batch_size;

        // Fisher-Yates shuffle so each epoch visits triples in a new order.
        let mut shuffled_triples = self.base.triples.clone();
        for i in (1..shuffled_triples.len()).rev() {
            let j = rng.random_range(0..i + 1);
            shuffled_triples.swap(i, j);
        }

        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
            let mut batch_loss = 0.0;

            for &pos_triple in batch_triples {
                let neg_samples = self
                    .base
                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);

                for neg_triple in neg_samples {
                    let pos_score =
                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
                    let neg_score =
                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;

                    let triple_loss = match self.loss_function {
                        LossFunction::Logistic => {
                            logistic_loss(pos_score, 1.0) + logistic_loss(neg_score, -1.0)
                        }
                        LossFunction::MarginRanking => {
                            let margin = self
                                .base
                                .config
                                .model_params
                                .get("margin")
                                .copied()
                                .unwrap_or(1.0);
                            margin_loss(pos_score, neg_score, margin)
                        }
                        LossFunction::SquaredLoss => {
                            (1.0 - pos_score).powi(2) + neg_score.powi(2)
                        }
                    };

                    batch_loss += triple_loss;

                    let (entity_grads, relation_grads) =
                        self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;

                    batch_entity_grads += &entity_grads;
                    batch_relation_grads += &relation_grads;
                }
            }

            gradient_update(
                &mut self.entity_embeddings,
                &batch_entity_grads,
                learning_rate,
                self.base.config.l2_reg,
            );

            gradient_update(
                &mut self.relation_embeddings,
                &batch_relation_grads,
                learning_rate,
                self.base.config.l2_reg,
            );

            // Optionally re-project entity embeddings onto the unit sphere.
            if self
                .base
                .config
                .model_params
                .get("normalize_entities")
                .copied()
                .unwrap_or(0.0)
                > 0.0
            {
                normalize_embeddings(&mut self.entity_embeddings);
            }

            total_loss += batch_loss;
        }

        Ok(total_loss / num_batches as f64)
    }
}

#[async_trait]
impl EmbeddingModel for DistMult {
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }

    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }

    fn model_type(&self) -> &'static str {
        "DistMult"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        // DistMult's symmetric score cannot distinguish (s, p, o) from
        // (o, p, s), so warn on predicates that look asymmetric.
        let predicate_str = triple.predicate.to_string();
        if predicate_str.contains("parent")
            || predicate_str.contains("child")
            || predicate_str.contains("born")
            || predicate_str.contains("founder")
        {
            warn!(
                "DistMult may not model asymmetric relations well: {}",
                predicate_str
            );
        }

        self.base.add_triple(triple)
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = Instant::now();
        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);

        self.initialize_embeddings();

        if !self.embeddings_initialized {
            return Err(anyhow!("No training data available"));
        }

        let mut loss_history = Vec::new();
        let learning_rate = self.base.config.learning_rate;

        info!("Starting DistMult training for {} epochs", max_epochs);

        for epoch in 0..max_epochs {
            let epoch_loss = self.train_epoch(learning_rate).await?;
            loss_history.push(epoch_loss);

            if epoch % 100 == 0 {
                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
            }

            if epoch > 10 && epoch_loss < 1e-6 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        self.base.mark_trained();
        let training_time = start_time.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let entity_id = self
            .base
            .get_entity_id(entity)
            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;

        let embedding = self.entity_embeddings.row(entity_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let relation_id = self
            .base
            .get_relation_id(relation)
            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;

        let embedding = self.relation_embeddings.row(relation_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        self.score_triple_ids(subject_id, predicate_id, object_id)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;

        let mut scores = Vec::new();

        for object_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let object_name = self.base.get_entity(object_id).unwrap().clone();
            scores.push((object_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for subject_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
            scores.push((subject_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for predicate_id in 0..self.base.num_relations() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
            scores.push((predicate_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }

    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }

    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("DistMult")
    }

    fn save(&self, path: &str) -> Result<()> {
        // TODO: persistence is not yet implemented; this stub only logs.
        info!("Saving DistMult model to {}", path);
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<()> {
        // TODO: persistence is not yet implemented; this stub only logs.
        info!("Loading DistMult model from {}", path);
        Ok(())
    }

    fn clear(&mut self) {
        self.base.clear();
        self.entity_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.relation_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.embeddings_initialized = false;
    }

    fn is_trained(&self) -> bool {
        self.base.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    async fn test_distmult_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = DistMult::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let similar_to = NamedNode::new("http://example.org/similarTo")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), similar_to.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), similar_to.clone(), alice.clone()))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 50);

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/similarTo",
            "http://example.org/bob",
        )?;

        assert!(score.is_finite());

        Ok(())
    }
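
    // A minimal illustrative sketch (added here with hypothetical
    // entity/relation IRIs) of DistMult's defining property: the trilinear
    // score sum_i h_i * r_i * t_i is symmetric in subject and object, so a
    // triple and its reverse always receive identical scores.
    #[tokio::test]
    async fn test_distmult_symmetric_scores() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(16)
            .with_max_epochs(5)
            .with_seed(7);

        let mut model = DistMult::new(config);

        let a = NamedNode::new("http://example.org/a")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let b = NamedNode::new("http://example.org/b")?;

        model.add_triple(Triple::new(a.clone(), knows.clone(), b.clone()))?;
        model.train(Some(3)).await?;

        let forward = model.score_triple(
            "http://example.org/a",
            "http://example.org/knows",
            "http://example.org/b",
        )?;
        let backward = model.score_triple(
            "http://example.org/b",
            "http://example.org/knows",
            "http://example.org/a",
        )?;

        // Swapping subject and object must not change the score.
        assert!((forward - backward).abs() < 1e-9);

        Ok(())
    }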
}