oxirs_vec/
kg_embeddings.rs

1//! Knowledge Graph Embeddings for RDF data
2//!
3//! This module implements various knowledge graph embedding methods:
4//! - TransE: Translation-based embeddings
5//! - ComplEx: Complex number embeddings
6//! - RotatE: Rotation-based embeddings
7
8use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::random_utils::NormalSampler as Normal;
10use crate::Vector;
11use anyhow::{anyhow, Result};
12use nalgebra::{Complex, DVector};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17/// Knowledge graph embedding model type
18#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum KGEmbeddingModelType {
20    /// Translation-based embeddings (TransE)
21    TransE,
22    /// Complex number embeddings (ComplEx)
23    ComplEx,
24    /// Rotation-based embeddings (RotatE)
25    RotatE,
26    /// Graph Convolutional Network (GCN)
27    GCN,
28    /// GraphSAGE (Graph Sample and Aggregate)
29    GraphSAGE,
30}
31
32/// Configuration for knowledge graph embeddings
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KGEmbeddingConfig {
35    /// Model type
36    pub model: KGEmbeddingModelType,
37    /// Embedding dimensions
38    pub dimensions: usize,
39    /// Learning rate
40    pub learning_rate: f32,
41    /// Margin for loss function
42    pub margin: f32,
43    /// Negative sampling ratio
44    pub negative_samples: usize,
45    /// Batch size for training
46    pub batch_size: usize,
47    /// Number of epochs
48    pub epochs: usize,
49    /// L1 or L2 norm
50    pub norm: usize,
51    /// Random seed
52    pub random_seed: Option<u64>,
53    /// Regularization weight
54    pub regularization: f32,
55}
56
57impl Default for KGEmbeddingConfig {
58    fn default() -> Self {
59        Self {
60            model: KGEmbeddingModelType::TransE,
61            dimensions: 100,
62            learning_rate: 0.01,
63            margin: 1.0,
64            negative_samples: 10,
65            batch_size: 100,
66            epochs: 100,
67            norm: 2,
68            random_seed: Some(42),
69            regularization: 0.0,
70        }
71    }
72}
73
/// A single knowledge graph statement: `(subject, predicate, object)`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Triple {
    /// Head entity of the statement.
    pub subject: String,
    /// Relation linking the subject to the object.
    pub predicate: String,
    /// Tail entity of the statement.
    pub object: String,
}

impl Triple {
    /// Builds a triple from its three already-owned components.
    pub fn new(subject: String, predicate: String, object: String) -> Self {
        Triple {
            subject,
            predicate,
            object,
        }
    }
}
91
92/// Base trait for knowledge graph embedding models
93pub trait KGEmbeddingModel: Send + Sync {
94    /// Train the model on triples
95    fn train(&mut self, triples: &[Triple]) -> Result<()>;
96
97    /// Get entity embedding
98    fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
99
100    /// Get relation embedding
101    fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
102
103    /// Score a triple
104    fn score_triple(&self, triple: &Triple) -> f32;
105
106    /// Predict tail entities for (head, relation, ?)
107    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
108
109    /// Predict head entities for (?, relation, tail)
110    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
111
112    /// Get all entity embeddings
113    fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
114
115    /// Get all relation embeddings
116    fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
117}
118
119/// TransE: Translation-based embeddings
120/// Learns embeddings where h + r ≈ t for triple (h, r, t)
121pub struct TransE {
122    config: KGEmbeddingConfig,
123    entity_embeddings: HashMap<String, DVector<f32>>,
124    relation_embeddings: HashMap<String, DVector<f32>>,
125    entities: Vec<String>,
126    relations: Vec<String>,
127}
128
129impl TransE {
130    pub fn new(config: KGEmbeddingConfig) -> Self {
131        Self {
132            config,
133            entity_embeddings: HashMap::new(),
134            relation_embeddings: HashMap::new(),
135            entities: Vec::new(),
136            relations: Vec::new(),
137        }
138    }
139
140    /// Initialize embeddings
141    fn initialize_embeddings(&mut self, triples: &[Triple]) {
142        // Collect unique entities and relations
143        let mut entities = std::collections::HashSet::new();
144        let mut relations = std::collections::HashSet::new();
145
146        for triple in triples {
147            entities.insert(triple.subject.clone());
148            entities.insert(triple.object.clone());
149            relations.insert(triple.predicate.clone());
150        }
151
152        self.entities = entities.into_iter().collect();
153        self.relations = relations.into_iter().collect();
154
155        // Initialize embeddings with uniform distribution
156        let mut rng = if let Some(seed) = self.config.random_seed {
157            Random::seed(seed)
158        } else {
159            Random::seed(42)
160        };
161
162        let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163        let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165        // Initialize entity embeddings
166        for entity in &self.entities {
167            let values: Vec<f32> = (0..self.config.dimensions)
168                .map(|_| rng.random_range(range_min..range_max))
169                .collect();
170            let mut embedding = DVector::from_vec(values);
171
172            // Normalize entities
173            let norm = embedding.norm();
174            if norm > 0.0 {
175                embedding /= norm;
176            }
177
178            self.entity_embeddings.insert(entity.clone(), embedding);
179        }
180
181        // Initialize relation embeddings
182        for relation in &self.relations {
183            let values: Vec<f32> = (0..self.config.dimensions)
184                .map(|_| rng.random_range(range_min..range_max))
185                .collect();
186            let embedding = DVector::from_vec(values);
187
188            // Relations are not normalized in TransE
189            self.relation_embeddings.insert(relation.clone(), embedding);
190        }
191    }
192
193    /// Generate negative samples
194    #[allow(deprecated)]
195    fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
196        let mut negatives = Vec::new();
197
198        for _ in 0..self.config.negative_samples {
199            if rng.gen_bool(0.5) {
200                // Corrupt head
201                let mut negative = triple.clone();
202                loop {
203                    let idx = rng.gen_range(0..self.entities.len());
204                    let entity = &self.entities[idx];
205                    if entity != &triple.subject {
206                        negative.subject = entity.clone();
207                        break;
208                    }
209                }
210                negatives.push(negative);
211            } else {
212                // Corrupt tail
213                let mut negative = triple.clone();
214                loop {
215                    let idx = rng.gen_range(0..self.entities.len());
216                    let entity = &self.entities[idx];
217                    if entity != &triple.object {
218                        negative.object = entity.clone();
219                        break;
220                    }
221                }
222                negatives.push(negative);
223            }
224        }
225
226        negatives
227    }
228
229    /// Calculate distance for a triple
230    fn distance(&self, triple: &Triple) -> f32 {
231        let h = self
232            .entity_embeddings
233            .get(&triple.subject)
234            .expect("subject entity should have embedding");
235        let r = self
236            .relation_embeddings
237            .get(&triple.predicate)
238            .expect("predicate relation should have embedding");
239        let t = self
240            .entity_embeddings
241            .get(&triple.object)
242            .expect("object entity should have embedding");
243
244        let translation = h + r - t;
245
246        match self.config.norm {
247            1 => translation.iter().map(|x| x.abs()).sum(),
248            2 => translation.norm(),
249            _ => translation.norm(),
250        }
251    }
252
253    /// Update embeddings using gradient descent
254    fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
255        let pos_dist = self.distance(positive);
256
257        for negative in negatives {
258            let neg_dist = self.distance(negative);
259            let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
260
261            if loss > 0.0 {
262                // Calculate gradients
263                let h_pos = self
264                    .entity_embeddings
265                    .get(&positive.subject)
266                    .expect("positive subject entity should have embedding")
267                    .clone();
268                let r = self
269                    .relation_embeddings
270                    .get(&positive.predicate)
271                    .expect("positive predicate relation should have embedding")
272                    .clone();
273                let t_pos = self
274                    .entity_embeddings
275                    .get(&positive.object)
276                    .expect("positive object entity should have embedding")
277                    .clone();
278
279                let h_neg = self
280                    .entity_embeddings
281                    .get(&negative.subject)
282                    .expect("negative subject entity should have embedding")
283                    .clone();
284                let t_neg = self
285                    .entity_embeddings
286                    .get(&negative.object)
287                    .expect("negative object entity should have embedding")
288                    .clone();
289
290                let pos_grad = &h_pos + &r - &t_pos;
291                let neg_grad = &h_neg + &r - &t_neg;
292
293                // Normalize gradients
294                let pos_norm = pos_grad.norm();
295                let neg_norm = neg_grad.norm();
296
297                let pos_grad_norm = if pos_norm > 0.0 {
298                    &pos_grad / pos_norm
299                } else {
300                    pos_grad
301                };
302                let neg_grad_norm = if neg_norm > 0.0 {
303                    &neg_grad / neg_norm
304                } else {
305                    neg_grad
306                };
307
308                // Update embeddings
309                let lr = self.config.learning_rate;
310
311                // Update positive triple embeddings
312                if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
313                    *h -= lr * &pos_grad_norm;
314                    // Re-normalize entity
315                    let norm = h.norm();
316                    if norm > 0.0 {
317                        *h /= norm;
318                    }
319                }
320
321                if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
322                    *r -= lr * (&pos_grad_norm - &neg_grad_norm);
323                }
324
325                if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
326                    *t += lr * &pos_grad_norm;
327                    // Re-normalize entity
328                    let norm = t.norm();
329                    if norm > 0.0 {
330                        *t /= norm;
331                    }
332                }
333
334                // Update negative triple embeddings
335                if positive.subject != negative.subject {
336                    if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
337                        *h += lr * &neg_grad_norm;
338                        // Re-normalize entity
339                        let norm = h.norm();
340                        if norm > 0.0 {
341                            *h /= norm;
342                        }
343                    }
344                }
345
346                if positive.object != negative.object {
347                    if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
348                        *t -= lr * &neg_grad_norm;
349                        // Re-normalize entity
350                        let norm = t.norm();
351                        if norm > 0.0 {
352                            *t /= norm;
353                        }
354                    }
355                }
356            }
357        }
358    }
359}
360
361impl KGEmbeddingModel for TransE {
362    fn train(&mut self, triples: &[Triple]) -> Result<()> {
363        if triples.is_empty() {
364            return Err(anyhow!("No triples provided for training"));
365        }
366
367        // Initialize embeddings
368        self.initialize_embeddings(triples);
369
370        let mut rng = if let Some(seed) = self.config.random_seed {
371            Random::seed(seed)
372        } else {
373            Random::seed(42)
374        };
375
376        // Training loop
377        for epoch in 0..self.config.epochs {
378            let mut total_loss = 0.0;
379            let mut batch_count = 0;
380
381            // Shuffle triples
382            let mut shuffled_triples = triples.to_vec();
383            // Note: Using manual random selection instead of SliceRandom
384            // Manually shuffle using Fisher-Yates algorithm
385            for i in (1..shuffled_triples.len()).rev() {
386                let j = rng.random_range(0..i + 1);
387                shuffled_triples.swap(i, j);
388            }
389
390            // Process batches
391            for batch in shuffled_triples.chunks(self.config.batch_size) {
392                for triple in batch {
393                    // Generate negative samples
394                    let negatives = self.generate_negative_samples(triple, &mut rng);
395
396                    // Calculate loss
397                    let pos_dist = self.distance(triple);
398                    for negative in &negatives {
399                        let neg_dist = self.distance(negative);
400                        let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
401                        total_loss += loss;
402                    }
403
404                    // Update embeddings
405                    self.update_embeddings(triple, &negatives);
406                }
407                batch_count += 1;
408            }
409
410            if epoch % 10 == 0 {
411                let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
412                tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
413            }
414        }
415
416        Ok(())
417    }
418
419    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
420        self.entity_embeddings
421            .get(entity)
422            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
423    }
424
425    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
426        self.relation_embeddings
427            .get(relation)
428            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
429    }
430
431    fn score_triple(&self, triple: &Triple) -> f32 {
432        -self.distance(triple)
433    }
434
435    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
436        let h = match self.entity_embeddings.get(head) {
437            Some(emb) => emb,
438            None => return Vec::new(),
439        };
440
441        let r = match self.relation_embeddings.get(relation) {
442            Some(emb) => emb,
443            None => return Vec::new(),
444        };
445
446        let translation = h + r;
447
448        let mut scores: Vec<(String, f32)> = self
449            .entities
450            .iter()
451            .filter(|e| *e != head)
452            .filter_map(|entity| {
453                self.entity_embeddings.get(entity).map(|t| {
454                    let distance = match self.config.norm {
455                        1 => (&translation - t).iter().map(|x| x.abs()).sum(),
456                        2 => (&translation - t).norm(),
457                        _ => (&translation - t).norm(),
458                    };
459                    (entity.clone(), -distance)
460                })
461            })
462            .collect();
463
464        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
465        scores.truncate(k);
466        scores
467    }
468
469    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
470        let t = match self.entity_embeddings.get(tail) {
471            Some(emb) => emb,
472            None => return Vec::new(),
473        };
474
475        let r = match self.relation_embeddings.get(relation) {
476            Some(emb) => emb,
477            None => return Vec::new(),
478        };
479
480        let target = t - r;
481
482        let mut scores: Vec<(String, f32)> = self
483            .entities
484            .iter()
485            .filter(|e| *e != tail)
486            .filter_map(|entity| {
487                self.entity_embeddings.get(entity).map(|h| {
488                    let distance = match self.config.norm {
489                        1 => (h - &target).iter().map(|x| x.abs()).sum(),
490                        2 => (h - &target).norm(),
491                        _ => (h - &target).norm(),
492                    };
493                    (entity.clone(), -distance)
494                })
495            })
496            .collect();
497
498        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
499        scores.truncate(k);
500        scores
501    }
502
503    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
504        self.entity_embeddings
505            .iter()
506            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
507            .collect()
508    }
509
510    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
511        self.relation_embeddings
512            .iter()
513            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
514            .collect()
515    }
516}
517
518/// ComplEx: Complex number embeddings
519/// Uses complex-valued embeddings and Hermitian dot product
520pub struct ComplEx {
521    config: KGEmbeddingConfig,
522    entity_embeddings_real: HashMap<String, DVector<f32>>,
523    entity_embeddings_imag: HashMap<String, DVector<f32>>,
524    relation_embeddings_real: HashMap<String, DVector<f32>>,
525    relation_embeddings_imag: HashMap<String, DVector<f32>>,
526    entities: Vec<String>,
527    relations: Vec<String>,
528}
529
530impl ComplEx {
531    pub fn new(config: KGEmbeddingConfig) -> Self {
532        Self {
533            config,
534            entity_embeddings_real: HashMap::new(),
535            entity_embeddings_imag: HashMap::new(),
536            relation_embeddings_real: HashMap::new(),
537            relation_embeddings_imag: HashMap::new(),
538            entities: Vec::new(),
539            relations: Vec::new(),
540        }
541    }
542
543    /// Initialize embeddings with Xavier initialization
544    fn initialize_embeddings(&mut self, triples: &[Triple]) {
545        // Collect unique entities and relations
546        let mut entities = std::collections::HashSet::new();
547        let mut relations = std::collections::HashSet::new();
548
549        for triple in triples {
550            entities.insert(triple.subject.clone());
551            entities.insert(triple.object.clone());
552            relations.insert(triple.predicate.clone());
553        }
554
555        self.entities = entities.into_iter().collect();
556        self.relations = relations.into_iter().collect();
557
558        // Initialize with Xavier initialization
559        let mut rng = if let Some(seed) = self.config.random_seed {
560            Random::seed(seed)
561        } else {
562            Random::seed(42)
563        };
564
565        let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
566        let normal =
567            Normal::new(0.0, std_dev).expect("normal distribution parameters should be valid");
568
569        // Initialize entity embeddings
570        for entity in &self.entities {
571            let real_values: Vec<f32> = (0..self.config.dimensions)
572                .map(|_| normal.sample(&mut rng))
573                .collect();
574            let imag_values: Vec<f32> = (0..self.config.dimensions)
575                .map(|_| normal.sample(&mut rng))
576                .collect();
577
578            self.entity_embeddings_real
579                .insert(entity.clone(), DVector::from_vec(real_values));
580            self.entity_embeddings_imag
581                .insert(entity.clone(), DVector::from_vec(imag_values));
582        }
583
584        // Initialize relation embeddings
585        for relation in &self.relations {
586            let real_values: Vec<f32> = (0..self.config.dimensions)
587                .map(|_| normal.sample(&mut rng))
588                .collect();
589            let imag_values: Vec<f32> = (0..self.config.dimensions)
590                .map(|_| normal.sample(&mut rng))
591                .collect();
592
593            self.relation_embeddings_real
594                .insert(relation.clone(), DVector::from_vec(real_values));
595            self.relation_embeddings_imag
596                .insert(relation.clone(), DVector::from_vec(imag_values));
597        }
598    }
599
600    /// Hermitian dot product for scoring
601    fn hermitian_dot(&self, triple: &Triple) -> f32 {
602        let h_real = self
603            .entity_embeddings_real
604            .get(&triple.subject)
605            .expect("subject entity should have real embedding");
606        let h_imag = self
607            .entity_embeddings_imag
608            .get(&triple.subject)
609            .expect("subject entity should have imag embedding");
610        let r_real = self
611            .relation_embeddings_real
612            .get(&triple.predicate)
613            .expect("predicate relation should have real embedding");
614        let r_imag = self
615            .relation_embeddings_imag
616            .get(&triple.predicate)
617            .expect("predicate relation should have imag embedding");
618        let t_real = self
619            .entity_embeddings_real
620            .get(&triple.object)
621            .expect("object entity should have real embedding");
622        let t_imag = self
623            .entity_embeddings_imag
624            .get(&triple.object)
625            .expect("object entity should have imag embedding");
626
627        // ComplEx scoring function: Re(<h, r, t̄>)
628        // = Re(∑ h_i * r_i * conj(t_i))
629        // = ∑ (h_real * r_real * t_real + h_real * r_imag * t_imag +
630        //      h_imag * r_real * t_imag - h_imag * r_imag * t_real)
631
632        let mut score = 0.0;
633        for i in 0..self.config.dimensions {
634            score += h_real[i] * r_real[i] * t_real[i]
635                + h_real[i] * r_imag[i] * t_imag[i]
636                + h_imag[i] * r_real[i] * t_imag[i]
637                - h_imag[i] * r_imag[i] * t_real[i];
638        }
639
640        score
641    }
642}
643
644impl KGEmbeddingModel for ComplEx {
645    fn train(&mut self, triples: &[Triple]) -> Result<()> {
646        if triples.is_empty() {
647            return Err(anyhow!("No triples provided for training"));
648        }
649
650        // Initialize embeddings
651        self.initialize_embeddings(triples);
652
653        // Training implementation would go here
654        // For brevity, using a simplified version
655
656        Ok(())
657    }
658
659    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
660        // Return concatenated real and imaginary parts
661        let real = self.entity_embeddings_real.get(entity)?;
662        let imag = self.entity_embeddings_imag.get(entity)?;
663
664        let mut values = Vec::with_capacity(self.config.dimensions * 2);
665        values.extend(real.iter().cloned());
666        values.extend(imag.iter().cloned());
667
668        Some(Vector::new(values))
669    }
670
671    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
672        // Return concatenated real and imaginary parts
673        let real = self.relation_embeddings_real.get(relation)?;
674        let imag = self.relation_embeddings_imag.get(relation)?;
675
676        let mut values = Vec::with_capacity(self.config.dimensions * 2);
677        values.extend(real.iter().cloned());
678        values.extend(imag.iter().cloned());
679
680        Some(Vector::new(values))
681    }
682
683    fn score_triple(&self, triple: &Triple) -> f32 {
684        self.hermitian_dot(triple)
685    }
686
687    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
688        let mut scores: Vec<(String, f32)> = self
689            .entities
690            .iter()
691            .filter(|e| *e != head)
692            .map(|tail| {
693                let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
694                let score = self.score_triple(&triple);
695                (tail.clone(), score)
696            })
697            .collect();
698
699        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
700        scores.truncate(k);
701        scores
702    }
703
704    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
705        let mut scores: Vec<(String, f32)> = self
706            .entities
707            .iter()
708            .filter(|e| *e != tail)
709            .map(|head| {
710                let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
711                let score = self.score_triple(&triple);
712                (head.clone(), score)
713            })
714            .collect();
715
716        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
717        scores.truncate(k);
718        scores
719    }
720
721    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
722        self.entity_embeddings_real
723            .iter()
724            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
725            .collect()
726    }
727
728    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
729        self.relation_embeddings_real
730            .iter()
731            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
732            .collect()
733    }
734}
735
736/// RotatE: Rotation-based embeddings
737/// Models relations as rotations in complex space
738pub struct RotatE {
739    config: KGEmbeddingConfig,
740    entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
741    relation_embeddings: HashMap<String, DVector<f32>>, // Phase angles
742    entities: Vec<String>,
743    relations: Vec<String>,
744}
745
746impl RotatE {
747    pub fn new(config: KGEmbeddingConfig) -> Self {
748        Self {
749            config,
750            entity_embeddings: HashMap::new(),
751            relation_embeddings: HashMap::new(),
752            entities: Vec::new(),
753            relations: Vec::new(),
754        }
755    }
756
757    /// Initialize embeddings
758    fn initialize_embeddings(&mut self, triples: &[Triple]) {
759        // Collect unique entities and relations
760        let mut entities = std::collections::HashSet::new();
761        let mut relations = std::collections::HashSet::new();
762
763        for triple in triples {
764            entities.insert(triple.subject.clone());
765            entities.insert(triple.object.clone());
766            relations.insert(triple.predicate.clone());
767        }
768
769        self.entities = entities.into_iter().collect();
770        self.relations = relations.into_iter().collect();
771
772        let mut rng = if let Some(seed) = self.config.random_seed {
773            Random::seed(seed)
774        } else {
775            Random::seed(42)
776        };
777
778        // Initialize entity embeddings (complex numbers with unit modulus)
779        let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
780
781        for entity in &self.entities {
782            let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
783                .map(|_| {
784                    let phase = rng.gen_range(phase_range.clone());
785                    Complex::new(phase.cos(), phase.sin())
786                })
787                .collect();
788
789            self.entity_embeddings
790                .insert(entity.clone(), DVector::from_vec(phases));
791        }
792
793        // Initialize relation embeddings (phase angles)
794        for relation in &self.relations {
795            let phases: Vec<f32> = (0..self.config.dimensions)
796                .map(|_| rng.gen_range(phase_range.clone()))
797                .collect();
798
799            self.relation_embeddings
800                .insert(relation.clone(), DVector::from_vec(phases));
801        }
802    }
803
804    /// Calculate distance for RotatE
805    fn distance(&self, triple: &Triple) -> f32 {
806        let h = self
807            .entity_embeddings
808            .get(&triple.subject)
809            .expect("subject entity should have embedding");
810        let r_phases = self
811            .relation_embeddings
812            .get(&triple.predicate)
813            .expect("predicate relation should have embedding");
814        let t = self
815            .entity_embeddings
816            .get(&triple.object)
817            .expect("object entity should have embedding");
818
819        // Convert relation phases to complex numbers
820        let r: DVector<Complex<f32>> = DVector::from_iterator(
821            self.config.dimensions,
822            r_phases
823                .iter()
824                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
825        );
826
827        // Apply rotation: h ∘ r (element-wise complex multiplication)
828        let rotated: DVector<Complex<f32>> = h.component_mul(&r);
829
830        // Calculate distance ||h ∘ r - t||
831        let diff = rotated - t;
832        diff.iter().map(|c| c.norm()).sum::<f32>()
833    }
834}
835
836impl KGEmbeddingModel for RotatE {
837    fn train(&mut self, triples: &[Triple]) -> Result<()> {
838        if triples.is_empty() {
839            return Err(anyhow!("No triples provided for training"));
840        }
841
842        // Initialize embeddings
843        self.initialize_embeddings(triples);
844
845        // Training implementation would go here
846        // For brevity, using a simplified version
847
848        Ok(())
849    }
850
851    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
852        // Return magnitude and phase representation
853        let complex_emb = self.entity_embeddings.get(entity)?;
854
855        let mut values = Vec::with_capacity(self.config.dimensions * 2);
856        for c in complex_emb.iter() {
857            values.push(c.re); // Real part
858            values.push(c.im); // Imaginary part
859        }
860
861        Some(Vector::new(values))
862    }
863
864    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
865        self.relation_embeddings
866            .get(relation)
867            .map(|phases| Vector::new(phases.iter().cloned().collect()))
868    }
869
870    fn score_triple(&self, triple: &Triple) -> f32 {
871        let gamma = 12.0; // Fixed margin parameter for RotatE
872        gamma - self.distance(triple)
873    }
874
875    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
876        let h = match self.entity_embeddings.get(head) {
877            Some(emb) => emb,
878            None => return Vec::new(),
879        };
880
881        let r_phases = match self.relation_embeddings.get(relation) {
882            Some(emb) => emb,
883            None => return Vec::new(),
884        };
885
886        // Convert relation phases to complex numbers
887        let r: DVector<Complex<f32>> = DVector::from_iterator(
888            self.config.dimensions,
889            r_phases
890                .iter()
891                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
892        );
893
894        // Apply rotation
895        let rotated = h.component_mul(&r);
896
897        let mut scores: Vec<(String, f32)> = self
898            .entities
899            .iter()
900            .filter(|e| *e != head)
901            .filter_map(|entity| {
902                self.entity_embeddings.get(entity).map(|t| {
903                    let diff = &rotated - t;
904                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
905                    (entity.clone(), -distance)
906                })
907            })
908            .collect();
909
910        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
911        scores.truncate(k);
912        scores
913    }
914
915    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
916        let t = match self.entity_embeddings.get(tail) {
917            Some(emb) => emb,
918            None => return Vec::new(),
919        };
920
921        let r_phases = match self.relation_embeddings.get(relation) {
922            Some(emb) => emb,
923            None => return Vec::new(),
924        };
925
926        // Convert relation phases to complex numbers (inverse rotation)
927        let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
928            self.config.dimensions,
929            r_phases
930                .iter()
931                .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
932        );
933
934        let mut scores: Vec<(String, f32)> = self
935            .entities
936            .iter()
937            .filter(|e| *e != tail)
938            .filter_map(|entity| {
939                self.entity_embeddings.get(entity).map(|h| {
940                    let rotated = h.component_mul(&r_inv);
941                    let diff = rotated - t;
942                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
943                    (entity.clone(), -distance)
944                })
945            })
946            .collect();
947
948        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
949        scores.truncate(k);
950        scores
951    }
952
953    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
954        self.entity_embeddings
955            .iter()
956            .map(|(k, v)| {
957                let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
958                (k.clone(), Vector::new(real_values))
959            })
960            .collect()
961    }
962
963    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
964        self.relation_embeddings
965            .iter()
966            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
967            .collect()
968    }
969}
970
971/// Unified knowledge graph embedding interface
972pub struct KGEmbedding {
973    model: Box<dyn KGEmbeddingModel>,
974    config: KGEmbeddingConfig,
975}
976
977impl KGEmbedding {
978    /// Create a new knowledge graph embedding model
979    pub fn new(config: KGEmbeddingConfig) -> Self {
980        let model: Box<dyn KGEmbeddingModel> = match config.model {
981            KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
982            KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
983            KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
984            KGEmbeddingModelType::GCN => {
985                // Create GCN with default parameters
986                let gcn = GCN::new(config.clone());
987                Box::new(GCNAdapter::new(gcn))
988            }
989            KGEmbeddingModelType::GraphSAGE => {
990                // Create GraphSAGE with default parameters
991                let graphsage = GraphSAGE::new(config.clone())
992                    .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
993                Box::new(GraphSAGEAdapter::new(graphsage))
994            }
995        };
996
997        Self { model, config }
998    }
999
1000    /// Train the model
1001    pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
1002        self.model.train(triples)
1003    }
1004
1005    /// Get entity embedding
1006    pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
1007        self.model.get_entity_embedding(entity)
1008    }
1009
1010    /// Get relation embedding
1011    pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
1012        self.model.get_relation_embedding(relation)
1013    }
1014
1015    /// Score a triple
1016    pub fn score_triple(&self, triple: &Triple) -> f32 {
1017        self.model.score_triple(triple)
1018    }
1019
1020    /// Link prediction: predict missing tail
1021    pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
1022        self.model.predict_tail(head, relation, k)
1023    }
1024
1025    /// Link prediction: predict missing head
1026    pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
1027        self.model.predict_head(relation, tail, k)
1028    }
1029
1030    /// Triple classification: determine if a triple is likely true
1031    pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1032        self.model.score_triple(triple) > threshold
1033    }
1034}
1035
1036/// Adapter to use GCN as a knowledge graph embedding model
1037pub struct GCNAdapter {
1038    gcn: GCN,
1039}
1040
1041impl GCNAdapter {
1042    pub fn new(gcn: GCN) -> Self {
1043        Self { gcn }
1044    }
1045}
1046
1047impl KGEmbeddingModel for GCNAdapter {
1048    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1049        // GCN training would be implemented here
1050        Ok(())
1051    }
1052
1053    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1054        // GCN embeddings would be computed from graph structure
1055        // For now, return a default embedding
1056        Some(Vector::new(vec![0.0; 128]))
1057    }
1058
1059    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1060        // Relations in GCN are typically handled differently
1061        Some(Vector::new(vec![0.0; 128]))
1062    }
1063
1064    fn score_triple(&self, _triple: &Triple) -> f32 {
1065        // GCN scoring would use graph structure
1066        0.5
1067    }
1068
1069    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1070        // Return mock predictions
1071        vec![]
1072    }
1073
1074    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1075        // Return mock predictions
1076        vec![]
1077    }
1078
1079    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1080        HashMap::new()
1081    }
1082
1083    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1084        HashMap::new()
1085    }
1086}
1087
1088/// Adapter to use GraphSAGE as a knowledge graph embedding model
1089pub struct GraphSAGEAdapter {
1090    graphsage: GraphSAGE,
1091}
1092
1093impl GraphSAGEAdapter {
1094    pub fn new(graphsage: GraphSAGE) -> Self {
1095        Self { graphsage }
1096    }
1097}
1098
1099impl KGEmbeddingModel for GraphSAGEAdapter {
1100    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1101        // GraphSAGE training would be implemented here
1102        Ok(())
1103    }
1104
1105    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1106        // GraphSAGE embeddings would be computed from neighbors
1107        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1108    }
1109
1110    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1111        // Relations in GraphSAGE are typically handled differently
1112        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1113    }
1114
1115    fn score_triple(&self, _triple: &Triple) -> f32 {
1116        // GraphSAGE scoring would use neighbor aggregation
1117        0.5
1118    }
1119
1120    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1121        // Return mock predictions
1122        vec![]
1123    }
1124
1125    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1126        // Return mock predictions
1127        vec![]
1128    }
1129
1130    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1131        HashMap::new()
1132    }
1133
1134    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1135        HashMap::new()
1136    }
1137}
1138
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Triple` from string slices.
    fn triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            subject.to_string(),
            predicate.to_string(),
            object.to_string(),
        )
    }

    /// Small social graph shared by all tests.
    fn create_test_triples() -> Vec<Triple> {
        vec![
            triple("Alice", "knows", "Bob"),
            triple("Bob", "knows", "Charlie"),
            triple("Alice", "likes", "Pizza"),
            triple("Bob", "likes", "Pasta"),
            triple("Charlie", "knows", "Alice"),
        ]
    }

    /// Builds a 50-dimensional model of the given kind and trains it on the
    /// shared triples; returns both so tests can score individual triples.
    fn trained_model(kind: KGEmbeddingModelType) -> (KGEmbedding, Vec<Triple>) {
        let config = KGEmbeddingConfig {
            model: kind,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();
        model.train(&triples).unwrap();

        (model, triples)
    }

    #[test]
    fn test_transe() {
        let (model, triples) = trained_model(KGEmbeddingModelType::TransE);

        // Embeddings exist for a known entity and relation.
        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());

        // Scoring yields a finite number.
        let score = model.score_triple(&triples[0]);
        assert!(score.is_finite());

        // Tail prediction returns at least one candidate.
        let predictions = model.predict_tail("Alice", "knows", 2);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_complex() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::ComplEx);

        // Embedding exists and holds real + imaginary parts.
        let bob = model.get_entity_embedding("Bob");
        assert!(bob.is_some());
        let emb = bob.unwrap();
        assert_eq!(emb.dimensions, 100);
    }

    #[test]
    fn test_rotate() {
        let (model, _triples) = trained_model(KGEmbeddingModelType::RotatE);

        // Embeddings exist for a known entity and relation.
        assert!(model.get_entity_embedding("Charlie").is_some());
        let likes = model.get_relation_embedding("likes");
        assert!(likes.is_some());

        // Relation embedding is one phase angle per dimension.
        let rel_emb = likes.unwrap();
        assert_eq!(rel_emb.dimensions, 50);
    }
}