/// Small deterministic linear congruential generator used to seed embeddings.
///
/// The state advances with wrapping 64-bit arithmetic as
/// `state = state * 6364136223846793005 + 1442695040888963407`, and samples are
/// taken from the high-order bits of the new state.
pub struct Lcg {
    state: u64,
}

impl Lcg {
    /// Multiplier of the recurrence.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Additive increment of the recurrence.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    /// Number of high-order state bits turned into one sample.
    ///
    /// `f32` has a 24-bit significand, so every 24-bit integer converts
    /// exactly and dividing by 2^24 can never round up to `1.0`.
    const SAMPLE_BITS: u32 = 24;

    /// Creates a generator from `seed`.
    ///
    /// The stored state is `seed + 1` (wrapping), so equal seeds always
    /// reproduce the same sequence.
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator and returns a sample in `[0, 1)`.
    ///
    /// Only the top 24 bits of the state are used. Keeping 31 bits and
    /// dividing by `u32::MAX as f32` (which rounds to 2^32) would confine the
    /// output to `[0, 0.5]`.
    pub fn next_f32(&mut self) -> f32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        let top_bits = (self.state >> (64 - Self::SAMPLE_BITS)) as u32;
        top_bits as f32 / (1_u32 << Self::SAMPLE_BITS) as f32
    }

    /// Advances the generator and returns the next sample multiplied by `max`.
    pub fn next_f32_range(&mut self, max: f32) -> f32 {
        self.next_f32() * max
    }
}
36
37#[derive(Debug, Clone)]
47pub struct RotatEPlus {
48 pub entity_phase: Vec<Vec<f32>>,
50 pub relation_phase: Vec<Vec<f32>>,
52 pub dim: usize,
54}
55
56impl RotatEPlus {
57 pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
59 let two_pi = 2.0 * std::f32::consts::PI;
60 let mut lcg = Lcg::new(42);
61
62 let entity_phase = (0..num_entities)
63 .map(|_| (0..dim).map(|_| lcg.next_f32_range(two_pi)).collect())
64 .collect();
65
66 let relation_phase = (0..num_relations)
67 .map(|_| (0..dim).map(|_| lcg.next_f32_range(two_pi)).collect())
68 .collect();
69
70 Self {
71 entity_phase,
72 relation_phase,
73 dim,
74 }
75 }
76
77 pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
80 let two_pi = 2.0 * std::f32::consts::PI;
81 let h = &self.entity_phase[head];
82 let r = &self.relation_phase[relation];
83 let t = &self.entity_phase[tail];
84
85 let l1: f32 = (0..self.dim)
86 .map(|i| {
87 let rotated = (h[i] + r[i]) % two_pi;
89 let raw = (rotated - t[i]).abs();
91 raw.min(two_pi - raw)
92 })
93 .sum();
94
95 -l1
96 }
97
98 pub fn update(
103 &mut self,
104 head: usize,
105 relation: usize,
106 tail: usize,
107 pos_score: f32,
108 neg_score: f32,
109 lr: f32,
110 ) {
111 let two_pi = 2.0 * std::f32::consts::PI;
112 let margin = 1.0_f32;
113 let loss_gradient = if pos_score - neg_score < margin {
114 1.0_f32
115 } else {
116 0.0_f32
117 };
118
119 if loss_gradient.abs() < 1e-9 {
120 return;
121 }
122
123 for i in 0..self.dim {
125 let h_phase = self.entity_phase[head][i];
126 let r_phase = self.relation_phase[relation][i];
127 let t_phase = self.entity_phase[tail][i];
128
129 let rotated = (h_phase + r_phase) % two_pi;
130 let diff = rotated - t_phase;
131 let sign = if diff > 0.0 { 1.0_f32 } else { -1.0_f32 };
133
134 let grad = sign * loss_gradient * lr;
136 self.entity_phase[head][i] = (self.entity_phase[head][i] - grad).rem_euclid(two_pi);
137 self.relation_phase[relation][i] =
138 (self.relation_phase[relation][i] - grad).rem_euclid(two_pi);
139 self.entity_phase[tail][i] = (self.entity_phase[tail][i] + grad).rem_euclid(two_pi);
140 }
141 }
142
143 pub fn entity_count(&self) -> usize {
145 self.entity_phase.len()
146 }
147
148 pub fn relation_count(&self) -> usize {
150 self.relation_phase.len()
151 }
152}
153
154#[derive(Debug, Clone)]
163pub struct PairRE {
164 pub entity_emb: Vec<Vec<f32>>,
166 pub relation_head: Vec<Vec<f32>>,
168 pub relation_tail: Vec<Vec<f32>>,
170 pub dim: usize,
172}
173
174impl PairRE {
175 pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
177 let mut lcg = Lcg::new(7);
178 let scale = 0.1_f32;
179
180 let entity_emb = (0..num_entities)
181 .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
182 .collect();
183
184 let relation_head = (0..num_relations)
185 .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
186 .collect();
187
188 let relation_tail = (0..num_relations)
189 .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * scale).collect())
190 .collect();
191
192 Self {
193 entity_emb,
194 relation_head,
195 relation_tail,
196 dim,
197 }
198 }
199
200 pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
202 let h = &self.entity_emb[head];
203 let rh = &self.relation_head[relation];
204 let t = &self.entity_emb[tail];
205 let rt = &self.relation_tail[relation];
206
207 let l2_sq: f32 = (0..self.dim)
208 .map(|i| {
209 let diff = h[i] * rh[i] - t[i] * rt[i];
210 diff * diff
211 })
212 .sum();
213
214 -l2_sq.sqrt()
215 }
216
217 pub fn update(&mut self, head: usize, relation: usize, tail: usize, label: f32, lr: f32) {
219 let h = self.entity_emb[head].clone();
221 let rh = self.relation_head[relation].clone();
222 let t = self.entity_emb[tail].clone();
223 let rt = self.relation_tail[relation].clone();
224
225 let diffs: Vec<f32> = (0..self.dim).map(|i| h[i] * rh[i] - t[i] * rt[i]).collect();
226 let norm: f32 = diffs.iter().map(|d| d * d).sum::<f32>().sqrt().max(1e-8);
227
228 let sign = label;
234
235 for i in 0..self.dim {
236 let grad = sign * diffs[i] / norm;
237 self.entity_emb[head][i] -= lr * grad * rh[i];
238 self.relation_head[relation][i] -= lr * grad * h[i];
239 self.entity_emb[tail][i] += lr * grad * rt[i];
240 self.relation_tail[relation][i] += lr * grad * t[i];
241 }
242 }
243
244 pub fn predict_tail(&self, head: usize, relation: usize, top_k: usize) -> Vec<(usize, f32)> {
246 let mut scores: Vec<(usize, f32)> = (0..self.entity_emb.len())
247 .map(|tail_idx| (tail_idx, self.score(head, relation, tail_idx)))
248 .collect();
249
250 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
252 scores.truncate(top_k);
253 scores
254 }
255
256 pub fn entity_count(&self) -> usize {
258 self.entity_emb.len()
259 }
260}
261
262#[derive(Debug, Clone)]
270pub struct Rescal {
271 pub entity_emb: Vec<Vec<f32>>,
273 pub relation_mat: Vec<Vec<Vec<f32>>>,
275 pub dim: usize,
277}
278
279impl Rescal {
280 pub fn new(num_entities: usize, num_relations: usize, dim: usize) -> Self {
282 let mut lcg = Lcg::new(13);
283 let e_scale = 1.0 / (dim as f32).sqrt();
284 let m_scale = 1.0 / (dim as f32);
285
286 let entity_emb = (0..num_entities)
287 .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * e_scale).collect())
288 .collect();
289
290 let relation_mat = (0..num_relations)
291 .map(|_| {
292 (0..dim)
293 .map(|_| (0..dim).map(|_| (lcg.next_f32() - 0.5) * m_scale).collect())
294 .collect()
295 })
296 .collect();
297
298 Self {
299 entity_emb,
300 relation_mat,
301 dim,
302 }
303 }
304
305 pub fn score(&self, head: usize, relation: usize, tail: usize) -> f32 {
307 let h = &self.entity_emb[head];
308 let t = &self.entity_emb[tail];
309 let m = &self.relation_mat[relation];
310
311 let mt: Vec<f32> = (0..self.dim)
313 .map(|i| (0..self.dim).map(|j| m[i][j] * t[j]).sum())
314 .collect();
315
316 h.iter().zip(mt.iter()).map(|(hi, mti)| hi * mti).sum()
318 }
319
320 pub fn update(&mut self, head: usize, relation: usize, tail: usize, label: f32, lr: f32) {
322 let s = self.score(head, relation, tail);
323 let err = s - label; let h = self.entity_emb[head].clone();
326 let t = self.entity_emb[tail].clone();
327 let m = self.relation_mat[relation].clone();
328
329 let mt: Vec<f32> = (0..self.dim)
331 .map(|i| (0..self.dim).map(|j| m[i][j] * t[j]).sum())
332 .collect();
333
334 let hm: Vec<f32> = (0..self.dim)
336 .map(|j| (0..self.dim).map(|i| h[i] * m[i][j]).sum())
337 .collect();
338
339 for i in 0..self.dim {
341 self.entity_emb[head][i] -= lr * err * mt[i];
342 self.entity_emb[tail][i] -= lr * err * hm[i];
343 for (j, t_j) in t.iter().enumerate() {
344 self.relation_mat[relation][i][j] -= lr * err * h[i] * t_j;
345 }
346 }
347 }
348
349 pub fn relation_matrix(&self, relation: usize) -> &Vec<Vec<f32>> {
351 &self.relation_mat[relation]
352 }
353
354 pub fn entity_count(&self) -> usize {
356 self.entity_emb.len()
357 }
358
359 pub fn relation_count(&self) -> usize {
361 self.relation_mat.len()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of a full turn, the modulus used for RotatE+ phases.
    const TWO_PI: f32 = 2.0 * std::f32::consts::PI;

    /// True when any pair of corresponding components differs by more than `1e-9`.
    fn differs(after: &[f32], before: &[f32]) -> bool {
        after
            .iter()
            .zip(before.iter())
            .any(|(a, b)| (a - b).abs() > 1e-9)
    }

    // ---- Lcg ----

    #[test]
    fn test_lcg_range() {
        let mut lcg = Lcg::new(1);
        for _ in 0..1000 {
            let v = lcg.next_f32();
            assert!((0.0..1.0).contains(&v), "LCG value out of [0,1): {v}");
        }
    }

    #[test]
    fn test_lcg_deterministic() {
        let mut a = Lcg::new(99);
        let mut b = Lcg::new(99);
        for _ in 0..50 {
            assert_eq!(a.next_f32().to_bits(), b.next_f32().to_bits());
        }
    }

    // ---- RotatEPlus ----

    #[test]
    fn test_rotate_plus_creation() {
        let m = RotatEPlus::new(10, 5, 16);
        assert_eq!(m.entity_count(), 10);
        assert_eq!(m.relation_count(), 5);
        assert_eq!(m.dim, 16);
    }

    #[test]
    fn test_rotate_plus_phases_in_range() {
        let m = RotatEPlus::new(5, 3, 8);
        for row in &m.entity_phase {
            for &v in row {
                assert!((0.0..TWO_PI).contains(&v), "entity phase out of range: {v}");
            }
        }
        for row in &m.relation_phase {
            for &v in row {
                assert!(
                    (0.0..TWO_PI).contains(&v),
                    "relation phase out of range: {v}"
                );
            }
        }
    }

    #[test]
    fn test_rotate_plus_score_is_finite() {
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "score should be finite: {s}");
    }

    #[test]
    fn test_rotate_plus_score_non_positive() {
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 1);
        assert!(s <= 0.0, "RotatE+ score should be ≤ 0 (it is -L1): {s}");
    }

    #[test]
    fn test_rotate_plus_self_score() {
        let m = RotatEPlus::new(4, 2, 8);
        let s = m.score(0, 0, 0);
        assert!(s.is_finite() && s <= 0.0);
    }

    #[test]
    fn test_rotate_plus_update_changes_embeddings() {
        let mut m = RotatEPlus::new(4, 2, 8);
        let before_h = m.entity_phase[0].clone();
        let pos_score = m.score(0, 0, 1);
        // Equal scores always violate the margin, so the step is guaranteed to
        // run regardless of how the random initialisation happens to fall.
        m.update(0, 0, 1, pos_score, pos_score, 0.01);
        assert!(
            differs(&m.entity_phase[0], &before_h),
            "update should modify entity phases"
        );
    }

    #[test]
    fn test_rotate_plus_update_keeps_phases_in_range() {
        let mut m = RotatEPlus::new(4, 2, 8);
        let pos_score = m.score(0, 0, 1);
        // Force the margin violation so the large step is really applied.
        m.update(0, 0, 1, pos_score, pos_score, 0.5);
        let touched = m.entity_phase[0]
            .iter()
            .chain(m.entity_phase[1].iter())
            .chain(m.relation_phase[0].iter());
        for &v in touched {
            assert!(
                v >= 0.0 && v < TWO_PI + 1e-5,
                "phase out of range after update: {v}"
            );
        }
    }

    #[test]
    fn test_rotate_plus_training_loop() {
        let mut m = RotatEPlus::new(6, 3, 16);
        let triples = [(0usize, 0usize, 1usize), (1, 1, 2), (2, 2, 3)];
        for _ in 0..20 {
            for &(h, r, t) in &triples {
                let neg_t = (t + 1) % 6;
                let ps = m.score(h, r, t);
                let ns = m.score(h, r, neg_t);
                m.update(h, r, t, ps, ns, 0.01);
            }
        }
        for &(h, r, t) in &triples {
            assert!(m.score(h, r, t).is_finite());
        }
    }

    // ---- PairRE ----

    #[test]
    fn test_pairre_creation() {
        let m = PairRE::new(8, 4, 16);
        assert_eq!(m.entity_count(), 8);
        assert_eq!(m.dim, 16);
    }

    #[test]
    fn test_pairre_score_finite() {
        let m = PairRE::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "PairRE score should be finite: {s}");
    }

    #[test]
    fn test_pairre_score_non_positive() {
        let m = PairRE::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s <= 0.0, "PairRE score should be ≤ 0 (it is -L2): {s}");
    }

    #[test]
    fn test_pairre_update_changes_embeddings() {
        let mut m = PairRE::new(5, 3, 8);
        let before = m.entity_emb[0].clone();
        m.update(0, 0, 1, 1.0, 0.01);
        assert!(
            differs(&m.entity_emb[0], &before),
            "update should modify embeddings"
        );
    }

    #[test]
    fn test_pairre_predict_tail_returns_correct_count() {
        let m = PairRE::new(10, 3, 8);
        let preds = m.predict_tail(0, 0, 5);
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_pairre_predict_tail_sorted_desc() {
        let m = PairRE::new(10, 3, 8);
        let preds = m.predict_tail(0, 0, 5);
        for w in preds.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "predictions should be sorted descending by score"
            );
        }
    }

    #[test]
    fn test_pairre_predict_tail_k_larger_than_entities() {
        let m = PairRE::new(3, 2, 8);
        let preds = m.predict_tail(0, 0, 100);
        // Only three entities exist, so the result is capped at three.
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_pairre_training_positive_vs_negative() {
        let mut m = PairRE::new(8, 4, 16);
        for _ in 0..100 {
            m.update(0, 0, 1, 1.0, 0.01); // pull the positive triple together
            m.update(0, 0, 2, -1.0, 0.01); // push the negative triple apart
        }
        let pos_score = m.score(0, 0, 1);
        let neg_score = m.score(0, 0, 2);
        assert!(
            pos_score > neg_score,
            "positive score {pos_score} should exceed negative {neg_score}"
        );
    }

    // ---- Rescal ----

    #[test]
    fn test_rescal_creation() {
        let m = Rescal::new(6, 3, 8);
        assert_eq!(m.entity_count(), 6);
        assert_eq!(m.relation_count(), 3);
        assert_eq!(m.dim, 8);
    }

    #[test]
    fn test_rescal_score_finite() {
        let m = Rescal::new(5, 3, 8);
        let s = m.score(0, 0, 1);
        assert!(s.is_finite(), "RESCAL score should be finite: {s}");
    }

    #[test]
    fn test_rescal_relation_matrix_shape() {
        let m = Rescal::new(5, 3, 8);
        let mat = m.relation_matrix(0);
        assert_eq!(mat.len(), 8);
        assert_eq!(mat[0].len(), 8);
    }

    #[test]
    fn test_rescal_update_changes_embeddings() {
        let mut m = Rescal::new(5, 3, 8);
        let before = m.entity_emb[0].clone();
        m.update(0, 0, 1, 1.0, 0.01);
        assert!(
            differs(&m.entity_emb[0], &before),
            "update should modify entity embeddings"
        );
    }

    #[test]
    fn test_rescal_training_converges() {
        let mut m = Rescal::new(5, 2, 4);
        let initial_score = m.score(0, 0, 1);
        for _ in 0..500 {
            m.update(0, 0, 1, 1.0, 0.001);
        }
        let final_score = m.score(0, 0, 1);
        assert!(
            final_score > initial_score,
            "RESCAL score should increase toward label"
        );
    }

    /// A non-symmetric relation matrix must score `(h, r, t)` and `(t, r, h)`
    /// differently; hand-set parameters keep the check deterministic.
    #[test]
    fn test_rescal_asymmetric_scores() {
        let mut m = Rescal::new(2, 1, 2);
        m.entity_emb = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        m.relation_mat = vec![vec![vec![0.0, 1.0], vec![0.0, 0.0]]];
        let s_fwd = m.score(0, 0, 1);
        let s_bwd = m.score(1, 0, 0);
        assert!(s_fwd.is_finite() && s_bwd.is_finite());
        assert_eq!(s_fwd, 1.0);
        assert_eq!(s_bwd, 0.0);
    }

    #[test]
    fn test_rescal_different_relations_give_different_scores() {
        let m = Rescal::new(5, 4, 8);
        let s0 = m.score(0, 0, 1);
        let s1 = m.score(0, 1, 1);
        let s2 = m.score(0, 2, 1);
        assert!(
            (s0 - s1).abs() > 1e-6 || (s1 - s2).abs() > 1e-6,
            "Different relations should produce different scores"
        );
    }

    // ---- shared interface ----

    #[test]
    fn test_all_models_score_interface() {
        let rotate = RotatEPlus::new(4, 2, 8);
        let pairre = PairRE::new(4, 2, 8);
        let rescal = Rescal::new(4, 2, 8);

        assert!(rotate.score(0, 0, 1).is_finite());
        assert!(pairre.score(0, 0, 1).is_finite());
        assert!(rescal.score(0, 0, 1).is_finite());
    }
}