1use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::Vector;
10use anyhow::{anyhow, Result};
11use nalgebra::{Complex, DVector};
12use crate::random_utils::{NormalSampler as Normal, UniformSampler as Uniform};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum KGEmbeddingModelType {
20 TransE,
22 ComplEx,
24 RotatE,
26 GCN,
28 GraphSAGE,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KGEmbeddingConfig {
35 pub model: KGEmbeddingModelType,
37 pub dimensions: usize,
39 pub learning_rate: f32,
41 pub margin: f32,
43 pub negative_samples: usize,
45 pub batch_size: usize,
47 pub epochs: usize,
49 pub norm: usize,
51 pub random_seed: Option<u64>,
53 pub regularization: f32,
55}
56
57impl Default for KGEmbeddingConfig {
58 fn default() -> Self {
59 Self {
60 model: KGEmbeddingModelType::TransE,
61 dimensions: 100,
62 learning_rate: 0.01,
63 margin: 1.0,
64 negative_samples: 10,
65 batch_size: 100,
66 epochs: 100,
67 norm: 2,
68 random_seed: Some(42),
69 regularization: 0.0,
70 }
71 }
72}
73
/// One `(subject, predicate, object)` fact of the knowledge graph.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Triple {
    /// Head entity id.
    pub subject: String,
    /// Relation id connecting `subject` to `object`.
    pub predicate: String,
    /// Tail entity id.
    pub object: String,
}

impl Triple {
    /// Builds a triple from owned head, relation and tail ids.
    pub fn new(subject: String, predicate: String, object: String) -> Self {
        Triple {
            subject,
            predicate,
            object,
        }
    }
}
91
92pub trait KGEmbeddingModel: Send + Sync {
94 fn train(&mut self, triples: &[Triple]) -> Result<()>;
96
97 fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
99
100 fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
102
103 fn score_triple(&self, triple: &Triple) -> f32;
105
106 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
108
109 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
111
112 fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
114
115 fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
117}
118
119pub struct TransE {
122 config: KGEmbeddingConfig,
123 entity_embeddings: HashMap<String, DVector<f32>>,
124 relation_embeddings: HashMap<String, DVector<f32>>,
125 entities: Vec<String>,
126 relations: Vec<String>,
127}
128
129impl TransE {
130 pub fn new(config: KGEmbeddingConfig) -> Self {
131 Self {
132 config,
133 entity_embeddings: HashMap::new(),
134 relation_embeddings: HashMap::new(),
135 entities: Vec::new(),
136 relations: Vec::new(),
137 }
138 }
139
140 fn initialize_embeddings(&mut self, triples: &[Triple]) {
142 let mut entities = std::collections::HashSet::new();
144 let mut relations = std::collections::HashSet::new();
145
146 for triple in triples {
147 entities.insert(triple.subject.clone());
148 entities.insert(triple.object.clone());
149 relations.insert(triple.predicate.clone());
150 }
151
152 self.entities = entities.into_iter().collect();
153 self.relations = relations.into_iter().collect();
154
155 let mut rng = if let Some(seed) = self.config.random_seed {
157 Random::seed(seed)
158 } else {
159 Random::seed(42)
160 };
161
162 let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163 let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165 for entity in &self.entities {
167 let values: Vec<f32> = (0..self.config.dimensions)
168 .map(|_| rng.gen_range(range_min..range_max))
169 .collect();
170 let mut embedding = DVector::from_vec(values);
171
172 let norm = embedding.norm();
174 if norm > 0.0 {
175 embedding /= norm;
176 }
177
178 self.entity_embeddings.insert(entity.clone(), embedding);
179 }
180
181 for relation in &self.relations {
183 let values: Vec<f32> = (0..self.config.dimensions)
184 .map(|_| rng.gen_range(range_min..range_max))
185 .collect();
186 let embedding = DVector::from_vec(values);
187
188 self.relation_embeddings.insert(relation.clone(), embedding);
190 }
191 }
192
193 fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
195 let mut negatives = Vec::new();
196
197 for _ in 0..self.config.negative_samples {
198 if rng.gen_bool(0.5) {
199 let mut negative = triple.clone();
201 loop {
202 let idx = rng.gen_range(0..self.entities.len());
203 let entity = &self.entities[idx];
204 if entity != &triple.subject {
205 negative.subject = entity.clone();
206 break;
207 }
208 }
209 negatives.push(negative);
210 } else {
211 let mut negative = triple.clone();
213 loop {
214 let idx = rng.gen_range(0..self.entities.len());
215 let entity = &self.entities[idx];
216 if entity != &triple.object {
217 negative.object = entity.clone();
218 break;
219 }
220 }
221 negatives.push(negative);
222 }
223 }
224
225 negatives
226 }
227
228 fn distance(&self, triple: &Triple) -> f32 {
230 let h = self.entity_embeddings.get(&triple.subject).unwrap();
231 let r = self.relation_embeddings.get(&triple.predicate).unwrap();
232 let t = self.entity_embeddings.get(&triple.object).unwrap();
233
234 let translation = h + r - t;
235
236 match self.config.norm {
237 1 => translation.iter().map(|x| x.abs()).sum(),
238 2 => translation.norm(),
239 _ => translation.norm(),
240 }
241 }
242
243 fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
245 let pos_dist = self.distance(positive);
246
247 for negative in negatives {
248 let neg_dist = self.distance(negative);
249 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
250
251 if loss > 0.0 {
252 let h_pos = self
254 .entity_embeddings
255 .get(&positive.subject)
256 .unwrap()
257 .clone();
258 let r = self
259 .relation_embeddings
260 .get(&positive.predicate)
261 .unwrap()
262 .clone();
263 let t_pos = self
264 .entity_embeddings
265 .get(&positive.object)
266 .unwrap()
267 .clone();
268
269 let h_neg = self
270 .entity_embeddings
271 .get(&negative.subject)
272 .unwrap()
273 .clone();
274 let t_neg = self
275 .entity_embeddings
276 .get(&negative.object)
277 .unwrap()
278 .clone();
279
280 let pos_grad = &h_pos + &r - &t_pos;
281 let neg_grad = &h_neg + &r - &t_neg;
282
283 let pos_norm = pos_grad.norm();
285 let neg_norm = neg_grad.norm();
286
287 let pos_grad_norm = if pos_norm > 0.0 {
288 &pos_grad / pos_norm
289 } else {
290 pos_grad
291 };
292 let neg_grad_norm = if neg_norm > 0.0 {
293 &neg_grad / neg_norm
294 } else {
295 neg_grad
296 };
297
298 let lr = self.config.learning_rate;
300
301 if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
303 *h -= lr * &pos_grad_norm;
304 let norm = h.norm();
306 if norm > 0.0 {
307 *h /= norm;
308 }
309 }
310
311 if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
312 *r -= lr * (&pos_grad_norm - &neg_grad_norm);
313 }
314
315 if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
316 *t += lr * &pos_grad_norm;
317 let norm = t.norm();
319 if norm > 0.0 {
320 *t /= norm;
321 }
322 }
323
324 if positive.subject != negative.subject {
326 if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
327 *h += lr * &neg_grad_norm;
328 let norm = h.norm();
330 if norm > 0.0 {
331 *h /= norm;
332 }
333 }
334 }
335
336 if positive.object != negative.object {
337 if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
338 *t -= lr * &neg_grad_norm;
339 let norm = t.norm();
341 if norm > 0.0 {
342 *t /= norm;
343 }
344 }
345 }
346 }
347 }
348 }
349}
350
351impl KGEmbeddingModel for TransE {
352 fn train(&mut self, triples: &[Triple]) -> Result<()> {
353 if triples.is_empty() {
354 return Err(anyhow!("No triples provided for training"));
355 }
356
357 self.initialize_embeddings(triples);
359
360 let mut rng = if let Some(seed) = self.config.random_seed {
361 Random::seed(seed)
362 } else {
363 Random::seed(42)
364 };
365
366 for epoch in 0..self.config.epochs {
368 let mut total_loss = 0.0;
369 let mut batch_count = 0;
370
371 let mut shuffled_triples = triples.to_vec();
373 for i in (1..shuffled_triples.len()).rev() {
376 let j = rng.gen_range(0..=i);
377 shuffled_triples.swap(i, j);
378 }
379
380 for batch in shuffled_triples.chunks(self.config.batch_size) {
382 for triple in batch {
383 let negatives = self.generate_negative_samples(triple, &mut rng);
385
386 let pos_dist = self.distance(triple);
388 for negative in &negatives {
389 let neg_dist = self.distance(negative);
390 let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
391 total_loss += loss;
392 }
393
394 self.update_embeddings(triple, &negatives);
396 }
397 batch_count += 1;
398 }
399
400 if epoch % 10 == 0 {
401 let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
402 tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
403 }
404 }
405
406 Ok(())
407 }
408
409 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
410 self.entity_embeddings
411 .get(entity)
412 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
413 }
414
415 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
416 self.relation_embeddings
417 .get(relation)
418 .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
419 }
420
421 fn score_triple(&self, triple: &Triple) -> f32 {
422 -self.distance(triple)
423 }
424
425 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
426 let h = match self.entity_embeddings.get(head) {
427 Some(emb) => emb,
428 None => return Vec::new(),
429 };
430
431 let r = match self.relation_embeddings.get(relation) {
432 Some(emb) => emb,
433 None => return Vec::new(),
434 };
435
436 let translation = h + r;
437
438 let mut scores: Vec<(String, f32)> = self
439 .entities
440 .iter()
441 .filter(|e| *e != head)
442 .filter_map(|entity| {
443 self.entity_embeddings.get(entity).map(|t| {
444 let distance = match self.config.norm {
445 1 => (&translation - t).iter().map(|x| x.abs()).sum(),
446 2 => (&translation - t).norm(),
447 _ => (&translation - t).norm(),
448 };
449 (entity.clone(), -distance)
450 })
451 })
452 .collect();
453
454 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
455 scores.truncate(k);
456 scores
457 }
458
459 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
460 let t = match self.entity_embeddings.get(tail) {
461 Some(emb) => emb,
462 None => return Vec::new(),
463 };
464
465 let r = match self.relation_embeddings.get(relation) {
466 Some(emb) => emb,
467 None => return Vec::new(),
468 };
469
470 let target = t - r;
471
472 let mut scores: Vec<(String, f32)> = self
473 .entities
474 .iter()
475 .filter(|e| *e != tail)
476 .filter_map(|entity| {
477 self.entity_embeddings.get(entity).map(|h| {
478 let distance = match self.config.norm {
479 1 => (h - &target).iter().map(|x| x.abs()).sum(),
480 2 => (h - &target).norm(),
481 _ => (h - &target).norm(),
482 };
483 (entity.clone(), -distance)
484 })
485 })
486 .collect();
487
488 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
489 scores.truncate(k);
490 scores
491 }
492
493 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
494 self.entity_embeddings
495 .iter()
496 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
497 .collect()
498 }
499
500 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
501 self.relation_embeddings
502 .iter()
503 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
504 .collect()
505 }
506}
507
508pub struct ComplEx {
511 config: KGEmbeddingConfig,
512 entity_embeddings_real: HashMap<String, DVector<f32>>,
513 entity_embeddings_imag: HashMap<String, DVector<f32>>,
514 relation_embeddings_real: HashMap<String, DVector<f32>>,
515 relation_embeddings_imag: HashMap<String, DVector<f32>>,
516 entities: Vec<String>,
517 relations: Vec<String>,
518}
519
520impl ComplEx {
521 pub fn new(config: KGEmbeddingConfig) -> Self {
522 Self {
523 config,
524 entity_embeddings_real: HashMap::new(),
525 entity_embeddings_imag: HashMap::new(),
526 relation_embeddings_real: HashMap::new(),
527 relation_embeddings_imag: HashMap::new(),
528 entities: Vec::new(),
529 relations: Vec::new(),
530 }
531 }
532
533 fn initialize_embeddings(&mut self, triples: &[Triple]) {
535 let mut entities = std::collections::HashSet::new();
537 let mut relations = std::collections::HashSet::new();
538
539 for triple in triples {
540 entities.insert(triple.subject.clone());
541 entities.insert(triple.object.clone());
542 relations.insert(triple.predicate.clone());
543 }
544
545 self.entities = entities.into_iter().collect();
546 self.relations = relations.into_iter().collect();
547
548 let mut rng = if let Some(seed) = self.config.random_seed {
550 Random::seed(seed)
551 } else {
552 Random::seed(42)
553 };
554
555 let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
556 let normal = Normal::new(0.0, std_dev).unwrap();
557
558 for entity in &self.entities {
560 let real_values: Vec<f32> = (0..self.config.dimensions)
561 .map(|_| normal.sample(&mut rng))
562 .collect();
563 let imag_values: Vec<f32> = (0..self.config.dimensions)
564 .map(|_| normal.sample(&mut rng))
565 .collect();
566
567 self.entity_embeddings_real
568 .insert(entity.clone(), DVector::from_vec(real_values));
569 self.entity_embeddings_imag
570 .insert(entity.clone(), DVector::from_vec(imag_values));
571 }
572
573 for relation in &self.relations {
575 let real_values: Vec<f32> = (0..self.config.dimensions)
576 .map(|_| normal.sample(&mut rng))
577 .collect();
578 let imag_values: Vec<f32> = (0..self.config.dimensions)
579 .map(|_| normal.sample(&mut rng))
580 .collect();
581
582 self.relation_embeddings_real
583 .insert(relation.clone(), DVector::from_vec(real_values));
584 self.relation_embeddings_imag
585 .insert(relation.clone(), DVector::from_vec(imag_values));
586 }
587 }
588
589 fn hermitian_dot(&self, triple: &Triple) -> f32 {
591 let h_real = self.entity_embeddings_real.get(&triple.subject).unwrap();
592 let h_imag = self.entity_embeddings_imag.get(&triple.subject).unwrap();
593 let r_real = self
594 .relation_embeddings_real
595 .get(&triple.predicate)
596 .unwrap();
597 let r_imag = self
598 .relation_embeddings_imag
599 .get(&triple.predicate)
600 .unwrap();
601 let t_real = self.entity_embeddings_real.get(&triple.object).unwrap();
602 let t_imag = self.entity_embeddings_imag.get(&triple.object).unwrap();
603
604 let mut score = 0.0;
610 for i in 0..self.config.dimensions {
611 score += h_real[i] * r_real[i] * t_real[i]
612 + h_real[i] * r_imag[i] * t_imag[i]
613 + h_imag[i] * r_real[i] * t_imag[i]
614 - h_imag[i] * r_imag[i] * t_real[i];
615 }
616
617 score
618 }
619}
620
621impl KGEmbeddingModel for ComplEx {
622 fn train(&mut self, triples: &[Triple]) -> Result<()> {
623 if triples.is_empty() {
624 return Err(anyhow!("No triples provided for training"));
625 }
626
627 self.initialize_embeddings(triples);
629
630 Ok(())
634 }
635
636 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
637 let real = self.entity_embeddings_real.get(entity)?;
639 let imag = self.entity_embeddings_imag.get(entity)?;
640
641 let mut values = Vec::with_capacity(self.config.dimensions * 2);
642 values.extend(real.iter().cloned());
643 values.extend(imag.iter().cloned());
644
645 Some(Vector::new(values))
646 }
647
648 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
649 let real = self.relation_embeddings_real.get(relation)?;
651 let imag = self.relation_embeddings_imag.get(relation)?;
652
653 let mut values = Vec::with_capacity(self.config.dimensions * 2);
654 values.extend(real.iter().cloned());
655 values.extend(imag.iter().cloned());
656
657 Some(Vector::new(values))
658 }
659
660 fn score_triple(&self, triple: &Triple) -> f32 {
661 self.hermitian_dot(triple)
662 }
663
664 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
665 let mut scores: Vec<(String, f32)> = self
666 .entities
667 .iter()
668 .filter(|e| *e != head)
669 .map(|tail| {
670 let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
671 let score = self.score_triple(&triple);
672 (tail.clone(), score)
673 })
674 .collect();
675
676 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
677 scores.truncate(k);
678 scores
679 }
680
681 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
682 let mut scores: Vec<(String, f32)> = self
683 .entities
684 .iter()
685 .filter(|e| *e != tail)
686 .map(|head| {
687 let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
688 let score = self.score_triple(&triple);
689 (head.clone(), score)
690 })
691 .collect();
692
693 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
694 scores.truncate(k);
695 scores
696 }
697
698 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
699 self.entity_embeddings_real
700 .iter()
701 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
702 .collect()
703 }
704
705 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
706 self.relation_embeddings_real
707 .iter()
708 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
709 .collect()
710 }
711}
712
713pub struct RotatE {
716 config: KGEmbeddingConfig,
717 entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
718 relation_embeddings: HashMap<String, DVector<f32>>, entities: Vec<String>,
720 relations: Vec<String>,
721}
722
723impl RotatE {
724 pub fn new(config: KGEmbeddingConfig) -> Self {
725 Self {
726 config,
727 entity_embeddings: HashMap::new(),
728 relation_embeddings: HashMap::new(),
729 entities: Vec::new(),
730 relations: Vec::new(),
731 }
732 }
733
734 fn initialize_embeddings(&mut self, triples: &[Triple]) {
736 let mut entities = std::collections::HashSet::new();
738 let mut relations = std::collections::HashSet::new();
739
740 for triple in triples {
741 entities.insert(triple.subject.clone());
742 entities.insert(triple.object.clone());
743 relations.insert(triple.predicate.clone());
744 }
745
746 self.entities = entities.into_iter().collect();
747 self.relations = relations.into_iter().collect();
748
749 let mut rng = if let Some(seed) = self.config.random_seed {
750 Random::seed(seed)
751 } else {
752 Random::seed(42)
753 };
754
755 let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
757
758 for entity in &self.entities {
759 let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
760 .map(|_| {
761 let phase = rng.gen_range(phase_range.clone());
762 Complex::new(phase.cos(), phase.sin())
763 })
764 .collect();
765
766 self.entity_embeddings
767 .insert(entity.clone(), DVector::from_vec(phases));
768 }
769
770 for relation in &self.relations {
772 let phases: Vec<f32> = (0..self.config.dimensions)
773 .map(|_| rng.gen_range(phase_range.clone()))
774 .collect();
775
776 self.relation_embeddings
777 .insert(relation.clone(), DVector::from_vec(phases));
778 }
779 }
780
781 fn distance(&self, triple: &Triple) -> f32 {
783 let h = self.entity_embeddings.get(&triple.subject).unwrap();
784 let r_phases = self.relation_embeddings.get(&triple.predicate).unwrap();
785 let t = self.entity_embeddings.get(&triple.object).unwrap();
786
787 let r: DVector<Complex<f32>> = DVector::from_iterator(
789 self.config.dimensions,
790 r_phases
791 .iter()
792 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
793 );
794
795 let rotated: DVector<Complex<f32>> = h.component_mul(&r);
797
798 let diff = rotated - t;
800 diff.iter().map(|c| c.norm()).sum::<f32>()
801 }
802}
803
804impl KGEmbeddingModel for RotatE {
805 fn train(&mut self, triples: &[Triple]) -> Result<()> {
806 if triples.is_empty() {
807 return Err(anyhow!("No triples provided for training"));
808 }
809
810 self.initialize_embeddings(triples);
812
813 Ok(())
817 }
818
819 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
820 let complex_emb = self.entity_embeddings.get(entity)?;
822
823 let mut values = Vec::with_capacity(self.config.dimensions * 2);
824 for c in complex_emb.iter() {
825 values.push(c.re); values.push(c.im); }
828
829 Some(Vector::new(values))
830 }
831
832 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
833 self.relation_embeddings
834 .get(relation)
835 .map(|phases| Vector::new(phases.iter().cloned().collect()))
836 }
837
838 fn score_triple(&self, triple: &Triple) -> f32 {
839 let gamma = 12.0; gamma - self.distance(triple)
841 }
842
843 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
844 let h = match self.entity_embeddings.get(head) {
845 Some(emb) => emb,
846 None => return Vec::new(),
847 };
848
849 let r_phases = match self.relation_embeddings.get(relation) {
850 Some(emb) => emb,
851 None => return Vec::new(),
852 };
853
854 let r: DVector<Complex<f32>> = DVector::from_iterator(
856 self.config.dimensions,
857 r_phases
858 .iter()
859 .map(|&phase| Complex::new(phase.cos(), phase.sin())),
860 );
861
862 let rotated = h.component_mul(&r);
864
865 let mut scores: Vec<(String, f32)> = self
866 .entities
867 .iter()
868 .filter(|e| *e != head)
869 .filter_map(|entity| {
870 self.entity_embeddings.get(entity).map(|t| {
871 let diff = &rotated - t;
872 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
873 (entity.clone(), -distance)
874 })
875 })
876 .collect();
877
878 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
879 scores.truncate(k);
880 scores
881 }
882
883 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
884 let t = match self.entity_embeddings.get(tail) {
885 Some(emb) => emb,
886 None => return Vec::new(),
887 };
888
889 let r_phases = match self.relation_embeddings.get(relation) {
890 Some(emb) => emb,
891 None => return Vec::new(),
892 };
893
894 let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
896 self.config.dimensions,
897 r_phases
898 .iter()
899 .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
900 );
901
902 let mut scores: Vec<(String, f32)> = self
903 .entities
904 .iter()
905 .filter(|e| *e != tail)
906 .filter_map(|entity| {
907 self.entity_embeddings.get(entity).map(|h| {
908 let rotated = h.component_mul(&r_inv);
909 let diff = rotated - t;
910 let distance: f32 = diff.iter().map(|c| c.norm()).sum();
911 (entity.clone(), -distance)
912 })
913 })
914 .collect();
915
916 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
917 scores.truncate(k);
918 scores
919 }
920
921 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
922 self.entity_embeddings
923 .iter()
924 .map(|(k, v)| {
925 let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
926 (k.clone(), Vector::new(real_values))
927 })
928 .collect()
929 }
930
931 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
932 self.relation_embeddings
933 .iter()
934 .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
935 .collect()
936 }
937}
938
939pub struct KGEmbedding {
941 model: Box<dyn KGEmbeddingModel>,
942 config: KGEmbeddingConfig,
943}
944
945impl KGEmbedding {
946 pub fn new(config: KGEmbeddingConfig) -> Self {
948 let model: Box<dyn KGEmbeddingModel> = match config.model {
949 KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
950 KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
951 KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
952 KGEmbeddingModelType::GCN => {
953 let gcn = GCN::new(config.clone());
955 Box::new(GCNAdapter::new(gcn))
956 }
957 KGEmbeddingModelType::GraphSAGE => {
958 let graphsage = GraphSAGE::new(config.clone())
960 .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
961 Box::new(GraphSAGEAdapter::new(graphsage))
962 }
963 };
964
965 Self { model, config }
966 }
967
968 pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
970 self.model.train(triples)
971 }
972
973 pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
975 self.model.get_entity_embedding(entity)
976 }
977
978 pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
980 self.model.get_relation_embedding(relation)
981 }
982
983 pub fn score_triple(&self, triple: &Triple) -> f32 {
985 self.model.score_triple(triple)
986 }
987
988 pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
990 self.model.predict_tail(head, relation, k)
991 }
992
993 pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
995 self.model.predict_head(relation, tail, k)
996 }
997
998 pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1000 self.model.score_triple(triple) > threshold
1001 }
1002}
1003
1004pub struct GCNAdapter {
1006 gcn: GCN,
1007}
1008
1009impl GCNAdapter {
1010 pub fn new(gcn: GCN) -> Self {
1011 Self { gcn }
1012 }
1013}
1014
1015impl KGEmbeddingModel for GCNAdapter {
1016 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1017 Ok(())
1019 }
1020
1021 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1022 Some(Vector::new(vec![0.0; 128]))
1025 }
1026
1027 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1028 Some(Vector::new(vec![0.0; 128]))
1030 }
1031
1032 fn score_triple(&self, _triple: &Triple) -> f32 {
1033 0.5
1035 }
1036
1037 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1038 vec![]
1040 }
1041
1042 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1043 vec![]
1045 }
1046
1047 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1048 HashMap::new()
1049 }
1050
1051 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1052 HashMap::new()
1053 }
1054}
1055
1056pub struct GraphSAGEAdapter {
1058 graphsage: GraphSAGE,
1059}
1060
1061impl GraphSAGEAdapter {
1062 pub fn new(graphsage: GraphSAGE) -> Self {
1063 Self { graphsage }
1064 }
1065}
1066
1067impl KGEmbeddingModel for GraphSAGEAdapter {
1068 fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1069 Ok(())
1071 }
1072
1073 fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1074 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1076 }
1077
1078 fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1079 Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1081 }
1082
1083 fn score_triple(&self, _triple: &Triple) -> f32 {
1084 0.5
1086 }
1087
1088 fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1089 vec![]
1091 }
1092
1093 fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1094 vec![]
1096 }
1097
1098 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1099 HashMap::new()
1100 }
1101
1102 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1103 HashMap::new()
1104 }
1105}
1106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Triple`] from string slices.
    fn triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(subject.to_string(), predicate.to_string(), object.to_string())
    }

    /// Small social graph shared by the model tests.
    fn create_test_triples() -> Vec<Triple> {
        [
            ("Alice", "knows", "Bob"),
            ("Bob", "knows", "Charlie"),
            ("Alice", "likes", "Pizza"),
            ("Bob", "likes", "Pasta"),
            ("Charlie", "knows", "Alice"),
        ]
        .iter()
        .map(|&(s, p, o)| triple(s, p, o))
        .collect()
    }

    /// Default config with 50 dimensions and 10 epochs for the requested model.
    fn small_config(model: KGEmbeddingModelType) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        }
    }

    /// Trains a fresh facade on the shared graph and returns it with the triples.
    fn trained(model: KGEmbeddingModelType) -> (KGEmbedding, Vec<Triple>) {
        let triples = create_test_triples();
        let mut embedding = KGEmbedding::new(small_config(model));
        embedding.train(&triples).unwrap();
        (embedding, triples)
    }

    #[test]
    fn test_transe() {
        let (model, triples) = trained(KGEmbeddingModelType::TransE);

        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());
        assert!(model.score_triple(&triples[0]).is_finite());
        assert!(!model.predict_tail("Alice", "knows", 2).is_empty());
    }

    #[test]
    fn test_complex() {
        let (model, _) = trained(KGEmbeddingModelType::ComplEx);

        let emb = model.get_entity_embedding("Bob");
        assert!(emb.is_some());
        // Real and imaginary halves are concatenated: 2 * 50.
        assert_eq!(emb.unwrap().dimensions, 100);
    }

    #[test]
    fn test_rotate() {
        let (model, _) = trained(KGEmbeddingModelType::RotatE);

        assert!(model.get_entity_embedding("Charlie").is_some());
        let rel_emb = model.get_relation_embedding("likes");
        assert!(rel_emb.is_some());
        // One phase per dimension.
        assert_eq!(rel_emb.unwrap().dimensions, 50);
    }
}