1use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::random_utils::NormalSampler as Normal;
10use crate::Vector;
11use anyhow::{anyhow, Result};
12use nalgebra::{Complex, DVector};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
/// Selects which embedding model a [`KGEmbedding`] instantiates.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum KGEmbeddingModelType {
    /// Backed by the [`TransE`] struct in this file.
    TransE,
    /// Backed by the [`ComplEx`] struct in this file.
    ComplEx,
    /// Backed by the [`RotatE`] struct in this file.
    RotatE,
    /// Backed by `crate::gnn_embeddings::GCN`, wrapped in a [`GCNAdapter`].
    GCN,
    /// Backed by `crate::gnn_embeddings::GraphSAGE`, wrapped in a
    /// [`GraphSAGEAdapter`].
    GraphSAGE,
}
31
/// Hyper-parameters shared by every model that [`KGEmbedding`] can build.
///
/// Not every model reads every field: of the models visible in this file only
/// [`TransE`] runs an optimisation loop, so the training-related fields are
/// only consumed there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KGEmbeddingConfig {
    /// Which model `KGEmbedding::new` constructs.
    pub model: KGEmbeddingModelType,
    /// Number of components per embedding vector. `ComplEx` and `RotatE`
    /// entity embeddings hold this many complex components.
    pub dimensions: usize,
    /// Step size applied to each TransE embedding update.
    pub learning_rate: f32,
    /// Margin of the TransE hinge loss `max(0, margin + d_pos - d_neg)`.
    pub margin: f32,
    /// Number of corrupted triples generated per positive triple (TransE).
    pub negative_samples: usize,
    /// Number of triples per chunk in the TransE training loop.
    pub batch_size: usize,
    /// Number of passes over the training triples (TransE).
    pub epochs: usize,
    /// Distance selector for TransE: `1` uses the L1 distance, every other
    /// value uses the L2 distance.
    pub norm: usize,
    /// Seed for the random generator; the models in this file fall back to
    /// seed `42` when this is `None`.
    pub random_seed: Option<u64>,
    /// Regularization strength. NOTE(review): no code visible in this file
    /// reads this field — confirm whether other modules use it.
    pub regularization: f32,
}
56
57impl Default for KGEmbeddingConfig {
58 fn default() -> Self {
59 Self {
60 model: KGEmbeddingModelType::TransE,
61 dimensions: 100,
62 learning_rate: 0.01,
63 margin: 1.0,
64 negative_samples: 10,
65 batch_size: 100,
66 epochs: 100,
67 norm: 2,
68 random_seed: Some(42),
69 regularization: 0.0,
70 }
71 }
72}
73
/// A single `(subject, predicate, object)` statement of a knowledge graph.
///
/// Two triples are equal (and hash identically) exactly when all three of
/// their components are equal.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Triple {
    /// Head entity of the statement.
    pub subject: String,
    /// Relation linking subject and object.
    pub predicate: String,
    /// Tail entity of the statement.
    pub object: String,
}

impl Triple {
    /// Builds a triple from its three components, taking ownership of each.
    pub fn new(subject: String, predicate: String, object: String) -> Self {
        Triple {
            subject,
            predicate,
            object,
        }
    }
}
91
/// Common interface implemented by every embedding model in this file.
///
/// Implementations must be `Send + Sync` so a boxed model can be shared
/// across threads.
pub trait KGEmbeddingModel: Send + Sync {
    /// Fits the model to `triples`. The implementations in this file return
    /// an error when `triples` is empty (the GNN adapters accept any input).
    fn train(&mut self, triples: &[Triple]) -> Result<()>;

    /// Returns the embedding stored for `entity`, or `None` when the model
    /// has no embedding for that name.
    fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;

    /// Returns the embedding stored for `relation`, or `None` when the model
    /// has no embedding for that name.
    fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;

    /// Scores a triple; in the implementations in this file a higher score
    /// means a more plausible triple.
    fn score_triple(&self, triple: &Triple) -> f32;

    /// Ranks candidate tail entities for `(head, relation, ?)` and returns at
    /// most `k` `(entity, score)` pairs, best score first.
    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;

    /// Ranks candidate head entities for `(?, relation, tail)` and returns at
    /// most `k` `(entity, score)` pairs, best score first.
    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;

    /// Returns a copy of every entity embedding, keyed by entity name.
    fn get_entity_embeddings(&self) -> HashMap<String, Vector>;

    /// Returns a copy of every relation embedding, keyed by relation name.
    fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
}
118
/// Translation-based embedding model: a triple `(h, r, t)` is scored by the
/// distance of `h + r - t`, so a plausible triple has `h + r` close to `t`.
pub struct TransE {
    /// Hyper-parameters used for initialisation, training and scoring.
    config: KGEmbeddingConfig,
    /// Real-valued embedding per entity name.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Real-valued embedding per relation name.
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Names of all entities seen in the training triples; used as the
    /// candidate pool for negative sampling and link prediction.
    entities: Vec<String>,
    /// Names of all relations seen in the training triples.
    relations: Vec<String>,
}
128
129impl TransE {
130 pub fn new(config: KGEmbeddingConfig) -> Self {
131 Self {
132 config,
133 entity_embeddings: HashMap::new(),
134 relation_embeddings: HashMap::new(),
135 entities: Vec::new(),
136 relations: Vec::new(),
137 }
138 }
139
140 fn initialize_embeddings(&mut self, triples: &[Triple]) {
142 let mut entities = std::collections::HashSet::new();
144 let mut relations = std::collections::HashSet::new();
145
146 for triple in triples {
147 entities.insert(triple.subject.clone());
148 entities.insert(triple.object.clone());
149 relations.insert(triple.predicate.clone());
150 }
151
152 self.entities = entities.into_iter().collect();
153 self.relations = relations.into_iter().collect();
154
155 let mut rng = if let Some(seed) = self.config.random_seed {
157 Random::seed(seed)
158 } else {
159 Random::seed(42)
160 };
161
162 let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163 let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165 for entity in &self.entities {
167 let values: Vec<f32> = (0..self.config.dimensions)
168 .map(|_| rng.random_range(range_min, range_max))
169 .collect();
170 let mut embedding = DVector::from_vec(values);
171
172 let norm = embedding.norm();
174 if norm > 0.0 {
175 embedding /= norm;
176 }
177
178 self.entity_embeddings.insert(entity.clone(), embedding);
179 }
180
181 for relation in &self.relations {
183 let values: Vec<f32> = (0..self.config.dimensions)
184 .map(|_| rng.random_range(range_min, range_max))
185 .collect();
186 let embedding = DVector::from_vec(values);
187
188 self.relation_embeddings.insert(relation.clone(), embedding);
190 }
191 }
192
193 #[allow(deprecated)]
195 fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
196 let mut negatives = Vec::new();
197
198 for _ in 0..self.config.negative_samples {
199 if rng.gen_bool(0.5) {
200 let mut negative = triple.clone();
202 loop {
203 let idx = rng.gen_range(0..self.entities.len());
204 let entity = &self.entities[idx];
205 if entity != &triple.subject {
206 negative.subject = entity.clone();
207 break;
208 }
209 }
210 negatives.push(negative);
211 } else {
212 let mut negative = triple.clone();
214 loop {
215 let idx = rng.gen_range(0..self.entities.len());
216 let entity = &self.entities[idx];
217 if entity != &triple.object {
218 negative.object = entity.clone();
219 break;
220 }
221 }
222 negatives.push(negative);
223 }
224 }
225
226 negatives
227 }
228
229 fn distance(&self, triple: &Triple) -> f32 {
231 let h = self.entity_embeddings.get(&triple.subject).unwrap();
232 let r = self.relation_embeddings.get(&triple.predicate).unwrap();
233 let t = self.entity_embeddings.get(&triple.object).unwrap();
234
235 let translation = h + r - t;
236
237 match self.config.norm {
238 1 => translation.iter().map(|x| x.abs()).sum(),
239 2 => translation.norm(),
240 _ => translation.norm(),
241 }
242 }
243
244 fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
246 let pos_dist = self.distance(positive);
247
248 for negative in negatives {
249 let neg_dist = self.distance(negative);
250 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
251
252 if loss > 0.0 {
253 let h_pos = self
255 .entity_embeddings
256 .get(&positive.subject)
257 .unwrap()
258 .clone();
259 let r = self
260 .relation_embeddings
261 .get(&positive.predicate)
262 .unwrap()
263 .clone();
264 let t_pos = self
265 .entity_embeddings
266 .get(&positive.object)
267 .unwrap()
268 .clone();
269
270 let h_neg = self
271 .entity_embeddings
272 .get(&negative.subject)
273 .unwrap()
274 .clone();
275 let t_neg = self
276 .entity_embeddings
277 .get(&negative.object)
278 .unwrap()
279 .clone();
280
281 let pos_grad = &h_pos + &r - &t_pos;
282 let neg_grad = &h_neg + &r - &t_neg;
283
284 let pos_norm = pos_grad.norm();
286 let neg_norm = neg_grad.norm();
287
288 let pos_grad_norm = if pos_norm > 0.0 {
289 &pos_grad / pos_norm
290 } else {
291 pos_grad
292 };
293 let neg_grad_norm = if neg_norm > 0.0 {
294 &neg_grad / neg_norm
295 } else {
296 neg_grad
297 };
298
299 let lr = self.config.learning_rate;
301
302 if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
304 *h -= lr * &pos_grad_norm;
305 let norm = h.norm();
307 if norm > 0.0 {
308 *h /= norm;
309 }
310 }
311
312 if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
313 *r -= lr * (&pos_grad_norm - &neg_grad_norm);
314 }
315
316 if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
317 *t += lr * &pos_grad_norm;
318 let norm = t.norm();
320 if norm > 0.0 {
321 *t /= norm;
322 }
323 }
324
325 if positive.subject != negative.subject {
327 if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
328 *h += lr * &neg_grad_norm;
329 let norm = h.norm();
331 if norm > 0.0 {
332 *h /= norm;
333 }
334 }
335 }
336
337 if positive.object != negative.object {
338 if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
339 *t -= lr * &neg_grad_norm;
340 let norm = t.norm();
342 if norm > 0.0 {
343 *t /= norm;
344 }
345 }
346 }
347 }
348 }
349 }
350}
351
352impl KGEmbeddingModel for TransE {
353 fn train(&mut self, triples: &[Triple]) -> Result<()> {
354 if triples.is_empty() {
355 return Err(anyhow!("No triples provided for training"));
356 }
357
358 self.initialize_embeddings(triples);
360
361 let mut rng = if let Some(seed) = self.config.random_seed {
362 Random::seed(seed)
363 } else {
364 Random::seed(42)
365 };
366
367 for epoch in 0..self.config.epochs {
369 let mut total_loss = 0.0;
370 let mut batch_count = 0;
371
372 let mut shuffled_triples = triples.to_vec();
374 for i in (1..shuffled_triples.len()).rev() {
377 let j = rng.random_range(0, i + 1);
378 shuffled_triples.swap(i, j);
379 }
380
381 for batch in shuffled_triples.chunks(self.config.batch_size) {
383 for triple in batch {
384 let negatives = self.generate_negative_samples(triple, &mut rng);
386
387 let pos_dist = self.distance(triple);
389 for negative in &negatives {
390 let neg_dist = self.distance(negative);
391 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
392 total_loss += loss;
393 }
394
395 self.update_embeddings(triple, &negatives);
397 }
398 batch_count += 1;
399 }
400
401 if epoch % 10 == 0 {
402 let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
403 tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
404 }
405 }
406
407 Ok(())
408 }
409
410 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
411 self.entity_embeddings
412 .get(entity)
413 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
414 }
415
416 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
417 self.relation_embeddings
418 .get(relation)
419 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
420 }
421
422 fn score_triple(&self, triple: &Triple) -> f32 {
423 -self.distance(triple)
424 }
425
426 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
427 let h = match self.entity_embeddings.get(head) {
428 Some(emb) => emb,
429 None => return Vec::new(),
430 };
431
432 let r = match self.relation_embeddings.get(relation) {
433 Some(emb) => emb,
434 None => return Vec::new(),
435 };
436
437 let translation = h + r;
438
439 let mut scores: Vec<(String, f32)> = self
440 .entities
441 .iter()
442 .filter(|e| *e != head)
443 .filter_map(|entity| {
444 self.entity_embeddings.get(entity).map(|t| {
445 let distance = match self.config.norm {
446 1 => (&translation - t).iter().map(|x| x.abs()).sum(),
447 2 => (&translation - t).norm(),
448 _ => (&translation - t).norm(),
449 };
450 (entity.clone(), -distance)
451 })
452 })
453 .collect();
454
455 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
456 scores.truncate(k);
457 scores
458 }
459
460 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
461 let t = match self.entity_embeddings.get(tail) {
462 Some(emb) => emb,
463 None => return Vec::new(),
464 };
465
466 let r = match self.relation_embeddings.get(relation) {
467 Some(emb) => emb,
468 None => return Vec::new(),
469 };
470
471 let target = t - r;
472
473 let mut scores: Vec<(String, f32)> = self
474 .entities
475 .iter()
476 .filter(|e| *e != tail)
477 .filter_map(|entity| {
478 self.entity_embeddings.get(entity).map(|h| {
479 let distance = match self.config.norm {
480 1 => (h - &target).iter().map(|x| x.abs()).sum(),
481 2 => (h - &target).norm(),
482 _ => (h - &target).norm(),
483 };
484 (entity.clone(), -distance)
485 })
486 })
487 .collect();
488
489 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
490 scores.truncate(k);
491 scores
492 }
493
494 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
495 self.entity_embeddings
496 .iter()
497 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
498 .collect()
499 }
500
501 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
502 self.relation_embeddings
503 .iter()
504 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
505 .collect()
506 }
507}
508
/// Complex-valued embedding model. Every entity and relation is stored as a
/// pair of real vectors holding the real and imaginary parts.
pub struct ComplEx {
    /// Hyper-parameters; `dimensions` is the number of complex components.
    config: KGEmbeddingConfig,
    /// Real parts of the entity embeddings, keyed by entity name.
    entity_embeddings_real: HashMap<String, DVector<f32>>,
    /// Imaginary parts of the entity embeddings, keyed by entity name.
    entity_embeddings_imag: HashMap<String, DVector<f32>>,
    /// Real parts of the relation embeddings, keyed by relation name.
    relation_embeddings_real: HashMap<String, DVector<f32>>,
    /// Imaginary parts of the relation embeddings, keyed by relation name.
    relation_embeddings_imag: HashMap<String, DVector<f32>>,
    /// Names of all entities seen in the training triples.
    entities: Vec<String>,
    /// Names of all relations seen in the training triples.
    relations: Vec<String>,
}
520
521impl ComplEx {
522 pub fn new(config: KGEmbeddingConfig) -> Self {
523 Self {
524 config,
525 entity_embeddings_real: HashMap::new(),
526 entity_embeddings_imag: HashMap::new(),
527 relation_embeddings_real: HashMap::new(),
528 relation_embeddings_imag: HashMap::new(),
529 entities: Vec::new(),
530 relations: Vec::new(),
531 }
532 }
533
534 fn initialize_embeddings(&mut self, triples: &[Triple]) {
536 let mut entities = std::collections::HashSet::new();
538 let mut relations = std::collections::HashSet::new();
539
540 for triple in triples {
541 entities.insert(triple.subject.clone());
542 entities.insert(triple.object.clone());
543 relations.insert(triple.predicate.clone());
544 }
545
546 self.entities = entities.into_iter().collect();
547 self.relations = relations.into_iter().collect();
548
549 let mut rng = if let Some(seed) = self.config.random_seed {
551 Random::seed(seed)
552 } else {
553 Random::seed(42)
554 };
555
556 let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
557 let normal = Normal::new(0.0, std_dev).unwrap();
558
559 for entity in &self.entities {
561 let real_values: Vec<f32> = (0..self.config.dimensions)
562 .map(|_| normal.sample(&mut rng))
563 .collect();
564 let imag_values: Vec<f32> = (0..self.config.dimensions)
565 .map(|_| normal.sample(&mut rng))
566 .collect();
567
568 self.entity_embeddings_real
569 .insert(entity.clone(), DVector::from_vec(real_values));
570 self.entity_embeddings_imag
571 .insert(entity.clone(), DVector::from_vec(imag_values));
572 }
573
574 for relation in &self.relations {
576 let real_values: Vec<f32> = (0..self.config.dimensions)
577 .map(|_| normal.sample(&mut rng))
578 .collect();
579 let imag_values: Vec<f32> = (0..self.config.dimensions)
580 .map(|_| normal.sample(&mut rng))
581 .collect();
582
583 self.relation_embeddings_real
584 .insert(relation.clone(), DVector::from_vec(real_values));
585 self.relation_embeddings_imag
586 .insert(relation.clone(), DVector::from_vec(imag_values));
587 }
588 }
589
590 fn hermitian_dot(&self, triple: &Triple) -> f32 {
592 let h_real = self.entity_embeddings_real.get(&triple.subject).unwrap();
593 let h_imag = self.entity_embeddings_imag.get(&triple.subject).unwrap();
594 let r_real = self
595 .relation_embeddings_real
596 .get(&triple.predicate)
597 .unwrap();
598 let r_imag = self
599 .relation_embeddings_imag
600 .get(&triple.predicate)
601 .unwrap();
602 let t_real = self.entity_embeddings_real.get(&triple.object).unwrap();
603 let t_imag = self.entity_embeddings_imag.get(&triple.object).unwrap();
604
605 let mut score = 0.0;
611 for i in 0..self.config.dimensions {
612 score += h_real[i] * r_real[i] * t_real[i]
613 + h_real[i] * r_imag[i] * t_imag[i]
614 + h_imag[i] * r_real[i] * t_imag[i]
615 - h_imag[i] * r_imag[i] * t_real[i];
616 }
617
618 score
619 }
620}
621
622impl KGEmbeddingModel for ComplEx {
623 fn train(&mut self, triples: &[Triple]) -> Result<()> {
624 if triples.is_empty() {
625 return Err(anyhow!("No triples provided for training"));
626 }
627
628 self.initialize_embeddings(triples);
630
631 Ok(())
635 }
636
637 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
638 let real = self.entity_embeddings_real.get(entity)?;
640 let imag = self.entity_embeddings_imag.get(entity)?;
641
642 let mut values = Vec::with_capacity(self.config.dimensions * 2);
643 values.extend(real.iter().cloned());
644 values.extend(imag.iter().cloned());
645
646 Some(Vector::new(values))
647 }
648
649 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
650 let real = self.relation_embeddings_real.get(relation)?;
652 let imag = self.relation_embeddings_imag.get(relation)?;
653
654 let mut values = Vec::with_capacity(self.config.dimensions * 2);
655 values.extend(real.iter().cloned());
656 values.extend(imag.iter().cloned());
657
658 Some(Vector::new(values))
659 }
660
661 fn score_triple(&self, triple: &Triple) -> f32 {
662 self.hermitian_dot(triple)
663 }
664
665 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
666 let mut scores: Vec<(String, f32)> = self
667 .entities
668 .iter()
669 .filter(|e| *e != head)
670 .map(|tail| {
671 let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
672 let score = self.score_triple(&triple);
673 (tail.clone(), score)
674 })
675 .collect();
676
677 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
678 scores.truncate(k);
679 scores
680 }
681
682 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
683 let mut scores: Vec<(String, f32)> = self
684 .entities
685 .iter()
686 .filter(|e| *e != tail)
687 .map(|head| {
688 let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
689 let score = self.score_triple(&triple);
690 (head.clone(), score)
691 })
692 .collect();
693
694 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
695 scores.truncate(k);
696 scores
697 }
698
699 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
700 self.entity_embeddings_real
701 .iter()
702 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
703 .collect()
704 }
705
706 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
707 self.relation_embeddings_real
708 .iter()
709 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
710 .collect()
711 }
712}
713
/// Rotation-based embedding model: entities are complex vectors and each
/// relation is a vector of phase angles that rotates the head entity
/// component-wise.
pub struct RotatE {
    /// Hyper-parameters; `dimensions` is the number of complex components.
    config: KGEmbeddingConfig,
    /// Complex embedding per entity name; initialised with unit-modulus
    /// components.
    entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
    /// Phase angles per relation name; initialised in `[-PI, PI)`.
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Names of all entities seen in the training triples.
    entities: Vec<String>,
    /// Names of all relations seen in the training triples.
    relations: Vec<String>,
}
723
724impl RotatE {
725 pub fn new(config: KGEmbeddingConfig) -> Self {
726 Self {
727 config,
728 entity_embeddings: HashMap::new(),
729 relation_embeddings: HashMap::new(),
730 entities: Vec::new(),
731 relations: Vec::new(),
732 }
733 }
734
735 fn initialize_embeddings(&mut self, triples: &[Triple]) {
737 let mut entities = std::collections::HashSet::new();
739 let mut relations = std::collections::HashSet::new();
740
741 for triple in triples {
742 entities.insert(triple.subject.clone());
743 entities.insert(triple.object.clone());
744 relations.insert(triple.predicate.clone());
745 }
746
747 self.entities = entities.into_iter().collect();
748 self.relations = relations.into_iter().collect();
749
750 let mut rng = if let Some(seed) = self.config.random_seed {
751 Random::seed(seed)
752 } else {
753 Random::seed(42)
754 };
755
756 let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
758
759 for entity in &self.entities {
760 let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
761 .map(|_| {
762 let phase = rng.gen_range(phase_range.clone());
763 Complex::new(phase.cos(), phase.sin())
764 })
765 .collect();
766
767 self.entity_embeddings
768 .insert(entity.clone(), DVector::from_vec(phases));
769 }
770
771 for relation in &self.relations {
773 let phases: Vec<f32> = (0..self.config.dimensions)
774 .map(|_| rng.gen_range(phase_range.clone()))
775 .collect();
776
777 self.relation_embeddings
778 .insert(relation.clone(), DVector::from_vec(phases));
779 }
780 }
781
782 fn distance(&self, triple: &Triple) -> f32 {
784 let h = self.entity_embeddings.get(&triple.subject).unwrap();
785 let r_phases = self.relation_embeddings.get(&triple.predicate).unwrap();
786 let t = self.entity_embeddings.get(&triple.object).unwrap();
787
788 let r: DVector<Complex<f32>> = DVector::from_iterator(
790 self.config.dimensions,
791 r_phases
792 .iter()
793 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
794 );
795
796 let rotated: DVector<Complex<f32>> = h.component_mul(&r);
798
799 let diff = rotated - t;
801 diff.iter().map(|c| c.norm()).sum::<f32>()
802 }
803}
804
805impl KGEmbeddingModel for RotatE {
806 fn train(&mut self, triples: &[Triple]) -> Result<()> {
807 if triples.is_empty() {
808 return Err(anyhow!("No triples provided for training"));
809 }
810
811 self.initialize_embeddings(triples);
813
814 Ok(())
818 }
819
820 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
821 let complex_emb = self.entity_embeddings.get(entity)?;
823
824 let mut values = Vec::with_capacity(self.config.dimensions * 2);
825 for c in complex_emb.iter() {
826 values.push(c.re); values.push(c.im); }
829
830 Some(Vector::new(values))
831 }
832
833 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
834 self.relation_embeddings
835 .get(relation)
836 .map(|phases| Vector::new(phases.iter().cloned().collect()))
837 }
838
839 fn score_triple(&self, triple: &Triple) -> f32 {
840 let gamma = 12.0; gamma - self.distance(triple)
842 }
843
844 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
845 let h = match self.entity_embeddings.get(head) {
846 Some(emb) => emb,
847 None => return Vec::new(),
848 };
849
850 let r_phases = match self.relation_embeddings.get(relation) {
851 Some(emb) => emb,
852 None => return Vec::new(),
853 };
854
855 let r: DVector<Complex<f32>> = DVector::from_iterator(
857 self.config.dimensions,
858 r_phases
859 .iter()
860 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
861 );
862
863 let rotated = h.component_mul(&r);
865
866 let mut scores: Vec<(String, f32)> = self
867 .entities
868 .iter()
869 .filter(|e| *e != head)
870 .filter_map(|entity| {
871 self.entity_embeddings.get(entity).map(|t| {
872 let diff = &rotated - t;
873 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
874 (entity.clone(), -distance)
875 })
876 })
877 .collect();
878
879 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
880 scores.truncate(k);
881 scores
882 }
883
884 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
885 let t = match self.entity_embeddings.get(tail) {
886 Some(emb) => emb,
887 None => return Vec::new(),
888 };
889
890 let r_phases = match self.relation_embeddings.get(relation) {
891 Some(emb) => emb,
892 None => return Vec::new(),
893 };
894
895 let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
897 self.config.dimensions,
898 r_phases
899 .iter()
900 .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
901 );
902
903 let mut scores: Vec<(String, f32)> = self
904 .entities
905 .iter()
906 .filter(|e| *e != tail)
907 .filter_map(|entity| {
908 self.entity_embeddings.get(entity).map(|h| {
909 let rotated = h.component_mul(&r_inv);
910 let diff = rotated - t;
911 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
912 (entity.clone(), -distance)
913 })
914 })
915 .collect();
916
917 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
918 scores.truncate(k);
919 scores
920 }
921
922 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
923 self.entity_embeddings
924 .iter()
925 .map(|(k, v)| {
926 let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
927 (k.clone(), Vector::new(real_values))
928 })
929 .collect()
930 }
931
932 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
933 self.relation_embeddings
934 .iter()
935 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
936 .collect()
937 }
938}
939
/// Facade that owns one concrete embedding model, chosen by
/// `KGEmbeddingConfig::model`, and forwards every operation to it.
pub struct KGEmbedding {
    /// The concrete model behind the facade.
    model: Box<dyn KGEmbeddingModel>,
    /// The configuration the model was built from. It is stored here but not
    /// read by any method visible in this file.
    config: KGEmbeddingConfig,
}
945
946impl KGEmbedding {
947 pub fn new(config: KGEmbeddingConfig) -> Self {
949 let model: Box<dyn KGEmbeddingModel> = match config.model {
950 KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
951 KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
952 KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
953 KGEmbeddingModelType::GCN => {
954 let gcn = GCN::new(config.clone());
956 Box::new(GCNAdapter::new(gcn))
957 }
958 KGEmbeddingModelType::GraphSAGE => {
959 let graphsage = GraphSAGE::new(config.clone())
961 .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
962 Box::new(GraphSAGEAdapter::new(graphsage))
963 }
964 };
965
966 Self { model, config }
967 }
968
969 pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
971 self.model.train(triples)
972 }
973
974 pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
976 self.model.get_entity_embedding(entity)
977 }
978
979 pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
981 self.model.get_relation_embedding(relation)
982 }
983
984 pub fn score_triple(&self, triple: &Triple) -> f32 {
986 self.model.score_triple(triple)
987 }
988
989 pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
991 self.model.predict_tail(head, relation, k)
992 }
993
994 pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
996 self.model.predict_head(relation, tail, k)
997 }
998
999 pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1001 self.model.score_triple(triple) > threshold
1002 }
1003}
1004
1005pub struct GCNAdapter {
1007 gcn: GCN,
1008}
1009
1010impl GCNAdapter {
1011 pub fn new(gcn: GCN) -> Self {
1012 Self { gcn }
1013 }
1014}
1015
1016impl KGEmbeddingModel for GCNAdapter {
1017 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1018 Ok(())
1020 }
1021
1022 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1023 Some(Vector::new(vec![0.0; 128]))
1026 }
1027
1028 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1029 Some(Vector::new(vec![0.0; 128]))
1031 }
1032
1033 fn score_triple(&self, _triple: &Triple) -> f32 {
1034 0.5
1036 }
1037
1038 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1039 vec![]
1041 }
1042
1043 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1044 vec![]
1046 }
1047
1048 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1049 HashMap::new()
1050 }
1051
1052 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1053 HashMap::new()
1054 }
1055}
1056
1057pub struct GraphSAGEAdapter {
1059 graphsage: GraphSAGE,
1060}
1061
1062impl GraphSAGEAdapter {
1063 pub fn new(graphsage: GraphSAGE) -> Self {
1064 Self { graphsage }
1065 }
1066}
1067
1068impl KGEmbeddingModel for GraphSAGEAdapter {
1069 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1070 Ok(())
1072 }
1073
1074 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1075 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1077 }
1078
1079 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1080 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1082 }
1083
1084 fn score_triple(&self, _triple: &Triple) -> f32 {
1085 0.5
1087 }
1088
1089 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1090 vec![]
1092 }
1093
1094 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1095 vec![]
1097 }
1098
1099 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1100 HashMap::new()
1101 }
1102
1103 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1104 HashMap::new()
1105 }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Five facts over the entities Alice, Bob, Charlie, Pizza and Pasta with
    /// the relations `knows` and `likes`.
    fn create_test_triples() -> Vec<Triple> {
        const FACTS: [(&str, &str, &str); 5] = [
            ("Alice", "knows", "Bob"),
            ("Bob", "knows", "Charlie"),
            ("Alice", "likes", "Pizza"),
            ("Bob", "likes", "Pasta"),
            ("Charlie", "knows", "Alice"),
        ];

        FACTS
            .iter()
            .map(|&(s, p, o)| Triple::new(s.to_string(), p.to_string(), o.to_string()))
            .collect()
    }

    /// Builds a 50-dimensional, 10-epoch model of the given kind and trains
    /// it on the shared fixture; returns the model together with the triples.
    fn trained_model(kind: KGEmbeddingModelType) -> (KGEmbedding, Vec<Triple>) {
        let config = KGEmbeddingConfig {
            model: kind,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();
        model.train(&triples).unwrap();

        (model, triples)
    }

    #[test]
    fn test_transe() {
        let (model, triples) = trained_model(KGEmbeddingModelType::TransE);

        // Embeddings exist for names seen during training.
        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());

        // Scoring a training triple yields a finite number.
        let score = model.score_triple(&triples[0]);
        assert!(score.is_finite());

        // Link prediction returns at least one candidate.
        let predictions = model.predict_tail("Alice", "knows", 2);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_complex() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::ComplEx);

        assert!(model.get_entity_embedding("Bob").is_some());

        // Real and imaginary halves are concatenated: 2 * 50 components.
        let emb = model.get_entity_embedding("Bob").unwrap();
        assert_eq!(emb.dimensions, 100);
    }

    #[test]
    fn test_rotate() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::RotatE);

        assert!(model.get_entity_embedding("Charlie").is_some());
        assert!(model.get_relation_embedding("likes").is_some());

        // Relations are stored as one phase per dimension.
        let rel_emb = model.get_relation_embedding("likes").unwrap();
        assert_eq!(rel_emb.dimensions, 50);
    }
}