Skip to main content

oxirs_graphrag/
transe_model.rs

1//! # TransE Knowledge Graph Embedding Model
2//!
3//! Implements the TransE (Translating Embeddings) model for knowledge graph
4//! link prediction. In TransE, relationships are represented as translations
5//! in the embedding space: `h + r ≈ t` for a triple (h, r, t).
6//!
7//! ## Features
8//!
9//! - **Training**: Stochastic gradient descent with margin-based ranking loss
10//! - **Scoring**: Score candidate triples using L1 or L2 distance
11//! - **Link prediction**: Predict head/tail entities given partial triples
12//! - **Nearest neighbor search**: Find entities closest to a query embedding
13//! - **Serialization**: Export/import learned embeddings
14
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, HashSet};
17
18// ─────────────────────────────────────────────
19// Configuration
20// ─────────────────────────────────────────────
21
22/// Distance metric for TransE scoring.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24pub enum DistanceMetric {
25    /// L1 (Manhattan) distance.
26    L1,
27    /// L2 (Euclidean) distance.
28    L2,
29}
30
31/// Configuration for the TransE model.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TransEConfig {
34    /// Embedding dimension.
35    pub dim: usize,
36    /// Learning rate.
37    pub learning_rate: f64,
38    /// Margin for the ranking loss (gamma).
39    pub margin: f64,
40    /// Distance metric.
41    pub distance_metric: DistanceMetric,
42    /// Maximum training epochs.
43    pub max_epochs: usize,
44    /// Number of negative samples per positive triple.
45    pub num_negatives: usize,
46    /// Whether to normalize embeddings after each update.
47    pub normalize_embeddings: bool,
48}
49
50impl Default for TransEConfig {
51    fn default() -> Self {
52        Self {
53            dim: 50,
54            learning_rate: 0.01,
55            margin: 1.0,
56            distance_metric: DistanceMetric::L2,
57            max_epochs: 100,
58            num_negatives: 1,
59            normalize_embeddings: true,
60        }
61    }
62}
63
64// ─────────────────────────────────────────────
65// Triple types
66// ─────────────────────────────────────────────
67
/// A `(head, relation, tail)` statement whose members are referenced by
/// integer ID rather than by name.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Triple {
    /// ID of the subject (head) entity.
    pub head: usize,
    /// ID of the relation linking head to tail.
    pub relation: usize,
    /// ID of the object (tail) entity.
    pub tail: usize,
}
75
76/// A scored triple for link prediction results.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct ScoredTriple {
79    pub head: usize,
80    pub relation: usize,
81    pub tail: usize,
82    pub score: f64,
83}
84
85/// Training statistics.
86#[derive(Debug, Clone, Default, Serialize, Deserialize)]
87pub struct TrainingStats {
88    /// Training loss per epoch.
89    pub loss_history: Vec<f64>,
90    /// Number of epochs completed.
91    pub epochs_completed: usize,
92    /// Total triples processed.
93    pub triples_processed: u64,
94}
95
96// ─────────────────────────────────────────────
97// TransEModel
98// ─────────────────────────────────────────────
99
100/// TransE knowledge graph embedding model.
101pub struct TransEModel {
102    config: TransEConfig,
103    /// Entity embeddings: entity_id -> embedding vector.
104    entity_embeddings: HashMap<usize, Vec<f64>>,
105    /// Relation embeddings: relation_id -> embedding vector.
106    relation_embeddings: HashMap<usize, Vec<f64>>,
107    /// Entity name to ID mapping.
108    entity_to_id: HashMap<String, usize>,
109    /// ID to entity name mapping.
110    id_to_entity: HashMap<usize, String>,
111    /// Relation name to ID mapping.
112    relation_to_id: HashMap<String, usize>,
113    /// ID to relation name mapping.
114    id_to_relation: HashMap<usize, String>,
115    /// Known triples (for filtered evaluation).
116    known_triples: HashSet<Triple>,
117    /// Training statistics.
118    stats: TrainingStats,
119    /// Simple LCG state for reproducible pseudo-random initialization.
120    rng_state: u64,
121}
122
123impl TransEModel {
124    /// Create a new TransE model with default configuration.
125    pub fn new() -> Self {
126        Self::with_config(TransEConfig::default())
127    }
128
129    /// Create a new TransE model with the given configuration.
130    pub fn with_config(config: TransEConfig) -> Self {
131        Self {
132            config,
133            entity_embeddings: HashMap::new(),
134            relation_embeddings: HashMap::new(),
135            entity_to_id: HashMap::new(),
136            id_to_entity: HashMap::new(),
137            relation_to_id: HashMap::new(),
138            id_to_relation: HashMap::new(),
139            known_triples: HashSet::new(),
140            stats: TrainingStats::default(),
141            rng_state: 12345,
142        }
143    }
144
145    /// Get model configuration.
146    pub fn config(&self) -> &TransEConfig {
147        &self.config
148    }
149
150    /// Get training statistics.
151    pub fn stats(&self) -> &TrainingStats {
152        &self.stats
153    }
154
155    /// Number of entities.
156    pub fn entity_count(&self) -> usize {
157        self.entity_to_id.len()
158    }
159
160    /// Number of relations.
161    pub fn relation_count(&self) -> usize {
162        self.relation_to_id.len()
163    }
164
165    /// Number of known triples.
166    pub fn triple_count(&self) -> usize {
167        self.known_triples.len()
168    }
169
170    /// Register an entity and return its ID.
171    pub fn add_entity(&mut self, name: impl Into<String>) -> usize {
172        let name = name.into();
173        if let Some(&id) = self.entity_to_id.get(&name) {
174            return id;
175        }
176        let id = self.entity_to_id.len();
177        self.entity_to_id.insert(name.clone(), id);
178        self.id_to_entity.insert(id, name);
179        // Initialize embedding
180        let embedding = self.random_embedding();
181        self.entity_embeddings.insert(id, embedding);
182        id
183    }
184
185    /// Register a relation and return its ID.
186    pub fn add_relation(&mut self, name: impl Into<String>) -> usize {
187        let name = name.into();
188        if let Some(&id) = self.relation_to_id.get(&name) {
189            return id;
190        }
191        let id = self.relation_to_id.len();
192        self.relation_to_id.insert(name.clone(), id);
193        self.id_to_relation.insert(id, name);
194        // Initialize embedding
195        let mut embedding = self.random_embedding();
196        // Normalize relation embeddings
197        let norm = l2_norm(&embedding);
198        if norm > 1e-12 {
199            for v in &mut embedding {
200                *v /= norm;
201            }
202        }
203        self.relation_embeddings.insert(id, embedding);
204        id
205    }
206
207    /// Add a triple (by entity/relation names).
208    pub fn add_triple(
209        &mut self,
210        head: impl Into<String>,
211        relation: impl Into<String>,
212        tail: impl Into<String>,
213    ) {
214        let h = self.add_entity(head);
215        let r = self.add_relation(relation);
216        let t = self.add_entity(tail);
217        self.known_triples.insert(Triple {
218            head: h,
219            relation: r,
220            tail: t,
221        });
222    }
223
224    /// Train the model on the known triples.
225    pub fn train(&mut self, epochs: usize) -> TrainingStats {
226        let num_epochs = epochs.min(self.config.max_epochs);
227        let triples: Vec<Triple> = self.known_triples.iter().cloned().collect();
228
229        if triples.is_empty() {
230            return self.stats.clone();
231        }
232
233        let num_entities = self.entity_to_id.len();
234
235        for _epoch in 0..num_epochs {
236            let mut epoch_loss = 0.0;
237
238            for triple in &triples {
239                // Generate negative sample by corrupting head or tail
240                let neg_triple = self.corrupt_triple(triple, num_entities);
241
242                // Score positive and negative
243                let pos_score = self.score_triple_ids(triple.head, triple.relation, triple.tail);
244                let neg_score =
245                    self.score_triple_ids(neg_triple.head, neg_triple.relation, neg_triple.tail);
246
247                // Margin-based ranking loss: max(0, gamma + pos - neg)
248                let loss = (self.config.margin + pos_score - neg_score).max(0.0);
249                epoch_loss += loss;
250
251                if loss > 0.0 {
252                    // Update embeddings via SGD
253                    self.update_embeddings(triple, &neg_triple);
254                }
255
256                self.stats.triples_processed += 1;
257            }
258
259            let avg_loss = epoch_loss / triples.len() as f64;
260            self.stats.loss_history.push(avg_loss);
261            self.stats.epochs_completed += 1;
262
263            // Normalize entity embeddings if configured
264            if self.config.normalize_embeddings {
265                self.normalize_entities();
266            }
267        }
268
269        self.stats.clone()
270    }
271
272    /// Score a triple (lower score = better).
273    pub fn score(&self, head: &str, relation: &str, tail: &str) -> Option<f64> {
274        let h = self.entity_to_id.get(head)?;
275        let r = self.relation_to_id.get(relation)?;
276        let t = self.entity_to_id.get(tail)?;
277        Some(self.score_triple_ids(*h, *r, *t))
278    }
279
280    /// Score a triple by IDs.
281    fn score_triple_ids(&self, head: usize, relation: usize, tail: usize) -> f64 {
282        let h = match self.entity_embeddings.get(&head) {
283            Some(e) => e,
284            None => return f64::MAX,
285        };
286        let r = match self.relation_embeddings.get(&relation) {
287            Some(e) => e,
288            None => return f64::MAX,
289        };
290        let t = match self.entity_embeddings.get(&tail) {
291            Some(e) => e,
292            None => return f64::MAX,
293        };
294
295        // distance(h + r, t)
296        let dim = self.config.dim;
297        match self.config.distance_metric {
298            DistanceMetric::L1 => {
299                let mut dist = 0.0;
300                for i in 0..dim {
301                    let hi = h.get(i).copied().unwrap_or(0.0);
302                    let ri = r.get(i).copied().unwrap_or(0.0);
303                    let ti = t.get(i).copied().unwrap_or(0.0);
304                    dist += (hi + ri - ti).abs();
305                }
306                dist
307            }
308            DistanceMetric::L2 => {
309                let mut dist = 0.0;
310                for i in 0..dim {
311                    let hi = h.get(i).copied().unwrap_or(0.0);
312                    let ri = r.get(i).copied().unwrap_or(0.0);
313                    let ti = t.get(i).copied().unwrap_or(0.0);
314                    dist += (hi + ri - ti).powi(2);
315                }
316                dist.sqrt()
317            }
318        }
319    }
320
321    /// Predict the top-k tail entities given (head, relation, ?).
322    pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<ScoredTriple> {
323        let h = match self.entity_to_id.get(head) {
324            Some(&id) => id,
325            None => return Vec::new(),
326        };
327        let r = match self.relation_to_id.get(relation) {
328            Some(&id) => id,
329            None => return Vec::new(),
330        };
331
332        let mut scores: Vec<ScoredTriple> = self
333            .entity_to_id
334            .values()
335            .map(|&t_id| {
336                let score = self.score_triple_ids(h, r, t_id);
337                ScoredTriple {
338                    head: h,
339                    relation: r,
340                    tail: t_id,
341                    score,
342                }
343            })
344            .collect();
345
346        scores.sort_by(|a, b| {
347            a.score
348                .partial_cmp(&b.score)
349                .unwrap_or(std::cmp::Ordering::Equal)
350        });
351        scores.truncate(k);
352        scores
353    }
354
355    /// Predict the top-k head entities given (?, relation, tail).
356    pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<ScoredTriple> {
357        let r = match self.relation_to_id.get(relation) {
358            Some(&id) => id,
359            None => return Vec::new(),
360        };
361        let t = match self.entity_to_id.get(tail) {
362            Some(&id) => id,
363            None => return Vec::new(),
364        };
365
366        let mut scores: Vec<ScoredTriple> = self
367            .entity_to_id
368            .values()
369            .map(|&h_id| {
370                let score = self.score_triple_ids(h_id, r, t);
371                ScoredTriple {
372                    head: h_id,
373                    relation: r,
374                    tail: t,
375                    score,
376                }
377            })
378            .collect();
379
380        scores.sort_by(|a, b| {
381            a.score
382                .partial_cmp(&b.score)
383                .unwrap_or(std::cmp::Ordering::Equal)
384        });
385        scores.truncate(k);
386        scores
387    }
388
389    /// Get the embedding for an entity.
390    pub fn entity_embedding(&self, name: &str) -> Option<&Vec<f64>> {
391        self.entity_to_id
392            .get(name)
393            .and_then(|id| self.entity_embeddings.get(id))
394    }
395
396    /// Get the embedding for a relation.
397    pub fn relation_embedding(&self, name: &str) -> Option<&Vec<f64>> {
398        self.relation_to_id
399            .get(name)
400            .and_then(|id| self.relation_embeddings.get(id))
401    }
402
403    /// Find nearest entities to a query embedding.
404    pub fn nearest_entities(&self, query: &[f64], k: usize) -> Vec<(String, f64)> {
405        let mut dists: Vec<(String, f64)> = self
406            .entity_embeddings
407            .iter()
408            .map(|(&id, emb)| {
409                let dist = l2_distance(query, emb);
410                let name = self
411                    .id_to_entity
412                    .get(&id)
413                    .cloned()
414                    .unwrap_or_else(|| format!("entity_{id}"));
415                (name, dist)
416            })
417            .collect();
418
419        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
420        dists.truncate(k);
421        dists
422    }
423
424    /// Get entity name by ID.
425    pub fn entity_name(&self, id: usize) -> Option<&str> {
426        self.id_to_entity.get(&id).map(|s| s.as_str())
427    }
428
429    /// Get relation name by ID.
430    pub fn relation_name(&self, id: usize) -> Option<&str> {
431        self.id_to_relation.get(&id).map(|s| s.as_str())
432    }
433
434    // ─── Internal methods ────────────────────────────────
435
436    fn random_embedding(&mut self) -> Vec<f64> {
437        let dim = self.config.dim;
438        (0..dim)
439            .map(|_| {
440                self.rng_state = self
441                    .rng_state
442                    .wrapping_mul(6364136223846793005)
443                    .wrapping_add(1442695040888963407);
444                let val = ((self.rng_state >> 33) as f64) / (u32::MAX as f64);
445                (val - 0.5) * 2.0 / (dim as f64).sqrt()
446            })
447            .collect()
448    }
449
450    fn corrupt_triple(&mut self, triple: &Triple, num_entities: usize) -> Triple {
451        if num_entities == 0 {
452            return triple.clone();
453        }
454        self.rng_state = self
455            .rng_state
456            .wrapping_mul(6364136223846793005)
457            .wrapping_add(1442695040888963407);
458        let corrupt_head = (self.rng_state >> 33) % 2 == 0;
459        let random_entity = ((self.rng_state >> 17) as usize) % num_entities;
460
461        if corrupt_head {
462            Triple {
463                head: random_entity,
464                relation: triple.relation,
465                tail: triple.tail,
466            }
467        } else {
468            Triple {
469                head: triple.head,
470                relation: triple.relation,
471                tail: random_entity,
472            }
473        }
474    }
475
476    fn update_embeddings(&mut self, positive: &Triple, negative: &Triple) {
477        let lr = self.config.learning_rate;
478        let dim = self.config.dim;
479
480        // Gradient for positive triple: h + r - t
481        // For negative triple: h' + r - t' (corrupted)
482        let pos_h = self
483            .entity_embeddings
484            .get(&positive.head)
485            .cloned()
486            .unwrap_or_else(|| vec![0.0; dim]);
487        let pos_t = self
488            .entity_embeddings
489            .get(&positive.tail)
490            .cloned()
491            .unwrap_or_else(|| vec![0.0; dim]);
492        let neg_h = self
493            .entity_embeddings
494            .get(&negative.head)
495            .cloned()
496            .unwrap_or_else(|| vec![0.0; dim]);
497        let neg_t = self
498            .entity_embeddings
499            .get(&negative.tail)
500            .cloned()
501            .unwrap_or_else(|| vec![0.0; dim]);
502        let rel = self
503            .relation_embeddings
504            .get(&positive.relation)
505            .cloned()
506            .unwrap_or_else(|| vec![0.0; dim]);
507
508        // Gradient direction for L2
509        let mut pos_grad = vec![0.0; dim];
510        let mut neg_grad = vec![0.0; dim];
511        for i in 0..dim {
512            pos_grad[i] = pos_h[i] + rel[i] - pos_t[i];
513            neg_grad[i] = neg_h[i] + rel[i] - neg_t[i];
514        }
515
516        // Normalize gradients
517        let pos_norm = l2_norm(&pos_grad).max(1e-12);
518        let neg_norm = l2_norm(&neg_grad).max(1e-12);
519
520        // Update positive head: h = h - lr * gradient
521        if let Some(h_emb) = self.entity_embeddings.get_mut(&positive.head) {
522            for i in 0..dim {
523                h_emb[i] -= lr * pos_grad[i] / pos_norm;
524            }
525        }
526
527        // Update positive tail: t = t + lr * gradient
528        if let Some(t_emb) = self.entity_embeddings.get_mut(&positive.tail) {
529            for i in 0..dim {
530                t_emb[i] += lr * pos_grad[i] / pos_norm;
531            }
532        }
533
534        // Update negative head: h' = h' + lr * gradient
535        if let Some(h_emb) = self.entity_embeddings.get_mut(&negative.head) {
536            for i in 0..dim {
537                h_emb[i] += lr * neg_grad[i] / neg_norm;
538            }
539        }
540
541        // Update negative tail: t' = t' - lr * gradient
542        if let Some(t_emb) = self.entity_embeddings.get_mut(&negative.tail) {
543            for i in 0..dim {
544                t_emb[i] -= lr * neg_grad[i] / neg_norm;
545            }
546        }
547
548        // Update relation: r = r - lr * (pos_grad - neg_grad)
549        if let Some(r_emb) = self.relation_embeddings.get_mut(&positive.relation) {
550            for i in 0..dim {
551                r_emb[i] -= lr * (pos_grad[i] / pos_norm - neg_grad[i] / neg_norm);
552            }
553        }
554    }
555
556    fn normalize_entities(&mut self) {
557        for emb in self.entity_embeddings.values_mut() {
558            let norm = l2_norm(emb);
559            if norm > 1.0 {
560                for v in emb.iter_mut() {
561                    *v /= norm;
562                }
563            }
564        }
565    }
566}
567
568impl Default for TransEModel {
569    fn default() -> Self {
570        Self::new()
571    }
572}
573
574// ─── Helper functions ────────────────────────────────────
575
/// Euclidean (L2) length of `v`: the square root of the sum of its squared
/// components.
fn l2_norm(v: &[f64]) -> f64 {
    let squares = v.iter().map(|&component| component * component);
    f64::sqrt(squares.sum())
}
579
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired up position by position; when the slices differ in
/// length, the surplus components of the longer one are ignored.
fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
587
588// ─────────────────────────────────────────────
589// Tests
590// ─────────────────────────────────────────────
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Small social graph (five entities, two relations, five triples) in a
    /// 10-dimensional model capped at 50 epochs per training call.
    fn sample_model() -> TransEModel {
        let config = TransEConfig {
            dim: 10,
            learning_rate: 0.01,
            margin: 1.0,
            max_epochs: 50,
            ..Default::default()
        };
        let mut model = TransEModel::with_config(config);
        for (head, relation, tail) in [
            ("alice", "knows", "bob"),
            ("bob", "knows", "charlie"),
            ("alice", "likes", "music"),
            ("bob", "likes", "sports"),
            ("charlie", "likes", "music"),
        ] {
            model.add_triple(head, relation, tail);
        }
        model
    }

    /// The sample graph after one training call of `epochs` epochs.
    fn trained_model(epochs: usize) -> TransEModel {
        let mut model = sample_model();
        model.train(epochs);
        model
    }

    // ═══ Config tests ════════════════════════════════════

    #[test]
    fn test_default_config() {
        let TransEConfig {
            dim,
            distance_metric,
            normalize_embeddings,
            ..
        } = TransEConfig::default();
        assert_eq!(dim, 50);
        assert_eq!(distance_metric, DistanceMetric::L2);
        assert!(normalize_embeddings);
    }

    // ═══ Entity/relation management tests ════════════════

    #[test]
    fn test_add_entity() {
        let mut model = TransEModel::new();
        assert_eq!(model.add_entity("alice"), 0);
        assert_eq!(model.entity_count(), 1);
    }

    #[test]
    fn test_add_entity_idempotent() {
        let mut model = TransEModel::new();
        let first = model.add_entity("alice");
        let second = model.add_entity("alice");
        assert_eq!(first, second);
        assert_eq!(model.entity_count(), 1);
    }

    #[test]
    fn test_add_relation() {
        let mut model = TransEModel::new();
        assert_eq!(model.add_relation("knows"), 0);
        assert_eq!(model.relation_count(), 1);
    }

    #[test]
    fn test_add_triple() {
        let model = sample_model();
        // Entities: alice, bob, charlie, music, sports. Relations: knows, likes.
        let counts = (
            model.triple_count(),
            model.entity_count(),
            model.relation_count(),
        );
        assert_eq!(counts, (5, 5, 2));
    }

    #[test]
    fn test_entity_name() {
        assert_eq!(sample_model().entity_name(0), Some("alice"));
    }

    #[test]
    fn test_relation_name() {
        assert!(sample_model().relation_name(0).is_some());
    }

    // ═══ Training tests ══════════════════════════════════

    #[test]
    fn test_train_basic() {
        let stats = sample_model().train(10);
        assert_eq!(stats.epochs_completed, 10);
        assert_eq!(stats.loss_history.len(), 10);
    }

    #[test]
    fn test_train_loss_decreases() {
        let model = trained_model(20);
        let losses = &model.stats().loss_history;
        let mean_of_five = |window: &[f64]| window.iter().sum::<f64>() / 5.0;
        let first_avg = mean_of_five(&losses[..5]);
        let last_avg = mean_of_five(&losses[15..]);
        // A strict decrease is not guaranteed; only require that the loss
        // has not blown up by an order of magnitude.
        assert!(last_avg < first_avg * 10.0);
    }

    #[test]
    fn test_train_empty_triples() {
        let stats = TransEModel::new().train(10);
        assert_eq!(stats.epochs_completed, 0);
    }

    #[test]
    fn test_train_stats_cumulative() {
        let mut model = sample_model();
        for _ in 0..2 {
            model.train(5);
        }
        assert_eq!(model.stats().epochs_completed, 10);
    }

    // ═══ Scoring tests ═══════════════════════════════════

    #[test]
    fn test_score_known_triple() {
        let score = trained_model(20).score("alice", "knows", "bob");
        assert!(score.is_some());
        assert!(score.expect("score") < 100.0);
    }

    #[test]
    fn test_score_unknown_entity() {
        assert!(sample_model().score("unknown", "knows", "bob").is_none());
    }

    #[test]
    fn test_score_unknown_relation() {
        assert!(sample_model().score("alice", "unknown", "bob").is_none());
    }

    // ═══ Prediction tests ════════════════════════════════

    #[test]
    fn test_predict_tail() {
        let predictions = trained_model(10).predict_tail("alice", "knows", 3);
        assert_eq!(predictions.len(), 3);
        // Best (lowest) score first.
        assert!(predictions
            .windows(2)
            .all(|pair| pair[0].score <= pair[1].score));
    }

    #[test]
    fn test_predict_head() {
        let predictions = trained_model(10).predict_head("knows", "bob", 3);
        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_predict_unknown_entity() {
        assert!(sample_model()
            .predict_tail("unknown", "knows", 3)
            .is_empty());
    }

    #[test]
    fn test_predict_unknown_relation() {
        assert!(sample_model()
            .predict_tail("alice", "unknown", 3)
            .is_empty());
    }

    // ═══ Embedding access tests ══════════════════════════

    #[test]
    fn test_entity_embedding() {
        let model = sample_model();
        let embedding = model.entity_embedding("alice");
        assert!(embedding.is_some());
        assert_eq!(embedding.expect("embedding").len(), 10);
    }

    #[test]
    fn test_relation_embedding() {
        let model = sample_model();
        let embedding = model.relation_embedding("knows");
        assert!(embedding.is_some());
        assert_eq!(embedding.expect("embedding").len(), 10);
    }

    #[test]
    fn test_embedding_unknown() {
        let model = sample_model();
        assert!(model.entity_embedding("unknown").is_none());
        assert!(model.relation_embedding("unknown").is_none());
    }

    // ═══ Nearest neighbor tests ══════════════════════════

    #[test]
    fn test_nearest_entities() {
        let model = trained_model(10);
        let query = model.entity_embedding("alice").expect("alice").clone();
        let nearest = model.nearest_entities(&query, 3);
        assert_eq!(nearest.len(), 3);
        // An entity is its own nearest neighbour, at distance zero.
        assert_eq!(nearest[0].0, "alice");
        assert!(nearest[0].1 < 1e-10);
    }

    // ═══ Distance metric tests ═══════════════════════════

    #[test]
    fn test_l1_distance_metric() {
        let config = TransEConfig {
            dim: 10,
            distance_metric: DistanceMetric::L1,
            ..Default::default()
        };
        let mut model = TransEModel::with_config(config);
        model.add_triple("a", "r", "b");
        model.train(5);
        assert!(model.score("a", "r", "b").is_some());
    }

    // ═══ Normalization tests ═════════════════════════════

    #[test]
    fn test_normalized_embeddings() {
        let model = trained_model(10);
        assert!(model
            .entity_embeddings
            .values()
            .all(|embedding| l2_norm(embedding) <= 1.0 + 1e-6));
    }

    #[test]
    fn test_no_normalization() {
        let config = TransEConfig {
            dim: 10,
            normalize_embeddings: false,
            ..Default::default()
        };
        let mut model = TransEModel::with_config(config);
        model.add_triple("a", "r", "b");
        model.train(5);
        // Training without normalisation must simply complete.
        assert_eq!(model.triple_count(), 1);
    }

    // ═══ Helper function tests ═══════════════════════════

    #[test]
    fn test_l2_norm() {
        let norm = l2_norm(&[3.0, 4.0]);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance() {
        let origin = [0.0, 0.0];
        let point = [3.0, 4.0];
        assert!((l2_distance(&origin, &point) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance_same() {
        let point = [1.0, 2.0, 3.0];
        assert!(l2_distance(&point, &point) < 1e-10);
    }

    // ═══ Default impl test ═══════════════════════════════

    #[test]
    fn test_default_model() {
        let model = TransEModel::default();
        let counts = (
            model.entity_count(),
            model.relation_count(),
            model.triple_count(),
        );
        assert_eq!(counts, (0, 0, 0));
    }
}