oxirs_vec/
gnn_embeddings.rs

1//! Graph Neural Network (GNN) embeddings for knowledge graphs
2//!
3//! This module implements GNN-based embedding methods:
4//! - GCN: Graph Convolutional Networks
5//! - GraphSAGE: Graph Sample and Aggregate
6
7use crate::random_utils::NormalSampler as Normal;
8use crate::{
9    kg_embeddings::{KGEmbeddingConfig, KGEmbeddingModel, Triple},
10    Vector,
11};
12use anyhow::{anyhow, Result};
13use nalgebra::{DMatrix, DVector};
14use scirs2_core::random::{Random, Rng};
15use std::collections::HashMap;
16
/// Graph Convolutional Network (GCN) embedding model.
///
/// Stores one embedding vector per entity and per relation, plus the
/// normalized adjacency matrix and the per-layer weight matrices that the
/// forward pass multiplies together.
pub struct GCN {
    /// Hyper-parameters; this module reads `dimensions`, `epochs` and `random_seed`.
    config: KGEmbeddingConfig,
    /// Current embedding of every known entity, keyed by entity name.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Embedding of every known relation, keyed by relation name.
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Entity names; the position in this list is the entity's row/column
    /// index in the adjacency and feature matrices.
    entities: Vec<String>,
    /// Relation names seen in the training triples.
    relations: Vec<String>,
    /// Normalized adjacency matrix (with self-loops); `None` until initialization.
    adjacency_matrix: Option<DMatrix<f32>>,
    /// One `dimensions x dimensions` weight matrix per layer.
    weight_matrices: Vec<DMatrix<f32>>,
    /// Number of graph-convolution layers applied by the forward pass.
    num_layers: usize,
}
28
29impl GCN {
30    pub fn new(config: KGEmbeddingConfig) -> Self {
31        let num_layers = 2; // Default to 2 layers
32        Self {
33            config,
34            entity_embeddings: HashMap::new(),
35            relation_embeddings: HashMap::new(),
36            entities: Vec::new(),
37            relations: Vec::new(),
38            adjacency_matrix: None,
39            weight_matrices: Vec::new(),
40            num_layers,
41        }
42    }
43
44    /// Initialize GCN with specified number of layers
45    pub fn with_layers(config: KGEmbeddingConfig, num_layers: usize) -> Self {
46        Self {
47            config,
48            entity_embeddings: HashMap::new(),
49            relation_embeddings: HashMap::new(),
50            entities: Vec::new(),
51            relations: Vec::new(),
52            adjacency_matrix: None,
53            weight_matrices: Vec::new(),
54            num_layers,
55        }
56    }
57
58    /// Initialize embeddings and graph structure
59    fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
60        // Collect unique entities and relations
61        let mut entities = std::collections::HashSet::new();
62        let mut relations = std::collections::HashSet::new();
63
64        for triple in triples {
65            entities.insert(triple.subject.clone());
66            entities.insert(triple.object.clone());
67            relations.insert(triple.predicate.clone());
68        }
69
70        self.entities = entities.into_iter().collect();
71        self.relations = relations.into_iter().collect();
72
73        let _num_entities = self.entities.len();
74
75        // Initialize entity embeddings
76        let mut rng = if let Some(seed) = self.config.random_seed {
77            Random::seed(seed)
78        } else {
79            Random::seed(42)
80        };
81
82        let normal = Normal::new(0.0, 0.1)
83            .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
84
85        for entity in &self.entities {
86            let embedding: Vec<f32> = (0..self.config.dimensions)
87                .map(|_| normal.sample(&mut rng))
88                .collect();
89            self.entity_embeddings
90                .insert(entity.clone(), DVector::from_vec(embedding));
91        }
92
93        for relation in &self.relations {
94            let embedding: Vec<f32> = (0..self.config.dimensions)
95                .map(|_| normal.sample(&mut rng))
96                .collect();
97            self.relation_embeddings
98                .insert(relation.clone(), DVector::from_vec(embedding));
99        }
100
101        // Build adjacency matrix
102        self.build_adjacency_matrix(triples)?;
103
104        // Initialize weight matrices for each layer
105        self.weight_matrices.clear();
106        for _ in 0..self.num_layers {
107            let weight_matrix =
108                DMatrix::from_fn(self.config.dimensions, self.config.dimensions, |_, _| {
109                    normal.sample(&mut rng)
110                });
111            self.weight_matrices.push(weight_matrix);
112        }
113
114        Ok(())
115    }
116
117    /// Build adjacency matrix from triples
118    fn build_adjacency_matrix(&mut self, triples: &[Triple]) -> Result<()> {
119        let num_entities = self.entities.len();
120        let mut adj_matrix = DMatrix::zeros(num_entities, num_entities);
121
122        // Create entity index mapping
123        let entity_to_index: HashMap<String, usize> = self
124            .entities
125            .iter()
126            .enumerate()
127            .map(|(i, entity)| (entity.clone(), i))
128            .collect();
129
130        // Fill adjacency matrix
131        for triple in triples {
132            if let (Some(&subject_idx), Some(&object_idx)) = (
133                entity_to_index.get(&triple.subject),
134                entity_to_index.get(&triple.object),
135            ) {
136                adj_matrix[(subject_idx, object_idx)] = 1.0;
137                adj_matrix[(object_idx, subject_idx)] = 1.0; // Undirected graph
138            }
139        }
140
141        // Add self-loops
142        for i in 0..num_entities {
143            adj_matrix[(i, i)] = 1.0;
144        }
145
146        // Normalize adjacency matrix (symmetric normalization)
147        self.adjacency_matrix = Some(self.normalize_adjacency_matrix(adj_matrix));
148
149        Ok(())
150    }
151
152    /// Symmetric normalization of adjacency matrix: D^(-1/2) * A * D^(-1/2)
153    fn normalize_adjacency_matrix(&self, mut adj_matrix: DMatrix<f32>) -> DMatrix<f32> {
154        let num_nodes = adj_matrix.nrows();
155
156        // Calculate degree matrix
157        let mut degrees = Vec::with_capacity(num_nodes);
158        for i in 0..num_nodes {
159            let degree: f32 = (0..num_nodes).map(|j| adj_matrix[(i, j)]).sum();
160            degrees.push(if degree > 0.0 {
161                1.0 / degree.sqrt()
162            } else {
163                0.0
164            });
165        }
166
167        // Apply symmetric normalization
168        for i in 0..num_nodes {
169            for j in 0..num_nodes {
170                adj_matrix[(i, j)] *= degrees[i] * degrees[j];
171            }
172        }
173
174        adj_matrix
175    }
176
177    /// Forward pass through GCN layers
178    fn forward_pass(&self, features: &DMatrix<f32>) -> Result<DMatrix<f32>> {
179        let adj_matrix = self
180            .adjacency_matrix
181            .as_ref()
182            .ok_or_else(|| anyhow!("Adjacency matrix not initialized"))?;
183
184        let mut hidden = features.clone();
185
186        for layer_idx in 0..self.num_layers {
187            let weight = &self.weight_matrices[layer_idx];
188
189            // GCN layer: H^(l+1) = σ(A * H^(l) * W^(l))
190            let linear_transform = &hidden * weight;
191            hidden = adj_matrix * &linear_transform;
192
193            // Apply ReLU activation (except for last layer)
194            if layer_idx < self.num_layers - 1 {
195                hidden = hidden.map(|x| x.max(0.0));
196            }
197        }
198
199        Ok(hidden)
200    }
201
202    /// Train the GCN model
203    fn train_gcn(&mut self, _triples: &[Triple]) -> Result<()> {
204        // Create feature matrix from current embeddings
205        let num_entities = self.entities.len();
206        let mut features = DMatrix::zeros(num_entities, self.config.dimensions);
207
208        for (i, entity) in self.entities.iter().enumerate() {
209            if let Some(embedding) = self.entity_embeddings.get(entity) {
210                for (j, &value) in embedding.iter().enumerate() {
211                    features[(i, j)] = value;
212                }
213            }
214        }
215
216        // Perform forward pass
217        let updated_features = self.forward_pass(&features)?;
218
219        // Update entity embeddings with new features
220        for (i, entity) in self.entities.iter().enumerate() {
221            let new_embedding: Vec<f32> = (0..self.config.dimensions)
222                .map(|j| updated_features[(i, j)])
223                .collect();
224            self.entity_embeddings
225                .insert(entity.clone(), DVector::from_vec(new_embedding));
226        }
227
228        Ok(())
229    }
230}
231
232impl KGEmbeddingModel for GCN {
233    fn train(&mut self, triples: &[Triple]) -> Result<()> {
234        self.initialize(triples)?;
235
236        for epoch in 0..self.config.epochs {
237            self.train_gcn(triples)?;
238
239            if epoch % 10 == 0 {
240                println!("GCN training epoch {}/{}", epoch, self.config.epochs);
241            }
242        }
243
244        Ok(())
245    }
246
247    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
248        self.entity_embeddings
249            .get(entity)
250            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
251    }
252
253    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
254        self.relation_embeddings
255            .get(relation)
256            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
257    }
258
259    fn score_triple(&self, triple: &Triple) -> f32 {
260        // For GCN, we use cosine similarity between subject and object embeddings
261        // after considering the relation
262        if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
263            self.get_entity_embedding(&triple.subject),
264            self.get_relation_embedding(&triple.predicate),
265            self.get_entity_embedding(&triple.object),
266        ) {
267            // Simple approach: h + r should be close to t
268            let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
269            predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
270        } else {
271            0.0
272        }
273    }
274
275    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
276        if let (Some(head_emb), Some(rel_emb)) = (
277            self.get_entity_embedding(head),
278            self.get_relation_embedding(relation),
279        ) {
280            let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
281
282            let mut scores = Vec::new();
283            for entity in &self.entities {
284                if entity != head {
285                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
286                        let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
287                        scores.push((entity.clone(), score));
288                    }
289                }
290            }
291
292            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
293            scores.into_iter().take(k).collect()
294        } else {
295            Vec::new()
296        }
297    }
298
299    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
300        if let (Some(rel_emb), Some(tail_emb)) = (
301            self.get_relation_embedding(relation),
302            self.get_entity_embedding(tail),
303        ) {
304            let mut scores = Vec::new();
305            for entity in &self.entities {
306                if entity != tail {
307                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
308                        let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
309                        let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
310                        scores.push((entity.clone(), score));
311                    }
312                }
313            }
314
315            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
316            scores.into_iter().take(k).collect()
317        } else {
318            Vec::new()
319        }
320    }
321
322    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
323        // This is a bit tricky because we store DVector but need to return HashMap<String, Vector>
324        // For now, we'll return an empty HashMap - this should be refactored
325        HashMap::new()
326    }
327
328    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
329        // Same issue as above
330        HashMap::new()
331    }
332}
333
/// GraphSAGE (Graph Sample and Aggregate) embedding model.
///
/// Each training epoch replaces an entity's embedding with the sum of its own
/// embedding and an aggregate of the embeddings of sampled neighbors.
pub struct GraphSAGE {
    /// Hyper-parameters; this module reads `dimensions`, `epochs` and `random_seed`.
    config: KGEmbeddingConfig,
    /// Current embedding of every known entity, keyed by entity name.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Embedding of every known relation, keyed by relation name.
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// Entity names seen in the training triples.
    entities: Vec<String>,
    /// Relation names seen in the training triples.
    relations: Vec<String>,
    /// Adjacency list: entity name -> neighbor names (both edge directions are stored).
    graph: HashMap<String, Vec<String>>,
    /// How neighbor embeddings are combined.
    aggregator_type: AggregatorType,
    /// Configured layer count (set to 2 by `new`; not read by the code in this file).
    num_layers: usize,
    /// Maximum number of neighbors used per node.
    sample_size: usize,
    /// How neighbors are chosen when a node has more than `sample_size` of them.
    sampling_strategy: SamplingStrategy,
}
347
/// Strategy used by `GraphSAGE` to combine the embeddings of sampled neighbors.
#[derive(Debug, Clone, Copy)]
pub enum AggregatorType {
    /// Element-wise average of the neighbor embeddings.
    Mean,
    /// Simplified sequential LSTM-style gating over the neighbors.
    LSTM,
    /// Element-wise maximum (max pooling) over the neighbor embeddings.
    Pool,
    /// Softmax-weighted sum of the neighbor embeddings.
    Attention,
}
355
/// Strategy used by `GraphSAGE` to pick neighbors when a node has more than
/// the configured sample size.
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
    /// Uniform random sampling.
    Uniform,
    /// Degree-based sampling (prefer high-degree neighbors).
    Degree,
    /// PageRank-based sampling (prefer important neighbors); currently
    /// approximated by degree-based sampling.
    PageRank,
    /// Sample the most recently added neighbors (for temporal graphs).
    Recent,
}
363
364impl GraphSAGE {
365    pub fn new(config: KGEmbeddingConfig) -> Self {
366        Self {
367            config,
368            entity_embeddings: HashMap::new(),
369            relation_embeddings: HashMap::new(),
370            entities: Vec::new(),
371            relations: Vec::new(),
372            graph: HashMap::new(),
373            aggregator_type: AggregatorType::Mean,
374            num_layers: 2,
375            sample_size: 10, // Number of neighbors to sample
376            sampling_strategy: SamplingStrategy::Uniform,
377        }
378    }
379
380    pub fn with_aggregator(mut self, aggregator: AggregatorType) -> Self {
381        self.aggregator_type = aggregator;
382        self
383    }
384
385    pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
386        self.sampling_strategy = strategy;
387        self
388    }
389
390    pub fn with_sample_size(mut self, size: usize) -> Self {
391        self.sample_size = size;
392        self
393    }
394
395    /// Get embedding dimensions
396    pub fn dimensions(&self) -> usize {
397        self.config.dimensions
398    }
399
400    /// Initialize GraphSAGE model
401    fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
402        // Collect unique entities and relations
403        let mut entities = std::collections::HashSet::new();
404        let mut relations = std::collections::HashSet::new();
405
406        for triple in triples {
407            entities.insert(triple.subject.clone());
408            entities.insert(triple.object.clone());
409            relations.insert(triple.predicate.clone());
410        }
411
412        self.entities = entities.into_iter().collect();
413        self.relations = relations.into_iter().collect();
414
415        // Build graph adjacency list
416        self.build_graph(triples);
417
418        // Initialize embeddings
419        let mut rng = if let Some(seed) = self.config.random_seed {
420            Random::seed(seed)
421        } else {
422            Random::seed(42)
423        };
424
425        let normal = Normal::new(0.0, 0.1)
426            .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
427
428        for entity in &self.entities {
429            let embedding: Vec<f32> = (0..self.config.dimensions)
430                .map(|_| normal.sample(&mut rng))
431                .collect();
432            self.entity_embeddings
433                .insert(entity.clone(), DVector::from_vec(embedding));
434        }
435
436        for relation in &self.relations {
437            let embedding: Vec<f32> = (0..self.config.dimensions)
438                .map(|_| normal.sample(&mut rng))
439                .collect();
440            self.relation_embeddings
441                .insert(relation.clone(), DVector::from_vec(embedding));
442        }
443
444        Ok(())
445    }
446
447    /// Build graph adjacency list
448    fn build_graph(&mut self, triples: &[Triple]) {
449        for triple in triples {
450            self.graph
451                .entry(triple.subject.clone())
452                .or_default()
453                .push(triple.object.clone());
454
455            self.graph
456                .entry(triple.object.clone())
457                .or_default()
458                .push(triple.subject.clone());
459        }
460    }
461
462    /// Sample neighbors for a node using different strategies
463    #[allow(deprecated)]
464    fn sample_neighbors(&self, node: &str, rng: &mut impl Rng) -> Vec<String> {
465        if let Some(neighbors) = self.graph.get(node) {
466            if neighbors.len() <= self.sample_size {
467                neighbors.clone()
468            } else {
469                match self.sampling_strategy {
470                    SamplingStrategy::Uniform => {
471                        // Note: Using manual random selection instead of SliceRandom
472                        // Manually sample neighbors using reservoir sampling
473                        let mut sampled = Vec::new();
474                        let sample_size = std::cmp::min(self.sample_size, neighbors.len());
475                        for (i, neighbor) in neighbors.iter().enumerate() {
476                            if sampled.len() < sample_size {
477                                sampled.push(neighbor.clone());
478                            } else {
479                                let j = rng.gen_range(0..=i);
480                                if j < sample_size {
481                                    sampled[j] = neighbor.clone();
482                                }
483                            }
484                        }
485                        sampled
486                    }
487                    SamplingStrategy::Degree => self.degree_based_sampling(neighbors, rng),
488                    SamplingStrategy::PageRank => {
489                        // Simplified PageRank-based sampling (use degree as approximation)
490                        self.degree_based_sampling(neighbors, rng)
491                    }
492                    SamplingStrategy::Recent => {
493                        // For recent sampling, take the last added neighbors
494                        neighbors
495                            .iter()
496                            .rev()
497                            .take(self.sample_size)
498                            .cloned()
499                            .collect()
500                    }
501                }
502            }
503        } else {
504            Vec::new()
505        }
506    }
507
508    /// Degree-based sampling: prefer neighbors with higher degree
509    #[allow(deprecated)]
510    fn degree_based_sampling(&self, neighbors: &[String], rng: &mut impl Rng) -> Vec<String> {
511        let mut neighbor_degrees: Vec<(String, usize)> = neighbors
512            .iter()
513            .map(|neighbor| {
514                let degree = self.graph.get(neighbor).map(|n| n.len()).unwrap_or(0);
515                (neighbor.clone(), degree)
516            })
517            .collect();
518
519        // Sort by degree (descending) and add some randomization
520        neighbor_degrees.sort_by(|a, b| {
521            let degree_cmp = b.1.cmp(&a.1);
522            if degree_cmp == std::cmp::Ordering::Equal {
523                // Add randomization for ties
524                if rng.gen_bool(0.5) {
525                    std::cmp::Ordering::Greater
526                } else {
527                    std::cmp::Ordering::Less
528                }
529            } else {
530                degree_cmp
531            }
532        });
533
534        neighbor_degrees
535            .into_iter()
536            .take(self.sample_size)
537            .map(|(neighbor, _)| neighbor)
538            .collect()
539    }
540
541    /// Aggregate neighbor embeddings
542    fn aggregate_neighbors(&self, neighbors: &[String]) -> Result<DVector<f32>> {
543        if neighbors.is_empty() {
544            return Ok(DVector::zeros(self.config.dimensions));
545        }
546
547        match self.aggregator_type {
548            AggregatorType::Mean => {
549                let mut sum = DVector::zeros(self.config.dimensions);
550                let mut count = 0;
551
552                for neighbor in neighbors {
553                    if let Some(embedding) = self.entity_embeddings.get(neighbor) {
554                        sum += embedding;
555                        count += 1;
556                    }
557                }
558
559                if count > 0 {
560                    Ok(sum / count as f32)
561                } else {
562                    Ok(DVector::zeros(self.config.dimensions))
563                }
564            }
565            AggregatorType::Pool => {
566                // Max pooling aggregator
567                let mut max_embedding =
568                    DVector::from_element(self.config.dimensions, f32::NEG_INFINITY);
569
570                for neighbor in neighbors {
571                    if let Some(embedding) = self.entity_embeddings.get(neighbor) {
572                        for i in 0..self.config.dimensions {
573                            max_embedding[i] = max_embedding[i].max(embedding[i]);
574                        }
575                    }
576                }
577
578                // Replace negative infinity with zeros
579                for i in 0..self.config.dimensions {
580                    if max_embedding[i] == f32::NEG_INFINITY {
581                        max_embedding[i] = 0.0;
582                    }
583                }
584
585                Ok(max_embedding)
586            }
587            AggregatorType::LSTM => {
588                // LSTM-based aggregator
589                self.lstm_aggregate(neighbors)
590            }
591            AggregatorType::Attention => {
592                // Attention-based aggregator
593                self.attention_aggregate(neighbors)
594            }
595        }
596    }
597
598    /// LSTM-based aggregator (simplified implementation)
599    fn lstm_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
600        if neighbors.is_empty() {
601            return Ok(DVector::zeros(self.config.dimensions));
602        }
603
604        // Simplified LSTM: process neighbors sequentially with forget/input gates
605        let mut cell_state = DVector::zeros(self.config.dimensions);
606        let mut hidden_state = DVector::zeros(self.config.dimensions);
607
608        for neighbor in neighbors {
609            if let Some(embedding) = self.entity_embeddings.get(neighbor) {
610                // Simplified LSTM gates (using tanh and sigmoid approximations)
611                let forget_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp())); // sigmoid
612                let input_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
613                let candidate = embedding.map(|x| x.tanh()); // tanh
614
615                // Update cell state
616                cell_state =
617                    cell_state.component_mul(&forget_gate) + input_gate.component_mul(&candidate);
618
619                // Update hidden state
620                let output_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
621                hidden_state = output_gate.component_mul(&cell_state.map(|x| x.tanh()));
622            }
623        }
624
625        Ok(hidden_state)
626    }
627
628    /// Attention-based aggregator
629    fn attention_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
630        if neighbors.is_empty() {
631            return Ok(DVector::zeros(self.config.dimensions));
632        }
633
634        let neighbor_embeddings: Vec<&DVector<f32>> = neighbors
635            .iter()
636            .filter_map(|neighbor| self.entity_embeddings.get(neighbor))
637            .collect();
638
639        if neighbor_embeddings.is_empty() {
640            return Ok(DVector::zeros(self.config.dimensions));
641        }
642
643        // Simple attention mechanism using dot-product attention
644        let mut attention_scores = Vec::new();
645        let mut weighted_sum = DVector::zeros(self.config.dimensions);
646
647        // Calculate attention scores (simplified: using magnitude as query)
648        let query = DVector::from_element(self.config.dimensions, 1.0); // Simple uniform query
649
650        for embedding in &neighbor_embeddings {
651            let score = query.dot(embedding).exp(); // Softmax will normalize
652            attention_scores.push(score);
653        }
654
655        // Normalize attention scores (softmax)
656        let total_score: f32 = attention_scores.iter().sum();
657        if total_score > 0.0 {
658            for score in &mut attention_scores {
659                *score /= total_score;
660            }
661        }
662
663        // Calculate weighted sum
664        for (embedding, &score) in neighbor_embeddings.iter().zip(attention_scores.iter()) {
665            weighted_sum += *embedding * score;
666        }
667
668        Ok(weighted_sum)
669    }
670
671    /// Forward pass for a single node
672    fn forward_node(&self, node: &str, rng: &mut impl Rng) -> Result<DVector<f32>> {
673        let neighbors = self.sample_neighbors(node, rng);
674        let neighbor_aggregate = self.aggregate_neighbors(&neighbors)?;
675
676        if let Some(node_embedding) = self.entity_embeddings.get(node) {
677            // Concatenate node embedding with aggregated neighbor embeddings
678            // For simplicity, we'll just add them (should be concatenation + linear transformation)
679            Ok(node_embedding + neighbor_aggregate)
680        } else {
681            Ok(neighbor_aggregate)
682        }
683    }
684}
685
686impl KGEmbeddingModel for GraphSAGE {
687    fn train(&mut self, triples: &[Triple]) -> Result<()> {
688        self.initialize(triples)?;
689
690        let mut rng = if let Some(seed) = self.config.random_seed {
691            Random::seed(seed)
692        } else {
693            Random::seed(42)
694        };
695
696        for epoch in 0..self.config.epochs {
697            let mut new_embeddings = HashMap::new();
698
699            // Update embeddings for all entities
700            for entity in &self.entities {
701                let new_embedding = self.forward_node(entity, &mut rng)?;
702                new_embeddings.insert(entity.clone(), new_embedding);
703            }
704
705            // Update embeddings
706            self.entity_embeddings = new_embeddings;
707
708            if epoch % 10 == 0 {
709                println!("GraphSAGE training epoch {}/{}", epoch, self.config.epochs);
710            }
711        }
712
713        Ok(())
714    }
715
716    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
717        self.entity_embeddings
718            .get(entity)
719            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
720    }
721
722    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
723        self.relation_embeddings
724            .get(relation)
725            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
726    }
727
728    fn score_triple(&self, triple: &Triple) -> f32 {
729        if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
730            self.get_entity_embedding(&triple.subject),
731            self.get_relation_embedding(&triple.predicate),
732            self.get_entity_embedding(&triple.object),
733        ) {
734            let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
735            predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
736        } else {
737            0.0
738        }
739    }
740
741    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
742        if let (Some(head_emb), Some(rel_emb)) = (
743            self.get_entity_embedding(head),
744            self.get_relation_embedding(relation),
745        ) {
746            let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
747
748            let mut scores = Vec::new();
749            for entity in &self.entities {
750                if entity != head {
751                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
752                        let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
753                        scores.push((entity.clone(), score));
754                    }
755                }
756            }
757
758            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
759            scores.into_iter().take(k).collect()
760        } else {
761            Vec::new()
762        }
763    }
764
765    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
766        if let (Some(rel_emb), Some(tail_emb)) = (
767            self.get_relation_embedding(relation),
768            self.get_entity_embedding(tail),
769        ) {
770            let mut scores = Vec::new();
771            for entity in &self.entities {
772                if entity != tail {
773                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
774                        let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
775                        let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
776                        scores.push((entity.clone(), score));
777                    }
778                }
779            }
780
781            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
782            scores.into_iter().take(k).collect()
783        } else {
784            Vec::new()
785        }
786    }
787
788    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
789        HashMap::new()
790    }
791
792    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
793        HashMap::new()
794    }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests: 64 dimensions, 10 epochs,
    /// batch size 32 and a fixed seed. Individual tests override fields with
    /// struct-update syntax.
    fn base_config(model: crate::kg_embeddings::KGEmbeddingModelType) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions: 64,
            learning_rate: 0.01,
            margin: 1.0,
            negative_samples: 5,
            batch_size: 32,
            epochs: 10,
            norm: 2,
            random_seed: Some(42),
            regularization: 0.01,
        }
    }

    /// A freshly created GCN defaults to two layers.
    #[test]
    fn test_gcn_creation() {
        let gcn = GCN::new(base_config(crate::kg_embeddings::KGEmbeddingModelType::GCN));
        assert_eq!(gcn.num_layers, 2);
    }

    /// A freshly created GraphSAGE model samples at most ten neighbors.
    #[test]
    fn test_graphsage_creation() {
        let graphsage = GraphSAGE::new(base_config(
            crate::kg_embeddings::KGEmbeddingModelType::GraphSAGE,
        ));
        assert_eq!(graphsage.sample_size, 10);
    }

    /// Training a small GCN succeeds and leaves an embedding for every entity.
    #[test]
    fn test_gnn_training() {
        let config = KGEmbeddingConfig {
            dimensions: 32,
            batch_size: 16,
            epochs: 5,
            ..base_config(crate::kg_embeddings::KGEmbeddingModelType::GCN)
        };

        let mut gcn = GCN::new(config);

        let triple = |subject: &str, predicate: &str, object: &str| {
            Triple::new(
                subject.to_string(),
                predicate.to_string(),
                object.to_string(),
            )
        };
        let triples = vec![
            triple("entity1", "relation1", "entity2"),
            triple("entity2", "relation2", "entity3"),
            triple("entity1", "relation3", "entity3"),
        ];

        // Should not panic
        gcn.train(&triples).unwrap();

        // Should have embeddings for all entities
        for name in ["entity1", "entity2", "entity3"].iter() {
            assert!(gcn.get_entity_embedding(name).is_some());
        }
    }
}