// scirs2_graph/knowledge_graph.rs
//! Knowledge Graph Embedding models
//!
//! This module implements classical knowledge graph embedding (KGE) methods for
//! learning low-dimensional representations of entities and relations in a
//! knowledge graph (KG).
//!
//! A KG is a collection of (head, relation, tail) *triples* `(h, r, t)` where
//! `h` and `t` are entities and `r` is a relation type.
//!
//! ## Implemented models
//!
//! | Model | Score function | Reference |
//! |-------|----------------|-----------|
//! | [`TransE`] | `−‖ h + r − t ‖` | Bordes et al. 2013 |
//! | [`DistMult`] | `h · r · t` (elementwise) | Yang et al. 2015 |
//! | [`ComplEx`] | `Re( h · r · conj(t) )` | Trouillon et al. 2016 |
//!
//! ## Training with negative sampling
//!
//! All models support uniform negative sampling: for each positive
//! triple `(h, r, t)`, one corrupt triple is generated by replacing either
//! the head or the tail with a random entity.  Margin-ranking loss is
//! minimised:
//! ```text
//!   L = max(0, γ + score_neg − score_pos)
//! ```

use std::collections::HashMap;

use scirs2_core::random::{Rng, RngExt};

use crate::error::{GraphError, Result};
33
34// ============================================================================
35// KGDataset
36// ============================================================================
37
/// A collection of (head, relation, tail) triples for a knowledge graph.
///
/// Entities and relations are referenced by their integer indices.  Use
/// [`KGDataset::from_str_triples`] to build a dataset from string labels
/// automatically.
#[derive(Debug, Clone)]
pub struct KGDataset {
    /// All triples as `(head_idx, rel_idx, tail_idx)`
    pub triples: Vec<(usize, usize, usize)>,
    /// Number of distinct entities; every head/tail index in `triples` is `< n_entities`
    pub n_entities: usize,
    /// Number of distinct relation types; every relation index is `< n_relations`
    pub n_relations: usize,
    /// Entity index → string label (length `n_entities`)
    pub entity_labels: Vec<String>,
    /// Relation index → string label (length `n_relations`)
    pub relation_labels: Vec<String>,
}
56
57impl KGDataset {
58    /// Build a `KGDataset` from raw integer triples.
59    ///
60    /// # Arguments
61    /// * `triples` – Vec of `(head, relation, tail)` triples.
62    /// * `n_entities` – Total number of entities (must exceed all indices).
63    /// * `n_relations` – Total number of relation types.
64    pub fn new(
65        triples: Vec<(usize, usize, usize)>,
66        n_entities: usize,
67        n_relations: usize,
68    ) -> Result<Self> {
69        for &(h, r, t) in &triples {
70            if h >= n_entities || t >= n_entities {
71                return Err(GraphError::InvalidParameter {
72                    param: "triples".to_string(),
73                    value: format!("entity index ({h},{t}) out of range"),
74                    expected: format!("< n_entities={n_entities}"),
75                    context: "KGDataset::new".to_string(),
76                });
77            }
78            if r >= n_relations {
79                return Err(GraphError::InvalidParameter {
80                    param: "triples".to_string(),
81                    value: format!("relation index {r} out of range"),
82                    expected: format!("< n_relations={n_relations}"),
83                    context: "KGDataset::new".to_string(),
84                });
85            }
86        }
87        let entity_labels = (0..n_entities).map(|i| format!("e{i}")).collect();
88        let relation_labels = (0..n_relations).map(|i| format!("r{i}")).collect();
89        Ok(KGDataset {
90            triples,
91            n_entities,
92            n_relations,
93            entity_labels,
94            relation_labels,
95        })
96    }
97
98    /// Build a dataset from string-labeled triples.
99    ///
100    /// Entity and relation vocabularies are inferred from the data in the order
101    /// they are first encountered.
102    pub fn from_str_triples(triples: &[(&str, &str, &str)]) -> Self {
103        let mut entity_map: HashMap<String, usize> = HashMap::new();
104        let mut relation_map: HashMap<String, usize> = HashMap::new();
105        let mut entity_labels: Vec<String> = Vec::new();
106        let mut relation_labels: Vec<String> = Vec::new();
107
108        let mut get_or_insert_entity = |s: &str| -> usize {
109            if let Some(&idx) = entity_map.get(s) {
110                idx
111            } else {
112                let idx = entity_labels.len();
113                entity_map.insert(s.to_string(), idx);
114                entity_labels.push(s.to_string());
115                idx
116            }
117        };
118
119        let mut indexed_triples: Vec<(usize, usize, usize)> = Vec::with_capacity(triples.len());
120        for &(h, r, t) in triples {
121            let hi = get_or_insert_entity(h);
122            let ti = get_or_insert_entity(t);
123            let ri = if let Some(&idx) = relation_map.get(r) {
124                idx
125            } else {
126                let idx = relation_labels.len();
127                relation_map.insert(r.to_string(), idx);
128                relation_labels.push(r.to_string());
129                idx
130            };
131            indexed_triples.push((hi, ri, ti));
132        }
133
134        let n_entities = entity_labels.len();
135        let n_relations = relation_labels.len();
136
137        KGDataset {
138            triples: indexed_triples,
139            n_entities,
140            n_relations,
141            entity_labels,
142            relation_labels,
143        }
144    }
145
146    /// Return the number of triples.
147    pub fn len(&self) -> usize {
148        self.triples.len()
149    }
150
151    /// Return true if the dataset contains no triples.
152    pub fn is_empty(&self) -> bool {
153        self.triples.is_empty()
154    }
155
156    /// Randomly corrupt one triple by replacing either the head or the tail
157    /// with a uniformly sampled entity (excluding the original).
158    pub fn corrupt_triple(&self, triple: (usize, usize, usize)) -> (usize, usize, usize) {
159        let (h, r, t) = triple;
160        let mut rng = scirs2_core::random::rng();
161        let replace_head = rng.random::<f64>() < 0.5;
162        if replace_head {
163            let mut new_h = (rng.random::<f64>() * self.n_entities as f64) as usize;
164            new_h = new_h.min(self.n_entities - 1);
165            // Avoid generating the same entity
166            if new_h == h && self.n_entities > 1 {
167                new_h = (new_h + 1) % self.n_entities;
168            }
169            (new_h, r, t)
170        } else {
171            let mut new_t = (rng.random::<f64>() * self.n_entities as f64) as usize;
172            new_t = new_t.min(self.n_entities - 1);
173            if new_t == t && self.n_entities > 1 {
174                new_t = (new_t + 1) % self.n_entities;
175            }
176            (h, r, new_t)
177        }
178    }
179}
180
181// ============================================================================
182// Embedding initialisation utilities
183// ============================================================================
184
185/// Initialise a flat embedding table `[n_items × dim]` with uniform noise in
186/// `[-scale, scale]`, then L2-normalise each row.
187fn init_embeddings(n_items: usize, dim: usize, scale: f64) -> Vec<Vec<f64>> {
188    let mut rng = scirs2_core::random::rng();
189    (0..n_items)
190        .map(|_| {
191            let mut row: Vec<f64> = (0..dim)
192                .map(|_| rng.random::<f64>() * 2.0 * scale - scale)
193                .collect();
194            let norm = row.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
195            row.iter_mut().for_each(|x| *x /= norm);
196            row
197        })
198        .collect()
199}
200
/// Euclidean (L2) norm of `v`: `sqrt(Σ v_k²)`.  Returns `0.0` for an empty slice.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    let mut acc = 0.0;
    for &x in v {
        acc += x * x;
    }
    acc.sqrt()
}
206
/// Scale `v` in place to unit L2 length.  The norm is floored at `1e-12`
/// so an all-zero vector is left unchanged instead of producing NaNs.
fn l2_normalize(v: &mut [f64]) {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-12);
    for x in v.iter_mut() {
        *x /= norm;
    }
}
212
213// ============================================================================
214// TransE
215// ============================================================================
216
/// TransE knowledge graph embedding model.
///
/// Score: `−‖ h + r − t ‖_p` (negated L-p distance, higher = more likely).
///
/// Training uses margin-based loss with negative sampling.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Entity embedding table `[n_entities, dim]`; rows are L2-normalised at
    /// initialisation and re-normalised after each gradient step
    pub entity_embeddings: Vec<Vec<f64>>,
    /// Relation embedding table `[n_relations, dim]`
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Embedding dimension
    pub dim: usize,
    /// Norm order `p`: 1 selects the L1 distance, any other value the L2
    pub norm_order: u32,
}
233
234impl TransE {
235    /// Create a TransE model with random initialisation.
236    ///
237    /// # Arguments
238    /// * `n_entities` – Number of entities.
239    /// * `n_relations` – Number of relation types.
240    /// * `dim` – Embedding dimension.
241    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
242        if dim == 0 {
243            return Err(GraphError::InvalidParameter {
244                param: "dim".to_string(),
245                value: "0".to_string(),
246                expected: "> 0".to_string(),
247                context: "TransE::new".to_string(),
248            });
249        }
250        let entity_embeddings = init_embeddings(n_entities, dim, 1.0 / (dim as f64).sqrt());
251        let relation_embeddings = init_embeddings(n_relations, dim, 1.0 / (dim as f64).sqrt());
252        Ok(TransE {
253            entity_embeddings,
254            relation_embeddings,
255            dim,
256            norm_order: 2,
257        })
258    }
259
260    /// Score a single triple: `−‖ h + r − t ‖` (higher = more plausible).
261    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
262        self.validate_indices(h, r, t)?;
263        let he = &self.entity_embeddings[h];
264        let re = &self.relation_embeddings[r];
265        let te = &self.entity_embeddings[t];
266        let dist = translation_distance(he, re, te, self.norm_order);
267        Ok(-dist)
268    }
269
270    /// Return the top-`k` entity indices most likely to complete `(h, r, ?)`.
271    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
272        let n = self.entity_embeddings.len();
273        if h >= n {
274            return Err(GraphError::InvalidParameter {
275                param: "h".to_string(),
276                value: format!("{h}"),
277                expected: format!("< {n}"),
278                context: "TransE::predict_tails".to_string(),
279            });
280        }
281        let he = &self.entity_embeddings[h];
282        let re = &self.relation_embeddings[r];
283        let mut scores: Vec<(usize, f64)> = (0..n)
284            .map(|t| {
285                let te = &self.entity_embeddings[t];
286                let dist = translation_distance(he, re, te, self.norm_order);
287                (t, -dist) // higher score = smaller distance
288            })
289            .collect();
290        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
291        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
292    }
293
294    /// Link prediction score for a triple (h, r, t).
295    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
296        self.score(h, r, t)
297    }
298
299    /// Train for one epoch using stochastic gradient descent and negative
300    /// sampling.  Returns the total margin-ranking loss.
301    ///
302    /// # Arguments
303    /// * `dataset` – Training triples.
304    /// * `lr` – Learning rate.
305    /// * `margin` – Margin `γ` for the margin-ranking loss.
306    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
307        let mut total_loss = 0.0;
308
309        for &(h, r, t) in &dataset.triples {
310            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
311
312            let pos_score = {
313                let he = &self.entity_embeddings[h];
314                let re = &self.relation_embeddings[r];
315                let te = &self.entity_embeddings[t];
316                translation_distance(he, re, te, self.norm_order)
317            };
318            let neg_score = {
319                let he = &self.entity_embeddings[nh];
320                let re = &self.relation_embeddings[nr];
321                let te = &self.entity_embeddings[nt];
322                translation_distance(he, re, te, self.norm_order)
323            };
324
325            let loss = (margin + pos_score - neg_score).max(0.0);
326            total_loss += loss;
327
328            if loss > 0.0 {
329                // Approximate gradient step
330                let dim = self.dim;
331
332                // Gradients for positive triple: minimise ‖h + r - t‖
333                let g_pos: Vec<f64> = (0..dim)
334                    .map(|k| {
335                        let diff = self.entity_embeddings[h][k] + self.relation_embeddings[r][k]
336                            - self.entity_embeddings[t][k];
337                        if diff >= 0.0 {
338                            1.0
339                        } else {
340                            -1.0
341                        }
342                    })
343                    .collect();
344
345                // Gradients for negative triple: maximise ‖nh + nr - nt‖
346                let g_neg: Vec<f64> = (0..dim)
347                    .map(|k| {
348                        let diff = self.entity_embeddings[nh][k] + self.relation_embeddings[nr][k]
349                            - self.entity_embeddings[nt][k];
350                        if diff >= 0.0 {
351                            1.0
352                        } else {
353                            -1.0
354                        }
355                    })
356                    .collect();
357
358                // Update positive entities and relation
359                for k in 0..dim {
360                    self.entity_embeddings[h][k] -= lr * g_pos[k];
361                    self.entity_embeddings[t][k] += lr * g_pos[k];
362                    self.relation_embeddings[r][k] -= lr * g_pos[k];
363                }
364                // Update negative entities
365                for k in 0..dim {
366                    self.entity_embeddings[nh][k] += lr * g_neg[k];
367                    self.entity_embeddings[nt][k] -= lr * g_neg[k];
368                }
369
370                // Re-normalise entity embeddings
371                l2_normalize(&mut self.entity_embeddings[h]);
372                l2_normalize(&mut self.entity_embeddings[t]);
373                l2_normalize(&mut self.entity_embeddings[nh]);
374                l2_normalize(&mut self.entity_embeddings[nt]);
375            }
376        }
377
378        total_loss
379    }
380
381    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
382        let ne = self.entity_embeddings.len();
383        let nr = self.relation_embeddings.len();
384        if h >= ne || t >= ne {
385            return Err(GraphError::InvalidParameter {
386                param: "entity_index".to_string(),
387                value: format!("({h},{t})"),
388                expected: format!("< {ne}"),
389                context: "TransE score".to_string(),
390            });
391        }
392        if r >= nr {
393            return Err(GraphError::InvalidParameter {
394                param: "relation_index".to_string(),
395                value: format!("{r}"),
396                expected: format!("< {nr}"),
397                context: "TransE score".to_string(),
398            });
399        }
400        Ok(())
401    }
402}
403
/// Compute the L-p translation distance `‖ h + r − t ‖_p`.
///
/// `norm_order == 1` gives the Manhattan distance; any other value is
/// treated as 2 (Euclidean).  Iteration stops at the shortest slice.
fn translation_distance(h: &[f64], r: &[f64], t: &[f64], norm_order: u32) -> f64 {
    let mut acc = 0.0;
    for ((hi, ri), ti) in h.iter().zip(r).zip(t) {
        let d = hi + ri - ti;
        acc += if norm_order == 1 { d.abs() } else { d * d };
    }
    if norm_order == 1 {
        acc
    } else {
        acc.sqrt()
    }
}
423
424// ============================================================================
425// DistMult
426// ============================================================================
427
/// DistMult knowledge graph embedding model.
///
/// Score: `Σ_k h_k · r_k · t_k` (element-wise bilinear product).
/// Equivalent to a bilinear model whose relation matrix is diagonal.
#[derive(Debug, Clone)]
pub struct DistMult {
    /// Entity embedding table `[n_entities, dim]`
    pub entity_embeddings: Vec<Vec<f64>>,
    /// Relation (diagonal) embedding table `[n_relations, dim]`
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Embedding dimension
    pub dim: usize,
}
440
441impl DistMult {
442    /// Create a DistMult model with random initialisation.
443    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
444        if dim == 0 {
445            return Err(GraphError::InvalidParameter {
446                param: "dim".to_string(),
447                value: "0".to_string(),
448                expected: "> 0".to_string(),
449                context: "DistMult::new".to_string(),
450            });
451        }
452        let mut rng = scirs2_core::random::rng();
453        let scale = 1.0 / (dim as f64).sqrt();
454        let mut mk_table = |n: usize| -> Vec<Vec<f64>> {
455            (0..n)
456                .map(|_| {
457                    (0..dim)
458                        .map(|_| rng.random::<f64>() * 2.0 * scale - scale)
459                        .collect()
460                })
461                .collect()
462        };
463        Ok(DistMult {
464            entity_embeddings: mk_table(n_entities),
465            relation_embeddings: mk_table(n_relations),
466            dim,
467        })
468    }
469
470    /// Score triple `(h, r, t)`: `Σ h_k r_k t_k`.
471    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
472        self.validate_indices(h, r, t)?;
473        let score = distmult_score(
474            &self.entity_embeddings[h],
475            &self.relation_embeddings[r],
476            &self.entity_embeddings[t],
477        );
478        Ok(score)
479    }
480
481    /// Link prediction score.
482    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
483        self.score(h, r, t)
484    }
485
486    /// Return the top-`k` entity indices most likely to complete `(h, r, ?)`.
487    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
488        let n = self.entity_embeddings.len();
489        if h >= n {
490            return Err(GraphError::InvalidParameter {
491                param: "h".to_string(),
492                value: format!("{h}"),
493                expected: format!("< {n}"),
494                context: "DistMult::predict_tails".to_string(),
495            });
496        }
497        let he = &self.entity_embeddings[h];
498        let re = &self.relation_embeddings[r];
499        let mut scores: Vec<(usize, f64)> = (0..n)
500            .map(|ti| {
501                let te = &self.entity_embeddings[ti];
502                (ti, distmult_score(he, re, te))
503            })
504            .collect();
505        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
506        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
507    }
508
509    /// Train one epoch with margin-based negative sampling.
510    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
511        let mut total_loss = 0.0;
512        for &(h, r, t) in &dataset.triples {
513            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
514
515            let pos = distmult_score(
516                &self.entity_embeddings[h],
517                &self.relation_embeddings[r],
518                &self.entity_embeddings[t],
519            );
520            let neg = distmult_score(
521                &self.entity_embeddings[nh],
522                &self.relation_embeddings[nr],
523                &self.entity_embeddings[nt],
524            );
525
526            let loss = (margin - pos + neg).max(0.0);
527            total_loss += loss;
528
529            if loss > 0.0 {
530                let dim = self.dim;
531                // Gradient: d/d(h_k) = r_k * t_k
532                for k in 0..dim {
533                    let re = self.relation_embeddings[r][k];
534                    let te = self.entity_embeddings[t][k];
535                    self.entity_embeddings[h][k] += lr * re * te;
536                }
537                for k in 0..dim {
538                    let re = self.relation_embeddings[nr][k];
539                    let te = self.entity_embeddings[nt][k];
540                    self.entity_embeddings[nh][k] -= lr * re * te;
541                }
542            }
543        }
544        total_loss
545    }
546
547    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
548        let ne = self.entity_embeddings.len();
549        let nr = self.relation_embeddings.len();
550        if h >= ne || t >= ne {
551            return Err(GraphError::InvalidParameter {
552                param: "entity_index".to_string(),
553                value: format!("({h},{t})"),
554                expected: format!("< {ne}"),
555                context: "DistMult score".to_string(),
556            });
557        }
558        if r >= nr {
559            return Err(GraphError::InvalidParameter {
560                param: "relation_index".to_string(),
561                value: format!("{r}"),
562                expected: format!("< {nr}"),
563                context: "DistMult score".to_string(),
564            });
565        }
566        Ok(())
567    }
568}
569
/// Trilinear DistMult score `Σ_k h_k · r_k · t_k`.
/// Iteration stops at the shortest of the three slices.
fn distmult_score(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
    h.iter()
        .zip(r)
        .zip(t)
        .fold(0.0, |acc, ((hi, ri), ti)| acc + hi * ri * ti)
}
577
578// ============================================================================
579// ComplEx
580// ============================================================================
581
/// ComplEx knowledge graph embedding model.
///
/// Entities and relations are embedded in complex vector space `ℂ^d`.
/// Each embedding is stored as two real vectors (real and imaginary parts).
///
/// Score: `Re( Σ_k h_k · r_k · conj(t_k) )`
///      = `Σ_k ( Re(h)·Re(r)·Re(t) + Im(h)·Re(r)·Im(t)
///              + Re(h)·Im(r)·Im(t) - Im(h)·Im(r)·Re(t) )`
#[derive(Debug, Clone)]
pub struct ComplEx {
    /// Real part of entity embeddings `[n_entities, dim]`
    pub entity_re: Vec<Vec<f64>>,
    /// Imaginary part of entity embeddings `[n_entities, dim]`
    pub entity_im: Vec<Vec<f64>>,
    /// Real part of relation embeddings `[n_relations, dim]`
    pub relation_re: Vec<Vec<f64>>,
    /// Imaginary part of relation embeddings `[n_relations, dim]`
    pub relation_im: Vec<Vec<f64>>,
    /// Embedding dimension (complex components per embedding)
    pub dim: usize,
}
603
604impl ComplEx {
605    /// Create a ComplEx model with random initialisation.
606    pub fn new(n_entities: usize, n_relations: usize, dim: usize) -> Result<Self> {
607        if dim == 0 {
608            return Err(GraphError::InvalidParameter {
609                param: "dim".to_string(),
610                value: "0".to_string(),
611                expected: "> 0".to_string(),
612                context: "ComplEx::new".to_string(),
613            });
614        }
615        let scale = 1.0 / (dim as f64).sqrt();
616        Ok(ComplEx {
617            entity_re: init_embeddings(n_entities, dim, scale),
618            entity_im: init_embeddings(n_entities, dim, scale),
619            relation_re: init_embeddings(n_relations, dim, scale),
620            relation_im: init_embeddings(n_relations, dim, scale),
621            dim,
622        })
623    }
624
625    /// Score triple `(h, r, t)` using the ComplEx scoring function.
626    pub fn score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
627        self.validate_indices(h, r, t)?;
628        let s = complex_score(
629            &self.entity_re[h],
630            &self.entity_im[h],
631            &self.relation_re[r],
632            &self.relation_im[r],
633            &self.entity_re[t],
634            &self.entity_im[t],
635        );
636        Ok(s)
637    }
638
639    /// Link prediction score.
640    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
641        self.score(h, r, t)
642    }
643
644    /// Return the top-`k` entity indices most likely to complete `(h, r, ?)`.
645    pub fn predict_tails(&self, h: usize, r: usize, k: usize) -> Result<Vec<usize>> {
646        let n = self.entity_re.len();
647        if h >= n {
648            return Err(GraphError::InvalidParameter {
649                param: "h".to_string(),
650                value: format!("{h}"),
651                expected: format!("< {n}"),
652                context: "ComplEx::predict_tails".to_string(),
653            });
654        }
655        let mut scores: Vec<(usize, f64)> = (0..n)
656            .map(|ti| {
657                let s = complex_score(
658                    &self.entity_re[h],
659                    &self.entity_im[h],
660                    &self.relation_re[r],
661                    &self.relation_im[r],
662                    &self.entity_re[ti],
663                    &self.entity_im[ti],
664                );
665                (ti, s)
666            })
667            .collect();
668        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
669        Ok(scores.into_iter().take(k).map(|(idx, _)| idx).collect())
670    }
671
672    /// Train one epoch with margin-based negative sampling.
673    pub fn train_epoch(&mut self, dataset: &KGDataset, lr: f64, margin: f64) -> f64 {
674        let mut total_loss = 0.0;
675        for &(h, r, t) in &dataset.triples {
676            let (nh, nr, nt) = dataset.corrupt_triple((h, r, t));
677
678            let pos = complex_score(
679                &self.entity_re[h],
680                &self.entity_im[h],
681                &self.relation_re[r],
682                &self.relation_im[r],
683                &self.entity_re[t],
684                &self.entity_im[t],
685            );
686            let neg = complex_score(
687                &self.entity_re[nh],
688                &self.entity_im[nh],
689                &self.relation_re[nr],
690                &self.relation_im[nr],
691                &self.entity_re[nt],
692                &self.entity_im[nt],
693            );
694
695            let loss = (margin - pos + neg).max(0.0);
696            total_loss += loss;
697
698            if loss > 0.0 {
699                let dim = self.dim;
700                // Gradient w.r.t. Re(h): d/d(Re(h_k)) = Re(r_k)*Re(t_k) + Im(r_k)*Im(t_k)
701                for k in 0..dim {
702                    let re_r = self.relation_re[r][k];
703                    let im_r = self.relation_im[r][k];
704                    let re_t = self.entity_re[t][k];
705                    let im_t = self.entity_im[t][k];
706                    let g_re_h = re_r * re_t + im_r * im_t;
707                    let g_im_h = re_r * im_t - im_r * re_t;
708                    self.entity_re[h][k] += lr * g_re_h;
709                    self.entity_im[h][k] += lr * g_im_h;
710
711                    // Negative gradient (subtract)
712                    let re_rn = self.relation_re[nr][k];
713                    let im_rn = self.relation_im[nr][k];
714                    let re_tn = self.entity_re[nt][k];
715                    let im_tn = self.entity_im[nt][k];
716                    let g_re_hn = re_rn * re_tn + im_rn * im_tn;
717                    let g_im_hn = re_rn * im_tn - im_rn * re_tn;
718                    self.entity_re[nh][k] -= lr * g_re_hn;
719                    self.entity_im[nh][k] -= lr * g_im_hn;
720                }
721            }
722        }
723        total_loss
724    }
725
726    fn validate_indices(&self, h: usize, r: usize, t: usize) -> Result<()> {
727        let ne = self.entity_re.len();
728        let nr = self.relation_re.len();
729        if h >= ne || t >= ne {
730            return Err(GraphError::InvalidParameter {
731                param: "entity_index".to_string(),
732                value: format!("({h},{t})"),
733                expected: format!("< {ne}"),
734                context: "ComplEx score".to_string(),
735            });
736        }
737        if r >= nr {
738            return Err(GraphError::InvalidParameter {
739                param: "relation_index".to_string(),
740                value: format!("{r}"),
741                expected: format!("< {nr}"),
742                context: "ComplEx score".to_string(),
743            });
744        }
745        Ok(())
746    }
747}
748
/// Compute the ComplEx score between complex-valued embeddings.
///
/// ```text
/// score = Re( h · r · conj(t) )
///       = Σ_k [ Re(h_k)·Re(r_k)·Re(t_k)
///              + Im(h_k)·Re(r_k)·Im(t_k)
///              + Re(h_k)·Im(r_k)·Im(t_k)
///              - Im(h_k)·Im(r_k)·Re(t_k) ]
/// ```
///
/// Iteration stops at the shortest of the six component slices.
fn complex_score(
    h_re: &[f64],
    h_im: &[f64],
    r_re: &[f64],
    r_im: &[f64],
    t_re: &[f64],
    t_im: &[f64],
) -> f64 {
    // Pair each embedding's real and imaginary components, then walk the
    // three pairs in lock step.
    let head = h_re.iter().zip(h_im);
    let rel = r_re.iter().zip(r_im);
    let tail = t_re.iter().zip(t_im);
    head.zip(rel)
        .zip(tail)
        .map(|(((hr, hi), (rr, ri)), (tr, ti))| {
            // Real part of (hr + i·hi)(rr + i·ri)(tr − i·ti).
            hr * rr * tr + hi * rr * ti + hr * ri * ti - hi * ri * tr
        })
        .sum()
}
777
778// ============================================================================
779// Unified link_prediction_score helper
780// ============================================================================
781
/// Scoring model enum for convenient dispatch over the three KGE models.
pub enum KgModel {
    /// TransE model (translation-distance scoring)
    TransE(TransE),
    /// DistMult model (diagonal bilinear scoring)
    DistMult(DistMult),
    /// ComplEx model (complex-valued bilinear scoring)
    ComplEx(ComplEx),
}
791
792impl KgModel {
793    /// Compute the link prediction score for triple `(h, r, t)`.
794    pub fn link_prediction_score(&self, h: usize, r: usize, t: usize) -> Result<f64> {
795        match self {
796            KgModel::TransE(m) => m.link_prediction_score(h, r, t),
797            KgModel::DistMult(m) => m.link_prediction_score(h, r, t),
798            KgModel::ComplEx(m) => m.link_prediction_score(h, r, t),
799        }
800    }
801}
802
803// ============================================================================
804// Tests
805// ============================================================================
806
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a tiny KG with 4 entities (0..=3), 2 relations (0..=1) and
    /// four hand-picked triples.
    fn simple_dataset() -> KGDataset {
        let triples = vec![(0, 0, 1), (1, 0, 2), (2, 1, 3), (0, 1, 3)];
        KGDataset::new(triples, 4, 2).expect("dataset")
    }

    // --- KGDataset ---

    #[test]
    fn test_dataset_creation() {
        let ds = simple_dataset();
        assert_eq!(ds.n_entities, 4);
        assert_eq!(ds.n_relations, 2);
        assert_eq!(ds.len(), 4);
        assert!(!ds.is_empty());
    }

    #[test]
    fn test_dataset_from_str_triples() {
        let raw = vec![
            ("Alice", "knows", "Bob"),
            ("Bob", "likes", "Carol"),
            ("Alice", "likes", "Carol"),
        ];
        let ds = KGDataset::from_str_triples(&raw);
        assert_eq!(ds.n_entities, 3); // Alice, Bob, Carol
        assert_eq!(ds.n_relations, 2); // knows, likes
        assert_eq!(ds.len(), 3);
    }

    #[test]
    fn test_dataset_out_of_bounds() {
        // Entity index 10 exceeds n_entities = 4, so construction must fail.
        let triples = vec![(10, 0, 1)];
        let result = KGDataset::new(triples, 4, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_corrupt_triple_changes_entity() {
        let ds = simple_dataset();
        let original = (0, 0, 1);
        let corrupted = ds.corrupt_triple(original);
        // Negative sampling replaces either head or tail; the relation is
        // never touched.
        assert_eq!(corrupted.1, 0);
        assert!(corrupted.0 != 0 || corrupted.2 != 1);
    }

    // --- TransE ---

    #[test]
    fn test_transe_score_finite() {
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_transe_score_range() {
        let model = TransE::new(4, 2, 8).expect("TransE::new");
        // Score = -distance, so <= 0 when using L2
        let score = model.score(0, 0, 1).expect("score");
        assert!(score <= 0.0);
    }

    #[test]
    fn test_transe_predict_tails_length() {
        let model = TransE::new(10, 3, 16).expect("TransE");
        let preds = model.predict_tails(0, 0, 5).expect("predict_tails");
        assert_eq!(preds.len(), 5);
        // All indices valid
        for &idx in &preds {
            assert!(idx < 10);
        }
    }

    #[test]
    fn test_transe_train_epoch_losses_finite() {
        // NOTE(review): formerly named `..._reduces_loss`, but nothing ever
        // compared the two epochs. A strict monotone-decrease assertion
        // would be flaky (SGD with random negative sampling), so this test
        // only pins that consecutive epochs run and produce finite losses.
        let ds = simple_dataset();
        let mut model = TransE::new(4, 2, 8).expect("TransE");
        let loss0 = model.train_epoch(&ds, 0.01, 1.0);
        let loss1 = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss0.is_finite());
        assert!(loss1.is_finite());
    }

    #[test]
    fn test_transe_invalid_index() {
        // Entity index 10 out of range for 4 entities.
        let model = TransE::new(4, 2, 8).expect("TransE");
        assert!(model.score(10, 0, 1).is_err());
    }

    // --- DistMult ---

    #[test]
    fn test_distmult_score_finite() {
        let model = DistMult::new(4, 2, 8).expect("DistMult");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_distmult_predict_tails() {
        let model = DistMult::new(10, 3, 16).expect("DistMult");
        let preds = model.predict_tails(0, 1, 3).expect("predict");
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_distmult_train_epoch() {
        let ds = simple_dataset();
        let mut model = DistMult::new(4, 2, 8).expect("DistMult");
        let loss = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss.is_finite());
    }

    // --- ComplEx ---

    #[test]
    fn test_complex_score_finite() {
        let model = ComplEx::new(4, 2, 8).expect("ComplEx");
        let score = model.score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_complex_predict_tails() {
        let model = ComplEx::new(10, 3, 16).expect("ComplEx");
        let preds = model.predict_tails(0, 0, 4).expect("predict");
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complex_train_epoch() {
        let ds = simple_dataset();
        let mut model = ComplEx::new(4, 2, 8).expect("ComplEx");
        let loss = model.train_epoch(&ds, 0.01, 1.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_complex_both_directions_finite() {
        // NOTE(review): formerly named `test_complex_antisymmetry`, but
        // score(h,r,t) != score(t,r,h) is only *likely* for random
        // embeddings, not guaranteed, so asserting inequality would be
        // flaky. We only verify both direction scores are well-defined.
        let model = ComplEx::new(4, 2, 16).expect("ComplEx");
        let s1 = model.score(0, 0, 1).expect("s1");
        let s2 = model.score(1, 0, 0).expect("s2");
        assert!(s1.is_finite());
        assert!(s2.is_finite());
    }

    // --- KgModel enum ---

    #[test]
    fn test_kgmodel_dispatch() {
        let transe = TransE::new(4, 2, 8).expect("TransE");
        let model = KgModel::TransE(transe);
        let score = model.link_prediction_score(0, 0, 1).expect("score");
        assert!(score.is_finite());
    }

    #[test]
    fn test_multi_epoch_training_transe() {
        let ds = simple_dataset();
        let mut model = TransE::new(4, 2, 16).expect("TransE");
        let mut losses = Vec::new();
        for _ in 0..5 {
            losses.push(model.train_epoch(&ds, 0.01, 1.0));
        }
        // All losses finite
        for loss in &losses {
            assert!(loss.is_finite());
        }
    }

    #[test]
    fn test_complex_score_manual_formula() {
        // NOTE(review): formerly named `..._symmetry_check`, but the test
        // actually pins the score *formula* Re(h · r · conj(t)) against a
        // hand-computed value — renamed to match what it asserts.
        let mut model = ComplEx::new(2, 1, 2).expect("ComplEx");
        // Manually set embeddings for a known result
        model.entity_re[0] = vec![1.0, 0.0];
        model.entity_im[0] = vec![0.0, 1.0];
        model.relation_re[0] = vec![1.0, 1.0];
        model.relation_im[0] = vec![0.0, 0.0];
        model.entity_re[1] = vec![1.0, 0.0];
        model.entity_im[1] = vec![0.0, 1.0];
        // score = Re(h * r * conj(t))
        // = Re([1+0i, 0+1i] * [1,1] * conj([1+0i, 0+1i]))
        // = Re([1, i] * [1,1] * [1,-i])
        // = Re([1*1*1, i*1*(-i)]) = Re([1, 1]) = 2.0
        let score = model.score(0, 0, 1).expect("manual score");
        assert!((score - 2.0).abs() < 1e-10, "expected 2.0, got {score}");
    }
}
1005}