1use std::collections::HashMap;
14use std::fmt;
15
/// Error type for knowledge-graph embedding training, scoring, and
/// (de)serialization in this module.
#[derive(Debug)]
pub enum KgError {
    /// A scoring/prediction method was called before a successful `train`.
    NotTrained,
    /// An entity id was out of range for the trained embedding table.
    UnknownEntity(EntityId),
    /// A relation id was out of range for the trained embedding table.
    UnknownRelation(RelationId),
    /// `embedding_dim` was 0.
    InvalidDimension,
    /// `train` was called with an empty triple slice.
    NoTrainingData,
    /// Numeric or parse failure, with a human-readable message.
    NumericalError(String),
    /// `top_k` was 0 in a prediction call.
    InvalidTopK,
}
38
39impl fmt::Display for KgError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 KgError::NotTrained => write!(f, "model has not been trained"),
43 KgError::UnknownEntity(id) => write!(f, "unknown entity id {id}"),
44 KgError::UnknownRelation(id) => write!(f, "unknown relation id {id}"),
45 KgError::InvalidDimension => write!(f, "embedding dimension must be > 0"),
46 KgError::NoTrainingData => write!(f, "no training triples provided"),
47 KgError::NumericalError(msg) => write!(f, "numerical error: {msg}"),
48 KgError::InvalidTopK => write!(f, "top_k must be > 0"),
49 }
50 }
51}
52
// `Display` + `Debug` above make the marker impl sufficient.
impl std::error::Error for KgError {}

/// Shorthand result type used throughout this module.
pub type KgResult<T> = Result<T, KgError>;

/// Dense index of an entity (row in the entity embedding table).
pub type EntityId = usize;
/// Dense index of a relation (row in the relation embedding table).
pub type RelationId = usize;
66
/// A directed knowledge-graph fact `(head, relation, tail)` in id form.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KgTriple {
    /// Subject entity id.
    pub head: EntityId,
    /// Relation id.
    pub relation: RelationId,
    /// Object entity id.
    pub tail: EntityId,
}
74
75impl KgTriple {
76 pub fn new(head: EntityId, relation: RelationId, tail: EntityId) -> Self {
78 Self {
79 head,
80 relation,
81 tail,
82 }
83 }
84}
85
/// Shared training hyperparameters for all embedding models in this module.
#[derive(Debug, Clone)]
pub struct KgEmbeddingConfig {
    /// Dimensionality of entity/relation vectors (must be > 0 to train).
    pub embedding_dim: usize,
    /// SGD step size.
    pub learning_rate: f64,
    /// Number of full passes over the training triples.
    pub num_epochs: usize,
    /// Mini-batch size. NOTE(review): none of the trainers in this file read
    /// this field — updates are per triple; confirm whether batching is planned.
    pub batch_size: usize,
    /// Negative samples drawn per positive triple per epoch.
    pub neg_samples: usize,
    /// Margin for the ranking loss (used by TransE and RotatE).
    pub margin: f64,
    /// L2 weight-decay coefficient applied inside the SGD updates.
    pub regularization: f64,
    /// Seed for the deterministic internal RNG.
    pub seed: u64,
}
106
107impl Default for KgEmbeddingConfig {
108 fn default() -> Self {
109 Self {
110 embedding_dim: 50,
111 learning_rate: 0.01,
112 num_epochs: 100,
113 batch_size: 32,
114 neg_samples: 1,
115 margin: 1.0,
116 regularization: 1e-4,
117 seed: 42,
118 }
119 }
120}
121
/// Learned embedding tables plus optional name-to-id lookup maps.
#[derive(Debug, Clone)]
pub struct KgEmbeddings {
    /// One row per entity, indexed by `EntityId`.
    pub entity_embeddings: Vec<Vec<f64>>,
    /// One row per relation, indexed by `RelationId`.
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Entity name → id; left empty by the trainers in this module.
    pub entity_to_id: HashMap<String, EntityId>,
    /// Relation name → id; left empty by the trainers in this module.
    pub relation_to_id: HashMap<String, RelationId>,
}
134
/// Per-epoch loss trace returned by each model's `train`.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Mean loss per epoch, in epoch order.
    pub losses: Vec<f64>,
    /// Mean loss of the final epoch (0.0 when no epochs ran).
    pub final_loss: f64,
    /// Number of epochs actually executed (equals `losses.len()`).
    pub epochs_trained: usize,
}
145
/// Common scoring / link-prediction interface implemented by all models in
/// this module (TransE, DistMult, RotatE).
pub trait KgModel {
    /// Plausibility score for `triple`; higher means more plausible.
    fn score(&self, triple: &KgTriple) -> KgResult<f64>;

    /// Top `top_k` candidate tails for `(head, relation, ?)` as
    /// `(entity, score)` pairs, best first.
    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;

    /// Top `top_k` candidate heads for `(?, relation, tail)` as
    /// `(entity, score)` pairs, best first.
    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;
}
171
/// Minimal 64-bit linear congruential generator: deterministic,
/// dependency-free pseudo-randomness for training (not cryptographic).
#[derive(Debug, Clone)]
struct Lcg {
    // Current generator state; advanced on every draw.
    state: u64,
}
182
183impl Lcg {
184 fn new(seed: u64) -> Self {
185 Self {
186 state: seed.wrapping_add(1),
187 }
188 }
189
190 fn next_f64(&mut self) -> f64 {
192 self.state = self
193 .state
194 .wrapping_mul(6_364_136_223_846_793_005)
195 .wrapping_add(1_442_695_040_888_963_407);
196 (self.state >> 11) as f64 / (1u64 << 53) as f64
197 }
198
199 fn next_usize(&mut self, n: usize) -> usize {
201 (self.next_f64() * n as f64) as usize % n
202 }
203}
204
/// Euclidean (L2) norm of `v`; returns 0.0 for an empty slice.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_sq = v.iter().fold(0.0_f64, |acc, &x| acc + x * x);
    sum_sq.sqrt()
}
212
/// Euclidean distance between `a` and `b`, truncated to the shorter slice.
fn l2_dist(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0_f64;
    for (x, y) in a.iter().zip(b) {
        let d = x - y;
        acc += d * d;
    }
    acc.sqrt()
}
220
/// Dot product of `a` and `b`, truncated to the shorter slice.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
224
/// Clamps every element of `v` into `[lo, hi]` in place.
fn clamp_vec(v: &mut [f64], lo: f64, hi: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(lo, hi));
}
230
/// Scales `v` to unit L2 norm in place. Near-zero (or NaN-norm) vectors are
/// left untouched to avoid dividing by ~0.
fn normalize_vec(v: &mut [f64]) {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    // `!(norm > eps)` (rather than `norm <= eps`) also bails out on NaN,
    // matching the original guard's behavior.
    if !(norm > 1e-12) {
        return;
    }
    for x in v.iter_mut() {
        *x /= norm;
    }
}
239
/// Draws a negative sample for `triple` by replacing its head or tail with a
/// uniformly random entity, rejecting candidates that appear in
/// `positive_set` (the relation is always kept).
///
/// Performs up to 20 rejection-sampling rounds; if every draw collides with a
/// known positive, falls back to `tail + 1 (mod num_entities)` WITHOUT
/// re-checking `positive_set`, so the fallback can still be a true triple.
/// Panics if `num_entities == 0` (modulo by zero in the RNG / fallback).
fn corrupt_triple(
    triple: &KgTriple,
    num_entities: usize,
    positive_set: &std::collections::HashSet<(usize, usize, usize)>,
    rng: &mut Lcg,
) -> KgTriple {
    for _ in 0..20 {
        // Fair coin: corrupt the head or the tail.
        let corrupt_head = rng.next_usize(2) == 0;
        let candidate = if corrupt_head {
            let new_head = rng.next_usize(num_entities);
            KgTriple::new(new_head, triple.relation, triple.tail)
        } else {
            let new_tail = rng.next_usize(num_entities);
            KgTriple::new(triple.head, triple.relation, new_tail)
        };
        if !positive_set.contains(&(candidate.head, candidate.relation, candidate.tail)) {
            return candidate;
        }
    }
    // Deterministic fallback after 20 failed draws (not checked for positivity).
    let new_tail = (triple.tail + 1) % num_entities;
    KgTriple::new(triple.head, triple.relation, new_tail)
}
269
/// TransE link-prediction model: a triple `(h, r, t)` is plausible when
/// `h + r ≈ t` in embedding space; `score = -||h + r - t||₂`.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Training hyperparameters.
    pub config: KgEmbeddingConfig,
    /// Learned embeddings; `None` until `train` succeeds.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
285
286impl TransE {
287 pub fn new(config: KgEmbeddingConfig) -> Self {
289 Self {
290 config,
291 embeddings: None,
292 num_entities: 0,
293 num_relations: 0,
294 }
295 }
296
297 pub fn train(
303 &mut self,
304 triples: &[KgTriple],
305 num_entities: usize,
306 num_relations: usize,
307 ) -> KgResult<TrainingHistory> {
308 if triples.is_empty() {
309 return Err(KgError::NoTrainingData);
310 }
311 if self.config.embedding_dim == 0 {
312 return Err(KgError::InvalidDimension);
313 }
314 self.num_entities = num_entities;
315 self.num_relations = num_relations;
316
317 let dim = self.config.embedding_dim;
318 let mut rng = Lcg::new(self.config.seed);
319
320 let bound = 6.0 / (dim as f64).sqrt();
322 let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
323 .map(|_| {
324 let mut v: Vec<f64> = (0..dim)
325 .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
326 .collect();
327 normalize_vec(&mut v);
328 v
329 })
330 .collect();
331
332 let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
334 .map(|_| {
335 (0..dim)
336 .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
337 .collect()
338 })
339 .collect();
340
341 let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
343 .iter()
344 .map(|t| (t.head, t.relation, t.tail))
345 .collect();
346
347 let lr = self.config.learning_rate;
348 let margin = self.config.margin;
349 let reg = self.config.regularization;
350 let mut losses = Vec::with_capacity(self.config.num_epochs);
351
352 for _epoch in 0..self.config.num_epochs {
353 let mut epoch_loss = 0.0_f64;
354 let mut count = 0usize;
355
356 for pos in triples {
357 for _ in 0..self.config.neg_samples {
358 let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
359
360 let h_pos = &ent_emb[pos.head];
361 let r = &rel_emb[pos.relation];
362 let t_pos = &ent_emb[pos.tail];
363 let h_neg = &ent_emb[neg.head];
364 let t_neg = &ent_emb[neg.tail];
365
366 let pos_diff: Vec<f64> = (0..dim).map(|i| h_pos[i] + r[i] - t_pos[i]).collect();
368 let neg_diff: Vec<f64> = (0..dim).map(|i| h_neg[i] + r[i] - t_neg[i]).collect();
369
370 let d_pos = l2_norm(&pos_diff);
371 let d_neg = l2_norm(&neg_diff);
372
373 let loss = (margin + d_pos - d_neg).max(0.0);
374 epoch_loss += loss;
375 count += 1;
376
377 if loss > 0.0 {
378 let grad_pos: Vec<f64> = if d_pos > 1e-12 {
380 pos_diff.iter().map(|x| x / d_pos).collect()
381 } else {
382 vec![0.0; dim]
383 };
384 let grad_neg: Vec<f64> = if d_neg > 1e-12 {
385 neg_diff.iter().map(|x| x / d_neg).collect()
386 } else {
387 vec![0.0; dim]
388 };
389
390 for i in 0..dim {
392 let g = grad_pos[i];
393 ent_emb[pos.head][i] -= lr * (g + reg * ent_emb[pos.head][i]);
394 rel_emb[pos.relation][i] -= lr * (g + reg * rel_emb[pos.relation][i]);
395 ent_emb[pos.tail][i] += lr * (g - reg * ent_emb[pos.tail][i]);
396 }
397
398 for i in 0..dim {
400 let g = grad_neg[i];
401 ent_emb[neg.head][i] += lr * (g + reg * ent_emb[neg.head][i]);
402 ent_emb[neg.tail][i] -= lr * (g - reg * ent_emb[neg.tail][i]);
403 }
404 }
405
406 normalize_vec(&mut ent_emb[pos.head]);
408 normalize_vec(&mut ent_emb[pos.tail]);
409 normalize_vec(&mut ent_emb[neg.head]);
410 normalize_vec(&mut ent_emb[neg.tail]);
411 }
412 }
413
414 let mean_loss = if count > 0 {
415 epoch_loss / count as f64
416 } else {
417 0.0
418 };
419 losses.push(mean_loss);
420 }
421
422 let final_loss = losses.last().copied().unwrap_or(0.0);
423 let epochs_trained = losses.len();
424
425 self.embeddings = Some(KgEmbeddings {
426 entity_embeddings: ent_emb,
427 relation_embeddings: rel_emb,
428 entity_to_id: HashMap::new(),
429 relation_to_id: HashMap::new(),
430 });
431
432 Ok(TrainingHistory {
433 losses,
434 final_loss,
435 epochs_trained,
436 })
437 }
438
439 pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
441 let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
442 let h = emb
443 .entity_embeddings
444 .get(triple.head)
445 .ok_or(KgError::UnknownEntity(triple.head))?;
446 let r = emb
447 .relation_embeddings
448 .get(triple.relation)
449 .ok_or(KgError::UnknownRelation(triple.relation))?;
450 let t = emb
451 .entity_embeddings
452 .get(triple.tail)
453 .ok_or(KgError::UnknownEntity(triple.tail))?;
454
455 Ok(-Self::score_fn(h, r, t))
456 }
457
458 pub fn predict_tail(
460 &self,
461 head: EntityId,
462 relation: RelationId,
463 top_k: usize,
464 ) -> KgResult<Vec<(EntityId, f64)>> {
465 if top_k == 0 {
466 return Err(KgError::InvalidTopK);
467 }
468 let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
469 let h = emb
470 .entity_embeddings
471 .get(head)
472 .ok_or(KgError::UnknownEntity(head))?;
473 let r = emb
474 .relation_embeddings
475 .get(relation)
476 .ok_or(KgError::UnknownRelation(relation))?;
477
478 let mut scored: Vec<(EntityId, f64)> = emb
479 .entity_embeddings
480 .iter()
481 .enumerate()
482 .map(|(id, t)| (id, -Self::score_fn(h, r, t)))
483 .collect();
484
485 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
486 scored.truncate(top_k);
487 Ok(scored)
488 }
489
490 pub fn predict_head(
492 &self,
493 relation: RelationId,
494 tail: EntityId,
495 top_k: usize,
496 ) -> KgResult<Vec<(EntityId, f64)>> {
497 if top_k == 0 {
498 return Err(KgError::InvalidTopK);
499 }
500 let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
501 let r = emb
502 .relation_embeddings
503 .get(relation)
504 .ok_or(KgError::UnknownRelation(relation))?;
505 let t = emb
506 .entity_embeddings
507 .get(tail)
508 .ok_or(KgError::UnknownEntity(tail))?;
509
510 let mut scored: Vec<(EntityId, f64)> = emb
511 .entity_embeddings
512 .iter()
513 .enumerate()
514 .map(|(id, h)| (id, -Self::score_fn(h, r, t)))
515 .collect();
516
517 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
518 scored.truncate(top_k);
519 Ok(scored)
520 }
521
522 pub fn normalize_entities(&mut self) {
524 if let Some(ref mut emb) = self.embeddings {
525 for v in emb.entity_embeddings.iter_mut() {
526 normalize_vec(v);
527 }
528 }
529 }
530
531 fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
533 let diff: Vec<f64> = (0..h.len()).map(|i| h[i] + r[i] - t[i]).collect();
534 l2_norm(&diff)
535 }
536}
537
538impl KgModel for TransE {
539 fn score(&self, triple: &KgTriple) -> KgResult<f64> {
540 self.score(triple)
541 }
542
543 fn predict_tail(
544 &self,
545 head: EntityId,
546 relation: RelationId,
547 top_k: usize,
548 ) -> KgResult<Vec<(EntityId, f64)>> {
549 self.predict_tail(head, relation, top_k)
550 }
551
552 fn predict_head(
553 &self,
554 relation: RelationId,
555 tail: EntityId,
556 top_k: usize,
557 ) -> KgResult<Vec<(EntityId, f64)>> {
558 self.predict_head(relation, tail, top_k)
559 }
560}
561
/// DistMult link-prediction model: trilinear score
/// `s(h, r, t) = Σ_i h_i · r_i · t_i` (higher is more plausible).
#[derive(Debug, Clone)]
pub struct DistMult {
    /// Training hyperparameters.
    pub config: KgEmbeddingConfig,
    /// Learned embeddings; `None` until `train` succeeds.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
577
impl DistMult {
    /// Creates an untrained model; call [`DistMult::train`] before scoring.
    pub fn new(config: KgEmbeddingConfig) -> Self {
        Self {
            config,
            embeddings: None,
            num_entities: 0,
            num_relations: 0,
        }
    }

    /// Trains embeddings with a per-triple logistic loss: positives minimize
    /// `-ln σ(s)`, corrupted negatives minimize `-ln σ(-s)`, where
    /// `s = Σ_i h_i · r_i · t_i`.
    ///
    /// Every id in `triples` must be in range (`< num_entities` /
    /// `< num_relations`), otherwise indexing panics.
    ///
    /// # Errors
    /// - [`KgError::NoTrainingData`] if `triples` is empty.
    /// - [`KgError::InvalidDimension`] if `config.embedding_dim == 0`.
    pub fn train(
        &mut self,
        triples: &[KgTriple],
        num_entities: usize,
        num_relations: usize,
    ) -> KgResult<TrainingHistory> {
        if triples.is_empty() {
            return Err(KgError::NoTrainingData);
        }
        if self.config.embedding_dim == 0 {
            return Err(KgError::InvalidDimension);
        }
        self.num_entities = num_entities;
        self.num_relations = num_relations;

        let dim = self.config.embedding_dim;
        let mut rng = Lcg::new(self.config.seed);
        // Uniform init in [-1/sqrt(dim), 1/sqrt(dim)]; no normalization here.
        let bound = 1.0 / (dim as f64).sqrt();

        let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
            .map(|_| {
                (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect()
            })
            .collect();
        let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
            .map(|_| {
                (0..dim)
                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
                    .collect()
            })
            .collect();

        // Known positives, used to reject false negatives during corruption.
        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
            .iter()
            .map(|t| (t.head, t.relation, t.tail))
            .collect();

        let lr = self.config.learning_rate;
        let reg = self.config.regularization;
        let mut losses = Vec::with_capacity(self.config.num_epochs);

        for _epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0_f64;
            let mut count = 0usize;

            for pos in triples {
                // --- positive step: push σ(s) toward 1 ---
                {
                    let s = Self::score_fn(
                        &ent_emb[pos.head],
                        &rel_emb[pos.relation],
                        &ent_emb[pos.tail],
                    );
                    let sig = sigmoid(s);
                    // ln σ is clipped at -100 so one extreme triple cannot
                    // dominate the reported epoch loss.
                    let loss = -sig.ln().max(-100.0);
                    epoch_loss += loss;
                    count += 1;

                    // dL/ds = -(1 - σ(s)); the trilinear score contributes the
                    // product of the other two factors per parameter.
                    let g = -(1.0 - sig);
                    for i in 0..dim {
                        let h_i = ent_emb[pos.head][i];
                        let r_i = rel_emb[pos.relation][i];
                        let t_i = ent_emb[pos.tail][i];
                        ent_emb[pos.head][i] -= lr * (g * r_i * t_i + reg * h_i);
                        rel_emb[pos.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
                        ent_emb[pos.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
                    }
                    // Keep magnitudes bounded so the trilinear products cannot
                    // blow up over many epochs.
                    clamp_vec(&mut ent_emb[pos.head], -10.0, 10.0);
                    clamp_vec(&mut rel_emb[pos.relation], -10.0, 10.0);
                    clamp_vec(&mut ent_emb[pos.tail], -10.0, 10.0);
                }

                // --- negative steps: push σ(s) toward 0 for corrupted triples ---
                for _ in 0..self.config.neg_samples {
                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
                    let s = Self::score_fn(
                        &ent_emb[neg.head],
                        &rel_emb[neg.relation],
                        &ent_emb[neg.tail],
                    );
                    let sig = sigmoid(-s);
                    let loss = -sig.ln().max(-100.0);
                    epoch_loss += loss;
                    count += 1;

                    // dL/ds = 1 - σ(-s).
                    let g = 1.0 - sig; for i in 0..dim {
                        let h_i = ent_emb[neg.head][i];
                        let r_i = rel_emb[neg.relation][i];
                        let t_i = ent_emb[neg.tail][i];
                        ent_emb[neg.head][i] -= lr * (g * r_i * t_i + reg * h_i);
                        rel_emb[neg.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
                        ent_emb[neg.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
                    }
                    clamp_vec(&mut ent_emb[neg.head], -10.0, 10.0);
                    clamp_vec(&mut ent_emb[neg.tail], -10.0, 10.0);
                }
            }

            let mean_loss = if count > 0 {
                epoch_loss / count as f64
            } else {
                0.0
            };
            losses.push(mean_loss);
        }

        let final_loss = losses.last().copied().unwrap_or(0.0);
        let epochs_trained = losses.len();

        // Name maps are not built here; only dense-id tables are produced.
        self.embeddings = Some(KgEmbeddings {
            entity_embeddings: ent_emb,
            relation_embeddings: rel_emb,
            entity_to_id: HashMap::new(),
            relation_to_id: HashMap::new(),
        });

        Ok(TrainingHistory {
            losses,
            final_loss,
            epochs_trained,
        })
    }

    /// Trilinear plausibility score `Σ_i h_i · r_i · t_i` (higher is better).
    ///
    /// # Errors
    /// [`KgError::NotTrained`] before training; `UnknownEntity` /
    /// `UnknownRelation` for out-of-range ids.
    pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
        let h = emb
            .entity_embeddings
            .get(triple.head)
            .ok_or(KgError::UnknownEntity(triple.head))?;
        let r = emb
            .relation_embeddings
            .get(triple.relation)
            .ok_or(KgError::UnknownRelation(triple.relation))?;
        let t = emb
            .entity_embeddings
            .get(triple.tail)
            .ok_or(KgError::UnknownEntity(triple.tail))?;
        Ok(Self::score_fn(h, r, t))
    }

    /// Ranks every entity as a candidate tail for `(head, relation, ?)` and
    /// returns the `top_k` best `(entity, score)` pairs, best first.
    ///
    /// # Errors
    /// [`KgError::InvalidTopK`] if `top_k == 0`; `NotTrained` / id errors as
    /// in [`DistMult::score`].
    pub fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        if top_k == 0 {
            return Err(KgError::InvalidTopK);
        }
        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
        let h = emb
            .entity_embeddings
            .get(head)
            .ok_or(KgError::UnknownEntity(head))?;
        let r = emb
            .relation_embeddings
            .get(relation)
            .ok_or(KgError::UnknownRelation(relation))?;

        let mut scored: Vec<(EntityId, f64)> = emb
            .entity_embeddings
            .iter()
            .enumerate()
            .map(|(id, t)| (id, Self::score_fn(h, r, t)))
            .collect();
        // Descending by score; incomparable pairs (NaN) treated as equal.
        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);
        Ok(scored)
    }

    /// Ranks every entity as a candidate head for `(?, relation, tail)` and
    /// returns the `top_k` best `(entity, score)` pairs, best first.
    ///
    /// # Errors
    /// Same error conditions as [`DistMult::predict_tail`].
    pub fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        if top_k == 0 {
            return Err(KgError::InvalidTopK);
        }
        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
        let r = emb
            .relation_embeddings
            .get(relation)
            .ok_or(KgError::UnknownRelation(relation))?;
        let t = emb
            .entity_embeddings
            .get(tail)
            .ok_or(KgError::UnknownEntity(tail))?;

        let mut scored: Vec<(EntityId, f64)> = emb
            .entity_embeddings
            .iter()
            .enumerate()
            .map(|(id, h)| (id, Self::score_fn(h, r, t)))
            .collect();
        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);
        Ok(scored)
    }

    /// Raw DistMult score: elementwise product of h, r, t summed over dims.
    fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
        h.iter()
            .zip(r.iter())
            .zip(t.iter())
            .map(|((hi, ri), ti)| hi * ri * ti)
            .sum()
    }
}
805
806impl KgModel for DistMult {
807 fn score(&self, triple: &KgTriple) -> KgResult<f64> {
808 self.score(triple)
809 }
810
811 fn predict_tail(
812 &self,
813 head: EntityId,
814 relation: RelationId,
815 top_k: usize,
816 ) -> KgResult<Vec<(EntityId, f64)>> {
817 self.predict_tail(head, relation, top_k)
818 }
819
820 fn predict_head(
821 &self,
822 relation: RelationId,
823 tail: EntityId,
824 top_k: usize,
825 ) -> KgResult<Vec<(EntityId, f64)>> {
826 self.predict_head(relation, tail, top_k)
827 }
828}
829
/// RotatE link-prediction model: relations act as per-dimension rotations in
/// the complex plane; `score = -||h ∘ e^{iφ} - t||`. Entities are stored as
/// separate real/imaginary tables, relations as phase (angle) vectors.
#[derive(Debug, Clone)]
pub struct RotatE {
    /// Training hyperparameters (`embedding_dim` counts real dimensions;
    /// training stores `(embedding_dim + 1) / 2` complex components).
    pub config: KgEmbeddingConfig,
    /// Real parts of the complex entity embeddings; `None` until trained.
    pub entity_re: Option<Vec<Vec<f64>>>,
    /// Imaginary parts of the complex entity embeddings; `None` until trained.
    pub entity_im: Option<Vec<Vec<f64>>>,
    /// Per-relation rotation angles in radians; `None` until trained.
    pub relation_phases: Option<Vec<Vec<f64>>>,
    // Table sizes recorded at training time.
    num_entities: usize,
    num_relations: usize,
}
852
853impl RotatE {
854 pub fn new(config: KgEmbeddingConfig) -> Self {
856 Self {
857 config,
858 entity_re: None,
859 entity_im: None,
860 relation_phases: None,
861 num_entities: 0,
862 num_relations: 0,
863 }
864 }
865
866 pub fn train(
868 &mut self,
869 triples: &[KgTriple],
870 num_entities: usize,
871 num_relations: usize,
872 ) -> KgResult<TrainingHistory> {
873 if triples.is_empty() {
874 return Err(KgError::NoTrainingData);
875 }
876 if self.config.embedding_dim == 0 {
877 return Err(KgError::InvalidDimension);
878 }
879
880 self.num_entities = num_entities;
881 self.num_relations = num_relations;
882
883 let half_dim = (self.config.embedding_dim + 1) / 2;
885 let mut rng = Lcg::new(self.config.seed);
886 let pi = std::f64::consts::PI;
887
888 let mut ent_re: Vec<Vec<f64>> = (0..num_entities)
890 .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
891 .collect();
892 let mut ent_im: Vec<Vec<f64>> = (0..num_entities)
893 .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
894 .collect();
895
896 for i in 0..num_entities {
898 for k in 0..half_dim {
899 let norm = (ent_re[i][k].powi(2) + ent_im[i][k].powi(2))
900 .sqrt()
901 .max(1e-12);
902 ent_re[i][k] /= norm;
903 ent_im[i][k] /= norm;
904 }
905 }
906
907 let mut rel_phases: Vec<Vec<f64>> = (0..num_relations)
909 .map(|_| {
910 (0..half_dim)
911 .map(|_| (rng.next_f64() * 2.0 - 1.0) * pi)
912 .collect()
913 })
914 .collect();
915
916 let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
917 .iter()
918 .map(|t| (t.head, t.relation, t.tail))
919 .collect();
920
921 let lr = self.config.learning_rate;
922 let margin = self.config.margin;
923 let reg = self.config.regularization;
924 let mut losses = Vec::with_capacity(self.config.num_epochs);
925
926 for _epoch in 0..self.config.num_epochs {
927 let mut epoch_loss = 0.0_f64;
928 let mut count = 0usize;
929
930 for pos in triples {
931 for _ in 0..self.config.neg_samples {
932 let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
933
934 let d_pos = Self::dist_fn(
935 &ent_re[pos.head],
936 &ent_im[pos.head],
937 &rel_phases[pos.relation],
938 &ent_re[pos.tail],
939 &ent_im[pos.tail],
940 );
941 let d_neg = Self::dist_fn(
942 &ent_re[neg.head],
943 &ent_im[neg.head],
944 &rel_phases[neg.relation],
945 &ent_re[neg.tail],
946 &ent_im[neg.tail],
947 );
948
949 let loss = (margin + d_pos - d_neg).max(0.0);
950 epoch_loss += loss;
951 count += 1;
952
953 if loss > 0.0 && d_pos > 1e-12 {
954 let r_re: Vec<f64> = rel_phases[pos.relation]
956 .iter()
957 .map(|&ph| ph.cos())
958 .collect();
959 let r_im: Vec<f64> = rel_phases[pos.relation]
960 .iter()
961 .map(|&ph| ph.sin())
962 .collect();
963
964 for k in 0..half_dim {
965 let (res_re, res_im) = Self::complex_multiply(
966 ent_re[pos.head][k],
967 ent_im[pos.head][k],
968 r_re[k],
969 r_im[k],
970 );
971 let err_re = res_re - ent_re[pos.tail][k];
972 let err_im = res_im - ent_im[pos.tail][k];
973
974 let g_scale = 1.0 / d_pos;
976
977 let d_h_re = g_scale * (err_re * r_re[k] + err_im * r_im[k]);
979 let d_h_im = g_scale * (err_im * r_re[k] - err_re * r_im[k]);
981 let d_ph = g_scale
983 * ((-ent_re[pos.head][k] * r_im[k]
984 + ent_im[pos.head][k] * r_re[k])
985 * err_re
986 + (-ent_re[pos.head][k] * r_re[k]
987 - ent_im[pos.head][k] * r_im[k])
988 * err_im);
989 let d_t_re = g_scale * (-err_re);
991 let d_t_im = g_scale * (-err_im);
993
994 ent_re[pos.head][k] -= lr * (d_h_re + reg * ent_re[pos.head][k]);
995 ent_im[pos.head][k] -= lr * (d_h_im + reg * ent_im[pos.head][k]);
996 rel_phases[pos.relation][k] -=
997 lr * (d_ph + reg * rel_phases[pos.relation][k]);
998 ent_re[pos.tail][k] -= lr * (d_t_re + reg * ent_re[pos.tail][k]);
999 ent_im[pos.tail][k] -= lr * (d_t_im + reg * ent_im[pos.tail][k]);
1000 }
1001
1002 for ph in rel_phases[pos.relation].iter_mut() {
1004 *ph = ph.clamp(-2.0 * pi, 2.0 * pi);
1005 }
1006 }
1007 }
1008 }
1009
1010 let mean_loss = if count > 0 {
1011 epoch_loss / count as f64
1012 } else {
1013 0.0
1014 };
1015 losses.push(mean_loss);
1016 }
1017
1018 let final_loss = losses.last().copied().unwrap_or(0.0);
1019 let epochs_trained = losses.len();
1020
1021 self.entity_re = Some(ent_re);
1022 self.entity_im = Some(ent_im);
1023 self.relation_phases = Some(rel_phases);
1024
1025 Ok(TrainingHistory {
1026 losses,
1027 final_loss,
1028 epochs_trained,
1029 })
1030 }
1031
1032 pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
1034 let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1035 let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1036 let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1037
1038 let h_re = ent_re
1039 .get(triple.head)
1040 .ok_or(KgError::UnknownEntity(triple.head))?;
1041 let h_im = ent_im
1042 .get(triple.head)
1043 .ok_or(KgError::UnknownEntity(triple.head))?;
1044 let ph = phases
1045 .get(triple.relation)
1046 .ok_or(KgError::UnknownRelation(triple.relation))?;
1047 let t_re = ent_re
1048 .get(triple.tail)
1049 .ok_or(KgError::UnknownEntity(triple.tail))?;
1050 let t_im = ent_im
1051 .get(triple.tail)
1052 .ok_or(KgError::UnknownEntity(triple.tail))?;
1053
1054 Ok(-Self::dist_fn(h_re, h_im, ph, t_re, t_im))
1055 }
1056
1057 pub fn predict_tail(
1059 &self,
1060 head: EntityId,
1061 relation: RelationId,
1062 top_k: usize,
1063 ) -> KgResult<Vec<(EntityId, f64)>> {
1064 if top_k == 0 {
1065 return Err(KgError::InvalidTopK);
1066 }
1067 let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1068 let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1069 let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1070
1071 let h_re = ent_re.get(head).ok_or(KgError::UnknownEntity(head))?;
1072 let h_im = ent_im.get(head).ok_or(KgError::UnknownEntity(head))?;
1073 let ph = phases
1074 .get(relation)
1075 .ok_or(KgError::UnknownRelation(relation))?;
1076
1077 let num = ent_re.len();
1078 let mut scored: Vec<(EntityId, f64)> = (0..num)
1079 .map(|id| (id, -Self::dist_fn(h_re, h_im, ph, &ent_re[id], &ent_im[id])))
1080 .collect();
1081 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1082 scored.truncate(top_k);
1083 Ok(scored)
1084 }
1085
1086 pub fn predict_head(
1088 &self,
1089 relation: RelationId,
1090 tail: EntityId,
1091 top_k: usize,
1092 ) -> KgResult<Vec<(EntityId, f64)>> {
1093 if top_k == 0 {
1094 return Err(KgError::InvalidTopK);
1095 }
1096 let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1097 let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1098 let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1099
1100 let ph = phases
1101 .get(relation)
1102 .ok_or(KgError::UnknownRelation(relation))?;
1103 let t_re = ent_re.get(tail).ok_or(KgError::UnknownEntity(tail))?;
1104 let t_im = ent_im.get(tail).ok_or(KgError::UnknownEntity(tail))?;
1105
1106 let num = ent_re.len();
1107 let mut scored: Vec<(EntityId, f64)> = (0..num)
1111 .map(|id| (id, -Self::dist_fn(&ent_re[id], &ent_im[id], ph, t_re, t_im)))
1112 .collect();
1113 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1114 scored.truncate(top_k);
1115 Ok(scored)
1116 }
1117
1118 fn dist_fn(h_re: &[f64], h_im: &[f64], phases: &[f64], t_re: &[f64], t_im: &[f64]) -> f64 {
1120 let sum_sq: f64 = phases
1121 .iter()
1122 .enumerate()
1123 .map(|(k, &ph)| {
1124 let (res_re, res_im) = Self::complex_multiply(h_re[k], h_im[k], ph.cos(), ph.sin());
1125 (res_re - t_re[k]).powi(2) + (res_im - t_im[k]).powi(2)
1126 })
1127 .sum();
1128 sum_sq.sqrt()
1129 }
1130
1131 pub fn complex_multiply(a_re: f64, a_im: f64, b_re: f64, b_im: f64) -> (f64, f64) {
1133 (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
1134 }
1135}
1136
1137impl KgModel for RotatE {
1138 fn score(&self, triple: &KgTriple) -> KgResult<f64> {
1139 self.score(triple)
1140 }
1141
1142 fn predict_tail(
1143 &self,
1144 head: EntityId,
1145 relation: RelationId,
1146 top_k: usize,
1147 ) -> KgResult<Vec<(EntityId, f64)>> {
1148 self.predict_tail(head, relation, top_k)
1149 }
1150
1151 fn predict_head(
1152 &self,
1153 relation: RelationId,
1154 tail: EntityId,
1155 top_k: usize,
1156 ) -> KgResult<Vec<(EntityId, f64)>> {
1157 self.predict_head(relation, tail, top_k)
1158 }
1159}
1160
1161pub struct LinkPredictionEvaluator;
1167
1168impl LinkPredictionEvaluator {
1169 pub fn hits_at_k(model: &dyn KgModel, test_triples: &[KgTriple], k: usize) -> f64 {
1172 if test_triples.is_empty() || k == 0 {
1173 return 0.0;
1174 }
1175 let hits: usize = test_triples
1176 .iter()
1177 .filter(|t| {
1178 model
1179 .predict_tail(t.head, t.relation, k)
1180 .map(|preds| preds.iter().any(|(eid, _)| *eid == t.tail))
1181 .unwrap_or(false)
1182 })
1183 .count();
1184 hits as f64 / test_triples.len() as f64
1185 }
1186
1187 pub fn mean_rank(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
1190 if test_triples.is_empty() {
1191 return 0.0;
1192 }
1193 let total: usize = test_triples
1194 .iter()
1195 .map(|t| {
1196 model
1197 .predict_tail(t.head, t.relation, num_entities)
1198 .map(|preds| {
1199 preds
1200 .iter()
1201 .position(|(eid, _)| *eid == t.tail)
1202 .map(|p| p + 1)
1203 .unwrap_or(num_entities + 1)
1204 })
1205 .unwrap_or(num_entities + 1)
1206 })
1207 .sum();
1208 total as f64 / test_triples.len() as f64
1209 }
1210
1211 pub fn mrr(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
1214 if test_triples.is_empty() {
1215 return 0.0;
1216 }
1217 let sum: f64 = test_triples
1218 .iter()
1219 .map(|t| {
1220 model
1221 .predict_tail(t.head, t.relation, num_entities)
1222 .map(|preds| {
1223 preds
1224 .iter()
1225 .position(|(eid, _)| *eid == t.tail)
1226 .map(|p| 1.0 / (p as f64 + 1.0))
1227 .unwrap_or(0.0)
1228 })
1229 .unwrap_or(0.0)
1230 })
1231 .sum();
1232 sum / test_triples.len() as f64
1233 }
1234}
1235
1236pub fn serialize_embeddings(emb: &KgEmbeddings) -> Vec<u8> {
1250 let mut out = String::new();
1251 out.push_str(&format!("ENTITIES {}\n", emb.entity_embeddings.len()));
1252 for row in &emb.entity_embeddings {
1253 let line: Vec<String> = row.iter().map(|x| format!("{x:.8}")).collect();
1254 out.push_str(&line.join(","));
1255 out.push('\n');
1256 }
1257 out.push_str(&format!("RELATIONS {}\n", emb.relation_embeddings.len()));
1258 for row in &emb.relation_embeddings {
1259 let line: Vec<String> = row.iter().map(|x| format!("{x:.8}")).collect();
1260 out.push_str(&line.join(","));
1261 out.push('\n');
1262 }
1263 out.into_bytes()
1264}
1265
/// Parses the text format produced by `serialize_embeddings`.
///
/// Expected layout: an `ENTITIES <n>` header, then `n` comma-separated f64
/// rows, then `RELATIONS <m>` and `m` rows. Name maps are not serialized, so
/// the returned `entity_to_id` / `relation_to_id` are empty. All failures are
/// reported as [`KgError::NumericalError`] with a descriptive message.
/// Trailing lines after the relation rows are ignored, and row lengths are
/// not cross-checked against each other.
pub fn deserialize_embeddings(data: &[u8]) -> Result<KgEmbeddings, KgError> {
    let text = std::str::from_utf8(data)
        .map_err(|e| KgError::NumericalError(format!("utf8 error: {e}")))?;
    let mut lines = text.lines();

    // Parses a "<prefix><count>" header line (e.g. "ENTITIES 12") into count.
    let parse_section_header = |line: &str, prefix: &str| -> Result<usize, KgError> {
        let rest = line
            .strip_prefix(prefix)
            .ok_or_else(|| KgError::NumericalError(format!("expected '{prefix}', got '{line}'")))?;
        rest.trim()
            .parse::<usize>()
            .map_err(|e| KgError::NumericalError(e.to_string()))
    };

    // Parses one comma-separated row of f64 values (fails on the first bad cell).
    let parse_row = |line: &str| -> Result<Vec<f64>, KgError> {
        line.split(',')
            .map(|s| {
                s.trim()
                    .parse::<f64>()
                    .map_err(|e| KgError::NumericalError(e.to_string()))
            })
            .collect()
    };

    let ent_header = lines
        .next()
        .ok_or(KgError::NumericalError("empty data".into()))?;
    let num_ent = parse_section_header(ent_header, "ENTITIES ")?;
    let mut entity_embeddings = Vec::with_capacity(num_ent);
    for _ in 0..num_ent {
        let line = lines
            .next()
            .ok_or(KgError::NumericalError("truncated entity data".into()))?;
        entity_embeddings.push(parse_row(line)?);
    }

    let rel_header = lines
        .next()
        .ok_or(KgError::NumericalError("missing RELATIONS header".into()))?;
    let num_rel = parse_section_header(rel_header, "RELATIONS ")?;
    let mut relation_embeddings = Vec::with_capacity(num_rel);
    for _ in 0..num_rel {
        let line = lines
            .next()
            .ok_or(KgError::NumericalError("truncated relation data".into()))?;
        relation_embeddings.push(parse_row(line)?);
    }

    Ok(KgEmbeddings {
        entity_embeddings,
        relation_embeddings,
        entity_to_id: HashMap::new(),
        relation_to_id: HashMap::new(),
    })
}
1323
/// Logistic function `1 / (1 + e^(-x))`.
#[inline]
fn sigmoid(x: f64) -> f64 {
    let e = (-x).exp();
    (1.0 + e).recip()
}
1332
1333#[cfg(test)]
1342#[path = "kg_embeddings_tests.rs"]
1343mod tests;