oxirs_vec/
kg_embeddings.rs

1//! Knowledge Graph Embeddings for RDF data
2//!
3//! This module implements various knowledge graph embedding methods:
4//! - TransE: Translation-based embeddings
5//! - ComplEx: Complex number embeddings
6//! - RotatE: Rotation-based embeddings
7
8use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::random_utils::NormalSampler as Normal;
10use crate::Vector;
11use anyhow::{anyhow, Result};
12use nalgebra::{Complex, DVector};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17/// Knowledge graph embedding model type
18#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum KGEmbeddingModelType {
20    /// Translation-based embeddings (TransE)
21    TransE,
22    /// Complex number embeddings (ComplEx)
23    ComplEx,
24    /// Rotation-based embeddings (RotatE)
25    RotatE,
26    /// Graph Convolutional Network (GCN)
27    GCN,
28    /// GraphSAGE (Graph Sample and Aggregate)
29    GraphSAGE,
30}
31
32/// Configuration for knowledge graph embeddings
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KGEmbeddingConfig {
35    /// Model type
36    pub model: KGEmbeddingModelType,
37    /// Embedding dimensions
38    pub dimensions: usize,
39    /// Learning rate
40    pub learning_rate: f32,
41    /// Margin for loss function
42    pub margin: f32,
43    /// Negative sampling ratio
44    pub negative_samples: usize,
45    /// Batch size for training
46    pub batch_size: usize,
47    /// Number of epochs
48    pub epochs: usize,
49    /// L1 or L2 norm
50    pub norm: usize,
51    /// Random seed
52    pub random_seed: Option<u64>,
53    /// Regularization weight
54    pub regularization: f32,
55}
56
57impl Default for KGEmbeddingConfig {
58    fn default() -> Self {
59        Self {
60            model: KGEmbeddingModelType::TransE,
61            dimensions: 100,
62            learning_rate: 0.01,
63            margin: 1.0,
64            negative_samples: 10,
65            batch_size: 100,
66            epochs: 100,
67            norm: 2,
68            random_seed: Some(42),
69            regularization: 0.0,
70        }
71    }
72}
73
/// A single (subject, predicate, object) statement of the knowledge graph.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Triple {
    /// Head entity of the statement
    pub subject: String,
    /// Relation linking subject and object
    pub predicate: String,
    /// Tail entity of the statement
    pub object: String,
}
81
82impl Triple {
83    pub fn new(subject: String, predicate: String, object: String) -> Self {
84        Self {
85            subject,
86            predicate,
87            object,
88        }
89    }
90}
91
92/// Base trait for knowledge graph embedding models
93pub trait KGEmbeddingModel: Send + Sync {
94    /// Train the model on triples
95    fn train(&mut self, triples: &[Triple]) -> Result<()>;
96
97    /// Get entity embedding
98    fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
99
100    /// Get relation embedding
101    fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
102
103    /// Score a triple
104    fn score_triple(&self, triple: &Triple) -> f32;
105
106    /// Predict tail entities for (head, relation, ?)
107    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
108
109    /// Predict head entities for (?, relation, tail)
110    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
111
112    /// Get all entity embeddings
113    fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
114
115    /// Get all relation embeddings
116    fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
117}
118
119/// TransE: Translation-based embeddings
120/// Learns embeddings where h + r ≈ t for triple (h, r, t)
121pub struct TransE {
122    config: KGEmbeddingConfig,
123    entity_embeddings: HashMap<String, DVector<f32>>,
124    relation_embeddings: HashMap<String, DVector<f32>>,
125    entities: Vec<String>,
126    relations: Vec<String>,
127}
128
129impl TransE {
130    pub fn new(config: KGEmbeddingConfig) -> Self {
131        Self {
132            config,
133            entity_embeddings: HashMap::new(),
134            relation_embeddings: HashMap::new(),
135            entities: Vec::new(),
136            relations: Vec::new(),
137        }
138    }
139
140    /// Initialize embeddings
141    fn initialize_embeddings(&mut self, triples: &[Triple]) {
142        // Collect unique entities and relations
143        let mut entities = std::collections::HashSet::new();
144        let mut relations = std::collections::HashSet::new();
145
146        for triple in triples {
147            entities.insert(triple.subject.clone());
148            entities.insert(triple.object.clone());
149            relations.insert(triple.predicate.clone());
150        }
151
152        self.entities = entities.into_iter().collect();
153        self.relations = relations.into_iter().collect();
154
155        // Initialize embeddings with uniform distribution
156        let mut rng = if let Some(seed) = self.config.random_seed {
157            Random::seed(seed)
158        } else {
159            Random::seed(42)
160        };
161
162        let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163        let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165        // Initialize entity embeddings
166        for entity in &self.entities {
167            let values: Vec<f32> = (0..self.config.dimensions)
168                .map(|_| rng.random_range(range_min, range_max))
169                .collect();
170            let mut embedding = DVector::from_vec(values);
171
172            // Normalize entities
173            let norm = embedding.norm();
174            if norm > 0.0 {
175                embedding /= norm;
176            }
177
178            self.entity_embeddings.insert(entity.clone(), embedding);
179        }
180
181        // Initialize relation embeddings
182        for relation in &self.relations {
183            let values: Vec<f32> = (0..self.config.dimensions)
184                .map(|_| rng.random_range(range_min, range_max))
185                .collect();
186            let embedding = DVector::from_vec(values);
187
188            // Relations are not normalized in TransE
189            self.relation_embeddings.insert(relation.clone(), embedding);
190        }
191    }
192
193    /// Generate negative samples
194    #[allow(deprecated)]
195    fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
196        let mut negatives = Vec::new();
197
198        for _ in 0..self.config.negative_samples {
199            if rng.gen_bool(0.5) {
200                // Corrupt head
201                let mut negative = triple.clone();
202                loop {
203                    let idx = rng.gen_range(0..self.entities.len());
204                    let entity = &self.entities[idx];
205                    if entity != &triple.subject {
206                        negative.subject = entity.clone();
207                        break;
208                    }
209                }
210                negatives.push(negative);
211            } else {
212                // Corrupt tail
213                let mut negative = triple.clone();
214                loop {
215                    let idx = rng.gen_range(0..self.entities.len());
216                    let entity = &self.entities[idx];
217                    if entity != &triple.object {
218                        negative.object = entity.clone();
219                        break;
220                    }
221                }
222                negatives.push(negative);
223            }
224        }
225
226        negatives
227    }
228
229    /// Calculate distance for a triple
230    fn distance(&self, triple: &Triple) -> f32 {
231        let h = self.entity_embeddings.get(&triple.subject).unwrap();
232        let r = self.relation_embeddings.get(&triple.predicate).unwrap();
233        let t = self.entity_embeddings.get(&triple.object).unwrap();
234
235        let translation = h + r - t;
236
237        match self.config.norm {
238            1 => translation.iter().map(|x| x.abs()).sum(),
239            2 => translation.norm(),
240            _ => translation.norm(),
241        }
242    }
243
244    /// Update embeddings using gradient descent
245    fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
246        let pos_dist = self.distance(positive);
247
248        for negative in negatives {
249            let neg_dist = self.distance(negative);
250            let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
251
252            if loss > 0.0 {
253                // Calculate gradients
254                let h_pos = self
255                    .entity_embeddings
256                    .get(&positive.subject)
257                    .unwrap()
258                    .clone();
259                let r = self
260                    .relation_embeddings
261                    .get(&positive.predicate)
262                    .unwrap()
263                    .clone();
264                let t_pos = self
265                    .entity_embeddings
266                    .get(&positive.object)
267                    .unwrap()
268                    .clone();
269
270                let h_neg = self
271                    .entity_embeddings
272                    .get(&negative.subject)
273                    .unwrap()
274                    .clone();
275                let t_neg = self
276                    .entity_embeddings
277                    .get(&negative.object)
278                    .unwrap()
279                    .clone();
280
281                let pos_grad = &h_pos + &r - &t_pos;
282                let neg_grad = &h_neg + &r - &t_neg;
283
284                // Normalize gradients
285                let pos_norm = pos_grad.norm();
286                let neg_norm = neg_grad.norm();
287
288                let pos_grad_norm = if pos_norm > 0.0 {
289                    &pos_grad / pos_norm
290                } else {
291                    pos_grad
292                };
293                let neg_grad_norm = if neg_norm > 0.0 {
294                    &neg_grad / neg_norm
295                } else {
296                    neg_grad
297                };
298
299                // Update embeddings
300                let lr = self.config.learning_rate;
301
302                // Update positive triple embeddings
303                if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
304                    *h -= lr * &pos_grad_norm;
305                    // Re-normalize entity
306                    let norm = h.norm();
307                    if norm > 0.0 {
308                        *h /= norm;
309                    }
310                }
311
312                if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
313                    *r -= lr * (&pos_grad_norm - &neg_grad_norm);
314                }
315
316                if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
317                    *t += lr * &pos_grad_norm;
318                    // Re-normalize entity
319                    let norm = t.norm();
320                    if norm > 0.0 {
321                        *t /= norm;
322                    }
323                }
324
325                // Update negative triple embeddings
326                if positive.subject != negative.subject {
327                    if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
328                        *h += lr * &neg_grad_norm;
329                        // Re-normalize entity
330                        let norm = h.norm();
331                        if norm > 0.0 {
332                            *h /= norm;
333                        }
334                    }
335                }
336
337                if positive.object != negative.object {
338                    if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
339                        *t -= lr * &neg_grad_norm;
340                        // Re-normalize entity
341                        let norm = t.norm();
342                        if norm > 0.0 {
343                            *t /= norm;
344                        }
345                    }
346                }
347            }
348        }
349    }
350}
351
352impl KGEmbeddingModel for TransE {
353    fn train(&mut self, triples: &[Triple]) -> Result<()> {
354        if triples.is_empty() {
355            return Err(anyhow!("No triples provided for training"));
356        }
357
358        // Initialize embeddings
359        self.initialize_embeddings(triples);
360
361        let mut rng = if let Some(seed) = self.config.random_seed {
362            Random::seed(seed)
363        } else {
364            Random::seed(42)
365        };
366
367        // Training loop
368        for epoch in 0..self.config.epochs {
369            let mut total_loss = 0.0;
370            let mut batch_count = 0;
371
372            // Shuffle triples
373            let mut shuffled_triples = triples.to_vec();
374            // Note: Using manual random selection instead of SliceRandom
375            // Manually shuffle using Fisher-Yates algorithm
376            for i in (1..shuffled_triples.len()).rev() {
377                let j = rng.random_range(0, i + 1);
378                shuffled_triples.swap(i, j);
379            }
380
381            // Process batches
382            for batch in shuffled_triples.chunks(self.config.batch_size) {
383                for triple in batch {
384                    // Generate negative samples
385                    let negatives = self.generate_negative_samples(triple, &mut rng);
386
387                    // Calculate loss
388                    let pos_dist = self.distance(triple);
389                    for negative in &negatives {
390                        let neg_dist = self.distance(negative);
391                        let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
392                        total_loss += loss;
393                    }
394
395                    // Update embeddings
396                    self.update_embeddings(triple, &negatives);
397                }
398                batch_count += 1;
399            }
400
401            if epoch % 10 == 0 {
402                let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
403                tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
404            }
405        }
406
407        Ok(())
408    }
409
410    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
411        self.entity_embeddings
412            .get(entity)
413            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
414    }
415
416    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
417        self.relation_embeddings
418            .get(relation)
419            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
420    }
421
422    fn score_triple(&self, triple: &Triple) -> f32 {
423        -self.distance(triple)
424    }
425
426    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
427        let h = match self.entity_embeddings.get(head) {
428            Some(emb) => emb,
429            None => return Vec::new(),
430        };
431
432        let r = match self.relation_embeddings.get(relation) {
433            Some(emb) => emb,
434            None => return Vec::new(),
435        };
436
437        let translation = h + r;
438
439        let mut scores: Vec<(String, f32)> = self
440            .entities
441            .iter()
442            .filter(|e| *e != head)
443            .filter_map(|entity| {
444                self.entity_embeddings.get(entity).map(|t| {
445                    let distance = match self.config.norm {
446                        1 => (&translation - t).iter().map(|x| x.abs()).sum(),
447                        2 => (&translation - t).norm(),
448                        _ => (&translation - t).norm(),
449                    };
450                    (entity.clone(), -distance)
451                })
452            })
453            .collect();
454
455        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
456        scores.truncate(k);
457        scores
458    }
459
460    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
461        let t = match self.entity_embeddings.get(tail) {
462            Some(emb) => emb,
463            None => return Vec::new(),
464        };
465
466        let r = match self.relation_embeddings.get(relation) {
467            Some(emb) => emb,
468            None => return Vec::new(),
469        };
470
471        let target = t - r;
472
473        let mut scores: Vec<(String, f32)> = self
474            .entities
475            .iter()
476            .filter(|e| *e != tail)
477            .filter_map(|entity| {
478                self.entity_embeddings.get(entity).map(|h| {
479                    let distance = match self.config.norm {
480                        1 => (h - &target).iter().map(|x| x.abs()).sum(),
481                        2 => (h - &target).norm(),
482                        _ => (h - &target).norm(),
483                    };
484                    (entity.clone(), -distance)
485                })
486            })
487            .collect();
488
489        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
490        scores.truncate(k);
491        scores
492    }
493
494    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
495        self.entity_embeddings
496            .iter()
497            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
498            .collect()
499    }
500
501    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
502        self.relation_embeddings
503            .iter()
504            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
505            .collect()
506    }
507}
508
509/// ComplEx: Complex number embeddings
510/// Uses complex-valued embeddings and Hermitian dot product
511pub struct ComplEx {
512    config: KGEmbeddingConfig,
513    entity_embeddings_real: HashMap<String, DVector<f32>>,
514    entity_embeddings_imag: HashMap<String, DVector<f32>>,
515    relation_embeddings_real: HashMap<String, DVector<f32>>,
516    relation_embeddings_imag: HashMap<String, DVector<f32>>,
517    entities: Vec<String>,
518    relations: Vec<String>,
519}
520
521impl ComplEx {
522    pub fn new(config: KGEmbeddingConfig) -> Self {
523        Self {
524            config,
525            entity_embeddings_real: HashMap::new(),
526            entity_embeddings_imag: HashMap::new(),
527            relation_embeddings_real: HashMap::new(),
528            relation_embeddings_imag: HashMap::new(),
529            entities: Vec::new(),
530            relations: Vec::new(),
531        }
532    }
533
534    /// Initialize embeddings with Xavier initialization
535    fn initialize_embeddings(&mut self, triples: &[Triple]) {
536        // Collect unique entities and relations
537        let mut entities = std::collections::HashSet::new();
538        let mut relations = std::collections::HashSet::new();
539
540        for triple in triples {
541            entities.insert(triple.subject.clone());
542            entities.insert(triple.object.clone());
543            relations.insert(triple.predicate.clone());
544        }
545
546        self.entities = entities.into_iter().collect();
547        self.relations = relations.into_iter().collect();
548
549        // Initialize with Xavier initialization
550        let mut rng = if let Some(seed) = self.config.random_seed {
551            Random::seed(seed)
552        } else {
553            Random::seed(42)
554        };
555
556        let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
557        let normal = Normal::new(0.0, std_dev).unwrap();
558
559        // Initialize entity embeddings
560        for entity in &self.entities {
561            let real_values: Vec<f32> = (0..self.config.dimensions)
562                .map(|_| normal.sample(&mut rng))
563                .collect();
564            let imag_values: Vec<f32> = (0..self.config.dimensions)
565                .map(|_| normal.sample(&mut rng))
566                .collect();
567
568            self.entity_embeddings_real
569                .insert(entity.clone(), DVector::from_vec(real_values));
570            self.entity_embeddings_imag
571                .insert(entity.clone(), DVector::from_vec(imag_values));
572        }
573
574        // Initialize relation embeddings
575        for relation in &self.relations {
576            let real_values: Vec<f32> = (0..self.config.dimensions)
577                .map(|_| normal.sample(&mut rng))
578                .collect();
579            let imag_values: Vec<f32> = (0..self.config.dimensions)
580                .map(|_| normal.sample(&mut rng))
581                .collect();
582
583            self.relation_embeddings_real
584                .insert(relation.clone(), DVector::from_vec(real_values));
585            self.relation_embeddings_imag
586                .insert(relation.clone(), DVector::from_vec(imag_values));
587        }
588    }
589
590    /// Hermitian dot product for scoring
591    fn hermitian_dot(&self, triple: &Triple) -> f32 {
592        let h_real = self.entity_embeddings_real.get(&triple.subject).unwrap();
593        let h_imag = self.entity_embeddings_imag.get(&triple.subject).unwrap();
594        let r_real = self
595            .relation_embeddings_real
596            .get(&triple.predicate)
597            .unwrap();
598        let r_imag = self
599            .relation_embeddings_imag
600            .get(&triple.predicate)
601            .unwrap();
602        let t_real = self.entity_embeddings_real.get(&triple.object).unwrap();
603        let t_imag = self.entity_embeddings_imag.get(&triple.object).unwrap();
604
605        // ComplEx scoring function: Re(<h, r, t̄>)
606        // = Re(∑ h_i * r_i * conj(t_i))
607        // = ∑ (h_real * r_real * t_real + h_real * r_imag * t_imag +
608        //      h_imag * r_real * t_imag - h_imag * r_imag * t_real)
609
610        let mut score = 0.0;
611        for i in 0..self.config.dimensions {
612            score += h_real[i] * r_real[i] * t_real[i]
613                + h_real[i] * r_imag[i] * t_imag[i]
614                + h_imag[i] * r_real[i] * t_imag[i]
615                - h_imag[i] * r_imag[i] * t_real[i];
616        }
617
618        score
619    }
620}
621
622impl KGEmbeddingModel for ComplEx {
623    fn train(&mut self, triples: &[Triple]) -> Result<()> {
624        if triples.is_empty() {
625            return Err(anyhow!("No triples provided for training"));
626        }
627
628        // Initialize embeddings
629        self.initialize_embeddings(triples);
630
631        // Training implementation would go here
632        // For brevity, using a simplified version
633
634        Ok(())
635    }
636
637    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
638        // Return concatenated real and imaginary parts
639        let real = self.entity_embeddings_real.get(entity)?;
640        let imag = self.entity_embeddings_imag.get(entity)?;
641
642        let mut values = Vec::with_capacity(self.config.dimensions * 2);
643        values.extend(real.iter().cloned());
644        values.extend(imag.iter().cloned());
645
646        Some(Vector::new(values))
647    }
648
649    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
650        // Return concatenated real and imaginary parts
651        let real = self.relation_embeddings_real.get(relation)?;
652        let imag = self.relation_embeddings_imag.get(relation)?;
653
654        let mut values = Vec::with_capacity(self.config.dimensions * 2);
655        values.extend(real.iter().cloned());
656        values.extend(imag.iter().cloned());
657
658        Some(Vector::new(values))
659    }
660
661    fn score_triple(&self, triple: &Triple) -> f32 {
662        self.hermitian_dot(triple)
663    }
664
665    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
666        let mut scores: Vec<(String, f32)> = self
667            .entities
668            .iter()
669            .filter(|e| *e != head)
670            .map(|tail| {
671                let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
672                let score = self.score_triple(&triple);
673                (tail.clone(), score)
674            })
675            .collect();
676
677        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
678        scores.truncate(k);
679        scores
680    }
681
682    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
683        let mut scores: Vec<(String, f32)> = self
684            .entities
685            .iter()
686            .filter(|e| *e != tail)
687            .map(|head| {
688                let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
689                let score = self.score_triple(&triple);
690                (head.clone(), score)
691            })
692            .collect();
693
694        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
695        scores.truncate(k);
696        scores
697    }
698
699    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
700        self.entity_embeddings_real
701            .iter()
702            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
703            .collect()
704    }
705
706    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
707        self.relation_embeddings_real
708            .iter()
709            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
710            .collect()
711    }
712}
713
714/// RotatE: Rotation-based embeddings
715/// Models relations as rotations in complex space
716pub struct RotatE {
717    config: KGEmbeddingConfig,
718    entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
719    relation_embeddings: HashMap<String, DVector<f32>>, // Phase angles
720    entities: Vec<String>,
721    relations: Vec<String>,
722}
723
724impl RotatE {
725    pub fn new(config: KGEmbeddingConfig) -> Self {
726        Self {
727            config,
728            entity_embeddings: HashMap::new(),
729            relation_embeddings: HashMap::new(),
730            entities: Vec::new(),
731            relations: Vec::new(),
732        }
733    }
734
735    /// Initialize embeddings
736    fn initialize_embeddings(&mut self, triples: &[Triple]) {
737        // Collect unique entities and relations
738        let mut entities = std::collections::HashSet::new();
739        let mut relations = std::collections::HashSet::new();
740
741        for triple in triples {
742            entities.insert(triple.subject.clone());
743            entities.insert(triple.object.clone());
744            relations.insert(triple.predicate.clone());
745        }
746
747        self.entities = entities.into_iter().collect();
748        self.relations = relations.into_iter().collect();
749
750        let mut rng = if let Some(seed) = self.config.random_seed {
751            Random::seed(seed)
752        } else {
753            Random::seed(42)
754        };
755
756        // Initialize entity embeddings (complex numbers with unit modulus)
757        let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
758
759        for entity in &self.entities {
760            let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
761                .map(|_| {
762                    let phase = rng.gen_range(phase_range.clone());
763                    Complex::new(phase.cos(), phase.sin())
764                })
765                .collect();
766
767            self.entity_embeddings
768                .insert(entity.clone(), DVector::from_vec(phases));
769        }
770
771        // Initialize relation embeddings (phase angles)
772        for relation in &self.relations {
773            let phases: Vec<f32> = (0..self.config.dimensions)
774                .map(|_| rng.gen_range(phase_range.clone()))
775                .collect();
776
777            self.relation_embeddings
778                .insert(relation.clone(), DVector::from_vec(phases));
779        }
780    }
781
782    /// Calculate distance for RotatE
783    fn distance(&self, triple: &Triple) -> f32 {
784        let h = self.entity_embeddings.get(&triple.subject).unwrap();
785        let r_phases = self.relation_embeddings.get(&triple.predicate).unwrap();
786        let t = self.entity_embeddings.get(&triple.object).unwrap();
787
788        // Convert relation phases to complex numbers
789        let r: DVector<Complex<f32>> = DVector::from_iterator(
790            self.config.dimensions,
791            r_phases
792                .iter()
793                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
794        );
795
796        // Apply rotation: h ∘ r (element-wise complex multiplication)
797        let rotated: DVector<Complex<f32>> = h.component_mul(&r);
798
799        // Calculate distance ||h ∘ r - t||
800        let diff = rotated - t;
801        diff.iter().map(|c| c.norm()).sum::<f32>()
802    }
803}
804
805impl KGEmbeddingModel for RotatE {
806    fn train(&mut self, triples: &[Triple]) -> Result<()> {
807        if triples.is_empty() {
808            return Err(anyhow!("No triples provided for training"));
809        }
810
811        // Initialize embeddings
812        self.initialize_embeddings(triples);
813
814        // Training implementation would go here
815        // For brevity, using a simplified version
816
817        Ok(())
818    }
819
820    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
821        // Return magnitude and phase representation
822        let complex_emb = self.entity_embeddings.get(entity)?;
823
824        let mut values = Vec::with_capacity(self.config.dimensions * 2);
825        for c in complex_emb.iter() {
826            values.push(c.re); // Real part
827            values.push(c.im); // Imaginary part
828        }
829
830        Some(Vector::new(values))
831    }
832
833    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
834        self.relation_embeddings
835            .get(relation)
836            .map(|phases| Vector::new(phases.iter().cloned().collect()))
837    }
838
839    fn score_triple(&self, triple: &Triple) -> f32 {
840        let gamma = 12.0; // Fixed margin parameter for RotatE
841        gamma - self.distance(triple)
842    }
843
844    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
845        let h = match self.entity_embeddings.get(head) {
846            Some(emb) => emb,
847            None => return Vec::new(),
848        };
849
850        let r_phases = match self.relation_embeddings.get(relation) {
851            Some(emb) => emb,
852            None => return Vec::new(),
853        };
854
855        // Convert relation phases to complex numbers
856        let r: DVector<Complex<f32>> = DVector::from_iterator(
857            self.config.dimensions,
858            r_phases
859                .iter()
860                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
861        );
862
863        // Apply rotation
864        let rotated = h.component_mul(&r);
865
866        let mut scores: Vec<(String, f32)> = self
867            .entities
868            .iter()
869            .filter(|e| *e != head)
870            .filter_map(|entity| {
871                self.entity_embeddings.get(entity).map(|t| {
872                    let diff = &rotated - t;
873                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
874                    (entity.clone(), -distance)
875                })
876            })
877            .collect();
878
879        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
880        scores.truncate(k);
881        scores
882    }
883
884    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
885        let t = match self.entity_embeddings.get(tail) {
886            Some(emb) => emb,
887            None => return Vec::new(),
888        };
889
890        let r_phases = match self.relation_embeddings.get(relation) {
891            Some(emb) => emb,
892            None => return Vec::new(),
893        };
894
895        // Convert relation phases to complex numbers (inverse rotation)
896        let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
897            self.config.dimensions,
898            r_phases
899                .iter()
900                .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
901        );
902
903        let mut scores: Vec<(String, f32)> = self
904            .entities
905            .iter()
906            .filter(|e| *e != tail)
907            .filter_map(|entity| {
908                self.entity_embeddings.get(entity).map(|h| {
909                    let rotated = h.component_mul(&r_inv);
910                    let diff = rotated - t;
911                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
912                    (entity.clone(), -distance)
913                })
914            })
915            .collect();
916
917        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
918        scores.truncate(k);
919        scores
920    }
921
922    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
923        self.entity_embeddings
924            .iter()
925            .map(|(k, v)| {
926                let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
927                (k.clone(), Vector::new(real_values))
928            })
929            .collect()
930    }
931
932    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
933        self.relation_embeddings
934            .iter()
935            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
936            .collect()
937    }
938}
939
940/// Unified knowledge graph embedding interface
941pub struct KGEmbedding {
942    model: Box<dyn KGEmbeddingModel>,
943    config: KGEmbeddingConfig,
944}
945
946impl KGEmbedding {
947    /// Create a new knowledge graph embedding model
948    pub fn new(config: KGEmbeddingConfig) -> Self {
949        let model: Box<dyn KGEmbeddingModel> = match config.model {
950            KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
951            KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
952            KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
953            KGEmbeddingModelType::GCN => {
954                // Create GCN with default parameters
955                let gcn = GCN::new(config.clone());
956                Box::new(GCNAdapter::new(gcn))
957            }
958            KGEmbeddingModelType::GraphSAGE => {
959                // Create GraphSAGE with default parameters
960                let graphsage = GraphSAGE::new(config.clone())
961                    .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
962                Box::new(GraphSAGEAdapter::new(graphsage))
963            }
964        };
965
966        Self { model, config }
967    }
968
969    /// Train the model
970    pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
971        self.model.train(triples)
972    }
973
974    /// Get entity embedding
975    pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
976        self.model.get_entity_embedding(entity)
977    }
978
979    /// Get relation embedding
980    pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
981        self.model.get_relation_embedding(relation)
982    }
983
984    /// Score a triple
985    pub fn score_triple(&self, triple: &Triple) -> f32 {
986        self.model.score_triple(triple)
987    }
988
989    /// Link prediction: predict missing tail
990    pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
991        self.model.predict_tail(head, relation, k)
992    }
993
994    /// Link prediction: predict missing head
995    pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
996        self.model.predict_head(relation, tail, k)
997    }
998
999    /// Triple classification: determine if a triple is likely true
1000    pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1001        self.model.score_triple(triple) > threshold
1002    }
1003}
1004
1005/// Adapter to use GCN as a knowledge graph embedding model
1006pub struct GCNAdapter {
1007    gcn: GCN,
1008}
1009
1010impl GCNAdapter {
1011    pub fn new(gcn: GCN) -> Self {
1012        Self { gcn }
1013    }
1014}
1015
1016impl KGEmbeddingModel for GCNAdapter {
1017    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1018        // GCN training would be implemented here
1019        Ok(())
1020    }
1021
1022    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1023        // GCN embeddings would be computed from graph structure
1024        // For now, return a default embedding
1025        Some(Vector::new(vec![0.0; 128]))
1026    }
1027
1028    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1029        // Relations in GCN are typically handled differently
1030        Some(Vector::new(vec![0.0; 128]))
1031    }
1032
1033    fn score_triple(&self, _triple: &Triple) -> f32 {
1034        // GCN scoring would use graph structure
1035        0.5
1036    }
1037
1038    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1039        // Return mock predictions
1040        vec![]
1041    }
1042
1043    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1044        // Return mock predictions
1045        vec![]
1046    }
1047
1048    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1049        HashMap::new()
1050    }
1051
1052    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1053        HashMap::new()
1054    }
1055}
1056
1057/// Adapter to use GraphSAGE as a knowledge graph embedding model
1058pub struct GraphSAGEAdapter {
1059    graphsage: GraphSAGE,
1060}
1061
1062impl GraphSAGEAdapter {
1063    pub fn new(graphsage: GraphSAGE) -> Self {
1064        Self { graphsage }
1065    }
1066}
1067
1068impl KGEmbeddingModel for GraphSAGEAdapter {
1069    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1070        // GraphSAGE training would be implemented here
1071        Ok(())
1072    }
1073
1074    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1075        // GraphSAGE embeddings would be computed from neighbors
1076        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1077    }
1078
1079    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1080        // Relations in GraphSAGE are typically handled differently
1081        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1082    }
1083
1084    fn score_triple(&self, _triple: &Triple) -> f32 {
1085        // GraphSAGE scoring would use neighbor aggregation
1086        0.5
1087    }
1088
1089    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1090        // Return mock predictions
1091        vec![]
1092    }
1093
1094    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1095        // Return mock predictions
1096        vec![]
1097    }
1098
1099    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1100        HashMap::new()
1101    }
1102
1103    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1104        HashMap::new()
1105    }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-triple toy graph shared by all tests.
    fn create_test_triples() -> Vec<Triple> {
        let facts = [
            ("Alice", "knows", "Bob"),
            ("Bob", "knows", "Charlie"),
            ("Alice", "likes", "Pizza"),
            ("Bob", "likes", "Pasta"),
            ("Charlie", "knows", "Alice"),
        ];

        facts
            .iter()
            .map(|&(subject, predicate, object)| {
                Triple::new(
                    subject.to_string(),
                    predicate.to_string(),
                    object.to_string(),
                )
            })
            .collect()
    }

    /// Builds a 50-dimensional model of the given type, trains it for 10
    /// epochs on the toy graph and returns it together with the triples.
    fn trained_model(model: KGEmbeddingModelType) -> (KGEmbedding, Vec<Triple>) {
        let config = KGEmbeddingConfig {
            model,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut embedding = KGEmbedding::new(config);
        let triples = create_test_triples();
        embedding.train(&triples).unwrap();

        (embedding, triples)
    }

    #[test]
    fn test_transe() {
        let (model, triples) = trained_model(KGEmbeddingModelType::TransE);

        // Embeddings exist for a trained entity and relation
        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());

        // Scoring yields a finite number
        let score = model.score_triple(&triples[0]);
        assert!(score.is_finite());

        // Tail prediction yields at least one candidate
        let predictions = model.predict_tail("Alice", "knows", 2);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_complex() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::ComplEx);

        // Embedding exists for a trained entity
        assert!(model.get_entity_embedding("Bob").is_some());
        let emb = model.get_entity_embedding("Bob").unwrap();
        assert_eq!(emb.dimensions, 100); // Real + imaginary parts
    }

    #[test]
    fn test_rotate() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::RotatE);

        // Embeddings exist for a trained entity and relation
        assert!(model.get_entity_embedding("Charlie").is_some());
        assert!(model.get_relation_embedding("likes").is_some());

        // Relation embedding is phase angles: one value per dimension
        let rel_emb = model.get_relation_embedding("likes").unwrap();
        assert_eq!(rel_emb.dimensions, 50);
    }
}