1use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::random_utils::NormalSampler as Normal;
10use crate::Vector;
11use anyhow::{anyhow, Result};
12use nalgebra::{Complex, DVector};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum KGEmbeddingModelType {
20 TransE,
22 ComplEx,
24 RotatE,
26 GCN,
28 GraphSAGE,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KGEmbeddingConfig {
35 pub model: KGEmbeddingModelType,
37 pub dimensions: usize,
39 pub learning_rate: f32,
41 pub margin: f32,
43 pub negative_samples: usize,
45 pub batch_size: usize,
47 pub epochs: usize,
49 pub norm: usize,
51 pub random_seed: Option<u64>,
53 pub regularization: f32,
55}
56
57impl Default for KGEmbeddingConfig {
58 fn default() -> Self {
59 Self {
60 model: KGEmbeddingModelType::TransE,
61 dimensions: 100,
62 learning_rate: 0.01,
63 margin: 1.0,
64 negative_samples: 10,
65 batch_size: 100,
66 epochs: 100,
67 norm: 2,
68 random_seed: Some(42),
69 regularization: 0.0,
70 }
71 }
72}
73
/// A `(subject, predicate, object)` fact of the knowledge graph.
///
/// Entities and relations are identified by their string names.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Triple {
    /// Head entity.
    pub subject: String,
    /// Relation linking `subject` to `object`.
    pub predicate: String,
    /// Tail entity.
    pub object: String,
}
81
82impl Triple {
83 pub fn new(subject: String, predicate: String, object: String) -> Self {
84 Self {
85 subject,
86 predicate,
87 object,
88 }
89 }
90}
91
92pub trait KGEmbeddingModel: Send + Sync {
94 fn train(&mut self, triples: &[Triple]) -> Result<()>;
96
97 fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
99
100 fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
102
103 fn score_triple(&self, triple: &Triple) -> f32;
105
106 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
108
109 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
111
112 fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
114
115 fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
117}
118
119pub struct TransE {
122 config: KGEmbeddingConfig,
123 entity_embeddings: HashMap<String, DVector<f32>>,
124 relation_embeddings: HashMap<String, DVector<f32>>,
125 entities: Vec<String>,
126 relations: Vec<String>,
127}
128
129impl TransE {
130 pub fn new(config: KGEmbeddingConfig) -> Self {
131 Self {
132 config,
133 entity_embeddings: HashMap::new(),
134 relation_embeddings: HashMap::new(),
135 entities: Vec::new(),
136 relations: Vec::new(),
137 }
138 }
139
140 fn initialize_embeddings(&mut self, triples: &[Triple]) {
142 let mut entities = std::collections::HashSet::new();
144 let mut relations = std::collections::HashSet::new();
145
146 for triple in triples {
147 entities.insert(triple.subject.clone());
148 entities.insert(triple.object.clone());
149 relations.insert(triple.predicate.clone());
150 }
151
152 self.entities = entities.into_iter().collect();
153 self.relations = relations.into_iter().collect();
154
155 let mut rng = if let Some(seed) = self.config.random_seed {
157 Random::seed(seed)
158 } else {
159 Random::seed(42)
160 };
161
162 let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163 let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165 for entity in &self.entities {
167 let values: Vec<f32> = (0..self.config.dimensions)
168 .map(|_| rng.random_range(range_min..range_max))
169 .collect();
170 let mut embedding = DVector::from_vec(values);
171
172 let norm = embedding.norm();
174 if norm > 0.0 {
175 embedding /= norm;
176 }
177
178 self.entity_embeddings.insert(entity.clone(), embedding);
179 }
180
181 for relation in &self.relations {
183 let values: Vec<f32> = (0..self.config.dimensions)
184 .map(|_| rng.random_range(range_min..range_max))
185 .collect();
186 let embedding = DVector::from_vec(values);
187
188 self.relation_embeddings.insert(relation.clone(), embedding);
190 }
191 }
192
193 #[allow(deprecated)]
195 fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
196 let mut negatives = Vec::new();
197
198 for _ in 0..self.config.negative_samples {
199 if rng.gen_bool(0.5) {
200 let mut negative = triple.clone();
202 loop {
203 let idx = rng.gen_range(0..self.entities.len());
204 let entity = &self.entities[idx];
205 if entity != &triple.subject {
206 negative.subject = entity.clone();
207 break;
208 }
209 }
210 negatives.push(negative);
211 } else {
212 let mut negative = triple.clone();
214 loop {
215 let idx = rng.gen_range(0..self.entities.len());
216 let entity = &self.entities[idx];
217 if entity != &triple.object {
218 negative.object = entity.clone();
219 break;
220 }
221 }
222 negatives.push(negative);
223 }
224 }
225
226 negatives
227 }
228
229 fn distance(&self, triple: &Triple) -> f32 {
231 let h = self
232 .entity_embeddings
233 .get(&triple.subject)
234 .expect("subject entity should have embedding");
235 let r = self
236 .relation_embeddings
237 .get(&triple.predicate)
238 .expect("predicate relation should have embedding");
239 let t = self
240 .entity_embeddings
241 .get(&triple.object)
242 .expect("object entity should have embedding");
243
244 let translation = h + r - t;
245
246 match self.config.norm {
247 1 => translation.iter().map(|x| x.abs()).sum(),
248 2 => translation.norm(),
249 _ => translation.norm(),
250 }
251 }
252
253 fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
255 let pos_dist = self.distance(positive);
256
257 for negative in negatives {
258 let neg_dist = self.distance(negative);
259 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
260
261 if loss > 0.0 {
262 let h_pos = self
264 .entity_embeddings
265 .get(&positive.subject)
266 .expect("positive subject entity should have embedding")
267 .clone();
268 let r = self
269 .relation_embeddings
270 .get(&positive.predicate)
271 .expect("positive predicate relation should have embedding")
272 .clone();
273 let t_pos = self
274 .entity_embeddings
275 .get(&positive.object)
276 .expect("positive object entity should have embedding")
277 .clone();
278
279 let h_neg = self
280 .entity_embeddings
281 .get(&negative.subject)
282 .expect("negative subject entity should have embedding")
283 .clone();
284 let t_neg = self
285 .entity_embeddings
286 .get(&negative.object)
287 .expect("negative object entity should have embedding")
288 .clone();
289
290 let pos_grad = &h_pos + &r - &t_pos;
291 let neg_grad = &h_neg + &r - &t_neg;
292
293 let pos_norm = pos_grad.norm();
295 let neg_norm = neg_grad.norm();
296
297 let pos_grad_norm = if pos_norm > 0.0 {
298 &pos_grad / pos_norm
299 } else {
300 pos_grad
301 };
302 let neg_grad_norm = if neg_norm > 0.0 {
303 &neg_grad / neg_norm
304 } else {
305 neg_grad
306 };
307
308 let lr = self.config.learning_rate;
310
311 if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
313 *h -= lr * &pos_grad_norm;
314 let norm = h.norm();
316 if norm > 0.0 {
317 *h /= norm;
318 }
319 }
320
321 if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
322 *r -= lr * (&pos_grad_norm - &neg_grad_norm);
323 }
324
325 if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
326 *t += lr * &pos_grad_norm;
327 let norm = t.norm();
329 if norm > 0.0 {
330 *t /= norm;
331 }
332 }
333
334 if positive.subject != negative.subject {
336 if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
337 *h += lr * &neg_grad_norm;
338 let norm = h.norm();
340 if norm > 0.0 {
341 *h /= norm;
342 }
343 }
344 }
345
346 if positive.object != negative.object {
347 if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
348 *t -= lr * &neg_grad_norm;
349 let norm = t.norm();
351 if norm > 0.0 {
352 *t /= norm;
353 }
354 }
355 }
356 }
357 }
358 }
359}
360
361impl KGEmbeddingModel for TransE {
362 fn train(&mut self, triples: &[Triple]) -> Result<()> {
363 if triples.is_empty() {
364 return Err(anyhow!("No triples provided for training"));
365 }
366
367 self.initialize_embeddings(triples);
369
370 let mut rng = if let Some(seed) = self.config.random_seed {
371 Random::seed(seed)
372 } else {
373 Random::seed(42)
374 };
375
376 for epoch in 0..self.config.epochs {
378 let mut total_loss = 0.0;
379 let mut batch_count = 0;
380
381 let mut shuffled_triples = triples.to_vec();
383 for i in (1..shuffled_triples.len()).rev() {
386 let j = rng.random_range(0..i + 1);
387 shuffled_triples.swap(i, j);
388 }
389
390 for batch in shuffled_triples.chunks(self.config.batch_size) {
392 for triple in batch {
393 let negatives = self.generate_negative_samples(triple, &mut rng);
395
396 let pos_dist = self.distance(triple);
398 for negative in &negatives {
399 let neg_dist = self.distance(negative);
400 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
401 total_loss += loss;
402 }
403
404 self.update_embeddings(triple, &negatives);
406 }
407 batch_count += 1;
408 }
409
410 if epoch % 10 == 0 {
411 let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
412 tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
413 }
414 }
415
416 Ok(())
417 }
418
419 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
420 self.entity_embeddings
421 .get(entity)
422 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
423 }
424
425 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
426 self.relation_embeddings
427 .get(relation)
428 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
429 }
430
431 fn score_triple(&self, triple: &Triple) -> f32 {
432 -self.distance(triple)
433 }
434
435 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
436 let h = match self.entity_embeddings.get(head) {
437 Some(emb) => emb,
438 None => return Vec::new(),
439 };
440
441 let r = match self.relation_embeddings.get(relation) {
442 Some(emb) => emb,
443 None => return Vec::new(),
444 };
445
446 let translation = h + r;
447
448 let mut scores: Vec<(String, f32)> = self
449 .entities
450 .iter()
451 .filter(|e| *e != head)
452 .filter_map(|entity| {
453 self.entity_embeddings.get(entity).map(|t| {
454 let distance = match self.config.norm {
455 1 => (&translation - t).iter().map(|x| x.abs()).sum(),
456 2 => (&translation - t).norm(),
457 _ => (&translation - t).norm(),
458 };
459 (entity.clone(), -distance)
460 })
461 })
462 .collect();
463
464 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
465 scores.truncate(k);
466 scores
467 }
468
469 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
470 let t = match self.entity_embeddings.get(tail) {
471 Some(emb) => emb,
472 None => return Vec::new(),
473 };
474
475 let r = match self.relation_embeddings.get(relation) {
476 Some(emb) => emb,
477 None => return Vec::new(),
478 };
479
480 let target = t - r;
481
482 let mut scores: Vec<(String, f32)> = self
483 .entities
484 .iter()
485 .filter(|e| *e != tail)
486 .filter_map(|entity| {
487 self.entity_embeddings.get(entity).map(|h| {
488 let distance = match self.config.norm {
489 1 => (h - &target).iter().map(|x| x.abs()).sum(),
490 2 => (h - &target).norm(),
491 _ => (h - &target).norm(),
492 };
493 (entity.clone(), -distance)
494 })
495 })
496 .collect();
497
498 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
499 scores.truncate(k);
500 scores
501 }
502
503 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
504 self.entity_embeddings
505 .iter()
506 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
507 .collect()
508 }
509
510 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
511 self.relation_embeddings
512 .iter()
513 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
514 .collect()
515 }
516}
517
518pub struct ComplEx {
521 config: KGEmbeddingConfig,
522 entity_embeddings_real: HashMap<String, DVector<f32>>,
523 entity_embeddings_imag: HashMap<String, DVector<f32>>,
524 relation_embeddings_real: HashMap<String, DVector<f32>>,
525 relation_embeddings_imag: HashMap<String, DVector<f32>>,
526 entities: Vec<String>,
527 relations: Vec<String>,
528}
529
530impl ComplEx {
531 pub fn new(config: KGEmbeddingConfig) -> Self {
532 Self {
533 config,
534 entity_embeddings_real: HashMap::new(),
535 entity_embeddings_imag: HashMap::new(),
536 relation_embeddings_real: HashMap::new(),
537 relation_embeddings_imag: HashMap::new(),
538 entities: Vec::new(),
539 relations: Vec::new(),
540 }
541 }
542
543 fn initialize_embeddings(&mut self, triples: &[Triple]) {
545 let mut entities = std::collections::HashSet::new();
547 let mut relations = std::collections::HashSet::new();
548
549 for triple in triples {
550 entities.insert(triple.subject.clone());
551 entities.insert(triple.object.clone());
552 relations.insert(triple.predicate.clone());
553 }
554
555 self.entities = entities.into_iter().collect();
556 self.relations = relations.into_iter().collect();
557
558 let mut rng = if let Some(seed) = self.config.random_seed {
560 Random::seed(seed)
561 } else {
562 Random::seed(42)
563 };
564
565 let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
566 let normal =
567 Normal::new(0.0, std_dev).expect("normal distribution parameters should be valid");
568
569 for entity in &self.entities {
571 let real_values: Vec<f32> = (0..self.config.dimensions)
572 .map(|_| normal.sample(&mut rng))
573 .collect();
574 let imag_values: Vec<f32> = (0..self.config.dimensions)
575 .map(|_| normal.sample(&mut rng))
576 .collect();
577
578 self.entity_embeddings_real
579 .insert(entity.clone(), DVector::from_vec(real_values));
580 self.entity_embeddings_imag
581 .insert(entity.clone(), DVector::from_vec(imag_values));
582 }
583
584 for relation in &self.relations {
586 let real_values: Vec<f32> = (0..self.config.dimensions)
587 .map(|_| normal.sample(&mut rng))
588 .collect();
589 let imag_values: Vec<f32> = (0..self.config.dimensions)
590 .map(|_| normal.sample(&mut rng))
591 .collect();
592
593 self.relation_embeddings_real
594 .insert(relation.clone(), DVector::from_vec(real_values));
595 self.relation_embeddings_imag
596 .insert(relation.clone(), DVector::from_vec(imag_values));
597 }
598 }
599
600 fn hermitian_dot(&self, triple: &Triple) -> f32 {
602 let h_real = self
603 .entity_embeddings_real
604 .get(&triple.subject)
605 .expect("subject entity should have real embedding");
606 let h_imag = self
607 .entity_embeddings_imag
608 .get(&triple.subject)
609 .expect("subject entity should have imag embedding");
610 let r_real = self
611 .relation_embeddings_real
612 .get(&triple.predicate)
613 .expect("predicate relation should have real embedding");
614 let r_imag = self
615 .relation_embeddings_imag
616 .get(&triple.predicate)
617 .expect("predicate relation should have imag embedding");
618 let t_real = self
619 .entity_embeddings_real
620 .get(&triple.object)
621 .expect("object entity should have real embedding");
622 let t_imag = self
623 .entity_embeddings_imag
624 .get(&triple.object)
625 .expect("object entity should have imag embedding");
626
627 let mut score = 0.0;
633 for i in 0..self.config.dimensions {
634 score += h_real[i] * r_real[i] * t_real[i]
635 + h_real[i] * r_imag[i] * t_imag[i]
636 + h_imag[i] * r_real[i] * t_imag[i]
637 - h_imag[i] * r_imag[i] * t_real[i];
638 }
639
640 score
641 }
642}
643
644impl KGEmbeddingModel for ComplEx {
645 fn train(&mut self, triples: &[Triple]) -> Result<()> {
646 if triples.is_empty() {
647 return Err(anyhow!("No triples provided for training"));
648 }
649
650 self.initialize_embeddings(triples);
652
653 Ok(())
657 }
658
659 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
660 let real = self.entity_embeddings_real.get(entity)?;
662 let imag = self.entity_embeddings_imag.get(entity)?;
663
664 let mut values = Vec::with_capacity(self.config.dimensions * 2);
665 values.extend(real.iter().cloned());
666 values.extend(imag.iter().cloned());
667
668 Some(Vector::new(values))
669 }
670
671 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
672 let real = self.relation_embeddings_real.get(relation)?;
674 let imag = self.relation_embeddings_imag.get(relation)?;
675
676 let mut values = Vec::with_capacity(self.config.dimensions * 2);
677 values.extend(real.iter().cloned());
678 values.extend(imag.iter().cloned());
679
680 Some(Vector::new(values))
681 }
682
683 fn score_triple(&self, triple: &Triple) -> f32 {
684 self.hermitian_dot(triple)
685 }
686
687 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
688 let mut scores: Vec<(String, f32)> = self
689 .entities
690 .iter()
691 .filter(|e| *e != head)
692 .map(|tail| {
693 let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
694 let score = self.score_triple(&triple);
695 (tail.clone(), score)
696 })
697 .collect();
698
699 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
700 scores.truncate(k);
701 scores
702 }
703
704 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
705 let mut scores: Vec<(String, f32)> = self
706 .entities
707 .iter()
708 .filter(|e| *e != tail)
709 .map(|head| {
710 let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
711 let score = self.score_triple(&triple);
712 (head.clone(), score)
713 })
714 .collect();
715
716 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
717 scores.truncate(k);
718 scores
719 }
720
721 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
722 self.entity_embeddings_real
723 .iter()
724 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
725 .collect()
726 }
727
728 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
729 self.relation_embeddings_real
730 .iter()
731 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
732 .collect()
733 }
734}
735
736pub struct RotatE {
739 config: KGEmbeddingConfig,
740 entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
741 relation_embeddings: HashMap<String, DVector<f32>>, entities: Vec<String>,
743 relations: Vec<String>,
744}
745
746impl RotatE {
747 pub fn new(config: KGEmbeddingConfig) -> Self {
748 Self {
749 config,
750 entity_embeddings: HashMap::new(),
751 relation_embeddings: HashMap::new(),
752 entities: Vec::new(),
753 relations: Vec::new(),
754 }
755 }
756
757 fn initialize_embeddings(&mut self, triples: &[Triple]) {
759 let mut entities = std::collections::HashSet::new();
761 let mut relations = std::collections::HashSet::new();
762
763 for triple in triples {
764 entities.insert(triple.subject.clone());
765 entities.insert(triple.object.clone());
766 relations.insert(triple.predicate.clone());
767 }
768
769 self.entities = entities.into_iter().collect();
770 self.relations = relations.into_iter().collect();
771
772 let mut rng = if let Some(seed) = self.config.random_seed {
773 Random::seed(seed)
774 } else {
775 Random::seed(42)
776 };
777
778 let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
780
781 for entity in &self.entities {
782 let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
783 .map(|_| {
784 let phase = rng.gen_range(phase_range.clone());
785 Complex::new(phase.cos(), phase.sin())
786 })
787 .collect();
788
789 self.entity_embeddings
790 .insert(entity.clone(), DVector::from_vec(phases));
791 }
792
793 for relation in &self.relations {
795 let phases: Vec<f32> = (0..self.config.dimensions)
796 .map(|_| rng.gen_range(phase_range.clone()))
797 .collect();
798
799 self.relation_embeddings
800 .insert(relation.clone(), DVector::from_vec(phases));
801 }
802 }
803
804 fn distance(&self, triple: &Triple) -> f32 {
806 let h = self
807 .entity_embeddings
808 .get(&triple.subject)
809 .expect("subject entity should have embedding");
810 let r_phases = self
811 .relation_embeddings
812 .get(&triple.predicate)
813 .expect("predicate relation should have embedding");
814 let t = self
815 .entity_embeddings
816 .get(&triple.object)
817 .expect("object entity should have embedding");
818
819 let r: DVector<Complex<f32>> = DVector::from_iterator(
821 self.config.dimensions,
822 r_phases
823 .iter()
824 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
825 );
826
827 let rotated: DVector<Complex<f32>> = h.component_mul(&r);
829
830 let diff = rotated - t;
832 diff.iter().map(|c| c.norm()).sum::<f32>()
833 }
834}
835
836impl KGEmbeddingModel for RotatE {
837 fn train(&mut self, triples: &[Triple]) -> Result<()> {
838 if triples.is_empty() {
839 return Err(anyhow!("No triples provided for training"));
840 }
841
842 self.initialize_embeddings(triples);
844
845 Ok(())
849 }
850
851 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
852 let complex_emb = self.entity_embeddings.get(entity)?;
854
855 let mut values = Vec::with_capacity(self.config.dimensions * 2);
856 for c in complex_emb.iter() {
857 values.push(c.re); values.push(c.im); }
860
861 Some(Vector::new(values))
862 }
863
864 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
865 self.relation_embeddings
866 .get(relation)
867 .map(|phases| Vector::new(phases.iter().cloned().collect()))
868 }
869
870 fn score_triple(&self, triple: &Triple) -> f32 {
871 let gamma = 12.0; gamma - self.distance(triple)
873 }
874
875 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
876 let h = match self.entity_embeddings.get(head) {
877 Some(emb) => emb,
878 None => return Vec::new(),
879 };
880
881 let r_phases = match self.relation_embeddings.get(relation) {
882 Some(emb) => emb,
883 None => return Vec::new(),
884 };
885
886 let r: DVector<Complex<f32>> = DVector::from_iterator(
888 self.config.dimensions,
889 r_phases
890 .iter()
891 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
892 );
893
894 let rotated = h.component_mul(&r);
896
897 let mut scores: Vec<(String, f32)> = self
898 .entities
899 .iter()
900 .filter(|e| *e != head)
901 .filter_map(|entity| {
902 self.entity_embeddings.get(entity).map(|t| {
903 let diff = &rotated - t;
904 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
905 (entity.clone(), -distance)
906 })
907 })
908 .collect();
909
910 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
911 scores.truncate(k);
912 scores
913 }
914
915 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
916 let t = match self.entity_embeddings.get(tail) {
917 Some(emb) => emb,
918 None => return Vec::new(),
919 };
920
921 let r_phases = match self.relation_embeddings.get(relation) {
922 Some(emb) => emb,
923 None => return Vec::new(),
924 };
925
926 let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
928 self.config.dimensions,
929 r_phases
930 .iter()
931 .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
932 );
933
934 let mut scores: Vec<(String, f32)> = self
935 .entities
936 .iter()
937 .filter(|e| *e != tail)
938 .filter_map(|entity| {
939 self.entity_embeddings.get(entity).map(|h| {
940 let rotated = h.component_mul(&r_inv);
941 let diff = rotated - t;
942 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
943 (entity.clone(), -distance)
944 })
945 })
946 .collect();
947
948 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
949 scores.truncate(k);
950 scores
951 }
952
953 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
954 self.entity_embeddings
955 .iter()
956 .map(|(k, v)| {
957 let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
958 (k.clone(), Vector::new(real_values))
959 })
960 .collect()
961 }
962
963 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
964 self.relation_embeddings
965 .iter()
966 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
967 .collect()
968 }
969}
970
971pub struct KGEmbedding {
973 model: Box<dyn KGEmbeddingModel>,
974 config: KGEmbeddingConfig,
975}
976
977impl KGEmbedding {
978 pub fn new(config: KGEmbeddingConfig) -> Self {
980 let model: Box<dyn KGEmbeddingModel> = match config.model {
981 KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
982 KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
983 KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
984 KGEmbeddingModelType::GCN => {
985 let gcn = GCN::new(config.clone());
987 Box::new(GCNAdapter::new(gcn))
988 }
989 KGEmbeddingModelType::GraphSAGE => {
990 let graphsage = GraphSAGE::new(config.clone())
992 .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
993 Box::new(GraphSAGEAdapter::new(graphsage))
994 }
995 };
996
997 Self { model, config }
998 }
999
1000 pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
1002 self.model.train(triples)
1003 }
1004
1005 pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
1007 self.model.get_entity_embedding(entity)
1008 }
1009
1010 pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
1012 self.model.get_relation_embedding(relation)
1013 }
1014
1015 pub fn score_triple(&self, triple: &Triple) -> f32 {
1017 self.model.score_triple(triple)
1018 }
1019
1020 pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
1022 self.model.predict_tail(head, relation, k)
1023 }
1024
1025 pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
1027 self.model.predict_head(relation, tail, k)
1028 }
1029
1030 pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1032 self.model.score_triple(triple) > threshold
1033 }
1034}
1035
1036pub struct GCNAdapter {
1038 gcn: GCN,
1039}
1040
1041impl GCNAdapter {
1042 pub fn new(gcn: GCN) -> Self {
1043 Self { gcn }
1044 }
1045}
1046
1047impl KGEmbeddingModel for GCNAdapter {
1048 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1049 Ok(())
1051 }
1052
1053 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1054 Some(Vector::new(vec![0.0; 128]))
1057 }
1058
1059 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1060 Some(Vector::new(vec![0.0; 128]))
1062 }
1063
1064 fn score_triple(&self, _triple: &Triple) -> f32 {
1065 0.5
1067 }
1068
1069 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1070 vec![]
1072 }
1073
1074 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1075 vec![]
1077 }
1078
1079 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1080 HashMap::new()
1081 }
1082
1083 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1084 HashMap::new()
1085 }
1086}
1087
1088pub struct GraphSAGEAdapter {
1090 graphsage: GraphSAGE,
1091}
1092
1093impl GraphSAGEAdapter {
1094 pub fn new(graphsage: GraphSAGE) -> Self {
1095 Self { graphsage }
1096 }
1097}
1098
1099impl KGEmbeddingModel for GraphSAGEAdapter {
1100 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1101 Ok(())
1103 }
1104
1105 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1106 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1108 }
1109
1110 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1111 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1113 }
1114
1115 fn score_triple(&self, _triple: &Triple) -> f32 {
1116 0.5
1118 }
1119
1120 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1121 vec![]
1123 }
1124
1125 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1126 vec![]
1128 }
1129
1130 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1131 HashMap::new()
1132 }
1133
1134 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1135 HashMap::new()
1136 }
1137}
1138
#[cfg(test)]
mod tests {
    use super::*;

    /// Small graph shared by every test: three people, two foods, two relations.
    fn create_test_triples() -> Vec<Triple> {
        let fact =
            |s: &str, p: &str, o: &str| Triple::new(s.to_string(), p.to_string(), o.to_string());
        vec![
            fact("Alice", "knows", "Bob"),
            fact("Bob", "knows", "Charlie"),
            fact("Alice", "likes", "Pizza"),
            fact("Bob", "likes", "Pasta"),
            fact("Charlie", "knows", "Alice"),
        ]
    }

    /// Config for `model` with 50 dimensions and 10 epochs; every other field is the default.
    fn small_config(model: KGEmbeddingModelType) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        }
    }

    #[test]
    fn test_transe() {
        let triples = create_test_triples();
        let mut kg = KGEmbedding::new(small_config(KGEmbeddingModelType::TransE));

        kg.train(&triples).unwrap();

        assert!(kg.get_entity_embedding("Alice").is_some());
        assert!(kg.get_relation_embedding("knows").is_some());
        assert!(kg.score_triple(&triples[0]).is_finite());
        assert!(!kg.predict_tail("Alice", "knows", 2).is_empty());
    }

    #[test]
    fn test_complex() {
        let triples = create_test_triples();
        let mut kg = KGEmbedding::new(small_config(KGEmbeddingModelType::ComplEx));

        kg.train(&triples).unwrap();

        let bob = kg.get_entity_embedding("Bob");
        assert!(bob.is_some());
        // Real and imaginary halves are concatenated: 2 * 50 components.
        assert_eq!(bob.unwrap().dimensions, 100);
    }

    #[test]
    fn test_rotate() {
        let triples = create_test_triples();
        let mut kg = KGEmbedding::new(small_config(KGEmbeddingModelType::RotatE));

        kg.train(&triples).unwrap();

        assert!(kg.get_entity_embedding("Charlie").is_some());
        let likes = kg.get_relation_embedding("likes");
        assert!(likes.is_some());
        // Relations store one phase per dimension.
        assert_eq!(likes.unwrap().dimensions, 50);
    }
}