1use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, HashSet};
17
/// Distance used to measure how far `h + r` lies from `t` when scoring a
/// triple `(h, r, t)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Manhattan distance: sum of absolute component differences.
    L1,
    /// Euclidean distance: square root of the sum of squared component differences.
    L2,
}
30
/// Hyperparameters for [`TransEModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransEConfig {
    /// Length of every entity and relation embedding vector.
    pub dim: usize,
    /// Step size applied to each gradient update during training.
    pub learning_rate: f64,
    /// Margin `γ` in the hinge loss `max(0, γ + d(positive) - d(negative))`.
    pub margin: f64,
    /// Distance used to score triples.
    pub distance_metric: DistanceMetric,
    /// Upper bound on the number of epochs a single `train` call runs;
    /// larger requests are silently capped to this value.
    pub max_epochs: usize,
    /// Number of negative (corrupted) triples intended per positive triple.
    /// NOTE(review): confirm in `TransEModel::train` how this is honored.
    pub num_negatives: usize,
    /// When true, after every epoch each entity embedding whose L2 norm
    /// exceeds 1 is rescaled to unit norm.
    pub normalize_embeddings: bool,
}
49
50impl Default for TransEConfig {
51 fn default() -> Self {
52 Self {
53 dim: 50,
54 learning_rate: 0.01,
55 margin: 1.0,
56 distance_metric: DistanceMetric::L2,
57 max_epochs: 100,
58 num_negatives: 1,
59 normalize_embeddings: true,
60 }
61 }
62}
63
/// A fact `(head, relation, tail)` expressed with the model's internal
/// integer ids (as returned by `add_entity` / `add_relation`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Triple {
    /// Head entity id.
    pub head: usize,
    /// Relation id.
    pub relation: usize,
    /// Tail entity id.
    pub tail: usize,
}
75
/// A candidate triple (internal ids) paired with its model score.
///
/// The score is a TransE distance, so lower values are ranked as more
/// plausible by the prediction methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTriple {
    /// Head entity id.
    pub head: usize,
    /// Relation id.
    pub relation: usize,
    /// Tail entity id.
    pub tail: usize,
    /// Distance between `head + relation` and `tail` under the configured metric.
    pub score: f64,
}
84
/// Statistics accumulated over every call to `TransEModel::train`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Average training loss recorded for each completed epoch, in order.
    pub loss_history: Vec<f64>,
    /// Total number of epochs completed.
    pub epochs_completed: usize,
    /// Total number of positive triples visited across all epochs.
    pub triples_processed: u64,
}
95
/// A TransE knowledge-graph embedding model.
///
/// Entity and relation names are interned to integer ids in insertion order,
/// and each gets a `config.dim`-dimensional embedding. A triple `(h, r, t)`
/// is scored by the distance between `h + r` and `t`; lower scores are more
/// plausible.
pub struct TransEModel {
    /// Hyperparameters supplied at construction.
    config: TransEConfig,
    /// Entity id -> embedding vector.
    entity_embeddings: HashMap<usize, Vec<f64>>,
    /// Relation id -> embedding vector.
    relation_embeddings: HashMap<usize, Vec<f64>>,
    /// Entity name -> id.
    entity_to_id: HashMap<String, usize>,
    /// Entity id -> name (inverse of `entity_to_id`).
    id_to_entity: HashMap<usize, String>,
    /// Relation name -> id.
    relation_to_id: HashMap<String, usize>,
    /// Relation id -> name (inverse of `relation_to_id`).
    id_to_relation: HashMap<usize, String>,
    /// Every triple added via `add_triple`; this is the training set.
    known_triples: HashSet<Triple>,
    /// Statistics accumulated across all `train` calls.
    stats: TrainingStats,
    /// State of the internal linear congruential generator used for
    /// embedding initialization and negative sampling.
    rng_state: u64,
}
122
123impl TransEModel {
124 pub fn new() -> Self {
126 Self::with_config(TransEConfig::default())
127 }
128
129 pub fn with_config(config: TransEConfig) -> Self {
131 Self {
132 config,
133 entity_embeddings: HashMap::new(),
134 relation_embeddings: HashMap::new(),
135 entity_to_id: HashMap::new(),
136 id_to_entity: HashMap::new(),
137 relation_to_id: HashMap::new(),
138 id_to_relation: HashMap::new(),
139 known_triples: HashSet::new(),
140 stats: TrainingStats::default(),
141 rng_state: 12345,
142 }
143 }
144
145 pub fn config(&self) -> &TransEConfig {
147 &self.config
148 }
149
150 pub fn stats(&self) -> &TrainingStats {
152 &self.stats
153 }
154
155 pub fn entity_count(&self) -> usize {
157 self.entity_to_id.len()
158 }
159
160 pub fn relation_count(&self) -> usize {
162 self.relation_to_id.len()
163 }
164
165 pub fn triple_count(&self) -> usize {
167 self.known_triples.len()
168 }
169
170 pub fn add_entity(&mut self, name: impl Into<String>) -> usize {
172 let name = name.into();
173 if let Some(&id) = self.entity_to_id.get(&name) {
174 return id;
175 }
176 let id = self.entity_to_id.len();
177 self.entity_to_id.insert(name.clone(), id);
178 self.id_to_entity.insert(id, name);
179 let embedding = self.random_embedding();
181 self.entity_embeddings.insert(id, embedding);
182 id
183 }
184
185 pub fn add_relation(&mut self, name: impl Into<String>) -> usize {
187 let name = name.into();
188 if let Some(&id) = self.relation_to_id.get(&name) {
189 return id;
190 }
191 let id = self.relation_to_id.len();
192 self.relation_to_id.insert(name.clone(), id);
193 self.id_to_relation.insert(id, name);
194 let mut embedding = self.random_embedding();
196 let norm = l2_norm(&embedding);
198 if norm > 1e-12 {
199 for v in &mut embedding {
200 *v /= norm;
201 }
202 }
203 self.relation_embeddings.insert(id, embedding);
204 id
205 }
206
207 pub fn add_triple(
209 &mut self,
210 head: impl Into<String>,
211 relation: impl Into<String>,
212 tail: impl Into<String>,
213 ) {
214 let h = self.add_entity(head);
215 let r = self.add_relation(relation);
216 let t = self.add_entity(tail);
217 self.known_triples.insert(Triple {
218 head: h,
219 relation: r,
220 tail: t,
221 });
222 }
223
224 pub fn train(&mut self, epochs: usize) -> TrainingStats {
226 let num_epochs = epochs.min(self.config.max_epochs);
227 let triples: Vec<Triple> = self.known_triples.iter().cloned().collect();
228
229 if triples.is_empty() {
230 return self.stats.clone();
231 }
232
233 let num_entities = self.entity_to_id.len();
234
235 for _epoch in 0..num_epochs {
236 let mut epoch_loss = 0.0;
237
238 for triple in &triples {
239 let neg_triple = self.corrupt_triple(triple, num_entities);
241
242 let pos_score = self.score_triple_ids(triple.head, triple.relation, triple.tail);
244 let neg_score =
245 self.score_triple_ids(neg_triple.head, neg_triple.relation, neg_triple.tail);
246
247 let loss = (self.config.margin + pos_score - neg_score).max(0.0);
249 epoch_loss += loss;
250
251 if loss > 0.0 {
252 self.update_embeddings(triple, &neg_triple);
254 }
255
256 self.stats.triples_processed += 1;
257 }
258
259 let avg_loss = epoch_loss / triples.len() as f64;
260 self.stats.loss_history.push(avg_loss);
261 self.stats.epochs_completed += 1;
262
263 if self.config.normalize_embeddings {
265 self.normalize_entities();
266 }
267 }
268
269 self.stats.clone()
270 }
271
272 pub fn score(&self, head: &str, relation: &str, tail: &str) -> Option<f64> {
274 let h = self.entity_to_id.get(head)?;
275 let r = self.relation_to_id.get(relation)?;
276 let t = self.entity_to_id.get(tail)?;
277 Some(self.score_triple_ids(*h, *r, *t))
278 }
279
280 fn score_triple_ids(&self, head: usize, relation: usize, tail: usize) -> f64 {
282 let h = match self.entity_embeddings.get(&head) {
283 Some(e) => e,
284 None => return f64::MAX,
285 };
286 let r = match self.relation_embeddings.get(&relation) {
287 Some(e) => e,
288 None => return f64::MAX,
289 };
290 let t = match self.entity_embeddings.get(&tail) {
291 Some(e) => e,
292 None => return f64::MAX,
293 };
294
295 let dim = self.config.dim;
297 match self.config.distance_metric {
298 DistanceMetric::L1 => {
299 let mut dist = 0.0;
300 for i in 0..dim {
301 let hi = h.get(i).copied().unwrap_or(0.0);
302 let ri = r.get(i).copied().unwrap_or(0.0);
303 let ti = t.get(i).copied().unwrap_or(0.0);
304 dist += (hi + ri - ti).abs();
305 }
306 dist
307 }
308 DistanceMetric::L2 => {
309 let mut dist = 0.0;
310 for i in 0..dim {
311 let hi = h.get(i).copied().unwrap_or(0.0);
312 let ri = r.get(i).copied().unwrap_or(0.0);
313 let ti = t.get(i).copied().unwrap_or(0.0);
314 dist += (hi + ri - ti).powi(2);
315 }
316 dist.sqrt()
317 }
318 }
319 }
320
321 pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<ScoredTriple> {
323 let h = match self.entity_to_id.get(head) {
324 Some(&id) => id,
325 None => return Vec::new(),
326 };
327 let r = match self.relation_to_id.get(relation) {
328 Some(&id) => id,
329 None => return Vec::new(),
330 };
331
332 let mut scores: Vec<ScoredTriple> = self
333 .entity_to_id
334 .values()
335 .map(|&t_id| {
336 let score = self.score_triple_ids(h, r, t_id);
337 ScoredTriple {
338 head: h,
339 relation: r,
340 tail: t_id,
341 score,
342 }
343 })
344 .collect();
345
346 scores.sort_by(|a, b| {
347 a.score
348 .partial_cmp(&b.score)
349 .unwrap_or(std::cmp::Ordering::Equal)
350 });
351 scores.truncate(k);
352 scores
353 }
354
355 pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<ScoredTriple> {
357 let r = match self.relation_to_id.get(relation) {
358 Some(&id) => id,
359 None => return Vec::new(),
360 };
361 let t = match self.entity_to_id.get(tail) {
362 Some(&id) => id,
363 None => return Vec::new(),
364 };
365
366 let mut scores: Vec<ScoredTriple> = self
367 .entity_to_id
368 .values()
369 .map(|&h_id| {
370 let score = self.score_triple_ids(h_id, r, t);
371 ScoredTriple {
372 head: h_id,
373 relation: r,
374 tail: t,
375 score,
376 }
377 })
378 .collect();
379
380 scores.sort_by(|a, b| {
381 a.score
382 .partial_cmp(&b.score)
383 .unwrap_or(std::cmp::Ordering::Equal)
384 });
385 scores.truncate(k);
386 scores
387 }
388
389 pub fn entity_embedding(&self, name: &str) -> Option<&Vec<f64>> {
391 self.entity_to_id
392 .get(name)
393 .and_then(|id| self.entity_embeddings.get(id))
394 }
395
396 pub fn relation_embedding(&self, name: &str) -> Option<&Vec<f64>> {
398 self.relation_to_id
399 .get(name)
400 .and_then(|id| self.relation_embeddings.get(id))
401 }
402
403 pub fn nearest_entities(&self, query: &[f64], k: usize) -> Vec<(String, f64)> {
405 let mut dists: Vec<(String, f64)> = self
406 .entity_embeddings
407 .iter()
408 .map(|(&id, emb)| {
409 let dist = l2_distance(query, emb);
410 let name = self
411 .id_to_entity
412 .get(&id)
413 .cloned()
414 .unwrap_or_else(|| format!("entity_{id}"));
415 (name, dist)
416 })
417 .collect();
418
419 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
420 dists.truncate(k);
421 dists
422 }
423
424 pub fn entity_name(&self, id: usize) -> Option<&str> {
426 self.id_to_entity.get(&id).map(|s| s.as_str())
427 }
428
429 pub fn relation_name(&self, id: usize) -> Option<&str> {
431 self.id_to_relation.get(&id).map(|s| s.as_str())
432 }
433
434 fn random_embedding(&mut self) -> Vec<f64> {
437 let dim = self.config.dim;
438 (0..dim)
439 .map(|_| {
440 self.rng_state = self
441 .rng_state
442 .wrapping_mul(6364136223846793005)
443 .wrapping_add(1442695040888963407);
444 let val = ((self.rng_state >> 33) as f64) / (u32::MAX as f64);
445 (val - 0.5) * 2.0 / (dim as f64).sqrt()
446 })
447 .collect()
448 }
449
450 fn corrupt_triple(&mut self, triple: &Triple, num_entities: usize) -> Triple {
451 if num_entities == 0 {
452 return triple.clone();
453 }
454 self.rng_state = self
455 .rng_state
456 .wrapping_mul(6364136223846793005)
457 .wrapping_add(1442695040888963407);
458 let corrupt_head = (self.rng_state >> 33) % 2 == 0;
459 let random_entity = ((self.rng_state >> 17) as usize) % num_entities;
460
461 if corrupt_head {
462 Triple {
463 head: random_entity,
464 relation: triple.relation,
465 tail: triple.tail,
466 }
467 } else {
468 Triple {
469 head: triple.head,
470 relation: triple.relation,
471 tail: random_entity,
472 }
473 }
474 }
475
476 fn update_embeddings(&mut self, positive: &Triple, negative: &Triple) {
477 let lr = self.config.learning_rate;
478 let dim = self.config.dim;
479
480 let pos_h = self
483 .entity_embeddings
484 .get(&positive.head)
485 .cloned()
486 .unwrap_or_else(|| vec![0.0; dim]);
487 let pos_t = self
488 .entity_embeddings
489 .get(&positive.tail)
490 .cloned()
491 .unwrap_or_else(|| vec![0.0; dim]);
492 let neg_h = self
493 .entity_embeddings
494 .get(&negative.head)
495 .cloned()
496 .unwrap_or_else(|| vec![0.0; dim]);
497 let neg_t = self
498 .entity_embeddings
499 .get(&negative.tail)
500 .cloned()
501 .unwrap_or_else(|| vec![0.0; dim]);
502 let rel = self
503 .relation_embeddings
504 .get(&positive.relation)
505 .cloned()
506 .unwrap_or_else(|| vec![0.0; dim]);
507
508 let mut pos_grad = vec![0.0; dim];
510 let mut neg_grad = vec![0.0; dim];
511 for i in 0..dim {
512 pos_grad[i] = pos_h[i] + rel[i] - pos_t[i];
513 neg_grad[i] = neg_h[i] + rel[i] - neg_t[i];
514 }
515
516 let pos_norm = l2_norm(&pos_grad).max(1e-12);
518 let neg_norm = l2_norm(&neg_grad).max(1e-12);
519
520 if let Some(h_emb) = self.entity_embeddings.get_mut(&positive.head) {
522 for i in 0..dim {
523 h_emb[i] -= lr * pos_grad[i] / pos_norm;
524 }
525 }
526
527 if let Some(t_emb) = self.entity_embeddings.get_mut(&positive.tail) {
529 for i in 0..dim {
530 t_emb[i] += lr * pos_grad[i] / pos_norm;
531 }
532 }
533
534 if let Some(h_emb) = self.entity_embeddings.get_mut(&negative.head) {
536 for i in 0..dim {
537 h_emb[i] += lr * neg_grad[i] / neg_norm;
538 }
539 }
540
541 if let Some(t_emb) = self.entity_embeddings.get_mut(&negative.tail) {
543 for i in 0..dim {
544 t_emb[i] -= lr * neg_grad[i] / neg_norm;
545 }
546 }
547
548 if let Some(r_emb) = self.relation_embeddings.get_mut(&positive.relation) {
550 for i in 0..dim {
551 r_emb[i] -= lr * (pos_grad[i] / pos_norm - neg_grad[i] / neg_norm);
552 }
553 }
554 }
555
556 fn normalize_entities(&mut self) {
557 for emb in self.entity_embeddings.values_mut() {
558 let norm = l2_norm(emb);
559 if norm > 1.0 {
560 for v in emb.iter_mut() {
561 *v /= norm;
562 }
563 }
564 }
565 }
566}
567
568impl Default for TransEModel {
569 fn default() -> Self {
570 Self::new()
571 }
572}
573
/// Euclidean (L2) length of `v`: the square root of the sum of squared
/// components. Returns zero for an empty slice.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
579
/// Euclidean distance between `a` and `b`.
///
/// Components are paired positionally and pairing stops at the shorter
/// slice, so any extra trailing components of the longer slice are ignored.
fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| (x - y).powi(2))
        .sum();
    squared.sqrt()
}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Five triples over five entities and two relations.
    ///
    /// Entity ids are assigned in insertion order: alice = 0, bob = 1,
    /// charlie = 2, music = 3, sports = 4. Relation ids: knows = 0, likes = 1.
    fn sample_model() -> TransEModel {
        let mut model = TransEModel::with_config(TransEConfig {
            dim: 10,
            learning_rate: 0.01,
            margin: 1.0,
            max_epochs: 50,
            ..Default::default()
        });
        model.add_triple("alice", "knows", "bob");
        model.add_triple("bob", "knows", "charlie");
        model.add_triple("alice", "likes", "music");
        model.add_triple("bob", "likes", "sports");
        model.add_triple("charlie", "likes", "music");
        model
    }

    // ---- configuration ----

    #[test]
    fn test_default_config() {
        let config = TransEConfig::default();
        assert_eq!(config.dim, 50);
        assert_eq!(config.learning_rate, 0.01);
        assert_eq!(config.margin, 1.0);
        assert_eq!(config.distance_metric, DistanceMetric::L2);
        assert_eq!(config.max_epochs, 100);
        assert_eq!(config.num_negatives, 1);
        assert!(config.normalize_embeddings);
    }

    // ---- vocabulary ----

    #[test]
    fn test_add_entity() {
        let mut model = TransEModel::new();
        let id = model.add_entity("alice");
        assert_eq!(id, 0);
        assert_eq!(model.entity_count(), 1);
    }

    #[test]
    fn test_add_entity_idempotent() {
        let mut model = TransEModel::new();
        let id1 = model.add_entity("alice");
        let id2 = model.add_entity("alice");
        assert_eq!(id1, id2);
        assert_eq!(model.entity_count(), 1);
    }

    #[test]
    fn test_add_relation() {
        let mut model = TransEModel::new();
        let id = model.add_relation("knows");
        assert_eq!(id, 0);
        assert_eq!(model.relation_count(), 1);
    }

    #[test]
    fn test_add_relation_idempotent() {
        let mut model = TransEModel::new();
        let id1 = model.add_relation("knows");
        let id2 = model.add_relation("knows");
        assert_eq!(id1, id2);
        assert_eq!(model.relation_count(), 1);
    }

    #[test]
    fn test_add_triple() {
        let model = sample_model();
        assert_eq!(model.triple_count(), 5);
        assert_eq!(model.entity_count(), 5);
        assert_eq!(model.relation_count(), 2);
    }

    #[test]
    fn test_add_triple_duplicate_ignored() {
        let mut model = sample_model();
        model.add_triple("alice", "knows", "bob");
        assert_eq!(model.triple_count(), 5);
        assert_eq!(model.entity_count(), 5);
        assert_eq!(model.relation_count(), 2);
    }

    #[test]
    fn test_entity_name() {
        let model = sample_model();
        assert_eq!(model.entity_name(0), Some("alice"));
        assert_eq!(model.entity_name(4), Some("sports"));
        assert_eq!(model.entity_name(99), None);
    }

    #[test]
    fn test_relation_name() {
        let model = sample_model();
        assert_eq!(model.relation_name(0), Some("knows"));
        assert_eq!(model.relation_name(1), Some("likes"));
        assert_eq!(model.relation_name(99), None);
    }

    // ---- training ----

    #[test]
    fn test_train_basic() {
        let mut model = sample_model();
        let stats = model.train(10);
        assert_eq!(stats.epochs_completed, 10);
        assert_eq!(stats.loss_history.len(), 10);
        // One visit per known triple per epoch.
        assert_eq!(stats.triples_processed, 50);
    }

    #[test]
    fn test_train_respects_max_epochs() {
        // `sample_model` caps training at 50 epochs per call.
        let mut model = sample_model();
        let stats = model.train(80);
        assert_eq!(stats.epochs_completed, 50);
        assert_eq!(stats.loss_history.len(), 50);
    }

    #[test]
    fn test_train_loss_decreases() {
        let mut model = sample_model();
        model.train(20);
        let losses = &model.stats().loss_history;
        assert_eq!(losses.len(), 20);
        // Hinge losses are always finite and non-negative.
        assert!(losses.iter().all(|l| l.is_finite() && *l >= 0.0));
        let first_avg: f64 = losses[..5].iter().sum::<f64>() / 5.0;
        let last_avg: f64 = losses[15..].iter().sum::<f64>() / 5.0;
        // Loose sanity bound: training must not blow up. A strict decrease is
        // not asserted because it is not guaranteed on such a tiny graph.
        assert!(last_avg < first_avg * 10.0);
    }

    #[test]
    fn test_train_empty_triples() {
        let mut model = TransEModel::new();
        let stats = model.train(10);
        assert_eq!(stats.epochs_completed, 0);
        assert!(stats.loss_history.is_empty());
    }

    #[test]
    fn test_train_stats_cumulative() {
        let mut model = sample_model();
        model.train(5);
        model.train(5);
        assert_eq!(model.stats().epochs_completed, 10);
        assert_eq!(model.stats().loss_history.len(), 10);
        assert_eq!(model.stats().triples_processed, 50);
    }

    // ---- scoring ----

    #[test]
    fn test_score_known_triple() {
        let mut model = sample_model();
        model.train(20);
        let score = model.score("alice", "knows", "bob").expect("score");
        assert!(score.is_finite());
        assert!(score >= 0.0);
        assert!(score < 100.0);
    }

    #[test]
    fn test_score_unknown_entity() {
        let model = sample_model();
        assert!(model.score("unknown", "knows", "bob").is_none());
    }

    #[test]
    fn test_score_unknown_relation() {
        let model = sample_model();
        assert!(model.score("alice", "unknown", "bob").is_none());
    }

    // ---- prediction ----

    #[test]
    fn test_predict_tail() {
        let mut model = sample_model();
        model.train(10);
        let predictions = model.predict_tail("alice", "knows", 3);
        assert_eq!(predictions.len(), 3);
        for window in predictions.windows(2) {
            assert!(window[0].score <= window[1].score);
        }
    }

    #[test]
    fn test_predict_head() {
        let mut model = sample_model();
        model.train(10);
        let predictions = model.predict_head("knows", "bob", 3);
        assert_eq!(predictions.len(), 3);
        // Every candidate keeps the queried tail (bob = 1).
        assert!(predictions.iter().all(|p| p.tail == 1));
        for window in predictions.windows(2) {
            assert!(window[0].score <= window[1].score);
        }
    }

    #[test]
    fn test_predict_k_exceeds_entity_count() {
        let model = sample_model();
        assert_eq!(model.predict_tail("alice", "knows", 100).len(), 5);
    }

    #[test]
    fn test_predict_unknown_entity() {
        let model = sample_model();
        assert!(model.predict_tail("unknown", "knows", 3).is_empty());
        assert!(model.predict_head("knows", "unknown", 3).is_empty());
    }

    #[test]
    fn test_predict_unknown_relation() {
        let model = sample_model();
        assert!(model.predict_tail("alice", "unknown", 3).is_empty());
        assert!(model.predict_head("unknown", "bob", 3).is_empty());
    }

    // ---- embeddings ----

    #[test]
    fn test_entity_embedding() {
        let model = sample_model();
        let emb = model.entity_embedding("alice");
        assert!(emb.is_some());
        assert_eq!(emb.expect("embedding").len(), 10);
    }

    #[test]
    fn test_relation_embedding() {
        let model = sample_model();
        let emb = model.relation_embedding("knows");
        assert!(emb.is_some());
        assert_eq!(emb.expect("embedding").len(), 10);
    }

    #[test]
    fn test_embedding_unknown() {
        let model = sample_model();
        assert!(model.entity_embedding("unknown").is_none());
        assert!(model.relation_embedding("unknown").is_none());
    }

    #[test]
    fn test_nearest_entities() {
        let mut model = sample_model();
        model.train(10);
        let alice_emb = model.entity_embedding("alice").expect("alice").clone();
        let nearest = model.nearest_entities(&alice_emb, 3);
        assert_eq!(nearest.len(), 3);
        // An entity is at distance zero from its own embedding.
        assert_eq!(nearest[0].0, "alice");
        assert!(nearest[0].1 < 1e-10);
    }

    // ---- configuration variants ----

    #[test]
    fn test_l1_distance_metric() {
        let mut model = TransEModel::with_config(TransEConfig {
            dim: 10,
            distance_metric: DistanceMetric::L1,
            ..Default::default()
        });
        model.add_triple("a", "r", "b");
        model.train(5);
        let score = model.score("a", "r", "b").expect("score");
        assert!(score.is_finite());
        assert!(score >= 0.0);
    }

    #[test]
    fn test_normalized_embeddings() {
        let mut model = sample_model();
        model.train(10);
        for emb in model.entity_embeddings.values() {
            let norm = l2_norm(emb);
            assert!(norm <= 1.0 + 1e-6);
        }
    }

    #[test]
    fn test_no_normalization() {
        let mut model = TransEModel::with_config(TransEConfig {
            dim: 10,
            normalize_embeddings: false,
            ..Default::default()
        });
        model.add_triple("a", "r", "b");
        let stats = model.train(5);
        assert_eq!(model.triple_count(), 1);
        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.loss_history.iter().all(|l| l.is_finite()));
    }

    // ---- helpers ----

    #[test]
    fn test_l2_norm() {
        let v = vec![3.0, 4.0];
        assert!((l2_norm(&v) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((l2_distance(&a, &b) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_l2_distance_same() {
        let a = vec![1.0, 2.0, 3.0];
        assert!(l2_distance(&a, &a) < 1e-10);
    }

    #[test]
    fn test_default_model() {
        let model = TransEModel::default();
        assert_eq!(model.entity_count(), 0);
        assert_eq!(model.relation_count(), 0);
        assert_eq!(model.triple_count(), 0);
    }
}
875}