// oxirs_embed/models/complex.rs

1//! ComplEx: Complex Embeddings for Simple Link Prediction
2//!
3//! ComplEx uses complex-valued embeddings to better model asymmetric relations.
4//! The scoring function is: Re(<h, r, conj(t)>) where Re denotes real part,
5//! <> denotes complex dot product, and conj denotes complex conjugate.
6//!
7//! Reference: Trouillon et al. "Complex Embeddings for Simple Link Prediction" (2016)
8
9use crate::models::{common::*, BaseModel};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use scirs2_core::ndarray_ext::Array2;
14#[allow(unused_imports)]
15use scirs2_core::random::{Random, Rng};
16use serde::{Deserialize, Serialize};
17use std::ops::AddAssign;
18use std::time::Instant;
19use tracing::{debug, info};
20use uuid::Uuid;
21
/// Type alias for gradient tensors, ordered as
/// (entity_real, entity_imag, relation_real, relation_imag),
/// each shaped like the corresponding embedding matrix.
type GradientTuple = (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>);
24
25/// ComplEx embedding model using complex-valued embeddings
/// ComplEx embedding model using complex-valued embeddings
///
/// Each entity/relation embedding is a complex vector stored as two parallel
/// real-valued matrices (real and imaginary parts), one row per id.
#[derive(Debug)]
pub struct ComplEx {
    /// Base model functionality (entity/relation registries, triples, config)
    base: BaseModel,
    /// Real part of entity embeddings (num_entities × dimensions)
    entity_embeddings_real: Array2<f64>,
    /// Imaginary part of entity embeddings (num_entities × dimensions)
    entity_embeddings_imag: Array2<f64>,
    /// Real part of relation embeddings (num_relations × dimensions)
    relation_embeddings_real: Array2<f64>,
    /// Imaginary part of relation embeddings (num_relations × dimensions)
    relation_embeddings_imag: Array2<f64>,
    /// Whether embeddings have been initialized (they are allocated lazily,
    /// on the first call to `train`)
    embeddings_initialized: bool,
    /// Regularization method applied during training
    regularization: RegularizationType,
}
43
44/// Regularization types for ComplEx
/// Regularization types for ComplEx
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RegularizationType {
    /// L2 regularization on embeddings
    L2,
    /// N3 regularization (nuclear 3-norm); note that the current
    /// implementation approximates this with an L2-style penalty
    /// (see `apply_n3_regularization`)
    N3,
    /// No additional regularization
    None,
}
54
55impl ComplEx {
56    /// Create a new ComplEx model
57    pub fn new(config: ModelConfig) -> Self {
58        let base = BaseModel::new(config.clone());
59
60        // Get ComplEx-specific parameters
61        let regularization = match config.model_params.get("regularization") {
62            Some(0.0) => RegularizationType::None,
63            Some(1.0) => RegularizationType::L2,
64            Some(2.0) => RegularizationType::N3,
65            _ => RegularizationType::N3, // Default to N3
66        };
67
68        Self {
69            base,
70            entity_embeddings_real: Array2::zeros((0, config.dimensions)),
71            entity_embeddings_imag: Array2::zeros((0, config.dimensions)),
72            relation_embeddings_real: Array2::zeros((0, config.dimensions)),
73            relation_embeddings_imag: Array2::zeros((0, config.dimensions)),
74            embeddings_initialized: false,
75            regularization,
76        }
77    }
78
79    /// Initialize complex embeddings
80    fn initialize_embeddings(&mut self) {
81        if self.embeddings_initialized {
82            return;
83        }
84
85        let num_entities = self.base.num_entities();
86        let num_relations = self.base.num_relations();
87        let dimensions = self.base.config.dimensions;
88
89        if num_entities == 0 || num_relations == 0 {
90            return;
91        }
92
93        let mut rng = Random::default();
94
95        // Initialize all embedding components with Xavier initialization
96        self.entity_embeddings_real =
97            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
98
99        self.entity_embeddings_imag =
100            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
101
102        self.relation_embeddings_real = xavier_init(
103            (num_relations, dimensions),
104            dimensions,
105            dimensions,
106            &mut rng,
107        );
108
109        self.relation_embeddings_imag = xavier_init(
110            (num_relations, dimensions),
111            dimensions,
112            dimensions,
113            &mut rng,
114        );
115
116        self.embeddings_initialized = true;
117        debug!(
118            "Initialized ComplEx embeddings: {} entities, {} relations, {} dimensions",
119            num_entities, num_relations, dimensions
120        );
121    }
122
123    /// Score a triple using ComplEx scoring function
124    /// Score = Re(<h, r, conj(t)>) = Re(h) * Re(r) * Re(t) + Re(h) * Im(r) * Im(t) +
125    ///                                 Im(h) * Re(r) * Im(t) - Im(h) * Im(r) * Re(t)
126    fn score_triple_ids(
127        &self,
128        subject_id: usize,
129        predicate_id: usize,
130        object_id: usize,
131    ) -> Result<f64> {
132        if !self.embeddings_initialized {
133            return Err(anyhow!("Model not trained"));
134        }
135
136        let h_real = self.entity_embeddings_real.row(subject_id);
137        let h_imag = self.entity_embeddings_imag.row(subject_id);
138        let r_real = self.relation_embeddings_real.row(predicate_id);
139        let r_imag = self.relation_embeddings_imag.row(predicate_id);
140        let t_real = self.entity_embeddings_real.row(object_id);
141        let t_imag = self.entity_embeddings_imag.row(object_id);
142
143        // Complex multiplication: (h_real + i*h_imag) * (r_real + i*r_imag) * conj(t_real + i*t_imag)
144        // = (h_real + i*h_imag) * (r_real + i*r_imag) * (t_real - i*t_imag)
145        let score = (&h_real * &r_real * t_real).sum()
146            + (&h_real * &r_imag * t_imag).sum()
147            + (&h_imag * &r_real * t_imag).sum()
148            - (&h_imag * &r_imag * t_real).sum();
149
150        Ok(score)
151    }
152
153    /// Compute gradients for ComplEx model
154    fn compute_gradients(
155        &self,
156        pos_triple: (usize, usize, usize),
157        neg_triple: (usize, usize, usize),
158        pos_score: f64,
159        neg_score: f64,
160    ) -> Result<GradientTuple> {
161        let mut entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
162        let mut entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
163        let mut relation_grads_real = Array2::zeros(self.relation_embeddings_real.raw_dim());
164        let mut relation_grads_imag = Array2::zeros(self.relation_embeddings_imag.raw_dim());
165
166        // Logistic loss gradients
167        let pos_sigmoid = sigmoid(pos_score);
168        let neg_sigmoid = sigmoid(neg_score);
169
170        let pos_grad_coeff = pos_sigmoid - 1.0; // Derivative of log(sigmoid(x))
171        let neg_grad_coeff = neg_sigmoid; // Derivative of log(1 - sigmoid(x))
172
173        // Compute gradients for positive triple
174        self.add_triple_gradients(
175            pos_triple,
176            pos_grad_coeff,
177            &mut entity_grads_real,
178            &mut entity_grads_imag,
179            &mut relation_grads_real,
180            &mut relation_grads_imag,
181        );
182
183        // Compute gradients for negative triple
184        self.add_triple_gradients(
185            neg_triple,
186            neg_grad_coeff,
187            &mut entity_grads_real,
188            &mut entity_grads_imag,
189            &mut relation_grads_real,
190            &mut relation_grads_imag,
191        );
192
193        Ok((
194            entity_grads_real,
195            entity_grads_imag,
196            relation_grads_real,
197            relation_grads_imag,
198        ))
199    }
200
201    /// Add gradients for a single triple
202    fn add_triple_gradients(
203        &self,
204        triple: (usize, usize, usize),
205        grad_coeff: f64,
206        entity_grads_real: &mut Array2<f64>,
207        entity_grads_imag: &mut Array2<f64>,
208        relation_grads_real: &mut Array2<f64>,
209        relation_grads_imag: &mut Array2<f64>,
210    ) {
211        let (s, p, o) = triple;
212
213        let h_real = self.entity_embeddings_real.row(s);
214        let h_imag = self.entity_embeddings_imag.row(s);
215        let r_real = self.relation_embeddings_real.row(p);
216        let r_imag = self.relation_embeddings_imag.row(p);
217        let t_real = self.entity_embeddings_real.row(o);
218        let t_imag = self.entity_embeddings_imag.row(o);
219
220        // Gradients w.r.t. h (subject)
221        // ∂score/∂h_real = r_real * t_real + r_imag * t_imag
222        // ∂score/∂h_imag = r_real * t_imag - r_imag * t_real
223        let h_real_grad = (&r_real * &t_real + &r_imag * &t_imag) * grad_coeff;
224        let h_imag_grad = (&r_real * &t_imag - &r_imag * &t_real) * grad_coeff;
225
226        entity_grads_real.row_mut(s).add_assign(&h_real_grad);
227        entity_grads_imag.row_mut(s).add_assign(&h_imag_grad);
228
229        // Gradients w.r.t. r (relation)
230        // ∂score/∂r_real = h_real * t_real + h_imag * t_imag
231        // ∂score/∂r_imag = h_real * t_imag - h_imag * t_real
232        let r_real_grad = (&h_real * &t_real + &h_imag * &t_imag) * grad_coeff;
233        let r_imag_grad = (&h_real * &t_imag - &h_imag * &t_real) * grad_coeff;
234
235        relation_grads_real.row_mut(p).add_assign(&r_real_grad);
236        relation_grads_imag.row_mut(p).add_assign(&r_imag_grad);
237
238        // Gradients w.r.t. t (object) - note the conjugate
239        // ∂score/∂t_real = h_real * r_real - h_imag * r_imag
240        // ∂score/∂t_imag = -(h_real * r_imag + h_imag * r_real)
241        let t_real_grad = (&h_real * &r_real - &h_imag * &r_imag) * grad_coeff;
242        let t_imag_grad = -(&h_real * &r_imag + &h_imag * &r_real) * grad_coeff;
243
244        entity_grads_real.row_mut(o).add_assign(&t_real_grad);
245        entity_grads_imag.row_mut(o).add_assign(&t_imag_grad);
246    }
247
248    /// Apply N3 regularization
249    fn apply_n3_regularization(
250        &self,
251        entity_grads_real: &mut Array2<f64>,
252        entity_grads_imag: &mut Array2<f64>,
253        relation_grads_real: &mut Array2<f64>,
254        relation_grads_imag: &mut Array2<f64>,
255        regularization_weight: f64,
256    ) {
257        // N3 regularization: penalize the nuclear 3-norm
258        // For complex embeddings, this becomes more involved
259        // For simplicity, we apply L2 regularization here
260        // A full N3 implementation would require more complex tensor operations
261
262        *entity_grads_real += &(&self.entity_embeddings_real * regularization_weight);
263        *entity_grads_imag += &(&self.entity_embeddings_imag * regularization_weight);
264        *relation_grads_real += &(&self.relation_embeddings_real * regularization_weight);
265        *relation_grads_imag += &(&self.relation_embeddings_imag * regularization_weight);
266    }
267
268    /// Perform one training epoch
    /// Perform one training epoch of mini-batch SGD.
    ///
    /// Shuffles a copy of all triples, then for each batch: draws
    /// `negative_samples` corrupted triples per positive, accumulates the
    /// logistic loss and its gradients, adds optional regularization
    /// gradients, and applies one SGD step per batch. Returns the mean
    /// batch loss for the epoch.
    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
        let mut rng = Random::default();

        let mut total_loss = 0.0;
        // Ceiling division: number of mini-batches in this epoch.
        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
            / self.base.config.batch_size;

        // Create shuffled batches
        let mut shuffled_triples = self.base.triples.clone();
        // Manual Fisher-Yates shuffle using scirs2-core
        for i in (1..shuffled_triples.len()).rev() {
            let j = rng.random_range(0..i + 1);
            shuffled_triples.swap(i, j);
        }

        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
            // Fresh gradient accumulators, one per embedding matrix.
            let mut batch_entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
            let mut batch_entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
            let mut batch_relation_grads_real =
                Array2::zeros(self.relation_embeddings_real.raw_dim());
            let mut batch_relation_grads_imag =
                Array2::zeros(self.relation_embeddings_imag.raw_dim());
            let mut batch_loss = 0.0;

            for &pos_triple in batch_triples {
                // Generate negative samples
                let neg_samples = self
                    .base
                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);

                for neg_triple in neg_samples {
                    // Compute scores
                    let pos_score =
                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
                    let neg_score =
                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;

                    // Logistic loss: positives pushed toward label +1,
                    // negatives toward label -1.
                    let pos_loss = logistic_loss(pos_score, 1.0);
                    let neg_loss = logistic_loss(neg_score, -1.0);
                    let total_triple_loss = pos_loss + neg_loss;

                    batch_loss += total_triple_loss;

                    // Compute and accumulate gradients
                    let (
                        entity_grads_real,
                        entity_grads_imag,
                        relation_grads_real,
                        relation_grads_imag,
                    ) = self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;

                    batch_entity_grads_real += &entity_grads_real;
                    batch_entity_grads_imag += &entity_grads_imag;
                    batch_relation_grads_real += &relation_grads_real;
                    batch_relation_grads_imag += &relation_grads_imag;
                }
            }

            // Apply regularization (l2_reg doubles as the N3 weight).
            match self.regularization {
                RegularizationType::L2 => {
                    let reg_weight = self.base.config.l2_reg;
                    batch_entity_grads_real += &(&self.entity_embeddings_real * reg_weight);
                    batch_entity_grads_imag += &(&self.entity_embeddings_imag * reg_weight);
                    batch_relation_grads_real += &(&self.relation_embeddings_real * reg_weight);
                    batch_relation_grads_imag += &(&self.relation_embeddings_imag * reg_weight);
                }
                RegularizationType::N3 => {
                    self.apply_n3_regularization(
                        &mut batch_entity_grads_real,
                        &mut batch_entity_grads_imag,
                        &mut batch_relation_grads_real,
                        &mut batch_relation_grads_imag,
                        self.base.config.l2_reg,
                    );
                }
                RegularizationType::None => {}
            }

            // Apply gradients (plain SGD descent step).
            self.entity_embeddings_real -= &(&batch_entity_grads_real * learning_rate);
            self.entity_embeddings_imag -= &(&batch_entity_grads_imag * learning_rate);
            self.relation_embeddings_real -= &(&batch_relation_grads_real * learning_rate);
            self.relation_embeddings_imag -= &(&batch_relation_grads_imag * learning_rate);

            total_loss += batch_loss;
        }

        Ok(total_loss / num_batches as f64)
    }
360
361    /// Get entity embedding as a concatenated real/imaginary vector
362    fn get_entity_embedding_vector(&self, entity_id: usize) -> Vector {
363        let real_part = self.entity_embeddings_real.row(entity_id);
364        let imag_part = self.entity_embeddings_imag.row(entity_id);
365
366        // Concatenate real and imaginary parts
367        let mut values = Vec::with_capacity(real_part.len() * 2);
368        for &val in real_part.iter() {
369            values.push(val as f32);
370        }
371        for &val in imag_part.iter() {
372            values.push(val as f32);
373        }
374
375        Vector::new(values)
376    }
377
378    /// Get relation embedding as a concatenated real/imaginary vector
379    fn get_relation_embedding_vector(&self, relation_id: usize) -> Vector {
380        let real_part = self.relation_embeddings_real.row(relation_id);
381        let imag_part = self.relation_embeddings_imag.row(relation_id);
382
383        // Concatenate real and imaginary parts
384        let mut values = Vec::with_capacity(real_part.len() * 2);
385        for &val in real_part.iter() {
386            values.push(val as f32);
387        }
388        for &val in imag_part.iter() {
389            values.push(val as f32);
390        }
391
392        Vector::new(values)
393    }
394}
395
396#[async_trait]
397impl EmbeddingModel for ComplEx {
    /// The configuration this model was constructed with.
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }

    /// Unique id assigned by the base model.
    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }

    /// Static model-type tag used in stats and logging.
    fn model_type(&self) -> &'static str {
        "ComplEx"
    }

    /// Register a triple for training (delegates to the base model).
    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.base.add_triple(triple)
    }
413
414    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
415        let start_time = Instant::now();
416        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
417
418        // Initialize embeddings if needed
419        self.initialize_embeddings();
420
421        if !self.embeddings_initialized {
422            return Err(anyhow!("No training data available"));
423        }
424
425        let mut loss_history = Vec::new();
426        let learning_rate = self.base.config.learning_rate;
427
428        info!("Starting ComplEx training for {} epochs", max_epochs);
429
430        for epoch in 0..max_epochs {
431            let epoch_loss = self.train_epoch(learning_rate).await?;
432            loss_history.push(epoch_loss);
433
434            if epoch % 100 == 0 {
435                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
436            }
437
438            // Simple convergence check
439            if epoch > 10 && epoch_loss < 1e-6 {
440                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
441                break;
442            }
443        }
444
445        self.base.mark_trained();
446        let training_time = start_time.elapsed().as_secs_f64();
447
448        Ok(TrainingStats {
449            epochs_completed: loss_history.len(),
450            final_loss: loss_history.last().copied().unwrap_or(0.0),
451            training_time_seconds: training_time,
452            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
453            loss_history,
454        })
455    }
456
457    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
458        if !self.embeddings_initialized {
459            return Err(anyhow!("Model not trained"));
460        }
461
462        let entity_id = self
463            .base
464            .get_entity_id(entity)
465            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
466
467        Ok(self.get_entity_embedding_vector(entity_id))
468    }
469
470    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
471        if !self.embeddings_initialized {
472            return Err(anyhow!("Model not trained"));
473        }
474
475        let relation_id = self
476            .base
477            .get_relation_id(relation)
478            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
479
480        Ok(self.get_relation_embedding_vector(relation_id))
481    }
482
483    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
484        let subject_id = self
485            .base
486            .get_entity_id(subject)
487            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
488        let predicate_id = self
489            .base
490            .get_relation_id(predicate)
491            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
492        let object_id = self
493            .base
494            .get_entity_id(object)
495            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
496
497        self.score_triple_ids(subject_id, predicate_id, object_id)
498    }
499
500    fn predict_objects(
501        &self,
502        subject: &str,
503        predicate: &str,
504        k: usize,
505    ) -> Result<Vec<(String, f64)>> {
506        if !self.embeddings_initialized {
507            return Err(anyhow!("Model not trained"));
508        }
509
510        let subject_id = self
511            .base
512            .get_entity_id(subject)
513            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
514        let predicate_id = self
515            .base
516            .get_relation_id(predicate)
517            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
518
519        let mut scores = Vec::new();
520
521        for object_id in 0..self.base.num_entities() {
522            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
523            let object_name = self
524                .base
525                .get_entity(object_id)
526                .expect("entity should exist for valid id")
527                .clone();
528            scores.push((object_name, score));
529        }
530
531        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
532        scores.truncate(k);
533
534        Ok(scores)
535    }
536
537    fn predict_subjects(
538        &self,
539        predicate: &str,
540        object: &str,
541        k: usize,
542    ) -> Result<Vec<(String, f64)>> {
543        if !self.embeddings_initialized {
544            return Err(anyhow!("Model not trained"));
545        }
546
547        let predicate_id = self
548            .base
549            .get_relation_id(predicate)
550            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
551        let object_id = self
552            .base
553            .get_entity_id(object)
554            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
555
556        let mut scores = Vec::new();
557
558        for subject_id in 0..self.base.num_entities() {
559            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
560            let subject_name = self
561                .base
562                .get_entity(subject_id)
563                .expect("entity should exist for valid id")
564                .clone();
565            scores.push((subject_name, score));
566        }
567
568        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
569        scores.truncate(k);
570
571        Ok(scores)
572    }
573
574    fn predict_relations(
575        &self,
576        subject: &str,
577        object: &str,
578        k: usize,
579    ) -> Result<Vec<(String, f64)>> {
580        if !self.embeddings_initialized {
581            return Err(anyhow!("Model not trained"));
582        }
583
584        let subject_id = self
585            .base
586            .get_entity_id(subject)
587            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
588        let object_id = self
589            .base
590            .get_entity_id(object)
591            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
592
593        let mut scores = Vec::new();
594
595        for predicate_id in 0..self.base.num_relations() {
596            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
597            let predicate_name = self
598                .base
599                .get_relation(predicate_id)
600                .expect("relation should exist for valid id")
601                .clone();
602            scores.push((predicate_name, score));
603        }
604
605        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
606        scores.truncate(k);
607
608        Ok(scores)
609    }
610
    /// All registered entity IRIs (delegates to the base model).
    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }

    /// All registered relation IRIs (delegates to the base model).
    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }

    /// Aggregate model statistics labelled with the "ComplEx" model type.
    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("ComplEx")
    }
622
    /// Persist the model to `path`.
    ///
    /// NOTE(review): currently a stub — it only logs and returns `Ok(())`
    /// without writing anything to disk. Callers will silently "succeed".
    fn save(&self, path: &str) -> Result<()> {
        info!("Saving ComplEx model to {}", path);
        Ok(())
    }

    /// Load the model from `path`.
    ///
    /// NOTE(review): currently a stub — it only logs and returns `Ok(())`
    /// without restoring any state.
    fn load(&mut self, path: &str) -> Result<()> {
        info!("Loading ComplEx model from {}", path);
        Ok(())
    }
632
633    fn clear(&mut self) {
634        self.base.clear();
635        self.entity_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
636        self.entity_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
637        self.relation_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
638        self.relation_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
639        self.embeddings_initialized = false;
640    }
641
    /// Whether `train` has completed at least once (delegates to base state).
    fn is_trained(&self) -> bool {
        self.base.is_trained
    }

    /// Text encoding is not applicable to a knowledge graph embedding model.
    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
651}
652
653#[cfg(test)]
654mod tests {
655    use super::*;
656    use crate::NamedNode;
657
    /// Smoke test: add one triple, train briefly, and sanity-check the
    /// embedding shape and that scoring produces a finite value.
    #[tokio::test]
    async fn test_complex_basic() -> Result<()> {
        // Small, seeded config keeps this test fast.
        let config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = ComplEx::new(config);

        // Add test triples
        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;

        // Train for a handful of epochs; only progress is asserted.
        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        // Test embeddings (should be 2x dimensions due to complex)
        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 100); // 2 * 50

        // Test scoring
        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        // Score should be a finite number
        assert!(score.is_finite());

        Ok(())
    }
694}