1use crate::models::{common::*, BaseModel};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use scirs2_core::ndarray_ext::Array2;
14#[allow(unused_imports)]
15use scirs2_core::random::{Random, Rng};
16use std::time::Instant;
17use tracing::{debug, info};
18use uuid::Uuid;
19
/// RotatE knowledge-graph embedding model.
///
/// Each entity is a complex vector, stored as two matrices (real and imaginary parts, one row
/// per entity id). Each relation is a vector of phase angles (one row per relation id). A
/// triple `(h, r, t)` is scored by rotating every coordinate of `h` by the matching phase of
/// `r` and measuring how far the result lands from `t` (see `score_triple_ids`).
#[derive(Debug)]
pub struct RotatE {
    /// Shared model state: configuration, entity/relation vocabularies and stored triples.
    base: BaseModel,
    /// Real parts of the entity embeddings, one row per entity id.
    entity_embeddings_real: Array2<f64>,
    /// Imaginary parts of the entity embeddings, one row per entity id.
    entity_embeddings_imag: Array2<f64>,
    /// Relation rotation angles in radians (initialized in `[0, 2π)`), one row per relation id.
    relation_phases: Array2<f64>,
    /// Whether the embedding matrices have been allocated and randomly initialized.
    embeddings_initialized: bool,
    /// Temperature of the self-adversarial negative sampling distribution
    /// (`model_params["adversarial_temperature"]`, default `1.0`).
    adversarial_temperature: f64,
    /// When true, every complex entity coordinate is rescaled to unit modulus after
    /// initialization and after each batch update (`model_params["modulus_constraint"] > 0.0`,
    /// enabled by default).
    modulus_constraint: bool,
}
38
39impl RotatE {
40 pub fn new(config: ModelConfig) -> Self {
42 let base = BaseModel::new(config.clone());
43
44 let adversarial_temperature = config
46 .model_params
47 .get("adversarial_temperature")
48 .copied()
49 .unwrap_or(1.0);
50
51 let modulus_constraint = config
52 .model_params
53 .get("modulus_constraint")
54 .map(|&x| x > 0.0)
55 .unwrap_or(true);
56
57 Self {
58 base,
59 entity_embeddings_real: Array2::zeros((0, config.dimensions)),
60 entity_embeddings_imag: Array2::zeros((0, config.dimensions)),
61 relation_phases: Array2::zeros((0, config.dimensions)),
62 embeddings_initialized: false,
63 adversarial_temperature,
64 modulus_constraint,
65 }
66 }
67
68 fn initialize_embeddings(&mut self) {
70 if self.embeddings_initialized {
71 return;
72 }
73
74 let num_entities = self.base.num_entities();
75 let num_relations = self.base.num_relations();
76 let dimensions = self.base.config.dimensions;
77
78 if num_entities == 0 || num_relations == 0 {
79 return;
80 }
81
82 let mut rng = Random::default();
83
84 self.entity_embeddings_real = uniform_init((num_entities, dimensions), -1.0, 1.0, &mut rng);
86
87 self.entity_embeddings_imag = uniform_init((num_entities, dimensions), -1.0, 1.0, &mut rng);
88
89 self.relation_phases = uniform_init(
91 (num_relations, dimensions),
92 0.0,
93 2.0 * std::f64::consts::PI,
94 &mut rng,
95 );
96
97 if self.modulus_constraint {
99 self.apply_modulus_constraint();
100 }
101
102 self.embeddings_initialized = true;
103 debug!(
104 "Initialized RotatE embeddings: {} entities, {} relations, {} dimensions",
105 num_entities, num_relations, dimensions
106 );
107 }
108
109 fn apply_modulus_constraint(&mut self) {
111 for i in 0..self.entity_embeddings_real.nrows() {
112 let mut real_row = self.entity_embeddings_real.row_mut(i);
113 let mut imag_row = self.entity_embeddings_imag.row_mut(i);
114
115 for j in 0..real_row.len() {
116 let real = real_row[j];
117 let imag = imag_row[j];
118 let modulus = (real * real + imag * imag).sqrt();
119
120 if modulus > 1e-10 {
121 real_row[j] = real / modulus;
122 imag_row[j] = imag / modulus;
123 }
124 }
125 }
126 }
127
128 fn score_triple_ids(
131 &self,
132 subject_id: usize,
133 predicate_id: usize,
134 object_id: usize,
135 ) -> Result<f64> {
136 if !self.embeddings_initialized {
137 return Err(anyhow!("Model not trained"));
138 }
139
140 let h_real = self.entity_embeddings_real.row(subject_id);
141 let h_imag = self.entity_embeddings_imag.row(subject_id);
142 let r_phases = self.relation_phases.row(predicate_id);
143 let t_real = self.entity_embeddings_real.row(object_id);
144 let t_imag = self.entity_embeddings_imag.row(object_id);
145
146 let mut distance_squared = 0.0;
152
153 for ((((&h_r, &h_i), &phase), &t_r), &t_i) in h_real
154 .iter()
155 .zip(h_imag.iter())
156 .zip(r_phases.iter())
157 .zip(t_real.iter())
158 .zip(t_imag.iter())
159 {
160 let cos_phase = phase.cos();
161 let sin_phase = phase.sin();
162
163 let rotated_real = h_r * cos_phase - h_i * sin_phase;
165 let rotated_imag = h_r * sin_phase + h_i * cos_phase;
166
167 let diff_real = rotated_real - t_r;
169 let diff_imag = rotated_imag - t_i;
170
171 distance_squared += diff_real * diff_real + diff_imag * diff_imag;
172 }
173
174 Ok(-distance_squared.sqrt())
176 }
177
178 fn compute_gradients(
180 &self,
181 pos_triple: (usize, usize, usize),
182 neg_triple: (usize, usize, usize),
183 pos_score: f64,
184 neg_score: f64,
185 ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)> {
186 let mut entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
187 let mut entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
188 let mut relation_grads = Array2::zeros(self.relation_phases.raw_dim());
189
190 let margin = self
192 .base
193 .config
194 .model_params
195 .get("margin")
196 .copied()
197 .unwrap_or(6.0);
198 let loss = margin + (-pos_score) - (-neg_score); if loss > 0.0 {
201 self.add_triple_gradients(
203 pos_triple,
204 1.0,
205 &mut entity_grads_real,
206 &mut entity_grads_imag,
207 &mut relation_grads,
208 );
209
210 self.add_triple_gradients(
212 neg_triple,
213 -1.0,
214 &mut entity_grads_real,
215 &mut entity_grads_imag,
216 &mut relation_grads,
217 );
218 }
219
220 Ok((entity_grads_real, entity_grads_imag, relation_grads))
221 }
222
223 fn add_triple_gradients(
225 &self,
226 triple: (usize, usize, usize),
227 grad_coeff: f64,
228 entity_grads_real: &mut Array2<f64>,
229 entity_grads_imag: &mut Array2<f64>,
230 relation_grads: &mut Array2<f64>,
231 ) {
232 let (s, p, o) = triple;
233
234 let h_real = self.entity_embeddings_real.row(s);
235 let h_imag = self.entity_embeddings_imag.row(s);
236 let r_phases = self.relation_phases.row(p);
237 let t_real = self.entity_embeddings_real.row(o);
238 let t_imag = self.entity_embeddings_imag.row(o);
239
240 for (i, ((((&h_r, &h_i), &phase), &t_r), &t_i)) in h_real
241 .iter()
242 .zip(h_imag.iter())
243 .zip(r_phases.iter())
244 .zip(t_real.iter())
245 .zip(t_imag.iter())
246 .enumerate()
247 {
248 let cos_phase = phase.cos();
249 let sin_phase = phase.sin();
250
251 let rotated_real = h_r * cos_phase - h_i * sin_phase;
253 let rotated_imag = h_r * sin_phase + h_i * cos_phase;
254
255 let diff_real = rotated_real - t_r;
257 let diff_imag = rotated_imag - t_i;
258
259 let distance = (diff_real * diff_real + diff_imag * diff_imag).sqrt();
260
261 if distance > 1e-10 {
262 let norm_factor = grad_coeff / distance;
263 let grad_real = diff_real * norm_factor;
264 let grad_imag = diff_imag * norm_factor;
265
266 entity_grads_real[[s, i]] += grad_real * cos_phase + grad_imag * sin_phase;
268 entity_grads_imag[[s, i]] += -grad_real * sin_phase + grad_imag * cos_phase;
269
270 entity_grads_real[[o, i]] -= grad_real;
272 entity_grads_imag[[o, i]] -= grad_imag;
273
274 let phase_grad = grad_real * (-h_r * sin_phase - h_i * cos_phase)
276 + grad_imag * (h_r * cos_phase - h_i * sin_phase);
277 relation_grads[[p, i]] += phase_grad;
278 }
279 }
280 }
281
282 fn generate_adversarial_negatives(
284 &self,
285 positive_triple: (usize, usize, usize),
286 num_samples: usize,
287 rng: &mut Random,
288 ) -> Vec<(usize, usize, usize)> {
289 let mut negatives = Vec::new();
290 let num_entities = self.base.num_entities();
291
292 for _ in 0..num_samples {
293 let corrupt_head = rng.random_f64() < 0.5;
295
296 if corrupt_head {
297 let mut candidate_scores = Vec::new();
299 for entity_id in 0..num_entities {
300 if entity_id != positive_triple.0 {
301 let neg_triple = (entity_id, positive_triple.1, positive_triple.2);
302 if let Ok(score) =
303 self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)
304 {
305 candidate_scores.push((entity_id, score));
306 }
307 }
308 }
309
310 if !candidate_scores.is_empty() {
311 let weights: Vec<f64> = candidate_scores
313 .iter()
314 .map(|(_, score)| (-score / self.adversarial_temperature).exp())
315 .collect();
316
317 let total_weight: f64 = weights.iter().sum();
318 let mut cumulative = 0.0;
319 let threshold = rng.random_f64() * total_weight;
320
321 for (i, &weight) in weights.iter().enumerate() {
322 cumulative += weight;
323 if cumulative >= threshold {
324 let entity_id = candidate_scores[i].0;
325 negatives.push((entity_id, positive_triple.1, positive_triple.2));
326 break;
327 }
328 }
329 }
330 } else {
331 let mut candidate_scores = Vec::new();
333 for entity_id in 0..num_entities {
334 if entity_id != positive_triple.2 {
335 let neg_triple = (positive_triple.0, positive_triple.1, entity_id);
336 if let Ok(score) =
337 self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)
338 {
339 candidate_scores.push((entity_id, score));
340 }
341 }
342 }
343
344 if !candidate_scores.is_empty() {
345 let weights: Vec<f64> = candidate_scores
346 .iter()
347 .map(|(_, score)| (-score / self.adversarial_temperature).exp())
348 .collect();
349
350 let total_weight: f64 = weights.iter().sum();
351 let mut cumulative = 0.0;
352 let threshold = rng.random_f64() * total_weight;
353
354 for (i, &weight) in weights.iter().enumerate() {
355 cumulative += weight;
356 if cumulative >= threshold {
357 let entity_id = candidate_scores[i].0;
358 negatives.push((positive_triple.0, positive_triple.1, entity_id));
359 break;
360 }
361 }
362 }
363 }
364 }
365
366 while negatives.len() < num_samples {
368 let corrupt_head = rng.random_f64() < 0.5;
369 let negative_triple = if corrupt_head {
370 let new_head = rng.random_range(0..num_entities);
371 (new_head, positive_triple.1, positive_triple.2)
372 } else {
373 let new_tail = rng.random_range(0..num_entities);
374 (positive_triple.0, positive_triple.1, new_tail)
375 };
376
377 if !self
378 .base
379 .has_triple(negative_triple.0, negative_triple.1, negative_triple.2)
380 {
381 negatives.push(negative_triple);
382 }
383 }
384
385 negatives
386 }
387
388 async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
390 let mut rng = Random::default();
391
392 let mut total_loss = 0.0;
393 let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
394 / self.base.config.batch_size;
395
396 let mut shuffled_triples = self.base.triples.clone();
397 for i in (1..shuffled_triples.len()).rev() {
399 let j = rng.random_range(0..i + 1);
400 shuffled_triples.swap(i, j);
401 }
402
403 for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
404 let mut batch_entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
405 let mut batch_entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
406 let mut batch_relation_grads = Array2::zeros(self.relation_phases.raw_dim());
407 let mut batch_loss = 0.0;
408
409 for &pos_triple in batch_triples {
410 let neg_samples = self.generate_adversarial_negatives(
412 pos_triple,
413 self.base.config.negative_samples,
414 &mut rng,
415 );
416
417 for neg_triple in neg_samples {
418 let pos_score =
419 self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
420 let neg_score =
421 self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
422
423 let pos_distance = -pos_score;
425 let neg_distance = -neg_score;
426
427 let margin = self
428 .base
429 .config
430 .model_params
431 .get("margin")
432 .copied()
433 .unwrap_or(6.0);
434 let loss = margin_loss(pos_distance, neg_distance, margin);
435 batch_loss += loss;
436
437 if loss > 0.0 {
438 let (entity_grads_real, entity_grads_imag, relation_grads) =
439 self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;
440
441 batch_entity_grads_real += &entity_grads_real;
442 batch_entity_grads_imag += &entity_grads_imag;
443 batch_relation_grads += &relation_grads;
444 }
445 }
446 }
447
448 gradient_update(
450 &mut self.entity_embeddings_real,
451 &batch_entity_grads_real,
452 learning_rate,
453 self.base.config.l2_reg,
454 );
455
456 gradient_update(
457 &mut self.entity_embeddings_imag,
458 &batch_entity_grads_imag,
459 learning_rate,
460 self.base.config.l2_reg,
461 );
462
463 gradient_update(
464 &mut self.relation_phases,
465 &batch_relation_grads,
466 learning_rate,
467 0.0, );
469
470 if self.modulus_constraint {
472 self.apply_modulus_constraint();
473 }
474
475 self.relation_phases.mapv_inplace(|x| {
477 let mut angle = x % (2.0 * std::f64::consts::PI);
478 if angle < 0.0 {
479 angle += 2.0 * std::f64::consts::PI;
480 }
481 angle
482 });
483
484 total_loss += batch_loss;
485 }
486
487 Ok(total_loss / num_batches as f64)
488 }
489
490 fn get_entity_embedding_vector(&self, entity_id: usize) -> Vector {
492 let real_part = self.entity_embeddings_real.row(entity_id);
493 let imag_part = self.entity_embeddings_imag.row(entity_id);
494
495 let mut values = Vec::with_capacity(real_part.len() * 2);
496 for &val in real_part.iter() {
497 values.push(val as f32);
498 }
499 for &val in imag_part.iter() {
500 values.push(val as f32);
501 }
502
503 Vector::new(values)
504 }
505
506 fn get_relation_embedding_vector(&self, relation_id: usize) -> Vector {
508 let phases = self.relation_phases.row(relation_id);
509 let values: Vec<f32> = phases.iter().copied().map(|x| x as f32).collect();
510 Vector::new(values)
511 }
512}
513
514#[async_trait]
515impl EmbeddingModel for RotatE {
516 fn config(&self) -> &ModelConfig {
517 &self.base.config
518 }
519
520 fn model_id(&self) -> &Uuid {
521 &self.base.model_id
522 }
523
524 fn model_type(&self) -> &'static str {
525 "RotatE"
526 }
527
528 fn add_triple(&mut self, triple: Triple) -> Result<()> {
529 self.base.add_triple(triple)
530 }
531
532 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
533 let start_time = Instant::now();
534 let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
535
536 self.initialize_embeddings();
537
538 if !self.embeddings_initialized {
539 return Err(anyhow!("No training data available"));
540 }
541
542 let mut loss_history = Vec::new();
543 let learning_rate = self.base.config.learning_rate;
544
545 info!("Starting RotatE training for {} epochs", max_epochs);
546
547 for epoch in 0..max_epochs {
548 let epoch_loss = self.train_epoch(learning_rate).await?;
549 loss_history.push(epoch_loss);
550
551 if epoch % 100 == 0 {
552 debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
553 }
554
555 if epoch > 10 && epoch_loss < 1e-6 {
556 info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
557 break;
558 }
559 }
560
561 self.base.mark_trained();
562 let training_time = start_time.elapsed().as_secs_f64();
563
564 Ok(TrainingStats {
565 epochs_completed: loss_history.len(),
566 final_loss: loss_history.last().copied().unwrap_or(0.0),
567 training_time_seconds: training_time,
568 convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
569 loss_history,
570 })
571 }
572
573 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
574 if !self.embeddings_initialized {
575 return Err(anyhow!("Model not trained"));
576 }
577
578 let entity_id = self
579 .base
580 .get_entity_id(entity)
581 .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
582
583 Ok(self.get_entity_embedding_vector(entity_id))
584 }
585
586 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
587 if !self.embeddings_initialized {
588 return Err(anyhow!("Model not trained"));
589 }
590
591 let relation_id = self
592 .base
593 .get_relation_id(relation)
594 .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
595
596 Ok(self.get_relation_embedding_vector(relation_id))
597 }
598
599 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
600 let subject_id = self
601 .base
602 .get_entity_id(subject)
603 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
604 let predicate_id = self
605 .base
606 .get_relation_id(predicate)
607 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
608 let object_id = self
609 .base
610 .get_entity_id(object)
611 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
612
613 self.score_triple_ids(subject_id, predicate_id, object_id)
614 }
615
616 fn predict_objects(
617 &self,
618 subject: &str,
619 predicate: &str,
620 k: usize,
621 ) -> Result<Vec<(String, f64)>> {
622 if !self.embeddings_initialized {
623 return Err(anyhow!("Model not trained"));
624 }
625
626 let subject_id = self
627 .base
628 .get_entity_id(subject)
629 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
630 let predicate_id = self
631 .base
632 .get_relation_id(predicate)
633 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
634
635 let mut scores = Vec::new();
636
637 for object_id in 0..self.base.num_entities() {
638 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
639 let object_name = self.base.get_entity(object_id).unwrap().clone();
640 scores.push((object_name, score));
641 }
642
643 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
644 scores.truncate(k);
645
646 Ok(scores)
647 }
648
649 fn predict_subjects(
650 &self,
651 predicate: &str,
652 object: &str,
653 k: usize,
654 ) -> Result<Vec<(String, f64)>> {
655 if !self.embeddings_initialized {
656 return Err(anyhow!("Model not trained"));
657 }
658
659 let predicate_id = self
660 .base
661 .get_relation_id(predicate)
662 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
663 let object_id = self
664 .base
665 .get_entity_id(object)
666 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
667
668 let mut scores = Vec::new();
669
670 for subject_id in 0..self.base.num_entities() {
671 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
672 let subject_name = self.base.get_entity(subject_id).unwrap().clone();
673 scores.push((subject_name, score));
674 }
675
676 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
677 scores.truncate(k);
678
679 Ok(scores)
680 }
681
682 fn predict_relations(
683 &self,
684 subject: &str,
685 object: &str,
686 k: usize,
687 ) -> Result<Vec<(String, f64)>> {
688 if !self.embeddings_initialized {
689 return Err(anyhow!("Model not trained"));
690 }
691
692 let subject_id = self
693 .base
694 .get_entity_id(subject)
695 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
696 let object_id = self
697 .base
698 .get_entity_id(object)
699 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
700
701 let mut scores = Vec::new();
702
703 for predicate_id in 0..self.base.num_relations() {
704 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
705 let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
706 scores.push((predicate_name, score));
707 }
708
709 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
710 scores.truncate(k);
711
712 Ok(scores)
713 }
714
715 fn get_entities(&self) -> Vec<String> {
716 self.base.get_entities()
717 }
718
719 fn get_relations(&self) -> Vec<String> {
720 self.base.get_relations()
721 }
722
723 fn get_stats(&self) -> ModelStats {
724 self.base.get_stats("RotatE")
725 }
726
727 fn save(&self, path: &str) -> Result<()> {
728 info!("Saving RotatE model to {}", path);
729 Ok(())
730 }
731
732 fn load(&mut self, path: &str) -> Result<()> {
733 info!("Loading RotatE model from {}", path);
734 Ok(())
735 }
736
737 fn clear(&mut self) {
738 self.base.clear();
739 self.entity_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
740 self.entity_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
741 self.relation_phases = Array2::zeros((0, self.base.config.dimensions));
742 self.embeddings_initialized = false;
743 }
744
745 fn is_trained(&self) -> bool {
746 self.base.is_trained
747 }
748
749 async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
750 Err(anyhow!(
751 "Knowledge graph embedding model does not support text encoding"
752 ))
753 }
754}
755
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small untrained model containing the single triple `alice knows bob`.
    fn alice_knows_bob() -> Result<RotatE> {
        let config = ModelConfig::default()
            .with_dimensions(10)
            .with_max_epochs(5)
            .with_seed(42);
        let mut model = RotatE::new(config);

        let alice = crate::NamedNode::new("http://example.org/alice")?;
        let knows = crate::NamedNode::new("http://example.org/knows")?;
        let bob = crate::NamedNode::new("http://example.org/bob")?;
        model.add_triple(crate::Triple::new(alice, knows, bob))?;

        Ok(model)
    }

    #[tokio::test]
    async fn test_rotate_basic() -> Result<()> {
        let mut model = alice_knows_bob()?;

        let stats = model.train(Some(3)).await?;
        assert!(stats.epochs_completed > 0);

        // Entity vectors concatenate real and imaginary parts: 2 × dimensions.
        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 20);

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;
        assert!(score.is_finite());

        Ok(())
    }

    #[test]
    fn test_untrained_model_rejects_queries() -> Result<()> {
        let model = alice_knows_bob()?;

        assert!(model.get_entity_embedding("http://example.org/alice").is_err());
        assert!(model.get_relation_embedding("http://example.org/knows").is_err());
        assert!(model
            .score_triple(
                "http://example.org/alice",
                "http://example.org/knows",
                "http://example.org/bob",
            )
            .is_err());
        assert!(model
            .predict_objects("http://example.org/alice", "http://example.org/knows", 1)
            .is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_train_without_triples_fails() {
        let mut model = RotatE::new(ModelConfig::default().with_dimensions(10));
        assert!(model.train(Some(1)).await.is_err());
    }

    #[tokio::test]
    async fn test_predictions_are_ranked_and_truncated() -> Result<()> {
        let mut model = alice_knows_bob()?;
        model.train(Some(3)).await?;

        let objects =
            model.predict_objects("http://example.org/alice", "http://example.org/knows", 2)?;
        assert_eq!(objects.len(), 2);
        assert!(objects[0].1 >= objects[1].1);

        let subjects =
            model.predict_subjects("http://example.org/knows", "http://example.org/bob", 1)?;
        assert_eq!(subjects.len(), 1);

        // Only one relation exists, so asking for five returns just that one.
        let relations =
            model.predict_relations("http://example.org/alice", "http://example.org/bob", 5)?;
        assert_eq!(relations.len(), 1);

        Ok(())
    }
}
795}