Skip to main content

oxirs_embed/models/
advanced_models.rs

1//! Advanced Knowledge Graph Embedding Models
2//!
3//! Implements RotatE+, PairRE, and RESCAL — three advanced KG embedding models
4//! that handle complex relation patterns including phase-space operations,
5//! 1-N/N-N relations, and bilinear scoring.
6
/// Minimal linear congruential generator (LCG) for deterministic parameter
/// initialization without an external `rand` dependency.
///
/// The same seed always produces the same sequence. The generator is fast and
/// reproducible but is not intended for cryptographic or statistical-quality use.
#[derive(Debug, Clone)]
pub struct Lcg {
    state: u64,
}

impl Lcg {
    /// Create a generator from `seed` (the seed is offset by one before use).
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance one step and return a value uniformly distributed in `[0.0, 1.0)`.
    pub fn next_f32(&mut self) -> f32 {
        // Knuth's MMIX LCG constants (multiplier / increment, modulus 2^64).
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // The high bits of an LCG are its best-mixed bits. Keep the top 24: an f32
        // mantissa holds 24 bits exactly, so `k / 2^24` is exact and strictly < 1.0.
        ((self.state >> 40) as f32) * (1.0 / 16_777_216.0)
    }

    /// Return a value in `[0.0, max)` (for positive, finite `max`).
    pub fn next_f32_range(&mut self, max: f32) -> f32 {
        self.next_f32() * max
    }
}
36
37// ─────────────────────────────────────────────
38// RotatE+
39// ─────────────────────────────────────────────
40
41/// RotatE+ with phase-space operations.
42///
43/// Entities and relations are represented as phase vectors in [0, 2π).
44/// The scoring function computes the L1 distance in phase space after
45/// applying the relational rotation: score = -||h ∘ r - t||_1.
46#[derive(Debug, Clone)]
47pub struct RotatEPlus {
48    /// Entity phase embeddings: `[num_entities][dim]`, values in \[0, 2π)
49    pub entity_phase: Vec<Vec<f32>>,
50    /// Relation phase embeddings: `[num_relations][dim]`, values in \[0, 2π)
51    pub relation_phase: Vec<Vec<f32>>,
52    /// Embedding dimension
53    pub dim: usize,
54}
55
56impl RotatEPlus {
57    /// Create a new RotatE+ model with random phase initialization.
58    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
59        let two_pi = 2.0 * std::f32::consts::PI;
60        let mut lcg = Lcg::new(42);
61
62        let entity_phase = (0..num_entities)
63            .map(|_| (0..dim).map(|_| lcg.next_f32_range(two_pi)).collect())
64            .collect();
65
66        let relation_phase = (0..num_relations)
67            .map(|_| (0..dim).map(|_| lcg.next_f32_range(two_pi)).collect())
68            .collect();
69
70        Self {
71            entity_phase,
72            relation_phase,
73            dim,
74        }
75    }
76
77    /// Compute score = -||h ∘ r - t||_1
78    /// where ∘ is element-wise phase addition mod 2π.
79    pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
80        let two_pi = 2.0 * std::f32::consts::PI;
81        let h = &self.entity_phase[head];
82        let r = &self.relation_phase[relation];
83        let t = &self.entity_phase[tail];
84
85        let l1: f32 = (0..self.dim)
86            .map(|i| {
87                // Phase addition mod 2π
88                let rotated = (h[i] + r[i]) % two_pi;
89                // Circular distance
90                let raw = (rotated - t[i]).abs();
91                raw.min(two_pi - raw)
92            })
93            .sum();
94
95        -l1
96    }
97
98    /// Update embeddings via margin-based gradient step.
99    ///
100    /// `pos_score` and `neg_score` are output from `score()`.
101    /// We push pos_score higher and neg_score lower.
102    pub fn update(
103        &mut self,
104        head: usize,
105        relation: usize,
106        tail: usize,
107        pos_score: f32,
108        neg_score: f32,
109        lr: f32,
110    ) {
111        let two_pi = 2.0 * std::f32::consts::PI;
112        let margin = 1.0_f32;
113        let loss_gradient = if pos_score - neg_score < margin {
114            1.0_f32
115        } else {
116            0.0_f32
117        };
118
119        if loss_gradient.abs() < 1e-9 {
120            return;
121        }
122
123        // Gradient sign: increase positive score (decrease L1), decrease negative
124        for i in 0..self.dim {
125            let h_phase = self.entity_phase[head][i];
126            let r_phase = self.relation_phase[relation][i];
127            let t_phase = self.entity_phase[tail][i];
128
129            let rotated = (h_phase + r_phase) % two_pi;
130            let diff = rotated - t_phase;
131            // Sign of gradient for L1
132            let sign = if diff > 0.0 { 1.0_f32 } else { -1.0_f32 };
133
134            // Positive example: push score up → decrease L1 distance
135            let grad = sign * loss_gradient * lr;
136            self.entity_phase[head][i] = (self.entity_phase[head][i] - grad).rem_euclid(two_pi);
137            self.relation_phase[relation][i] =
138                (self.relation_phase[relation][i] - grad).rem_euclid(two_pi);
139            self.entity_phase[tail][i] = (self.entity_phase[tail][i] + grad).rem_euclid(two_pi);
140        }
141    }
142
143    /// Number of entities
144    pub fn entity_count(&self) -> usize {
145        self.entity_phase.len()
146    }
147
148    /// Number of relations
149    pub fn relation_count(&self) -> usize {
150        self.relation_phase.len()
151    }
152}
153
154// ─────────────────────────────────────────────
155// PairRE
156// ─────────────────────────────────────────────
157
158/// PairRE: Handles 1-N, N-1, and N-N relations with paired relation vectors.
159///
160/// Each relation has two vectors r_h (applied to head) and r_t (applied to tail).
161/// score = -||h ⊙ r_h - t ⊙ r_t||_2
162#[derive(Debug, Clone)]
163pub struct PairRE {
164    /// Entity embeddings: `[num_entities][dim]`
165    pub entity_emb: Vec<Vec<f32>>,
166    /// Head-side relation vectors: `[num_relations][dim]`
167    pub relation_head: Vec<Vec<f32>>,
168    /// Tail-side relation vectors: `[num_relations][dim]`
169    pub relation_tail: Vec<Vec<f32>>,
170    /// Embedding dimension
171    pub dim: usize,
172}
173
174impl PairRE {
175    /// Create a new PairRE model with small random initializations.
176    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
177        let mut lcg = Lcg::new(7);
178        let scale = 0.1_f32;
179
180        let entity_emb = (0..num_entities)
181            .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
182            .collect();
183
184        let relation_head = (0..num_relations)
185            .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
186            .collect();
187
188        let relation_tail = (0..num_relations)
189            .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
190            .collect();
191
192        Self {
193            entity_emb,
194            relation_head,
195            relation_tail,
196            dim,
197        }
198    }
199
200    /// Compute score = -||h ⊙ rh - t ⊙ rt||_2
201    pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
202        let h = &self.entity_emb[head];
203        let rh = &self.relation_head[relation];
204        let t = &self.entity_emb[tail];
205        let rt = &self.relation_tail[relation];
206
207        let l2_sq: f32 = (0..self.dim)
208            .map(|i| {
209                let diff = h[i] * rh[i] - t[i] * rt[i];
210                diff * diff
211            })
212            .sum();
213
214        -l2_sq.sqrt()
215    }
216
217    /// SGD-style update step.
218    pub fn update(&mut self, head: usize, relation: usize, tail: usize, label: f32, lr: f32) {
219        // Compute current score and gradient
220        let h = self.entity_emb[head].clone();
221        let rh = self.relation_head[relation].clone();
222        let t = self.entity_emb[tail].clone();
223        let rt = self.relation_tail[relation].clone();
224
225        let diffs: Vec<f32> = (0..self.dim).map(|i| h[i] * rh[i] - t[i] * rt[i]).collect();
226        let norm: f32 = diffs.iter().map(|d| d * d).sum::<f32>().sqrt().max(1e-8);
227
228        // Gradient of ||...||_2 w.r.t. diff[i] = diff[i] / norm
229        // score = -||d||_2, so d_score/d_diff[i] = -diff[i]/norm
230        // For positive (label=+1): maximize score → update in +gradient direction → diff decreases
231        // For negative (label=-1): minimize score → update in -gradient direction → diff increases
232        // sign = +label achieves this: positive label pulls diff toward zero, negative pushes out
233        let sign = label;
234
235        for i in 0..self.dim {
236            let grad = sign * diffs[i] / norm;
237            self.entity_emb[head][i] -= lr * grad * rh[i];
238            self.relation_head[relation][i] -= lr * grad * h[i];
239            self.entity_emb[tail][i] += lr * grad * rt[i];
240            self.relation_tail[relation][i] += lr * grad * t[i];
241        }
242    }
243
244    /// Predict top-k tail entities for a (head, relation) query.
245    pub fn predict_tail(&self, head: usize, relation: usize, top_k: usize) -> Vec<(usize, f32)> {
246        let mut scores: Vec<(usize, f32)> = (0..self.entity_emb.len())
247            .map(|tail_idx| (tail_idx, self.score(head, relation, tail_idx)))
248            .collect();
249
250        // Sort descending by score
251        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
252        scores.truncate(top_k);
253        scores
254    }
255
256    /// Number of entities
257    pub fn entity_count(&self) -> usize {
258        self.entity_emb.len()
259    }
260}
261
262// ─────────────────────────────────────────────
263// RESCAL
264// ─────────────────────────────────────────────
265
266/// RESCAL: Bilinear model for knowledge graph embedding.
267///
268/// Each relation has a full dim×dim matrix. Scoring: h^T * M_r * t
269#[derive(Debug, Clone)]
270pub struct Rescal {
271    /// Entity embeddings: `[num_entities][dim]`
272    pub entity_emb: Vec<Vec<f32>>,
273    /// Relation matrices: `[num_relations][dim][dim]`
274    pub relation_mat: Vec<Vec<Vec<f32>>>,
275    /// Embedding dimension
276    pub dim: usize,
277}
278
279impl Rescal {
280    /// Create a new RESCAL model.
281    pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
282        let mut lcg = Lcg::new(13);
283        let e_scale = 1.0 / (dim as f32).sqrt();
284        let m_scale = 1.0 / (dim as f32);
285
286        let entity_emb = (0..num_entities)
287            .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * e_scale).collect())
288            .collect();
289
290        let relation_mat = (0..num_relations)
291            .map(|_| {
292                (0..dim)
293                    .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * m_scale).collect())
294                    .collect()
295            })
296            .collect();
297
298        Self {
299            entity_emb,
300            relation_mat,
301            dim,
302        }
303    }
304
305    /// Compute score = h^T * M_r * t
306    pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
307        let h = &self.entity_emb[head];
308        let t = &self.entity_emb[tail];
309        let m = &self.relation_mat[relation];
310
311        // M_r * t → dim-vector
312        let mt: Vec<f32> = (0..self.dim)
313            .map(|i| (0..self.dim).map(|j| m[i][j] * t[j]).sum())
314            .collect();
315
316        // h^T * (M_r * t)
317        h.iter().zip(mt.iter()).map(|(hi, mti)| hi * mti).sum()
318    }
319
320    /// SGD update step minimizing squared loss (score - label)^2.
321    pub fn update(&mut self, head: usize, relation: usize, tail: usize, label: f32, lr: f32) {
322        let s = self.score(head, relation, tail);
323        let err = s - label; // gradient of 0.5*(s-label)^2 = err
324
325        let h = self.entity_emb[head].clone();
326        let t = self.entity_emb[tail].clone();
327        let m = self.relation_mat[relation].clone();
328
329        // ∂loss/∂h_i = err * (M_r * t)[i]
330        let mt: Vec<f32> = (0..self.dim)
331            .map(|i| (0..self.dim).map(|j| m[i][j] * t[j]).sum())
332            .collect();
333
334        // ∂loss/∂t_j = err * (h^T * M_r)[j]
335        let hm: Vec<f32> = (0..self.dim)
336            .map(|j| (0..self.dim).map(|i| h[i] * m[i][j]).sum())
337            .collect();
338
339        // Apply gradients
340        for i in 0..self.dim {
341            self.entity_emb[head][i] -= lr * err * mt[i];
342            self.entity_emb[tail][i] -= lr * err * hm[i];
343            for (j, t_j) in t.iter().enumerate() {
344                self.relation_mat[relation][i][j] -= lr * err * h[i] * t_j;
345            }
346        }
347    }
348
349    /// Access the relation matrix for a given relation.
350    pub fn relation_matrix(&self, relation: usize) -> &Vec<Vec<f32>> {
351        &self.relation_mat[relation]
352    }
353
354    /// Number of entities
355    pub fn entity_count(&self) -> usize {
356        self.entity_emb.len()
357    }
358
359    /// Number of relations
360    pub fn relation_count(&self) -> usize {
361        self.relation_mat.len()
362    }
363}
364
365// ─────────────────────────────────────────────
366// Tests
367// ─────────────────────────────────────────────
368
#[cfg(test)]
mod tests {
    use super::*;

    const TWO_PI: f32 = 2.0 * std::f32::consts::PI;

    // ── LCG ──────────────────────────────────

    #[test]
    fn test_lcg_range() {
        let mut lcg = Lcg::new(1);
        for _ in 0..1000 {
            let v = lcg.next_f32();
            assert!((0.0..1.0).contains(&v), "LCG value out of [0,1): {v}");
        }
    }

    #[test]
    fn test_lcg_range_scaled() {
        let mut lcg = Lcg::new(3);
        for _ in 0..1000 {
            let v = lcg.next_f32_range(TWO_PI);
            assert!((0.0..TWO_PI).contains(&v), "scaled LCG value out of [0, 2π): {v}");
        }
    }

    #[test]
    fn test_lcg_deterministic() {
        let mut a = Lcg::new(99);
        let mut b = Lcg::new(99);
        for _ in 0..50 {
            assert_eq!(a.next_f32().to_bits(), b.next_f32().to_bits());
        }
    }

    // ── RotatE+ ───────────────────────────────

    #[test]
    fn test_rotate_plus_creation() {
        let m = RotatEPlus::new(10, 5, 16);
        assert_eq!(m.entity_count(), 10);
        assert_eq!(m.relation_count(), 5);
        assert_eq!(m.dim, 16);
    }

    #[test]
    fn test_rotate_plus_phases_in_range() {
        let m = RotatEPlus::new(5, 3, 8);
        for row in &m.entity_phase {
            for &v in row {
                assert!(v >= 0.0 && v < TWO_PI, "entity phase out of range: {v}");
            }
        }
        for row in &m.relation_phase {
            for &v in row {
                assert!(v >= 0.0 && v < TWO_PI, "relation phase out of range: {v}");
            }
        }
    }

    #[test]
    fn test_rotate_plus_score_is_finite() {
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "score should be finite: {s}");
    }

    #[test]
    fn test_rotate_plus_score_non_positive() {
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 1);
        assert!(s <= 0.0, "RotatE+ score should be ≤ 0 (negated distance): {s}");
    }

    #[test]
    fn test_rotate_plus_self_score() {
        // score(h, r, h) depends only on r: each dimension contributes the circular
        // distance of the rotation itself, min(r_i, 2π - r_i).
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 0);
        let expected: f32 = -m.relation_phase[0]
            .iter()
            .map(|&r| r.min(TWO_PI - r))
            .sum::<f32>();
        assert!(s.is_finite() && s <= 0.0);
        assert!((s - expected).abs() < 1e-3, "self score {s} != {expected}");
    }

    #[test]
    fn test_rotate_plus_update_changes_embeddings() {
        let mut m = RotatEPlus::new(4, 2, 8);
        let before_h = m.entity_phase[0].clone();
        let pos_score = m.score(0, 0, 1);
        // Passing pos_score as the negative score guarantees a margin violation
        // (difference 0 < margin), so the update always runs regardless of init.
        m.update(0, 0, 1, pos_score, pos_score, 0.01);
        let changed = m.entity_phase[0]
            .iter()
            .zip(before_h.iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(changed, "update should modify entity phases");
    }

    #[test]
    fn test_rotate_plus_update_keeps_phases_in_range() {
        let mut m = RotatEPlus::new(4, 2, 8);
        let pos_score = m.score(0, 0, 1);
        // Equal scores force a margin violation so the large step is actually applied.
        m.update(0, 0, 1, pos_score, pos_score, 0.5);
        for &v in m.entity_phase[0]
            .iter()
            .chain(m.entity_phase[1].iter())
            .chain(m.relation_phase[0].iter())
        {
            assert!(
                v >= 0.0 && v < TWO_PI + 1e-5,
                "phase out of range after update: {v}"
            );
        }
    }

    #[test]
    fn test_rotate_plus_training_loop() {
        let mut m = RotatEPlus::new(6, 3, 16);
        let triples = [(0usize, 0usize, 1usize), (1, 1, 2), (2, 2, 3)];
        for _ in 0..20 {
            for &(h, r, t) in &triples {
                let neg_t = (t + 1) % 6;
                let ps = m.score(h, r, t);
                let ns = m.score(h, r, neg_t);
                m.update(h, r, t, ps, ns, 0.01);
            }
        }
        // Should not panic and scores should be finite
        for &(h, r, t) in &triples {
            assert!(m.score(h, r, t).is_finite());
        }
    }

    // ── PairRE ────────────────────────────────

    #[test]
    fn test_pairre_creation() {
        let m = PairRE::new(8, 4, 16);
        assert_eq!(m.entity_count(), 8);
        assert_eq!(m.dim, 16);
    }

    #[test]
    fn test_pairre_score_finite() {
        let m = PairRE::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "PairRE score should be finite: {s}");
    }

    #[test]
    fn test_pairre_score_non_positive() {
        let m = PairRE::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s <= 0.0, "PairRE score should be ≤ 0 (it is -L2): {s}");
    }

    #[test]
    fn test_pairre_update_changes_embeddings() {
        let mut m = PairRE::new(5, 3, 8);
        let before = m.entity_emb[0].clone();
        m.update(0, 0, 1, 1.0, 0.01);
        let changed = m.entity_emb[0]
            .iter()
            .zip(before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(changed, "update should modify embeddings");
    }

    #[test]
    fn test_pairre_predict_tail_returns_correct_count() {
        let m = PairRE::new(10, 3, 8);
        let preds = m.predict_tail(0, 0, 5);
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_pairre_predict_tail_sorted_desc() {
        let m = PairRE::new(10, 3, 8);
        let preds = m.predict_tail(0, 0, 5);
        for w in preds.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "predictions should be sorted descending by score"
            );
        }
    }

    #[test]
    fn test_pairre_predict_tail_k_larger_than_entities() {
        let m = PairRE::new(3, 2, 8);
        let preds = m.predict_tail(0, 0, 100);
        assert_eq!(preds.len(), 3); // capped at entity count
    }

    #[test]
    fn test_pairre_training_positive_vs_negative() {
        let mut m = PairRE::new(8, 4, 16);
        // After many steps the positive triple should score higher than the negative
        for _ in 0..100 {
            m.update(0, 0, 1, 1.0, 0.01); // positive
            m.update(0, 0, 2, -1.0, 0.01); // negative
        }
        let pos_score = m.score(0, 0, 1);
        let neg_score = m.score(0, 0, 2);
        assert!(
            pos_score > neg_score,
            "positive score {pos_score} should exceed negative {neg_score}"
        );
    }

    // ── RESCAL ────────────────────────────────

    #[test]
    fn test_rescal_creation() {
        let m = Rescal::new(6, 3, 8);
        assert_eq!(m.entity_count(), 6);
        assert_eq!(m.relation_count(), 3);
        assert_eq!(m.dim, 8);
    }

    #[test]
    fn test_rescal_score_finite() {
        let m = Rescal::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "RESCAL score should be finite: {s}");
    }

    #[test]
    fn test_rescal_relation_matrix_shape() {
        let m = Rescal::new(5, 3, 8);
        let mat = m.relation_matrix(0);
        assert_eq!(mat.len(), 8);
        assert_eq!(mat[0].len(), 8);
    }

    #[test]
    fn test_rescal_update_changes_embeddings() {
        let mut m = Rescal::new(5, 3, 8);
        let before = m.entity_emb[0].clone();
        m.update(0, 0, 1, 1.0, 0.01);
        let changed = m.entity_emb[0]
            .iter()
            .zip(before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(changed, "update should modify entity embeddings");
    }

    #[test]
    fn test_rescal_training_converges() {
        let mut m = Rescal::new(5, 2, 4);
        // Train on one positive triple, expect score to increase toward 1.0
        let initial_score = m.score(0, 0, 1);
        for _ in 0..500 {
            m.update(0, 0, 1, 1.0, 0.001);
        }
        let final_score = m.score(0, 0, 1);
        assert!(
            final_score > initial_score,
            "RESCAL score should increase toward label"
        );
    }

    #[test]
    fn test_rescal_reverse_direction_scores_finite() {
        let m = Rescal::new(5, 3, 8);
        // A general (non-symmetric) M_r lets score(h, r, t) and score(t, r, h)
        // differ, so RESCAL can model asymmetric relations. Only finiteness is
        // checked here, since random init may coincide.
        let s_fwd = m.score(0, 0, 1);
        let s_bwd = m.score(1, 0, 0);
        assert!(s_fwd.is_finite() && s_bwd.is_finite());
    }

    #[test]
    fn test_rescal_different_relations_give_different_scores() {
        let m = Rescal::new(5, 4, 8);
        let s0 = m.score(0, 0, 1);
        let s1 = m.score(0, 1, 1);
        let s2 = m.score(0, 2, 1);
        // At least two of them should differ (with very high probability for random init)
        assert!(
            (s0 - s1).abs() > 1e-6 || (s1 - s2).abs() > 1e-6,
            "Different relations should produce different scores"
        );
    }

    #[test]
    fn test_all_models_score_interface() {
        let rotate = RotatEPlus::new(4, 2, 8);
        let pairre = PairRE::new(4, 2, 8);
        let rescal = Rescal::new(4, 2, 8);

        // All should produce finite scores for valid indices
        assert!(rotate.score(0, 0, 1).is_finite());
        assert!(pairre.score(0, 0, 1).is_finite());
        assert!(rescal.score(0, 0, 1).is_finite());
    }
}