oxirs_embed/models/
rotate.rs

//! RotatE: rotation-based knowledge graph embeddings.
//!
//! RotatE models relations as rotations in complex space, which lets it handle
//! symmetric, antisymmetric, inverse, and compositional relation patterns.
//! Each relation is an element-wise rotation that maps the head entity onto the tail entity.
//!
//! Reference: Sun et al., "RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space" (ICLR 2019).
8
use crate::models::{common::*, BaseModel};
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use scirs2_core::ndarray_ext::Array2;
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use std::f64::consts::TAU;
use std::time::Instant;
use tracing::{debug, info};
use uuid::Uuid;
19
20/// RotatE embedding model using complex rotations
21#[derive(Debug)]
22pub struct RotatE {
23    /// Base model functionality
24    base: BaseModel,
25    /// Real part of entity embeddings (num_entities × dimensions)
26    entity_embeddings_real: Array2<f64>,
27    /// Imaginary part of entity embeddings (num_entities × dimensions)
28    entity_embeddings_imag: Array2<f64>,
29    /// Relation phases/angles (num_relations × dimensions) - angles in [0, 2π]
30    relation_phases: Array2<f64>,
31    /// Whether embeddings have been initialized
32    embeddings_initialized: bool,
33    /// Adversarial temperature for negative sampling
34    adversarial_temperature: f64,
35    /// Modulus constraint for entity embeddings
36    modulus_constraint: bool,
37}
38
39impl RotatE {
40    /// Create a new RotatE model
41    pub fn new(config: ModelConfig) -> Self {
42        let base = BaseModel::new(config.clone());
43
44        // Get RotatE-specific parameters
45        let adversarial_temperature = config
46            .model_params
47            .get("adversarial_temperature")
48            .copied()
49            .unwrap_or(1.0);
50
51        let modulus_constraint = config
52            .model_params
53            .get("modulus_constraint")
54            .map(|&x| x > 0.0)
55            .unwrap_or(true);
56
57        Self {
58            base,
59            entity_embeddings_real: Array2::zeros((0, config.dimensions)),
60            entity_embeddings_imag: Array2::zeros((0, config.dimensions)),
61            relation_phases: Array2::zeros((0, config.dimensions)),
62            embeddings_initialized: false,
63            adversarial_temperature,
64            modulus_constraint,
65        }
66    }
67
68    /// Initialize embeddings with proper constraints
69    fn initialize_embeddings(&mut self) {
70        if self.embeddings_initialized {
71            return;
72        }
73
74        let num_entities = self.base.num_entities();
75        let num_relations = self.base.num_relations();
76        let dimensions = self.base.config.dimensions;
77
78        if num_entities == 0 || num_relations == 0 {
79            return;
80        }
81
82        let mut rng = Random::default();
83
84        // Initialize entity embeddings with uniform distribution
85        self.entity_embeddings_real = uniform_init((num_entities, dimensions), -1.0, 1.0, &mut rng);
86
87        self.entity_embeddings_imag = uniform_init((num_entities, dimensions), -1.0, 1.0, &mut rng);
88
89        // Initialize relation phases uniformly in [0, 2π]
90        self.relation_phases = uniform_init(
91            (num_relations, dimensions),
92            0.0,
93            2.0 * std::f64::consts::PI,
94            &mut rng,
95        );
96
97        // Apply modulus constraint to entity embeddings (normalize to unit circle)
98        if self.modulus_constraint {
99            self.apply_modulus_constraint();
100        }
101
102        self.embeddings_initialized = true;
103        debug!(
104            "Initialized RotatE embeddings: {} entities, {} relations, {} dimensions",
105            num_entities, num_relations, dimensions
106        );
107    }
108
109    /// Apply modulus constraint to entity embeddings
110    fn apply_modulus_constraint(&mut self) {
111        for i in 0..self.entity_embeddings_real.nrows() {
112            let mut real_row = self.entity_embeddings_real.row_mut(i);
113            let mut imag_row = self.entity_embeddings_imag.row_mut(i);
114
115            for j in 0..real_row.len() {
116                let real = real_row[j];
117                let imag = imag_row[j];
118                let modulus = (real * real + imag * imag).sqrt();
119
120                if modulus > 1e-10 {
121                    real_row[j] = real / modulus;
122                    imag_row[j] = imag / modulus;
123                }
124            }
125        }
126    }
127
128    /// Score a triple using RotatE scoring function
129    /// Score = ||h ○ r - t||, where ○ denotes complex multiplication (rotation)
130    fn score_triple_ids(
131        &self,
132        subject_id: usize,
133        predicate_id: usize,
134        object_id: usize,
135    ) -> Result<f64> {
136        if !self.embeddings_initialized {
137            return Err(anyhow!("Model not trained"));
138        }
139
140        let h_real = self.entity_embeddings_real.row(subject_id);
141        let h_imag = self.entity_embeddings_imag.row(subject_id);
142        let r_phases = self.relation_phases.row(predicate_id);
143        let t_real = self.entity_embeddings_real.row(object_id);
144        let t_imag = self.entity_embeddings_imag.row(object_id);
145
146        // Compute h ○ r (rotation of h by r)
147        // r is represented as e^(i*θ) = cos(θ) + i*sin(θ)
148        // h ○ r = (h_real + i*h_imag) * (cos(θ) + i*sin(θ))
149        //       = (h_real*cos(θ) - h_imag*sin(θ)) + i*(h_real*sin(θ) + h_imag*cos(θ))
150
151        let mut distance_squared = 0.0;
152
153        for ((((&h_r, &h_i), &phase), &t_r), &t_i) in h_real
154            .iter()
155            .zip(h_imag.iter())
156            .zip(r_phases.iter())
157            .zip(t_real.iter())
158            .zip(t_imag.iter())
159        {
160            let cos_phase = phase.cos();
161            let sin_phase = phase.sin();
162
163            // Rotated head entity
164            let rotated_real = h_r * cos_phase - h_i * sin_phase;
165            let rotated_imag = h_r * sin_phase + h_i * cos_phase;
166
167            // Distance components
168            let diff_real = rotated_real - t_r;
169            let diff_imag = rotated_imag - t_i;
170
171            distance_squared += diff_real * diff_real + diff_imag * diff_imag;
172        }
173
174        // Return negative distance as score (higher is better)
175        Ok(-distance_squared.sqrt())
176    }
177
178    /// Compute gradients for RotatE model
179    fn compute_gradients(
180        &self,
181        pos_triple: (usize, usize, usize),
182        neg_triple: (usize, usize, usize),
183        pos_score: f64,
184        neg_score: f64,
185    ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)> {
186        let mut entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
187        let mut entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
188        let mut relation_grads = Array2::zeros(self.relation_phases.raw_dim());
189
190        // Margin-based ranking loss gradients
191        let margin = self
192            .base
193            .config
194            .model_params
195            .get("margin")
196            .copied()
197            .unwrap_or(6.0);
198        let loss = margin + (-pos_score) - (-neg_score); // Convert back to distances
199
200        if loss > 0.0 {
201            // Compute gradients for positive triple (increase distance)
202            self.add_triple_gradients(
203                pos_triple,
204                1.0,
205                &mut entity_grads_real,
206                &mut entity_grads_imag,
207                &mut relation_grads,
208            );
209
210            // Compute gradients for negative triple (decrease distance)
211            self.add_triple_gradients(
212                neg_triple,
213                -1.0,
214                &mut entity_grads_real,
215                &mut entity_grads_imag,
216                &mut relation_grads,
217            );
218        }
219
220        Ok((entity_grads_real, entity_grads_imag, relation_grads))
221    }
222
223    /// Add gradients for a single triple
224    fn add_triple_gradients(
225        &self,
226        triple: (usize, usize, usize),
227        grad_coeff: f64,
228        entity_grads_real: &mut Array2<f64>,
229        entity_grads_imag: &mut Array2<f64>,
230        relation_grads: &mut Array2<f64>,
231    ) {
232        let (s, p, o) = triple;
233
234        let h_real = self.entity_embeddings_real.row(s);
235        let h_imag = self.entity_embeddings_imag.row(s);
236        let r_phases = self.relation_phases.row(p);
237        let t_real = self.entity_embeddings_real.row(o);
238        let t_imag = self.entity_embeddings_imag.row(o);
239
240        for (i, ((((&h_r, &h_i), &phase), &t_r), &t_i)) in h_real
241            .iter()
242            .zip(h_imag.iter())
243            .zip(r_phases.iter())
244            .zip(t_real.iter())
245            .zip(t_imag.iter())
246            .enumerate()
247        {
248            let cos_phase = phase.cos();
249            let sin_phase = phase.sin();
250
251            // Rotated head entity
252            let rotated_real = h_r * cos_phase - h_i * sin_phase;
253            let rotated_imag = h_r * sin_phase + h_i * cos_phase;
254
255            // Distance components
256            let diff_real = rotated_real - t_r;
257            let diff_imag = rotated_imag - t_i;
258
259            let distance = (diff_real * diff_real + diff_imag * diff_imag).sqrt();
260
261            if distance > 1e-10 {
262                let norm_factor = grad_coeff / distance;
263                let grad_real = diff_real * norm_factor;
264                let grad_imag = diff_imag * norm_factor;
265
266                // Gradients w.r.t. head entity (subject)
267                entity_grads_real[[s, i]] += grad_real * cos_phase + grad_imag * sin_phase;
268                entity_grads_imag[[s, i]] += -grad_real * sin_phase + grad_imag * cos_phase;
269
270                // Gradients w.r.t. tail entity (object)
271                entity_grads_real[[o, i]] -= grad_real;
272                entity_grads_imag[[o, i]] -= grad_imag;
273
274                // Gradients w.r.t. relation phases
275                let phase_grad = grad_real * (-h_r * sin_phase - h_i * cos_phase)
276                    + grad_imag * (h_r * cos_phase - h_i * sin_phase);
277                relation_grads[[p, i]] += phase_grad;
278            }
279        }
280    }
281
282    /// Generate adversarial negative samples
283    fn generate_adversarial_negatives(
284        &self,
285        positive_triple: (usize, usize, usize),
286        num_samples: usize,
287        rng: &mut Random,
288    ) -> Vec<(usize, usize, usize)> {
289        let mut negatives = Vec::new();
290        let num_entities = self.base.num_entities();
291
292        for _ in 0..num_samples {
293            // Choose to corrupt either head or tail
294            let corrupt_head = rng.random_f64() < 0.5;
295
296            if corrupt_head {
297                // Sample entity according to adversarial distribution
298                let mut candidate_scores = Vec::new();
299                for entity_id in 0..num_entities {
300                    if entity_id != positive_triple.0 {
301                        let neg_triple = (entity_id, positive_triple.1, positive_triple.2);
302                        if let Ok(score) =
303                            self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)
304                        {
305                            candidate_scores.push((entity_id, score));
306                        }
307                    }
308                }
309
310                if !candidate_scores.is_empty() {
311                    // Use adversarial sampling based on scores
312                    let weights: Vec<f64> = candidate_scores
313                        .iter()
314                        .map(|(_, score)| (-score / self.adversarial_temperature).exp())
315                        .collect();
316
317                    let total_weight: f64 = weights.iter().sum();
318                    let mut cumulative = 0.0;
319                    let threshold = rng.random_f64() * total_weight;
320
321                    for (i, &weight) in weights.iter().enumerate() {
322                        cumulative += weight;
323                        if cumulative >= threshold {
324                            let entity_id = candidate_scores[i].0;
325                            negatives.push((entity_id, positive_triple.1, positive_triple.2));
326                            break;
327                        }
328                    }
329                }
330            } else {
331                // Similar logic for corrupting tail
332                let mut candidate_scores = Vec::new();
333                for entity_id in 0..num_entities {
334                    if entity_id != positive_triple.2 {
335                        let neg_triple = (positive_triple.0, positive_triple.1, entity_id);
336                        if let Ok(score) =
337                            self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)
338                        {
339                            candidate_scores.push((entity_id, score));
340                        }
341                    }
342                }
343
344                if !candidate_scores.is_empty() {
345                    let weights: Vec<f64> = candidate_scores
346                        .iter()
347                        .map(|(_, score)| (-score / self.adversarial_temperature).exp())
348                        .collect();
349
350                    let total_weight: f64 = weights.iter().sum();
351                    let mut cumulative = 0.0;
352                    let threshold = rng.random_f64() * total_weight;
353
354                    for (i, &weight) in weights.iter().enumerate() {
355                        cumulative += weight;
356                        if cumulative >= threshold {
357                            let entity_id = candidate_scores[i].0;
358                            negatives.push((positive_triple.0, positive_triple.1, entity_id));
359                            break;
360                        }
361                    }
362                }
363            }
364        }
365
366        // Fall back to uniform sampling if adversarial sampling fails
367        while negatives.len() < num_samples {
368            let corrupt_head = rng.random_f64() < 0.5;
369            let negative_triple = if corrupt_head {
370                let new_head = rng.random_range(0..num_entities);
371                (new_head, positive_triple.1, positive_triple.2)
372            } else {
373                let new_tail = rng.random_range(0..num_entities);
374                (positive_triple.0, positive_triple.1, new_tail)
375            };
376
377            if !self
378                .base
379                .has_triple(negative_triple.0, negative_triple.1, negative_triple.2)
380            {
381                negatives.push(negative_triple);
382            }
383        }
384
385        negatives
386    }
387
388    /// Perform one training epoch
389    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
390        let mut rng = Random::default();
391
392        let mut total_loss = 0.0;
393        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
394            / self.base.config.batch_size;
395
396        let mut shuffled_triples = self.base.triples.clone();
397        // Manual Fisher-Yates shuffle using scirs2-core
398        for i in (1..shuffled_triples.len()).rev() {
399            let j = rng.random_range(0..i + 1);
400            shuffled_triples.swap(i, j);
401        }
402
403        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
404            let mut batch_entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
405            let mut batch_entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
406            let mut batch_relation_grads = Array2::zeros(self.relation_phases.raw_dim());
407            let mut batch_loss = 0.0;
408
409            for &pos_triple in batch_triples {
410                // Use adversarial negative sampling
411                let neg_samples = self.generate_adversarial_negatives(
412                    pos_triple,
413                    self.base.config.negative_samples,
414                    &mut rng,
415                );
416
417                for neg_triple in neg_samples {
418                    let pos_score =
419                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
420                    let neg_score =
421                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
422
423                    // Convert scores back to distances for loss computation
424                    let pos_distance = -pos_score;
425                    let neg_distance = -neg_score;
426
427                    let margin = self
428                        .base
429                        .config
430                        .model_params
431                        .get("margin")
432                        .copied()
433                        .unwrap_or(6.0);
434                    let loss = margin_loss(pos_distance, neg_distance, margin);
435                    batch_loss += loss;
436
437                    if loss > 0.0 {
438                        let (entity_grads_real, entity_grads_imag, relation_grads) =
439                            self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;
440
441                        batch_entity_grads_real += &entity_grads_real;
442                        batch_entity_grads_imag += &entity_grads_imag;
443                        batch_relation_grads += &relation_grads;
444                    }
445                }
446            }
447
448            // Apply gradients with regularization
449            gradient_update(
450                &mut self.entity_embeddings_real,
451                &batch_entity_grads_real,
452                learning_rate,
453                self.base.config.l2_reg,
454            );
455
456            gradient_update(
457                &mut self.entity_embeddings_imag,
458                &batch_entity_grads_imag,
459                learning_rate,
460                self.base.config.l2_reg,
461            );
462
463            gradient_update(
464                &mut self.relation_phases,
465                &batch_relation_grads,
466                learning_rate,
467                0.0, // No regularization on phases
468            );
469
470            // Apply modulus constraint
471            if self.modulus_constraint {
472                self.apply_modulus_constraint();
473            }
474
475            // Constrain relation phases to [0, 2π]
476            self.relation_phases.mapv_inplace(|x| {
477                let mut angle = x % (2.0 * std::f64::consts::PI);
478                if angle < 0.0 {
479                    angle += 2.0 * std::f64::consts::PI;
480                }
481                angle
482            });
483
484            total_loss += batch_loss;
485        }
486
487        Ok(total_loss / num_batches as f64)
488    }
489
490    /// Get entity embedding as concatenated real/imaginary vector
491    fn get_entity_embedding_vector(&self, entity_id: usize) -> Vector {
492        let real_part = self.entity_embeddings_real.row(entity_id);
493        let imag_part = self.entity_embeddings_imag.row(entity_id);
494
495        let mut values = Vec::with_capacity(real_part.len() * 2);
496        for &val in real_part.iter() {
497            values.push(val as f32);
498        }
499        for &val in imag_part.iter() {
500            values.push(val as f32);
501        }
502
503        Vector::new(values)
504    }
505
506    /// Get relation embedding as phase vector
507    fn get_relation_embedding_vector(&self, relation_id: usize) -> Vector {
508        let phases = self.relation_phases.row(relation_id);
509        let values: Vec<f32> = phases.iter().copied().map(|x| x as f32).collect();
510        Vector::new(values)
511    }
512}
513
514#[async_trait]
515impl EmbeddingModel for RotatE {
516    fn config(&self) -> &ModelConfig {
517        &self.base.config
518    }
519
520    fn model_id(&self) -> &Uuid {
521        &self.base.model_id
522    }
523
524    fn model_type(&self) -> &'static str {
525        "RotatE"
526    }
527
528    fn add_triple(&mut self, triple: Triple) -> Result<()> {
529        self.base.add_triple(triple)
530    }
531
532    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
533        let start_time = Instant::now();
534        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
535
536        self.initialize_embeddings();
537
538        if !self.embeddings_initialized {
539            return Err(anyhow!("No training data available"));
540        }
541
542        let mut loss_history = Vec::new();
543        let learning_rate = self.base.config.learning_rate;
544
545        info!("Starting RotatE training for {} epochs", max_epochs);
546
547        for epoch in 0..max_epochs {
548            let epoch_loss = self.train_epoch(learning_rate).await?;
549            loss_history.push(epoch_loss);
550
551            if epoch % 100 == 0 {
552                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
553            }
554
555            if epoch > 10 && epoch_loss < 1e-6 {
556                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
557                break;
558            }
559        }
560
561        self.base.mark_trained();
562        let training_time = start_time.elapsed().as_secs_f64();
563
564        Ok(TrainingStats {
565            epochs_completed: loss_history.len(),
566            final_loss: loss_history.last().copied().unwrap_or(0.0),
567            training_time_seconds: training_time,
568            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
569            loss_history,
570        })
571    }
572
573    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
574        if !self.embeddings_initialized {
575            return Err(anyhow!("Model not trained"));
576        }
577
578        let entity_id = self
579            .base
580            .get_entity_id(entity)
581            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
582
583        Ok(self.get_entity_embedding_vector(entity_id))
584    }
585
586    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
587        if !self.embeddings_initialized {
588            return Err(anyhow!("Model not trained"));
589        }
590
591        let relation_id = self
592            .base
593            .get_relation_id(relation)
594            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
595
596        Ok(self.get_relation_embedding_vector(relation_id))
597    }
598
599    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
600        let subject_id = self
601            .base
602            .get_entity_id(subject)
603            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
604        let predicate_id = self
605            .base
606            .get_relation_id(predicate)
607            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
608        let object_id = self
609            .base
610            .get_entity_id(object)
611            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
612
613        self.score_triple_ids(subject_id, predicate_id, object_id)
614    }
615
616    fn predict_objects(
617        &self,
618        subject: &str,
619        predicate: &str,
620        k: usize,
621    ) -> Result<Vec<(String, f64)>> {
622        if !self.embeddings_initialized {
623            return Err(anyhow!("Model not trained"));
624        }
625
626        let subject_id = self
627            .base
628            .get_entity_id(subject)
629            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
630        let predicate_id = self
631            .base
632            .get_relation_id(predicate)
633            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
634
635        let mut scores = Vec::new();
636
637        for object_id in 0..self.base.num_entities() {
638            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
639            let object_name = self.base.get_entity(object_id).unwrap().clone();
640            scores.push((object_name, score));
641        }
642
643        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
644        scores.truncate(k);
645
646        Ok(scores)
647    }
648
649    fn predict_subjects(
650        &self,
651        predicate: &str,
652        object: &str,
653        k: usize,
654    ) -> Result<Vec<(String, f64)>> {
655        if !self.embeddings_initialized {
656            return Err(anyhow!("Model not trained"));
657        }
658
659        let predicate_id = self
660            .base
661            .get_relation_id(predicate)
662            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
663        let object_id = self
664            .base
665            .get_entity_id(object)
666            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
667
668        let mut scores = Vec::new();
669
670        for subject_id in 0..self.base.num_entities() {
671            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
672            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
673            scores.push((subject_name, score));
674        }
675
676        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
677        scores.truncate(k);
678
679        Ok(scores)
680    }
681
682    fn predict_relations(
683        &self,
684        subject: &str,
685        object: &str,
686        k: usize,
687    ) -> Result<Vec<(String, f64)>> {
688        if !self.embeddings_initialized {
689            return Err(anyhow!("Model not trained"));
690        }
691
692        let subject_id = self
693            .base
694            .get_entity_id(subject)
695            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
696        let object_id = self
697            .base
698            .get_entity_id(object)
699            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
700
701        let mut scores = Vec::new();
702
703        for predicate_id in 0..self.base.num_relations() {
704            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
705            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
706            scores.push((predicate_name, score));
707        }
708
709        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
710        scores.truncate(k);
711
712        Ok(scores)
713    }
714
715    fn get_entities(&self) -> Vec<String> {
716        self.base.get_entities()
717    }
718
719    fn get_relations(&self) -> Vec<String> {
720        self.base.get_relations()
721    }
722
723    fn get_stats(&self) -> ModelStats {
724        self.base.get_stats("RotatE")
725    }
726
727    fn save(&self, path: &str) -> Result<()> {
728        info!("Saving RotatE model to {}", path);
729        Ok(())
730    }
731
732    fn load(&mut self, path: &str) -> Result<()> {
733        info!("Loading RotatE model from {}", path);
734        Ok(())
735    }
736
737    fn clear(&mut self) {
738        self.base.clear();
739        self.entity_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
740        self.entity_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
741        self.relation_phases = Array2::zeros((0, self.base.config.dimensions));
742        self.embeddings_initialized = false;
743    }
744
745    fn is_trained(&self) -> bool {
746        self.base.is_trained
747    }
748
749    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
750        Err(anyhow!(
751            "Knowledge graph embedding model does not support text encoding"
752        ))
753    }
754}
755
#[cfg(test)]
mod tests {
    use super::*;

    const ALICE: &str = "http://example.org/alice";
    const KNOWS: &str = "http://example.org/knows";
    const BOB: &str = "http://example.org/bob";

    /// Trains on a single triple, then checks the embedding width and that scoring
    /// yields a finite value.
    #[tokio::test]
    async fn test_rotate_basic() -> Result<()> {
        let mut model = RotatE::new(
            ModelConfig::default()
                .with_dimensions(10)
                .with_max_epochs(5)
                .with_seed(42),
        );

        let triple = crate::Triple::new(
            crate::NamedNode::new(ALICE)?,
            crate::NamedNode::new(KNOWS)?,
            crate::NamedNode::new(BOB)?,
        );
        model.add_triple(triple)?;

        let stats = model.train(Some(3)).await?;
        assert!(stats.epochs_completed > 0);

        // Every complex coordinate contributes one real and one imaginary component.
        let embedding = model.get_entity_embedding(ALICE)?;
        assert_eq!(embedding.dimensions, 2 * 10);

        let score = model.score_triple(ALICE, KNOWS, BOB)?;
        assert!(score.is_finite());

        Ok(())
    }
}