// oxirs_embed/models/complex.rs

1//! ComplEx: Complex Embeddings for Simple Link Prediction
2//!
3//! ComplEx uses complex-valued embeddings to better model asymmetric relations.
4//! The scoring function is: Re(<h, r, conj(t)>) where Re denotes real part,
5//! <> denotes complex dot product, and conj denotes complex conjugate.
6//!
7//! Reference: Trouillon et al. "Complex Embeddings for Simple Link Prediction" (2016)
8
9use crate::models::{common::*, BaseModel};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use scirs2_core::ndarray_ext::Array2;
14#[allow(unused_imports)]
15use scirs2_core::random::{Random, Rng};
16use serde::{Deserialize, Serialize};
17use std::ops::AddAssign;
18use std::time::Instant;
19use tracing::{debug, info};
20use uuid::Uuid;
21
22/// Type alias for gradient tensors
23type GradientTuple = (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>);
24
25/// ComplEx embedding model using complex-valued embeddings
26#[derive(Debug)]
27pub struct ComplEx {
28    /// Base model functionality
29    base: BaseModel,
30    /// Real part of entity embeddings (num_entities × dimensions)
31    entity_embeddings_real: Array2<f64>,
32    /// Imaginary part of entity embeddings (num_entities × dimensions)
33    entity_embeddings_imag: Array2<f64>,
34    /// Real part of relation embeddings (num_relations × dimensions)
35    relation_embeddings_real: Array2<f64>,
36    /// Imaginary part of relation embeddings (num_relations × dimensions)
37    relation_embeddings_imag: Array2<f64>,
38    /// Whether embeddings have been initialized
39    embeddings_initialized: bool,
40    /// Regularization method
41    regularization: RegularizationType,
42}
43
44/// Regularization types for ComplEx
45#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
46pub enum RegularizationType {
47    /// L2 regularization on embeddings
48    L2,
49    /// N3 regularization (nuclear 3-norm)
50    N3,
51    /// No additional regularization
52    None,
53}
54
55impl ComplEx {
56    /// Create a new ComplEx model
57    pub fn new(config: ModelConfig) -> Self {
58        let base = BaseModel::new(config.clone());
59
60        // Get ComplEx-specific parameters
61        let regularization = match config.model_params.get("regularization") {
62            Some(0.0) => RegularizationType::None,
63            Some(1.0) => RegularizationType::L2,
64            Some(2.0) => RegularizationType::N3,
65            _ => RegularizationType::N3, // Default to N3
66        };
67
68        Self {
69            base,
70            entity_embeddings_real: Array2::zeros((0, config.dimensions)),
71            entity_embeddings_imag: Array2::zeros((0, config.dimensions)),
72            relation_embeddings_real: Array2::zeros((0, config.dimensions)),
73            relation_embeddings_imag: Array2::zeros((0, config.dimensions)),
74            embeddings_initialized: false,
75            regularization,
76        }
77    }
78
79    /// Initialize complex embeddings
80    fn initialize_embeddings(&mut self) {
81        if self.embeddings_initialized {
82            return;
83        }
84
85        let num_entities = self.base.num_entities();
86        let num_relations = self.base.num_relations();
87        let dimensions = self.base.config.dimensions;
88
89        if num_entities == 0 || num_relations == 0 {
90            return;
91        }
92
93        let mut rng = Random::default();
94
95        // Initialize all embedding components with Xavier initialization
96        self.entity_embeddings_real =
97            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
98
99        self.entity_embeddings_imag =
100            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
101
102        self.relation_embeddings_real = xavier_init(
103            (num_relations, dimensions),
104            dimensions,
105            dimensions,
106            &mut rng,
107        );
108
109        self.relation_embeddings_imag = xavier_init(
110            (num_relations, dimensions),
111            dimensions,
112            dimensions,
113            &mut rng,
114        );
115
116        self.embeddings_initialized = true;
117        debug!(
118            "Initialized ComplEx embeddings: {} entities, {} relations, {} dimensions",
119            num_entities, num_relations, dimensions
120        );
121    }
122
123    /// Score a triple using ComplEx scoring function
124    /// Score = Re(<h, r, conj(t)>) = Re(h) * Re(r) * Re(t) + Re(h) * Im(r) * Im(t) +
125    ///                                 Im(h) * Re(r) * Im(t) - Im(h) * Im(r) * Re(t)
126    fn score_triple_ids(
127        &self,
128        subject_id: usize,
129        predicate_id: usize,
130        object_id: usize,
131    ) -> Result<f64> {
132        if !self.embeddings_initialized {
133            return Err(anyhow!("Model not trained"));
134        }
135
136        let h_real = self.entity_embeddings_real.row(subject_id);
137        let h_imag = self.entity_embeddings_imag.row(subject_id);
138        let r_real = self.relation_embeddings_real.row(predicate_id);
139        let r_imag = self.relation_embeddings_imag.row(predicate_id);
140        let t_real = self.entity_embeddings_real.row(object_id);
141        let t_imag = self.entity_embeddings_imag.row(object_id);
142
143        // Complex multiplication: (h_real + i*h_imag) * (r_real + i*r_imag) * conj(t_real + i*t_imag)
144        // = (h_real + i*h_imag) * (r_real + i*r_imag) * (t_real - i*t_imag)
145        let score = (&h_real * &r_real * t_real).sum()
146            + (&h_real * &r_imag * t_imag).sum()
147            + (&h_imag * &r_real * t_imag).sum()
148            - (&h_imag * &r_imag * t_real).sum();
149
150        Ok(score)
151    }
152
153    /// Compute gradients for ComplEx model
154    fn compute_gradients(
155        &self,
156        pos_triple: (usize, usize, usize),
157        neg_triple: (usize, usize, usize),
158        pos_score: f64,
159        neg_score: f64,
160    ) -> Result<GradientTuple> {
161        let mut entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
162        let mut entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
163        let mut relation_grads_real = Array2::zeros(self.relation_embeddings_real.raw_dim());
164        let mut relation_grads_imag = Array2::zeros(self.relation_embeddings_imag.raw_dim());
165
166        // Logistic loss gradients
167        let pos_sigmoid = sigmoid(pos_score);
168        let neg_sigmoid = sigmoid(neg_score);
169
170        let pos_grad_coeff = pos_sigmoid - 1.0; // Derivative of log(sigmoid(x))
171        let neg_grad_coeff = neg_sigmoid; // Derivative of log(1 - sigmoid(x))
172
173        // Compute gradients for positive triple
174        self.add_triple_gradients(
175            pos_triple,
176            pos_grad_coeff,
177            &mut entity_grads_real,
178            &mut entity_grads_imag,
179            &mut relation_grads_real,
180            &mut relation_grads_imag,
181        );
182
183        // Compute gradients for negative triple
184        self.add_triple_gradients(
185            neg_triple,
186            neg_grad_coeff,
187            &mut entity_grads_real,
188            &mut entity_grads_imag,
189            &mut relation_grads_real,
190            &mut relation_grads_imag,
191        );
192
193        Ok((
194            entity_grads_real,
195            entity_grads_imag,
196            relation_grads_real,
197            relation_grads_imag,
198        ))
199    }
200
201    /// Add gradients for a single triple
202    fn add_triple_gradients(
203        &self,
204        triple: (usize, usize, usize),
205        grad_coeff: f64,
206        entity_grads_real: &mut Array2<f64>,
207        entity_grads_imag: &mut Array2<f64>,
208        relation_grads_real: &mut Array2<f64>,
209        relation_grads_imag: &mut Array2<f64>,
210    ) {
211        let (s, p, o) = triple;
212
213        let h_real = self.entity_embeddings_real.row(s);
214        let h_imag = self.entity_embeddings_imag.row(s);
215        let r_real = self.relation_embeddings_real.row(p);
216        let r_imag = self.relation_embeddings_imag.row(p);
217        let t_real = self.entity_embeddings_real.row(o);
218        let t_imag = self.entity_embeddings_imag.row(o);
219
220        // Gradients w.r.t. h (subject)
221        // ∂score/∂h_real = r_real * t_real + r_imag * t_imag
222        // ∂score/∂h_imag = r_real * t_imag - r_imag * t_real
223        let h_real_grad = (&r_real * &t_real + &r_imag * &t_imag) * grad_coeff;
224        let h_imag_grad = (&r_real * &t_imag - &r_imag * &t_real) * grad_coeff;
225
226        entity_grads_real.row_mut(s).add_assign(&h_real_grad);
227        entity_grads_imag.row_mut(s).add_assign(&h_imag_grad);
228
229        // Gradients w.r.t. r (relation)
230        // ∂score/∂r_real = h_real * t_real + h_imag * t_imag
231        // ∂score/∂r_imag = h_real * t_imag - h_imag * t_real
232        let r_real_grad = (&h_real * &t_real + &h_imag * &t_imag) * grad_coeff;
233        let r_imag_grad = (&h_real * &t_imag - &h_imag * &t_real) * grad_coeff;
234
235        relation_grads_real.row_mut(p).add_assign(&r_real_grad);
236        relation_grads_imag.row_mut(p).add_assign(&r_imag_grad);
237
238        // Gradients w.r.t. t (object) - note the conjugate
239        // ∂score/∂t_real = h_real * r_real - h_imag * r_imag
240        // ∂score/∂t_imag = -(h_real * r_imag + h_imag * r_real)
241        let t_real_grad = (&h_real * &r_real - &h_imag * &r_imag) * grad_coeff;
242        let t_imag_grad = -(&h_real * &r_imag + &h_imag * &r_real) * grad_coeff;
243
244        entity_grads_real.row_mut(o).add_assign(&t_real_grad);
245        entity_grads_imag.row_mut(o).add_assign(&t_imag_grad);
246    }
247
248    /// Apply N3 regularization
249    fn apply_n3_regularization(
250        &self,
251        entity_grads_real: &mut Array2<f64>,
252        entity_grads_imag: &mut Array2<f64>,
253        relation_grads_real: &mut Array2<f64>,
254        relation_grads_imag: &mut Array2<f64>,
255        regularization_weight: f64,
256    ) {
257        // N3 regularization: penalize the nuclear 3-norm
258        // For complex embeddings, this becomes more involved
259        // For simplicity, we apply L2 regularization here
260        // A full N3 implementation would require more complex tensor operations
261
262        *entity_grads_real += &(&self.entity_embeddings_real * regularization_weight);
263        *entity_grads_imag += &(&self.entity_embeddings_imag * regularization_weight);
264        *relation_grads_real += &(&self.relation_embeddings_real * regularization_weight);
265        *relation_grads_imag += &(&self.relation_embeddings_imag * regularization_weight);
266    }
267
268    /// Perform one training epoch
269    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
270        let mut rng = Random::default();
271
272        let mut total_loss = 0.0;
273        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
274            / self.base.config.batch_size;
275
276        // Create shuffled batches
277        let mut shuffled_triples = self.base.triples.clone();
278        // Manual Fisher-Yates shuffle using scirs2-core
279        for i in (1..shuffled_triples.len()).rev() {
280            let j = rng.random_range(0..i + 1);
281            shuffled_triples.swap(i, j);
282        }
283
284        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
285            let mut batch_entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
286            let mut batch_entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
287            let mut batch_relation_grads_real =
288                Array2::zeros(self.relation_embeddings_real.raw_dim());
289            let mut batch_relation_grads_imag =
290                Array2::zeros(self.relation_embeddings_imag.raw_dim());
291            let mut batch_loss = 0.0;
292
293            for &pos_triple in batch_triples {
294                // Generate negative samples
295                let neg_samples = self
296                    .base
297                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);
298
299                for neg_triple in neg_samples {
300                    // Compute scores
301                    let pos_score =
302                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
303                    let neg_score =
304                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
305
306                    // Compute logistic loss
307                    let pos_loss = logistic_loss(pos_score, 1.0);
308                    let neg_loss = logistic_loss(neg_score, -1.0);
309                    let total_triple_loss = pos_loss + neg_loss;
310
311                    batch_loss += total_triple_loss;
312
313                    // Compute and accumulate gradients
314                    let (
315                        entity_grads_real,
316                        entity_grads_imag,
317                        relation_grads_real,
318                        relation_grads_imag,
319                    ) = self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;
320
321                    batch_entity_grads_real += &entity_grads_real;
322                    batch_entity_grads_imag += &entity_grads_imag;
323                    batch_relation_grads_real += &relation_grads_real;
324                    batch_relation_grads_imag += &relation_grads_imag;
325                }
326            }
327
328            // Apply regularization
329            match self.regularization {
330                RegularizationType::L2 => {
331                    let reg_weight = self.base.config.l2_reg;
332                    batch_entity_grads_real += &(&self.entity_embeddings_real * reg_weight);
333                    batch_entity_grads_imag += &(&self.entity_embeddings_imag * reg_weight);
334                    batch_relation_grads_real += &(&self.relation_embeddings_real * reg_weight);
335                    batch_relation_grads_imag += &(&self.relation_embeddings_imag * reg_weight);
336                }
337                RegularizationType::N3 => {
338                    self.apply_n3_regularization(
339                        &mut batch_entity_grads_real,
340                        &mut batch_entity_grads_imag,
341                        &mut batch_relation_grads_real,
342                        &mut batch_relation_grads_imag,
343                        self.base.config.l2_reg,
344                    );
345                }
346                RegularizationType::None => {}
347            }
348
349            // Apply gradients
350            self.entity_embeddings_real -= &(&batch_entity_grads_real * learning_rate);
351            self.entity_embeddings_imag -= &(&batch_entity_grads_imag * learning_rate);
352            self.relation_embeddings_real -= &(&batch_relation_grads_real * learning_rate);
353            self.relation_embeddings_imag -= &(&batch_relation_grads_imag * learning_rate);
354
355            total_loss += batch_loss;
356        }
357
358        Ok(total_loss / num_batches as f64)
359    }
360
361    /// Get entity embedding as a concatenated real/imaginary vector
362    fn get_entity_embedding_vector(&self, entity_id: usize) -> Vector {
363        let real_part = self.entity_embeddings_real.row(entity_id);
364        let imag_part = self.entity_embeddings_imag.row(entity_id);
365
366        // Concatenate real and imaginary parts
367        let mut values = Vec::with_capacity(real_part.len() * 2);
368        for &val in real_part.iter() {
369            values.push(val as f32);
370        }
371        for &val in imag_part.iter() {
372            values.push(val as f32);
373        }
374
375        Vector::new(values)
376    }
377
378    /// Get relation embedding as a concatenated real/imaginary vector
379    fn get_relation_embedding_vector(&self, relation_id: usize) -> Vector {
380        let real_part = self.relation_embeddings_real.row(relation_id);
381        let imag_part = self.relation_embeddings_imag.row(relation_id);
382
383        // Concatenate real and imaginary parts
384        let mut values = Vec::with_capacity(real_part.len() * 2);
385        for &val in real_part.iter() {
386            values.push(val as f32);
387        }
388        for &val in imag_part.iter() {
389            values.push(val as f32);
390        }
391
392        Vector::new(values)
393    }
394}
395
396#[async_trait]
397impl EmbeddingModel for ComplEx {
398    fn config(&self) -> &ModelConfig {
399        &self.base.config
400    }
401
402    fn model_id(&self) -> &Uuid {
403        &self.base.model_id
404    }
405
406    fn model_type(&self) -> &'static str {
407        "ComplEx"
408    }
409
410    fn add_triple(&mut self, triple: Triple) -> Result<()> {
411        self.base.add_triple(triple)
412    }
413
414    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
415        let start_time = Instant::now();
416        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
417
418        // Initialize embeddings if needed
419        self.initialize_embeddings();
420
421        if !self.embeddings_initialized {
422            return Err(anyhow!("No training data available"));
423        }
424
425        let mut loss_history = Vec::new();
426        let learning_rate = self.base.config.learning_rate;
427
428        info!("Starting ComplEx training for {} epochs", max_epochs);
429
430        for epoch in 0..max_epochs {
431            let epoch_loss = self.train_epoch(learning_rate).await?;
432            loss_history.push(epoch_loss);
433
434            if epoch % 100 == 0 {
435                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
436            }
437
438            // Simple convergence check
439            if epoch > 10 && epoch_loss < 1e-6 {
440                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
441                break;
442            }
443        }
444
445        self.base.mark_trained();
446        let training_time = start_time.elapsed().as_secs_f64();
447
448        Ok(TrainingStats {
449            epochs_completed: loss_history.len(),
450            final_loss: loss_history.last().copied().unwrap_or(0.0),
451            training_time_seconds: training_time,
452            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
453            loss_history,
454        })
455    }
456
457    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
458        if !self.embeddings_initialized {
459            return Err(anyhow!("Model not trained"));
460        }
461
462        let entity_id = self
463            .base
464            .get_entity_id(entity)
465            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
466
467        Ok(self.get_entity_embedding_vector(entity_id))
468    }
469
470    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
471        if !self.embeddings_initialized {
472            return Err(anyhow!("Model not trained"));
473        }
474
475        let relation_id = self
476            .base
477            .get_relation_id(relation)
478            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
479
480        Ok(self.get_relation_embedding_vector(relation_id))
481    }
482
483    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
484        let subject_id = self
485            .base
486            .get_entity_id(subject)
487            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
488        let predicate_id = self
489            .base
490            .get_relation_id(predicate)
491            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
492        let object_id = self
493            .base
494            .get_entity_id(object)
495            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
496
497        self.score_triple_ids(subject_id, predicate_id, object_id)
498    }
499
500    fn predict_objects(
501        &self,
502        subject: &str,
503        predicate: &str,
504        k: usize,
505    ) -> Result<Vec<(String, f64)>> {
506        if !self.embeddings_initialized {
507            return Err(anyhow!("Model not trained"));
508        }
509
510        let subject_id = self
511            .base
512            .get_entity_id(subject)
513            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
514        let predicate_id = self
515            .base
516            .get_relation_id(predicate)
517            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
518
519        let mut scores = Vec::new();
520
521        for object_id in 0..self.base.num_entities() {
522            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
523            let object_name = self.base.get_entity(object_id).unwrap().clone();
524            scores.push((object_name, score));
525        }
526
527        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
528        scores.truncate(k);
529
530        Ok(scores)
531    }
532
533    fn predict_subjects(
534        &self,
535        predicate: &str,
536        object: &str,
537        k: usize,
538    ) -> Result<Vec<(String, f64)>> {
539        if !self.embeddings_initialized {
540            return Err(anyhow!("Model not trained"));
541        }
542
543        let predicate_id = self
544            .base
545            .get_relation_id(predicate)
546            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
547        let object_id = self
548            .base
549            .get_entity_id(object)
550            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
551
552        let mut scores = Vec::new();
553
554        for subject_id in 0..self.base.num_entities() {
555            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
556            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
557            scores.push((subject_name, score));
558        }
559
560        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
561        scores.truncate(k);
562
563        Ok(scores)
564    }
565
566    fn predict_relations(
567        &self,
568        subject: &str,
569        object: &str,
570        k: usize,
571    ) -> Result<Vec<(String, f64)>> {
572        if !self.embeddings_initialized {
573            return Err(anyhow!("Model not trained"));
574        }
575
576        let subject_id = self
577            .base
578            .get_entity_id(subject)
579            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
580        let object_id = self
581            .base
582            .get_entity_id(object)
583            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
584
585        let mut scores = Vec::new();
586
587        for predicate_id in 0..self.base.num_relations() {
588            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
589            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
590            scores.push((predicate_name, score));
591        }
592
593        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
594        scores.truncate(k);
595
596        Ok(scores)
597    }
598
599    fn get_entities(&self) -> Vec<String> {
600        self.base.get_entities()
601    }
602
603    fn get_relations(&self) -> Vec<String> {
604        self.base.get_relations()
605    }
606
607    fn get_stats(&self) -> ModelStats {
608        self.base.get_stats("ComplEx")
609    }
610
611    fn save(&self, path: &str) -> Result<()> {
612        info!("Saving ComplEx model to {}", path);
613        Ok(())
614    }
615
616    fn load(&mut self, path: &str) -> Result<()> {
617        info!("Loading ComplEx model from {}", path);
618        Ok(())
619    }
620
621    fn clear(&mut self) {
622        self.base.clear();
623        self.entity_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
624        self.entity_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
625        self.relation_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
626        self.relation_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
627        self.embeddings_initialized = false;
628    }
629
630    fn is_trained(&self) -> bool {
631        self.base.is_trained
632    }
633
634    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
635        Err(anyhow!(
636            "Knowledge graph embedding model does not support text encoding"
637        ))
638    }
639}
640
641#[cfg(test)]
642mod tests {
643    use super::*;
644    use crate::NamedNode;
645
646    #[tokio::test]
647    async fn test_complex_basic() -> Result<()> {
648        let config = ModelConfig::default()
649            .with_dimensions(50)
650            .with_max_epochs(10)
651            .with_seed(42);
652
653        let mut model = ComplEx::new(config);
654
655        // Add test triples
656        let alice = NamedNode::new("http://example.org/alice")?;
657        let knows = NamedNode::new("http://example.org/knows")?;
658        let bob = NamedNode::new("http://example.org/bob")?;
659
660        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
661
662        // Train
663        let stats = model.train(Some(5)).await?;
664        assert!(stats.epochs_completed > 0);
665
666        // Test embeddings (should be 2x dimensions due to complex)
667        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
668        assert_eq!(alice_emb.dimensions, 100); // 2 * 50
669
670        // Test scoring
671        let score = model.score_triple(
672            "http://example.org/alice",
673            "http://example.org/knows",
674            "http://example.org/bob",
675        )?;
676
677        // Score should be a finite number
678        assert!(score.is_finite());
679
680        Ok(())
681    }
682}