oxirs_vec/
kg_embeddings.rs

1//! Knowledge Graph Embeddings for RDF data
2//!
3//! This module implements various knowledge graph embedding methods:
4//! - TransE: Translation-based embeddings
5//! - ComplEx: Complex number embeddings
6//! - RotatE: Rotation-based embeddings
7
8use crate::gnn_embeddings::{GraphSAGE, GCN};
9use crate::Vector;
10use anyhow::{anyhow, Result};
11use nalgebra::{Complex, DVector};
12use crate::random_utils::{NormalSampler as Normal, UniformSampler as Uniform};
13use scirs2_core::random::{Random, Rng};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17/// Knowledge graph embedding model type
18#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
19pub enum KGEmbeddingModelType {
20    /// Translation-based embeddings (TransE)
21    TransE,
22    /// Complex number embeddings (ComplEx)
23    ComplEx,
24    /// Rotation-based embeddings (RotatE)
25    RotatE,
26    /// Graph Convolutional Network (GCN)
27    GCN,
28    /// GraphSAGE (Graph Sample and Aggregate)
29    GraphSAGE,
30}
31
32/// Configuration for knowledge graph embeddings
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct KGEmbeddingConfig {
35    /// Model type
36    pub model: KGEmbeddingModelType,
37    /// Embedding dimensions
38    pub dimensions: usize,
39    /// Learning rate
40    pub learning_rate: f32,
41    /// Margin for loss function
42    pub margin: f32,
43    /// Negative sampling ratio
44    pub negative_samples: usize,
45    /// Batch size for training
46    pub batch_size: usize,
47    /// Number of epochs
48    pub epochs: usize,
49    /// L1 or L2 norm
50    pub norm: usize,
51    /// Random seed
52    pub random_seed: Option<u64>,
53    /// Regularization weight
54    pub regularization: f32,
55}
56
57impl Default for KGEmbeddingConfig {
58    fn default() -> Self {
59        Self {
60            model: KGEmbeddingModelType::TransE,
61            dimensions: 100,
62            learning_rate: 0.01,
63            margin: 1.0,
64            negative_samples: 10,
65            batch_size: 100,
66            epochs: 100,
67            norm: 2,
68            random_seed: Some(42),
69            regularization: 0.0,
70        }
71    }
72}
73
/// A single `(subject, predicate, object)` statement of a knowledge graph.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct Triple {
    /// Head entity.
    pub subject: String,
    /// Relation connecting the subject to the object.
    pub predicate: String,
    /// Tail entity.
    pub object: String,
}

impl Triple {
    /// Assembles a triple from its three owned parts.
    pub fn new(subject: String, predicate: String, object: String) -> Self {
        Triple { subject, predicate, object }
    }
}
91
92/// Base trait for knowledge graph embedding models
93pub trait KGEmbeddingModel: Send + Sync {
94    /// Train the model on triples
95    fn train(&mut self, triples: &[Triple]) -> Result<()>;
96
97    /// Get entity embedding
98    fn get_entity_embedding(&self, entity: &str) -> Option<Vector>;
99
100    /// Get relation embedding
101    fn get_relation_embedding(&self, relation: &str) -> Option<Vector>;
102
103    /// Score a triple
104    fn score_triple(&self, triple: &Triple) -> f32;
105
106    /// Predict tail entities for (head, relation, ?)
107    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)>;
108
109    /// Predict head entities for (?, relation, tail)
110    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)>;
111
112    /// Get all entity embeddings
113    fn get_entity_embeddings(&self) -> HashMap<String, Vector>;
114
115    /// Get all relation embeddings
116    fn get_relation_embeddings(&self) -> HashMap<String, Vector>;
117}
118
119/// TransE: Translation-based embeddings
120/// Learns embeddings where h + r ≈ t for triple (h, r, t)
121pub struct TransE {
122    config: KGEmbeddingConfig,
123    entity_embeddings: HashMap<String, DVector<f32>>,
124    relation_embeddings: HashMap<String, DVector<f32>>,
125    entities: Vec<String>,
126    relations: Vec<String>,
127}
128
129impl TransE {
130    pub fn new(config: KGEmbeddingConfig) -> Self {
131        Self {
132            config,
133            entity_embeddings: HashMap::new(),
134            relation_embeddings: HashMap::new(),
135            entities: Vec::new(),
136            relations: Vec::new(),
137        }
138    }
139
140    /// Initialize embeddings
141    fn initialize_embeddings(&mut self, triples: &[Triple]) {
142        // Collect unique entities and relations
143        let mut entities = std::collections::HashSet::new();
144        let mut relations = std::collections::HashSet::new();
145
146        for triple in triples {
147            entities.insert(triple.subject.clone());
148            entities.insert(triple.object.clone());
149            relations.insert(triple.predicate.clone());
150        }
151
152        self.entities = entities.into_iter().collect();
153        self.relations = relations.into_iter().collect();
154
155        // Initialize embeddings with uniform distribution
156        let mut rng = if let Some(seed) = self.config.random_seed {
157            Random::seed(seed)
158        } else {
159            Random::seed(42)
160        };
161
162        let range_min = -6.0 / (self.config.dimensions as f32).sqrt();
163        let range_max = 6.0 / (self.config.dimensions as f32).sqrt();
164
165        // Initialize entity embeddings
166        for entity in &self.entities {
167            let values: Vec<f32> = (0..self.config.dimensions)
168                .map(|_| rng.gen_range(range_min..range_max))
169                .collect();
170            let mut embedding = DVector::from_vec(values);
171
172            // Normalize entities
173            let norm = embedding.norm();
174            if norm > 0.0 {
175                embedding /= norm;
176            }
177
178            self.entity_embeddings.insert(entity.clone(), embedding);
179        }
180
181        // Initialize relation embeddings
182        for relation in &self.relations {
183            let values: Vec<f32> = (0..self.config.dimensions)
184                .map(|_| rng.gen_range(range_min..range_max))
185                .collect();
186            let embedding = DVector::from_vec(values);
187
188            // Relations are not normalized in TransE
189            self.relation_embeddings.insert(relation.clone(), embedding);
190        }
191    }
192
193    /// Generate negative samples
194    fn generate_negative_samples(&self, triple: &Triple, rng: &mut impl Rng) -> Vec<Triple> {
195        let mut negatives = Vec::new();
196
197        for _ in 0..self.config.negative_samples {
198            if rng.gen_bool(0.5) {
199                // Corrupt head
200                let mut negative = triple.clone();
201                loop {
202                    let idx = rng.gen_range(0..self.entities.len());
203                    let entity = &self.entities[idx];
204                    if entity != &triple.subject {
205                        negative.subject = entity.clone();
206                        break;
207                    }
208                }
209                negatives.push(negative);
210            } else {
211                // Corrupt tail
212                let mut negative = triple.clone();
213                loop {
214                    let idx = rng.gen_range(0..self.entities.len());
215                    let entity = &self.entities[idx];
216                    if entity != &triple.object {
217                        negative.object = entity.clone();
218                        break;
219                    }
220                }
221                negatives.push(negative);
222            }
223        }
224
225        negatives
226    }
227
228    /// Calculate distance for a triple
229    fn distance(&self, triple: &Triple) -> f32 {
230        let h = self.entity_embeddings.get(&triple.subject).unwrap();
231        let r = self.relation_embeddings.get(&triple.predicate).unwrap();
232        let t = self.entity_embeddings.get(&triple.object).unwrap();
233
234        let translation = h + r - t;
235
236        match self.config.norm {
237            1 => translation.iter().map(|x| x.abs()).sum(),
238            2 => translation.norm(),
239            _ => translation.norm(),
240        }
241    }
242
243    /// Update embeddings using gradient descent
244    fn update_embeddings(&mut self, positive: &Triple, negatives: &[Triple]) {
245        let pos_dist = self.distance(positive);
246
247        for negative in negatives {
248            let neg_dist = self.distance(negative);
249            let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
250
251            if loss > 0.0 {
252                // Calculate gradients
253                let h_pos = self
254                    .entity_embeddings
255                    .get(&positive.subject)
256                    .unwrap()
257                    .clone();
258                let r = self
259                    .relation_embeddings
260                    .get(&positive.predicate)
261                    .unwrap()
262                    .clone();
263                let t_pos = self
264                    .entity_embeddings
265                    .get(&positive.object)
266                    .unwrap()
267                    .clone();
268
269                let h_neg = self
270                    .entity_embeddings
271                    .get(&negative.subject)
272                    .unwrap()
273                    .clone();
274                let t_neg = self
275                    .entity_embeddings
276                    .get(&negative.object)
277                    .unwrap()
278                    .clone();
279
280                let pos_grad = &h_pos + &r - &t_pos;
281                let neg_grad = &h_neg + &r - &t_neg;
282
283                // Normalize gradients
284                let pos_norm = pos_grad.norm();
285                let neg_norm = neg_grad.norm();
286
287                let pos_grad_norm = if pos_norm > 0.0 {
288                    &pos_grad / pos_norm
289                } else {
290                    pos_grad
291                };
292                let neg_grad_norm = if neg_norm > 0.0 {
293                    &neg_grad / neg_norm
294                } else {
295                    neg_grad
296                };
297
298                // Update embeddings
299                let lr = self.config.learning_rate;
300
301                // Update positive triple embeddings
302                if let Some(h) = self.entity_embeddings.get_mut(&positive.subject) {
303                    *h -= lr * &pos_grad_norm;
304                    // Re-normalize entity
305                    let norm = h.norm();
306                    if norm > 0.0 {
307                        *h /= norm;
308                    }
309                }
310
311                if let Some(r) = self.relation_embeddings.get_mut(&positive.predicate) {
312                    *r -= lr * (&pos_grad_norm - &neg_grad_norm);
313                }
314
315                if let Some(t) = self.entity_embeddings.get_mut(&positive.object) {
316                    *t += lr * &pos_grad_norm;
317                    // Re-normalize entity
318                    let norm = t.norm();
319                    if norm > 0.0 {
320                        *t /= norm;
321                    }
322                }
323
324                // Update negative triple embeddings
325                if positive.subject != negative.subject {
326                    if let Some(h) = self.entity_embeddings.get_mut(&negative.subject) {
327                        *h += lr * &neg_grad_norm;
328                        // Re-normalize entity
329                        let norm = h.norm();
330                        if norm > 0.0 {
331                            *h /= norm;
332                        }
333                    }
334                }
335
336                if positive.object != negative.object {
337                    if let Some(t) = self.entity_embeddings.get_mut(&negative.object) {
338                        *t -= lr * &neg_grad_norm;
339                        // Re-normalize entity
340                        let norm = t.norm();
341                        if norm > 0.0 {
342                            *t /= norm;
343                        }
344                    }
345                }
346            }
347        }
348    }
349}
350
351impl KGEmbeddingModel for TransE {
352    fn train(&mut self, triples: &[Triple]) -> Result<()> {
353        if triples.is_empty() {
354            return Err(anyhow!("No triples provided for training"));
355        }
356
357        // Initialize embeddings
358        self.initialize_embeddings(triples);
359
360        let mut rng = if let Some(seed) = self.config.random_seed {
361            Random::seed(seed)
362        } else {
363            Random::seed(42)
364        };
365
366        // Training loop
367        for epoch in 0..self.config.epochs {
368            let mut total_loss = 0.0;
369            let mut batch_count = 0;
370
371            // Shuffle triples
372            let mut shuffled_triples = triples.to_vec();
373            // Note: Using manual random selection instead of SliceRandom
374            // Manually shuffle using Fisher-Yates algorithm
375            for i in (1..shuffled_triples.len()).rev() {
376                let j = rng.gen_range(0..=i);
377                shuffled_triples.swap(i, j);
378            }
379
380            // Process batches
381            for batch in shuffled_triples.chunks(self.config.batch_size) {
382                for triple in batch {
383                    // Generate negative samples
384                    let negatives = self.generate_negative_samples(triple, &mut rng);
385
386                    // Calculate loss
387                    let pos_dist = self.distance(triple);
388                    for negative in &negatives {
389                        let neg_dist = self.distance(negative);
390                        let loss = (self.config.margin + pos_dist - neg_dist).max(0.0);
391                        total_loss += loss;
392                    }
393
394                    // Update embeddings
395                    self.update_embeddings(triple, &negatives);
396                }
397                batch_count += 1;
398            }
399
400            if epoch % 10 == 0 {
401                let avg_loss = total_loss / (batch_count as f32 * self.config.batch_size as f32);
402                tracing::info!("Epoch {}: Average loss = {:.4}", epoch, avg_loss);
403            }
404        }
405
406        Ok(())
407    }
408
409    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
410        self.entity_embeddings
411            .get(entity)
412            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
413    }
414
415    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
416        self.relation_embeddings
417            .get(relation)
418            .map(|embedding| Vector::new(embedding.iter().cloned().collect()))
419    }
420
421    fn score_triple(&self, triple: &Triple) -> f32 {
422        -self.distance(triple)
423    }
424
425    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
426        let h = match self.entity_embeddings.get(head) {
427            Some(emb) => emb,
428            None => return Vec::new(),
429        };
430
431        let r = match self.relation_embeddings.get(relation) {
432            Some(emb) => emb,
433            None => return Vec::new(),
434        };
435
436        let translation = h + r;
437
438        let mut scores: Vec<(String, f32)> = self
439            .entities
440            .iter()
441            .filter(|e| *e != head)
442            .filter_map(|entity| {
443                self.entity_embeddings.get(entity).map(|t| {
444                    let distance = match self.config.norm {
445                        1 => (&translation - t).iter().map(|x| x.abs()).sum(),
446                        2 => (&translation - t).norm(),
447                        _ => (&translation - t).norm(),
448                    };
449                    (entity.clone(), -distance)
450                })
451            })
452            .collect();
453
454        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
455        scores.truncate(k);
456        scores
457    }
458
459    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
460        let t = match self.entity_embeddings.get(tail) {
461            Some(emb) => emb,
462            None => return Vec::new(),
463        };
464
465        let r = match self.relation_embeddings.get(relation) {
466            Some(emb) => emb,
467            None => return Vec::new(),
468        };
469
470        let target = t - r;
471
472        let mut scores: Vec<(String, f32)> = self
473            .entities
474            .iter()
475            .filter(|e| *e != tail)
476            .filter_map(|entity| {
477                self.entity_embeddings.get(entity).map(|h| {
478                    let distance = match self.config.norm {
479                        1 => (h - &target).iter().map(|x| x.abs()).sum(),
480                        2 => (h - &target).norm(),
481                        _ => (h - &target).norm(),
482                    };
483                    (entity.clone(), -distance)
484                })
485            })
486            .collect();
487
488        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
489        scores.truncate(k);
490        scores
491    }
492
493    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
494        self.entity_embeddings
495            .iter()
496            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
497            .collect()
498    }
499
500    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
501        self.relation_embeddings
502            .iter()
503            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
504            .collect()
505    }
506}
507
508/// ComplEx: Complex number embeddings
509/// Uses complex-valued embeddings and Hermitian dot product
510pub struct ComplEx {
511    config: KGEmbeddingConfig,
512    entity_embeddings_real: HashMap<String, DVector<f32>>,
513    entity_embeddings_imag: HashMap<String, DVector<f32>>,
514    relation_embeddings_real: HashMap<String, DVector<f32>>,
515    relation_embeddings_imag: HashMap<String, DVector<f32>>,
516    entities: Vec<String>,
517    relations: Vec<String>,
518}
519
520impl ComplEx {
521    pub fn new(config: KGEmbeddingConfig) -> Self {
522        Self {
523            config,
524            entity_embeddings_real: HashMap::new(),
525            entity_embeddings_imag: HashMap::new(),
526            relation_embeddings_real: HashMap::new(),
527            relation_embeddings_imag: HashMap::new(),
528            entities: Vec::new(),
529            relations: Vec::new(),
530        }
531    }
532
533    /// Initialize embeddings with Xavier initialization
534    fn initialize_embeddings(&mut self, triples: &[Triple]) {
535        // Collect unique entities and relations
536        let mut entities = std::collections::HashSet::new();
537        let mut relations = std::collections::HashSet::new();
538
539        for triple in triples {
540            entities.insert(triple.subject.clone());
541            entities.insert(triple.object.clone());
542            relations.insert(triple.predicate.clone());
543        }
544
545        self.entities = entities.into_iter().collect();
546        self.relations = relations.into_iter().collect();
547
548        // Initialize with Xavier initialization
549        let mut rng = if let Some(seed) = self.config.random_seed {
550            Random::seed(seed)
551        } else {
552            Random::seed(42)
553        };
554
555        let std_dev = (2.0 / self.config.dimensions as f32).sqrt();
556        let normal = Normal::new(0.0, std_dev).unwrap();
557
558        // Initialize entity embeddings
559        for entity in &self.entities {
560            let real_values: Vec<f32> = (0..self.config.dimensions)
561                .map(|_| normal.sample(&mut rng))
562                .collect();
563            let imag_values: Vec<f32> = (0..self.config.dimensions)
564                .map(|_| normal.sample(&mut rng))
565                .collect();
566
567            self.entity_embeddings_real
568                .insert(entity.clone(), DVector::from_vec(real_values));
569            self.entity_embeddings_imag
570                .insert(entity.clone(), DVector::from_vec(imag_values));
571        }
572
573        // Initialize relation embeddings
574        for relation in &self.relations {
575            let real_values: Vec<f32> = (0..self.config.dimensions)
576                .map(|_| normal.sample(&mut rng))
577                .collect();
578            let imag_values: Vec<f32> = (0..self.config.dimensions)
579                .map(|_| normal.sample(&mut rng))
580                .collect();
581
582            self.relation_embeddings_real
583                .insert(relation.clone(), DVector::from_vec(real_values));
584            self.relation_embeddings_imag
585                .insert(relation.clone(), DVector::from_vec(imag_values));
586        }
587    }
588
589    /// Hermitian dot product for scoring
590    fn hermitian_dot(&self, triple: &Triple) -> f32 {
591        let h_real = self.entity_embeddings_real.get(&triple.subject).unwrap();
592        let h_imag = self.entity_embeddings_imag.get(&triple.subject).unwrap();
593        let r_real = self
594            .relation_embeddings_real
595            .get(&triple.predicate)
596            .unwrap();
597        let r_imag = self
598            .relation_embeddings_imag
599            .get(&triple.predicate)
600            .unwrap();
601        let t_real = self.entity_embeddings_real.get(&triple.object).unwrap();
602        let t_imag = self.entity_embeddings_imag.get(&triple.object).unwrap();
603
604        // ComplEx scoring function: Re(<h, r, t̄>)
605        // = Re(∑ h_i * r_i * conj(t_i))
606        // = ∑ (h_real * r_real * t_real + h_real * r_imag * t_imag +
607        //      h_imag * r_real * t_imag - h_imag * r_imag * t_real)
608
609        let mut score = 0.0;
610        for i in 0..self.config.dimensions {
611            score += h_real[i] * r_real[i] * t_real[i]
612                + h_real[i] * r_imag[i] * t_imag[i]
613                + h_imag[i] * r_real[i] * t_imag[i]
614                - h_imag[i] * r_imag[i] * t_real[i];
615        }
616
617        score
618    }
619}
620
621impl KGEmbeddingModel for ComplEx {
622    fn train(&mut self, triples: &[Triple]) -> Result<()> {
623        if triples.is_empty() {
624            return Err(anyhow!("No triples provided for training"));
625        }
626
627        // Initialize embeddings
628        self.initialize_embeddings(triples);
629
630        // Training implementation would go here
631        // For brevity, using a simplified version
632
633        Ok(())
634    }
635
636    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
637        // Return concatenated real and imaginary parts
638        let real = self.entity_embeddings_real.get(entity)?;
639        let imag = self.entity_embeddings_imag.get(entity)?;
640
641        let mut values = Vec::with_capacity(self.config.dimensions * 2);
642        values.extend(real.iter().cloned());
643        values.extend(imag.iter().cloned());
644
645        Some(Vector::new(values))
646    }
647
648    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
649        // Return concatenated real and imaginary parts
650        let real = self.relation_embeddings_real.get(relation)?;
651        let imag = self.relation_embeddings_imag.get(relation)?;
652
653        let mut values = Vec::with_capacity(self.config.dimensions * 2);
654        values.extend(real.iter().cloned());
655        values.extend(imag.iter().cloned());
656
657        Some(Vector::new(values))
658    }
659
660    fn score_triple(&self, triple: &Triple) -> f32 {
661        self.hermitian_dot(triple)
662    }
663
664    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
665        let mut scores: Vec<(String, f32)> = self
666            .entities
667            .iter()
668            .filter(|e| *e != head)
669            .map(|tail| {
670                let triple = Triple::new(head.to_string(), relation.to_string(), tail.clone());
671                let score = self.score_triple(&triple);
672                (tail.clone(), score)
673            })
674            .collect();
675
676        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
677        scores.truncate(k);
678        scores
679    }
680
681    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
682        let mut scores: Vec<(String, f32)> = self
683            .entities
684            .iter()
685            .filter(|e| *e != tail)
686            .map(|head| {
687                let triple = Triple::new(head.clone(), relation.to_string(), tail.to_string());
688                let score = self.score_triple(&triple);
689                (head.clone(), score)
690            })
691            .collect();
692
693        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
694        scores.truncate(k);
695        scores
696    }
697
698    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
699        self.entity_embeddings_real
700            .iter()
701            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
702            .collect()
703    }
704
705    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
706        self.relation_embeddings_real
707            .iter()
708            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
709            .collect()
710    }
711}
712
713/// RotatE: Rotation-based embeddings
714/// Models relations as rotations in complex space
715pub struct RotatE {
716    config: KGEmbeddingConfig,
717    entity_embeddings: HashMap<String, DVector<Complex<f32>>>,
718    relation_embeddings: HashMap<String, DVector<f32>>, // Phase angles
719    entities: Vec<String>,
720    relations: Vec<String>,
721}
722
723impl RotatE {
724    pub fn new(config: KGEmbeddingConfig) -> Self {
725        Self {
726            config,
727            entity_embeddings: HashMap::new(),
728            relation_embeddings: HashMap::new(),
729            entities: Vec::new(),
730            relations: Vec::new(),
731        }
732    }
733
734    /// Initialize embeddings
735    fn initialize_embeddings(&mut self, triples: &[Triple]) {
736        // Collect unique entities and relations
737        let mut entities = std::collections::HashSet::new();
738        let mut relations = std::collections::HashSet::new();
739
740        for triple in triples {
741            entities.insert(triple.subject.clone());
742            entities.insert(triple.object.clone());
743            relations.insert(triple.predicate.clone());
744        }
745
746        self.entities = entities.into_iter().collect();
747        self.relations = relations.into_iter().collect();
748
749        let mut rng = if let Some(seed) = self.config.random_seed {
750            Random::seed(seed)
751        } else {
752            Random::seed(42)
753        };
754
755        // Initialize entity embeddings (complex numbers with unit modulus)
756        let phase_range = -std::f32::consts::PI..std::f32::consts::PI;
757
758        for entity in &self.entities {
759            let phases: Vec<Complex<f32>> = (0..self.config.dimensions)
760                .map(|_| {
761                    let phase = rng.gen_range(phase_range.clone());
762                    Complex::new(phase.cos(), phase.sin())
763                })
764                .collect();
765
766            self.entity_embeddings
767                .insert(entity.clone(), DVector::from_vec(phases));
768        }
769
770        // Initialize relation embeddings (phase angles)
771        for relation in &self.relations {
772            let phases: Vec<f32> = (0..self.config.dimensions)
773                .map(|_| rng.gen_range(phase_range.clone()))
774                .collect();
775
776            self.relation_embeddings
777                .insert(relation.clone(), DVector::from_vec(phases));
778        }
779    }
780
781    /// Calculate distance for RotatE
782    fn distance(&self, triple: &Triple) -> f32 {
783        let h = self.entity_embeddings.get(&triple.subject).unwrap();
784        let r_phases = self.relation_embeddings.get(&triple.predicate).unwrap();
785        let t = self.entity_embeddings.get(&triple.object).unwrap();
786
787        // Convert relation phases to complex numbers
788        let r: DVector<Complex<f32>> = DVector::from_iterator(
789            self.config.dimensions,
790            r_phases
791                .iter()
792                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
793        );
794
795        // Apply rotation: h ∘ r (element-wise complex multiplication)
796        let rotated: DVector<Complex<f32>> = h.component_mul(&r);
797
798        // Calculate distance ||h ∘ r - t||
799        let diff = rotated - t;
800        diff.iter().map(|c| c.norm()).sum::<f32>()
801    }
802}
803
804impl KGEmbeddingModel for RotatE {
805    fn train(&mut self, triples: &[Triple]) -> Result<()> {
806        if triples.is_empty() {
807            return Err(anyhow!("No triples provided for training"));
808        }
809
810        // Initialize embeddings
811        self.initialize_embeddings(triples);
812
813        // Training implementation would go here
814        // For brevity, using a simplified version
815
816        Ok(())
817    }
818
819    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
820        // Return magnitude and phase representation
821        let complex_emb = self.entity_embeddings.get(entity)?;
822
823        let mut values = Vec::with_capacity(self.config.dimensions * 2);
824        for c in complex_emb.iter() {
825            values.push(c.re); // Real part
826            values.push(c.im); // Imaginary part
827        }
828
829        Some(Vector::new(values))
830    }
831
832    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
833        self.relation_embeddings
834            .get(relation)
835            .map(|phases| Vector::new(phases.iter().cloned().collect()))
836    }
837
838    fn score_triple(&self, triple: &Triple) -> f32 {
839        let gamma = 12.0; // Fixed margin parameter for RotatE
840        gamma - self.distance(triple)
841    }
842
843    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
844        let h = match self.entity_embeddings.get(head) {
845            Some(emb) => emb,
846            None => return Vec::new(),
847        };
848
849        let r_phases = match self.relation_embeddings.get(relation) {
850            Some(emb) => emb,
851            None => return Vec::new(),
852        };
853
854        // Convert relation phases to complex numbers
855        let r: DVector<Complex<f32>> = DVector::from_iterator(
856            self.config.dimensions,
857            r_phases
858                .iter()
859                .map(|&phase| Complex::new(phase.cos(), phase.sin())),
860        );
861
862        // Apply rotation
863        let rotated = h.component_mul(&r);
864
865        let mut scores: Vec<(String, f32)> = self
866            .entities
867            .iter()
868            .filter(|e| *e != head)
869            .filter_map(|entity| {
870                self.entity_embeddings.get(entity).map(|t| {
871                    let diff = &rotated - t;
872                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
873                    (entity.clone(), -distance)
874                })
875            })
876            .collect();
877
878        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
879        scores.truncate(k);
880        scores
881    }
882
883    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
884        let t = match self.entity_embeddings.get(tail) {
885            Some(emb) => emb,
886            None => return Vec::new(),
887        };
888
889        let r_phases = match self.relation_embeddings.get(relation) {
890            Some(emb) => emb,
891            None => return Vec::new(),
892        };
893
894        // Convert relation phases to complex numbers (inverse rotation)
895        let r_inv: DVector<Complex<f32>> = DVector::from_iterator(
896            self.config.dimensions,
897            r_phases
898                .iter()
899                .map(|&phase| Complex::new(phase.cos(), -phase.sin())),
900        );
901
902        let mut scores: Vec<(String, f32)> = self
903            .entities
904            .iter()
905            .filter(|e| *e != tail)
906            .filter_map(|entity| {
907                self.entity_embeddings.get(entity).map(|h| {
908                    let rotated = h.component_mul(&r_inv);
909                    let diff = rotated - t;
910                    let distance: f32 = diff.iter().map(|c| c.norm()).sum();
911                    (entity.clone(), -distance)
912                })
913            })
914            .collect();
915
916        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
917        scores.truncate(k);
918        scores
919    }
920
921    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
922        self.entity_embeddings
923            .iter()
924            .map(|(k, v)| {
925                let real_values: Vec<f32> = v.iter().map(|c| c.re).collect();
926                (k.clone(), Vector::new(real_values))
927            })
928            .collect()
929    }
930
931    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
932        self.relation_embeddings
933            .iter()
934            .map(|(k, v)| (k.clone(), Vector::new(v.as_slice().to_vec())))
935            .collect()
936    }
937}
938
939/// Unified knowledge graph embedding interface
940pub struct KGEmbedding {
941    model: Box<dyn KGEmbeddingModel>,
942    config: KGEmbeddingConfig,
943}
944
945impl KGEmbedding {
946    /// Create a new knowledge graph embedding model
947    pub fn new(config: KGEmbeddingConfig) -> Self {
948        let model: Box<dyn KGEmbeddingModel> = match config.model {
949            KGEmbeddingModelType::TransE => Box::new(TransE::new(config.clone())),
950            KGEmbeddingModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
951            KGEmbeddingModelType::RotatE => Box::new(RotatE::new(config.clone())),
952            KGEmbeddingModelType::GCN => {
953                // Create GCN with default parameters
954                let gcn = GCN::new(config.clone());
955                Box::new(GCNAdapter::new(gcn))
956            }
957            KGEmbeddingModelType::GraphSAGE => {
958                // Create GraphSAGE with default parameters
959                let graphsage = GraphSAGE::new(config.clone())
960                    .with_aggregator(crate::gnn_embeddings::AggregatorType::Mean);
961                Box::new(GraphSAGEAdapter::new(graphsage))
962            }
963        };
964
965        Self { model, config }
966    }
967
968    /// Train the model
969    pub fn train(&mut self, triples: &[Triple]) -> Result<()> {
970        self.model.train(triples)
971    }
972
973    /// Get entity embedding
974    pub fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
975        self.model.get_entity_embedding(entity)
976    }
977
978    /// Get relation embedding
979    pub fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
980        self.model.get_relation_embedding(relation)
981    }
982
983    /// Score a triple
984    pub fn score_triple(&self, triple: &Triple) -> f32 {
985        self.model.score_triple(triple)
986    }
987
988    /// Link prediction: predict missing tail
989    pub fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
990        self.model.predict_tail(head, relation, k)
991    }
992
993    /// Link prediction: predict missing head
994    pub fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
995        self.model.predict_head(relation, tail, k)
996    }
997
998    /// Triple classification: determine if a triple is likely true
999    pub fn classify_triple(&self, triple: &Triple, threshold: f32) -> bool {
1000        self.model.score_triple(triple) > threshold
1001    }
1002}
1003
1004/// Adapter to use GCN as a knowledge graph embedding model
1005pub struct GCNAdapter {
1006    gcn: GCN,
1007}
1008
1009impl GCNAdapter {
1010    pub fn new(gcn: GCN) -> Self {
1011        Self { gcn }
1012    }
1013}
1014
1015impl KGEmbeddingModel for GCNAdapter {
1016    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1017        // GCN training would be implemented here
1018        Ok(())
1019    }
1020
1021    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1022        // GCN embeddings would be computed from graph structure
1023        // For now, return a default embedding
1024        Some(Vector::new(vec![0.0; 128]))
1025    }
1026
1027    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1028        // Relations in GCN are typically handled differently
1029        Some(Vector::new(vec![0.0; 128]))
1030    }
1031
1032    fn score_triple(&self, _triple: &Triple) -> f32 {
1033        // GCN scoring would use graph structure
1034        0.5
1035    }
1036
1037    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1038        // Return mock predictions
1039        vec![]
1040    }
1041
1042    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1043        // Return mock predictions
1044        vec![]
1045    }
1046
1047    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1048        HashMap::new()
1049    }
1050
1051    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1052        HashMap::new()
1053    }
1054}
1055
1056/// Adapter to use GraphSAGE as a knowledge graph embedding model
1057pub struct GraphSAGEAdapter {
1058    graphsage: GraphSAGE,
1059}
1060
1061impl GraphSAGEAdapter {
1062    pub fn new(graphsage: GraphSAGE) -> Self {
1063        Self { graphsage }
1064    }
1065}
1066
1067impl KGEmbeddingModel for GraphSAGEAdapter {
1068    fn train(&mut self, _triples: &[Triple]) -> Result<()> {
1069        // GraphSAGE training would be implemented here
1070        Ok(())
1071    }
1072
1073    fn get_entity_embedding(&self, _entity: &str) -> Option<Vector> {
1074        // GraphSAGE embeddings would be computed from neighbors
1075        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1076    }
1077
1078    fn get_relation_embedding(&self, _relation: &str) -> Option<Vector> {
1079        // Relations in GraphSAGE are typically handled differently
1080        Some(Vector::new(vec![0.0; self.graphsage.dimensions()]))
1081    }
1082
1083    fn score_triple(&self, _triple: &Triple) -> f32 {
1084        // GraphSAGE scoring would use neighbor aggregation
1085        0.5
1086    }
1087
1088    fn predict_tail(&self, _head: &str, _relation: &str, _k: usize) -> Vec<(String, f32)> {
1089        // Return mock predictions
1090        vec![]
1091    }
1092
1093    fn predict_head(&self, _relation: &str, _tail: &str, _k: usize) -> Vec<(String, f32)> {
1094        // Return mock predictions
1095        vec![]
1096    }
1097
1098    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
1099        HashMap::new()
1100    }
1101
1102    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
1103        HashMap::new()
1104    }
1105}
1106
#[cfg(test)]
mod tests {
    use super::*;

    /// Small social graph shared by every test.
    fn create_test_triples() -> Vec<Triple> {
        vec![
            Triple::new("Alice".to_string(), "knows".to_string(), "Bob".to_string()),
            Triple::new(
                "Bob".to_string(),
                "knows".to_string(),
                "Charlie".to_string(),
            ),
            Triple::new(
                "Alice".to_string(),
                "likes".to_string(),
                "Pizza".to_string(),
            ),
            Triple::new("Bob".to_string(), "likes".to_string(), "Pasta".to_string()),
            Triple::new(
                "Charlie".to_string(),
                "knows".to_string(),
                "Alice".to_string(),
            ),
        ]
    }

    /// Invariants every link-prediction result list must satisfy.
    fn assert_well_ranked(predictions: &[(String, f32)], query_entity: &str, k: usize) {
        assert!(!predictions.is_empty(), "expected at least one candidate");
        assert!(predictions.len() <= k, "result must be truncated to k");
        assert!(
            predictions.iter().all(|(name, _)| name != query_entity),
            "the query entity must not be ranked against itself"
        );
        assert!(
            predictions.windows(2).all(|w| w[0].1 >= w[1].1),
            "scores must be sorted best-first"
        );
    }

    #[test]
    fn test_transe() {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::TransE,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();

        model.train(&triples).unwrap();

        // Test embeddings exist
        assert!(model.get_entity_embedding("Alice").is_some());
        assert!(model.get_relation_embedding("knows").is_some());

        // Test scoring
        let score = model.score_triple(&triples[0]);
        assert!(score.is_finite());

        // Test prediction
        let predictions = model.predict_tail("Alice", "knows", 2);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_complex() {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::ComplEx,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();

        model.train(&triples).unwrap();

        let emb = model
            .get_entity_embedding("Bob")
            .expect("Bob appears in the training triples");
        assert_eq!(emb.dimensions, 100); // Real + imaginary parts
    }

    #[test]
    fn test_rotate() {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::RotatE,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        let triples = create_test_triples();

        model.train(&triples).unwrap();

        // Test embeddings exist
        assert!(model.get_entity_embedding("Charlie").is_some());
        assert!(model.get_relation_embedding("likes").is_some());

        // Relation embeddings are phase angles, one per dimension
        let rel_emb = model
            .get_relation_embedding("likes")
            .expect("likes appears in the training triples");
        assert_eq!(rel_emb.dimensions, 50);
    }

    #[test]
    fn test_rotate_link_prediction() {
        let config = KGEmbeddingConfig {
            model: KGEmbeddingModelType::RotatE,
            dimensions: 50,
            epochs: 10,
            ..Default::default()
        };

        let mut model = KGEmbedding::new(config);
        model.train(&create_test_triples()).unwrap();

        assert_well_ranked(&model.predict_tail("Alice", "knows", 3), "Alice", 3);
        assert_well_ranked(&model.predict_head("knows", "Bob", 3), "Bob", 3);

        // Unknown entities or relations yield no predictions rather than panicking.
        assert!(model.predict_tail("Nobody", "knows", 3).is_empty());
        assert!(model.predict_tail("Alice", "unknown_relation", 3).is_empty());
        assert!(model.predict_head("knows", "Nobody", 3).is_empty());
        assert!(model.predict_head("unknown_relation", "Bob", 3).is_empty());
    }
}