oxirs_embed/models/
transe.rs

1//! TransE: Translating Embeddings for Modeling Multi-relational Data
2//!
3//! TransE models relations as translations in the embedding space:
4//! h + r ≈ t for a true triple (h, r, t)
5//!
6//! Reference: Bordes et al. "Translating Embeddings for Modeling Multi-relational Data" (2013)
7
8use crate::models::{common::*, BaseModel};
9use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
10use anyhow::{anyhow, Result};
11use async_trait::async_trait;
12use scirs2_core::ndarray_ext::{Array1, Array2};
13#[allow(unused_imports)]
14use scirs2_core::random::{Random, Rng};
15use serde::{Deserialize, Serialize};
16use std::ops::{AddAssign, SubAssign};
17use std::time::Instant;
18use tracing::{debug, info};
19use uuid::Uuid;
20
/// TransE embedding model
///
/// Keeps one embedding row per entity and per relation. A triple `(h, r, t)` is
/// scored by the negated distance between `h + r` and `t` under the configured
/// [`DistanceMetric`], so more plausible triples get higher scores.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Base model functionality (vocabulary/id mapping, stored triples, config)
    base: BaseModel,
    /// Entity embeddings matrix (num_entities × dimensions); has 0 rows until the
    /// embeddings are initialized by training
    entity_embeddings: Array2<f64>,
    /// Relation embeddings matrix (num_relations × dimensions); has 0 rows until the
    /// embeddings are initialized by training
    relation_embeddings: Array2<f64>,
    /// Whether embeddings have been initialized (the matrices above are populated)
    embeddings_initialized: bool,
    /// Distance metric for scoring (selected from `model_params["distance_metric"]`)
    distance_metric: DistanceMetric,
    /// Margin for ranking loss (from `model_params["margin"]`, default 1.0)
    margin: f64,
}
37
38/// Distance metrics supported by TransE
39#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
40pub enum DistanceMetric {
41    /// L1 (Manhattan) distance
42    L1,
43    /// L2 (Euclidean) distance
44    L2,
45    /// Cosine distance (1 - cosine similarity)
46    Cosine,
47}
48
49impl TransE {
50    /// Create a new TransE model
51    pub fn new(config: ModelConfig) -> Self {
52        let base = BaseModel::new(config.clone());
53
54        // Get TransE-specific parameters
55        let distance_metric = match config.model_params.get("distance_metric") {
56            Some(0.0) => DistanceMetric::L1,
57            Some(1.0) => DistanceMetric::L2,
58            Some(2.0) => DistanceMetric::Cosine,
59            _ => DistanceMetric::L2, // Default to L2
60        };
61
62        let margin = config.model_params.get("margin").copied().unwrap_or(1.0);
63
64        Self {
65            base,
66            entity_embeddings: Array2::zeros((0, config.dimensions)),
67            relation_embeddings: Array2::zeros((0, config.dimensions)),
68            embeddings_initialized: false,
69            distance_metric,
70            margin,
71        }
72    }
73
74    /// Create a new TransE model with L1 (Manhattan) distance metric
75    pub fn with_l1_distance(mut config: ModelConfig) -> Self {
76        config
77            .model_params
78            .insert("distance_metric".to_string(), 0.0);
79        Self::new(config)
80    }
81
82    /// Create a new TransE model with L2 (Euclidean) distance metric
83    pub fn with_l2_distance(mut config: ModelConfig) -> Self {
84        config
85            .model_params
86            .insert("distance_metric".to_string(), 1.0);
87        Self::new(config)
88    }
89
90    /// Create a new TransE model with Cosine distance metric
91    pub fn with_cosine_distance(mut config: ModelConfig) -> Self {
92        config
93            .model_params
94            .insert("distance_metric".to_string(), 2.0);
95        Self::new(config)
96    }
97
98    /// Create a new TransE model with custom margin for ranking loss
99    pub fn with_margin(mut config: ModelConfig, margin: f64) -> Self {
100        config.model_params.insert("margin".to_string(), margin);
101        Self::new(config)
102    }
103
104    /// Get the current distance metric
105    pub fn distance_metric(&self) -> DistanceMetric {
106        self.distance_metric
107    }
108
109    /// Get the current margin value
110    pub fn margin(&self) -> f64 {
111        self.margin
112    }
113
114    /// Initialize embeddings after entities and relations are known
115    fn initialize_embeddings(&mut self) {
116        if self.embeddings_initialized {
117            return;
118        }
119
120        let num_entities = self.base.num_entities();
121        let num_relations = self.base.num_relations();
122        let dimensions = self.base.config.dimensions;
123
124        if num_entities == 0 || num_relations == 0 {
125            return;
126        }
127
128        let mut rng = Random::default();
129
130        // Initialize entity embeddings with Xavier initialization
131        self.entity_embeddings =
132            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
133
134        // Initialize relation embeddings with Xavier initialization
135        self.relation_embeddings = xavier_init(
136            (num_relations, dimensions),
137            dimensions,
138            dimensions,
139            &mut rng,
140        );
141
142        // Normalize entity embeddings to unit sphere
143        normalize_embeddings(&mut self.entity_embeddings);
144
145        self.embeddings_initialized = true;
146        debug!(
147            "Initialized TransE embeddings: {} entities, {} relations, {} dimensions",
148            num_entities, num_relations, dimensions
149        );
150    }
151
152    /// Score a triple using TransE scoring function
153    fn score_triple_ids(
154        &self,
155        subject_id: usize,
156        predicate_id: usize,
157        object_id: usize,
158    ) -> Result<f64> {
159        if !self.embeddings_initialized {
160            return Err(anyhow!("Model not trained"));
161        }
162
163        let h = self.entity_embeddings.row(subject_id);
164        let r = self.relation_embeddings.row(predicate_id);
165        let t = self.entity_embeddings.row(object_id);
166
167        // Compute h + r - t
168        let diff = &h + &r - t;
169
170        // Distance metric determines scoring (lower distance = higher score)
171        let distance = match self.distance_metric {
172            DistanceMetric::L1 => diff.mapv(|x| x.abs()).sum(),
173            DistanceMetric::L2 => diff.mapv(|x| x * x).sum().sqrt(),
174            DistanceMetric::Cosine => {
175                // For cosine distance, we compute 1 - cosine_similarity(h + r, t)
176                let h_plus_r = &h + &r;
177                let dot_product = (&h_plus_r * &t).sum();
178                let norm_h_plus_r = h_plus_r.mapv(|x| x * x).sum().sqrt();
179                let norm_t = t.mapv(|x| x * x).sum().sqrt();
180
181                if norm_h_plus_r == 0.0 || norm_t == 0.0 {
182                    1.0 // Maximum distance for zero vectors
183                } else {
184                    let cosine_sim = dot_product / (norm_h_plus_r * norm_t);
185                    1.0 - cosine_sim.clamp(-1.0, 1.0) // Clamp to [-1, 1] and convert to distance
186                }
187            }
188        };
189
190        // Return negative distance as score (higher is better)
191        Ok(-distance)
192    }
193
194    /// Compute gradients for a training triple
195    fn compute_gradients(
196        &self,
197        pos_triple: (usize, usize, usize),
198        neg_triple: (usize, usize, usize),
199    ) -> Result<(Array2<f64>, Array2<f64>)> {
200        let (pos_s, pos_p, pos_o) = pos_triple;
201        let (neg_s, neg_p, neg_o) = neg_triple;
202
203        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
204        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
205
206        // Get embeddings
207        let pos_h = self.entity_embeddings.row(pos_s);
208        let pos_r = self.relation_embeddings.row(pos_p);
209        let pos_t = self.entity_embeddings.row(pos_o);
210
211        let neg_h = self.entity_embeddings.row(neg_s);
212        let neg_r = self.relation_embeddings.row(neg_p);
213        let neg_t = self.entity_embeddings.row(neg_o);
214
215        // Compute differences
216        let pos_diff = &pos_h + &pos_r - pos_t;
217        let neg_diff = &neg_h + &neg_r - neg_t;
218
219        // Compute distances
220        let pos_distance = match self.distance_metric {
221            DistanceMetric::L1 => pos_diff.mapv(|x| x.abs()).sum(),
222            DistanceMetric::L2 => pos_diff.mapv(|x| x * x).sum().sqrt(),
223            DistanceMetric::Cosine => {
224                let norm = pos_diff.mapv(|x| x * x).sum().sqrt();
225                if norm > 1e-10 {
226                    1.0 - (pos_diff.dot(&pos_diff) / (norm * norm)).clamp(-1.0, 1.0)
227                } else {
228                    0.0
229                }
230            }
231        };
232
233        let neg_distance = match self.distance_metric {
234            DistanceMetric::L1 => neg_diff.mapv(|x| x.abs()).sum(),
235            DistanceMetric::L2 => neg_diff.mapv(|x| x * x).sum().sqrt(),
236            DistanceMetric::Cosine => {
237                let norm = neg_diff.mapv(|x| x * x).sum().sqrt();
238                if norm > 1e-10 {
239                    1.0 - (neg_diff.dot(&neg_diff) / (norm * norm)).clamp(-1.0, 1.0)
240                } else {
241                    0.0
242                }
243            }
244        };
245
246        // Check if we need to update (margin loss > 0)
247        let loss = self.margin + pos_distance - neg_distance;
248        if loss > 0.0 {
249            // Compute gradient direction based on distance metric
250            let pos_grad_direction = match self.distance_metric {
251                DistanceMetric::L1 => pos_diff.mapv(|x| {
252                    if x > 0.0 {
253                        1.0
254                    } else if x < 0.0 {
255                        -1.0
256                    } else {
257                        0.0
258                    }
259                }),
260                DistanceMetric::L2 => {
261                    if pos_distance > 1e-10 {
262                        &pos_diff / pos_distance
263                    } else {
264                        Array1::zeros(pos_diff.len())
265                    }
266                }
267                DistanceMetric::Cosine => {
268                    let norm_sq = pos_diff.mapv(|x| x * x).sum();
269                    if norm_sq > 1e-10 {
270                        &pos_diff / norm_sq.sqrt()
271                    } else {
272                        Array1::zeros(pos_diff.len())
273                    }
274                }
275            };
276
277            let neg_grad_direction = match self.distance_metric {
278                DistanceMetric::L1 => neg_diff.mapv(|x| {
279                    if x > 0.0 {
280                        1.0
281                    } else if x < 0.0 {
282                        -1.0
283                    } else {
284                        0.0
285                    }
286                }),
287                DistanceMetric::L2 => {
288                    if neg_distance > 1e-10 {
289                        &neg_diff / neg_distance
290                    } else {
291                        Array1::zeros(neg_diff.len())
292                    }
293                }
294                DistanceMetric::Cosine => {
295                    let norm_sq = neg_diff.mapv(|x| x * x).sum();
296                    if norm_sq > 1e-10 {
297                        &neg_diff / norm_sq.sqrt()
298                    } else {
299                        Array1::zeros(neg_diff.len())
300                    }
301                }
302            };
303
304            // Update gradients for positive triple (increase distance)
305            entity_grads.row_mut(pos_s).add_assign(&pos_grad_direction);
306            relation_grads
307                .row_mut(pos_p)
308                .add_assign(&pos_grad_direction);
309            entity_grads.row_mut(pos_o).sub_assign(&pos_grad_direction);
310
311            // Update gradients for negative triple (decrease distance)
312            entity_grads.row_mut(neg_s).sub_assign(&neg_grad_direction);
313            relation_grads
314                .row_mut(neg_p)
315                .sub_assign(&neg_grad_direction);
316            entity_grads.row_mut(neg_o).add_assign(&neg_grad_direction);
317        }
318
319        Ok((entity_grads, relation_grads))
320    }
321
322    /// Perform one training epoch
323    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
324        let mut rng = Random::default();
325
326        let mut total_loss = 0.0;
327        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
328            / self.base.config.batch_size;
329
330        // Create shuffled batches
331        let mut shuffled_triples = self.base.triples.clone();
332        // Manual Fisher-Yates shuffle using scirs2-core
333        for i in (1..shuffled_triples.len()).rev() {
334            let j = rng.random_range(0..i + 1);
335            shuffled_triples.swap(i, j);
336        }
337
338        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
339            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
340            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
341            let mut batch_loss = 0.0;
342
343            for &pos_triple in batch_triples {
344                // Generate negative samples
345                let neg_samples = self
346                    .base
347                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);
348
349                for neg_triple in neg_samples {
350                    // Compute scores
351                    let pos_score =
352                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
353                    let neg_score =
354                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
355
356                    // Convert scores to distances (negate because score = -distance)
357                    let pos_distance = -pos_score;
358                    let neg_distance = -neg_score;
359
360                    // Compute margin loss
361                    let loss = margin_loss(pos_distance, neg_distance, self.margin);
362                    batch_loss += loss;
363
364                    if loss > 0.0 {
365                        // Compute and accumulate gradients
366                        let (entity_grads, relation_grads) =
367                            self.compute_gradients(pos_triple, neg_triple)?;
368                        batch_entity_grads += &entity_grads;
369                        batch_relation_grads += &relation_grads;
370                    }
371                }
372            }
373
374            // Apply gradients
375            if batch_loss > 0.0 {
376                gradient_update(
377                    &mut self.entity_embeddings,
378                    &batch_entity_grads,
379                    learning_rate,
380                    self.base.config.l2_reg,
381                );
382
383                gradient_update(
384                    &mut self.relation_embeddings,
385                    &batch_relation_grads,
386                    learning_rate,
387                    self.base.config.l2_reg,
388                );
389
390                // Normalize entity embeddings
391                normalize_embeddings(&mut self.entity_embeddings);
392            }
393
394            total_loss += batch_loss;
395        }
396
397        Ok(total_loss / num_batches as f64)
398    }
399}
400
401#[async_trait]
402impl EmbeddingModel for TransE {
403    fn config(&self) -> &ModelConfig {
404        &self.base.config
405    }
406
407    fn model_id(&self) -> &Uuid {
408        &self.base.model_id
409    }
410
411    fn model_type(&self) -> &'static str {
412        "TransE"
413    }
414
415    fn add_triple(&mut self, triple: Triple) -> Result<()> {
416        self.base.add_triple(triple)
417    }
418
419    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
420        let start_time = Instant::now();
421        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
422
423        // Initialize embeddings if needed
424        self.initialize_embeddings();
425
426        if !self.embeddings_initialized {
427            return Err(anyhow!("No training data available"));
428        }
429
430        let mut loss_history = Vec::new();
431        let learning_rate = self.base.config.learning_rate;
432
433        info!("Starting TransE training for {} epochs", max_epochs);
434
435        for epoch in 0..max_epochs {
436            let epoch_loss = self.train_epoch(learning_rate).await?;
437            loss_history.push(epoch_loss);
438
439            if epoch % 100 == 0 {
440                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
441            }
442
443            // Simple convergence check
444            if epoch > 10 && epoch_loss < 1e-6 {
445                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
446                break;
447            }
448        }
449
450        self.base.mark_trained();
451        let training_time = start_time.elapsed().as_secs_f64();
452
453        Ok(TrainingStats {
454            epochs_completed: loss_history.len(),
455            final_loss: loss_history.last().copied().unwrap_or(0.0),
456            training_time_seconds: training_time,
457            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
458            loss_history,
459        })
460    }
461
462    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
463        if !self.embeddings_initialized {
464            return Err(anyhow!("Model not trained"));
465        }
466
467        let entity_id = self
468            .base
469            .get_entity_id(entity)
470            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
471
472        let embedding = self.entity_embeddings.row(entity_id).to_owned();
473        Ok(ndarray_to_vector(&embedding))
474    }
475
476    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
477        if !self.embeddings_initialized {
478            return Err(anyhow!("Model not trained"));
479        }
480
481        let relation_id = self
482            .base
483            .get_relation_id(relation)
484            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
485
486        let embedding = self.relation_embeddings.row(relation_id).to_owned();
487        Ok(ndarray_to_vector(&embedding))
488    }
489
490    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
491        let subject_id = self
492            .base
493            .get_entity_id(subject)
494            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
495        let predicate_id = self
496            .base
497            .get_relation_id(predicate)
498            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
499        let object_id = self
500            .base
501            .get_entity_id(object)
502            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
503
504        self.score_triple_ids(subject_id, predicate_id, object_id)
505    }
506
507    fn predict_objects(
508        &self,
509        subject: &str,
510        predicate: &str,
511        k: usize,
512    ) -> Result<Vec<(String, f64)>> {
513        if !self.embeddings_initialized {
514            return Err(anyhow!("Model not trained"));
515        }
516
517        let subject_id = self
518            .base
519            .get_entity_id(subject)
520            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
521        let predicate_id = self
522            .base
523            .get_relation_id(predicate)
524            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
525
526        let mut scores = Vec::new();
527
528        for object_id in 0..self.base.num_entities() {
529            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
530            let object_name = self.base.get_entity(object_id).unwrap().clone();
531            scores.push((object_name, score));
532        }
533
534        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
535        scores.truncate(k);
536
537        Ok(scores)
538    }
539
540    fn predict_subjects(
541        &self,
542        predicate: &str,
543        object: &str,
544        k: usize,
545    ) -> Result<Vec<(String, f64)>> {
546        if !self.embeddings_initialized {
547            return Err(anyhow!("Model not trained"));
548        }
549
550        let predicate_id = self
551            .base
552            .get_relation_id(predicate)
553            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
554        let object_id = self
555            .base
556            .get_entity_id(object)
557            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
558
559        let mut scores = Vec::new();
560
561        for subject_id in 0..self.base.num_entities() {
562            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
563            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
564            scores.push((subject_name, score));
565        }
566
567        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
568        scores.truncate(k);
569
570        Ok(scores)
571    }
572
573    fn predict_relations(
574        &self,
575        subject: &str,
576        object: &str,
577        k: usize,
578    ) -> Result<Vec<(String, f64)>> {
579        if !self.embeddings_initialized {
580            return Err(anyhow!("Model not trained"));
581        }
582
583        let subject_id = self
584            .base
585            .get_entity_id(subject)
586            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
587        let object_id = self
588            .base
589            .get_entity_id(object)
590            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
591
592        let mut scores = Vec::new();
593
594        for predicate_id in 0..self.base.num_relations() {
595            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
596            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
597            scores.push((predicate_name, score));
598        }
599
600        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
601        scores.truncate(k);
602
603        Ok(scores)
604    }
605
606    fn get_entities(&self) -> Vec<String> {
607        self.base.get_entities()
608    }
609
610    fn get_relations(&self) -> Vec<String> {
611        self.base.get_relations()
612    }
613
614    fn get_stats(&self) -> ModelStats {
615        self.base.get_stats("TransE")
616    }
617
618    fn save(&self, path: &str) -> Result<()> {
619        // For now, just a placeholder
620        // In a full implementation, this would serialize the model to file
621        info!("Saving TransE model to {}", path);
622        Ok(())
623    }
624
625    fn load(&mut self, path: &str) -> Result<()> {
626        // For now, just a placeholder
627        // In a full implementation, this would deserialize the model from file
628        info!("Loading TransE model from {}", path);
629        Ok(())
630    }
631
632    fn clear(&mut self) {
633        self.base.clear();
634        self.entity_embeddings = Array2::zeros((0, self.base.config.dimensions));
635        self.relation_embeddings = Array2::zeros((0, self.base.config.dimensions));
636        self.embeddings_initialized = false;
637    }
638
639    fn is_trained(&self) -> bool {
640        self.base.is_trained
641    }
642
643    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
644        Err(anyhow!(
645            "TransE is a knowledge graph embedding model and does not support text encoding"
646        ))
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    const ALICE: &str = "http://example.org/alice";
    const KNOWS: &str = "http://example.org/knows";
    const BOB: &str = "http://example.org/bob";

    /// The `alice`, `knows` and `bob` nodes shared by every test.
    fn alice_knows_bob() -> Result<(NamedNode, NamedNode, NamedNode)> {
        Ok((
            NamedNode::new(ALICE)?,
            NamedNode::new(KNOWS)?,
            NamedNode::new(BOB)?,
        ))
    }

    /// Score `(alice, knows, bob)` on the given model.
    fn score_alice_knows_bob(model: &TransE) -> Result<f64> {
        model.score_triple(ALICE, KNOWS, BOB)
    }

    #[tokio::test]
    async fn test_transe_basic() -> Result<()> {
        let mut model = TransE::new(
            ModelConfig::default()
                .with_dimensions(50)
                .with_max_epochs(10)
                .with_seed(42),
        );

        // Two symmetric facts: alice knows bob, bob knows alice.
        let (alice, knows, bob) = alice_knows_bob()?;
        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob, knows, alice))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        let alice_embedding = model.get_entity_embedding(ALICE)?;
        assert_eq!(alice_embedding.dimensions, 50);

        // A score is a negated distance and must therefore be finite.
        assert!(score_alice_knows_bob(&model)?.is_finite());

        Ok(())
    }

    #[tokio::test]
    async fn test_transe_distance_metrics() -> Result<()> {
        let base_config = ModelConfig::default()
            .with_dimensions(10)
            .with_max_epochs(5)
            .with_seed(42);

        // Each convenience constructor selects its metric.
        let mut model_l1 = TransE::with_l1_distance(base_config.clone());
        assert!(matches!(model_l1.distance_metric(), DistanceMetric::L1));

        let mut model_l2 = TransE::with_l2_distance(base_config.clone());
        assert!(matches!(model_l2.distance_metric(), DistanceMetric::L2));

        let mut model_cosine = TransE::with_cosine_distance(base_config.clone());
        assert!(matches!(
            model_cosine.distance_metric(),
            DistanceMetric::Cosine
        ));

        let model_margin = TransE::with_margin(base_config.clone(), 2.0);
        assert_eq!(model_margin.margin(), 2.0);

        // Feed the same fact to every metric variant, then train each of them.
        let (alice, knows, bob) = alice_knows_bob()?;
        let triple = Triple::new(alice, knows, bob);
        for model in [&mut model_l1, &mut model_l2, &mut model_cosine] {
            model.add_triple(triple.clone())?;
        }
        for model in [&mut model_l1, &mut model_l2, &mut model_cosine] {
            model.train(Some(3)).await?;
        }

        let score_l1 = score_alice_knows_bob(&model_l1)?;
        let score_l2 = score_alice_knows_bob(&model_l2)?;
        let score_cosine = score_alice_knows_bob(&model_cosine)?;

        assert!(score_l1.is_finite());
        assert!(score_l2.is_finite());
        assert!(score_cosine.is_finite());

        // The metrics are expected to disagree; print them for inspection.
        println!("L1 score: {score_l1}, L2 score: {score_l2}, Cosine score: {score_cosine}");

        Ok(())
    }
}