// oxirs_vec/gnn_embeddings.rs

1//! Graph Neural Network (GNN) embeddings for knowledge graphs
2//!
3//! This module implements GNN-based embedding methods:
4//! - GCN: Graph Convolutional Networks
5//! - GraphSAGE: Graph Sample and Aggregate
6
7use crate::{
8    kg_embeddings::{KGEmbeddingConfig, KGEmbeddingModel, Triple},
9    Vector,
10};
11use anyhow::{anyhow, Result};
12use nalgebra::{DMatrix, DVector};
13use crate::random_utils::NormalSampler as Normal;
14use scirs2_core::random::{Random, Rng};
15use std::collections::HashMap;
16
/// Graph Convolutional Network (GCN) embedding model
pub struct GCN {
    /// Shared KG-embedding hyperparameters (dimensions, epochs, seed, ...).
    config: KGEmbeddingConfig,
    /// Learned vector per entity, keyed by entity identifier.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Learned vector per relation (used only when scoring triples).
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// All distinct entities collected from the training triples.
    entities: Vec<String>,
    /// All distinct relations collected from the training triples.
    relations: Vec<String>,
    /// Symmetrically normalized adjacency matrix; `None` until training begins.
    adjacency_matrix: Option<DMatrix<f32>>,
    /// One `dimensions x dimensions` weight matrix per GCN layer.
    weight_matrices: Vec<DMatrix<f32>>,
    /// Number of message-passing layers (defaults to 2).
    num_layers: usize,
}
28
29impl GCN {
30    pub fn new(config: KGEmbeddingConfig) -> Self {
31        let num_layers = 2; // Default to 2 layers
32        Self {
33            config,
34            entity_embeddings: HashMap::new(),
35            relation_embeddings: HashMap::new(),
36            entities: Vec::new(),
37            relations: Vec::new(),
38            adjacency_matrix: None,
39            weight_matrices: Vec::new(),
40            num_layers,
41        }
42    }
43
44    /// Initialize GCN with specified number of layers
45    pub fn with_layers(config: KGEmbeddingConfig, num_layers: usize) -> Self {
46        Self {
47            config,
48            entity_embeddings: HashMap::new(),
49            relation_embeddings: HashMap::new(),
50            entities: Vec::new(),
51            relations: Vec::new(),
52            adjacency_matrix: None,
53            weight_matrices: Vec::new(),
54            num_layers,
55        }
56    }
57
58    /// Initialize embeddings and graph structure
59    fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
60        // Collect unique entities and relations
61        let mut entities = std::collections::HashSet::new();
62        let mut relations = std::collections::HashSet::new();
63
64        for triple in triples {
65            entities.insert(triple.subject.clone());
66            entities.insert(triple.object.clone());
67            relations.insert(triple.predicate.clone());
68        }
69
70        self.entities = entities.into_iter().collect();
71        self.relations = relations.into_iter().collect();
72
73        let _num_entities = self.entities.len();
74
75        // Initialize entity embeddings
76        let mut rng = if let Some(seed) = self.config.random_seed {
77            Random::seed(seed)
78        } else {
79            Random::seed(42)
80        };
81
82        let normal = Normal::new(0.0, 0.1)
83            .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
84
85        for entity in &self.entities {
86            let embedding: Vec<f32> = (0..self.config.dimensions)
87                .map(|_| normal.sample(&mut rng))
88                .collect();
89            self.entity_embeddings
90                .insert(entity.clone(), DVector::from_vec(embedding));
91        }
92
93        for relation in &self.relations {
94            let embedding: Vec<f32> = (0..self.config.dimensions)
95                .map(|_| normal.sample(&mut rng))
96                .collect();
97            self.relation_embeddings
98                .insert(relation.clone(), DVector::from_vec(embedding));
99        }
100
101        // Build adjacency matrix
102        self.build_adjacency_matrix(triples)?;
103
104        // Initialize weight matrices for each layer
105        self.weight_matrices.clear();
106        for _ in 0..self.num_layers {
107            let weight_matrix =
108                DMatrix::from_fn(self.config.dimensions, self.config.dimensions, |_, _| {
109                    normal.sample(&mut rng)
110                });
111            self.weight_matrices.push(weight_matrix);
112        }
113
114        Ok(())
115    }
116
117    /// Build adjacency matrix from triples
118    fn build_adjacency_matrix(&mut self, triples: &[Triple]) -> Result<()> {
119        let num_entities = self.entities.len();
120        let mut adj_matrix = DMatrix::zeros(num_entities, num_entities);
121
122        // Create entity index mapping
123        let entity_to_index: HashMap<String, usize> = self
124            .entities
125            .iter()
126            .enumerate()
127            .map(|(i, entity)| (entity.clone(), i))
128            .collect();
129
130        // Fill adjacency matrix
131        for triple in triples {
132            if let (Some(&subject_idx), Some(&object_idx)) = (
133                entity_to_index.get(&triple.subject),
134                entity_to_index.get(&triple.object),
135            ) {
136                adj_matrix[(subject_idx, object_idx)] = 1.0;
137                adj_matrix[(object_idx, subject_idx)] = 1.0; // Undirected graph
138            }
139        }
140
141        // Add self-loops
142        for i in 0..num_entities {
143            adj_matrix[(i, i)] = 1.0;
144        }
145
146        // Normalize adjacency matrix (symmetric normalization)
147        self.adjacency_matrix = Some(self.normalize_adjacency_matrix(adj_matrix));
148
149        Ok(())
150    }
151
152    /// Symmetric normalization of adjacency matrix: D^(-1/2) * A * D^(-1/2)
153    fn normalize_adjacency_matrix(&self, mut adj_matrix: DMatrix<f32>) -> DMatrix<f32> {
154        let num_nodes = adj_matrix.nrows();
155
156        // Calculate degree matrix
157        let mut degrees = Vec::with_capacity(num_nodes);
158        for i in 0..num_nodes {
159            let degree: f32 = (0..num_nodes).map(|j| adj_matrix[(i, j)]).sum();
160            degrees.push(if degree > 0.0 {
161                1.0 / degree.sqrt()
162            } else {
163                0.0
164            });
165        }
166
167        // Apply symmetric normalization
168        for i in 0..num_nodes {
169            for j in 0..num_nodes {
170                adj_matrix[(i, j)] *= degrees[i] * degrees[j];
171            }
172        }
173
174        adj_matrix
175    }
176
177    /// Forward pass through GCN layers
178    fn forward_pass(&self, features: &DMatrix<f32>) -> Result<DMatrix<f32>> {
179        let adj_matrix = self
180            .adjacency_matrix
181            .as_ref()
182            .ok_or_else(|| anyhow!("Adjacency matrix not initialized"))?;
183
184        let mut hidden = features.clone();
185
186        for layer_idx in 0..self.num_layers {
187            let weight = &self.weight_matrices[layer_idx];
188
189            // GCN layer: H^(l+1) = σ(A * H^(l) * W^(l))
190            let linear_transform = &hidden * weight;
191            hidden = adj_matrix * &linear_transform;
192
193            // Apply ReLU activation (except for last layer)
194            if layer_idx < self.num_layers - 1 {
195                hidden = hidden.map(|x| x.max(0.0));
196            }
197        }
198
199        Ok(hidden)
200    }
201
202    /// Train the GCN model
203    fn train_gcn(&mut self, _triples: &[Triple]) -> Result<()> {
204        // Create feature matrix from current embeddings
205        let num_entities = self.entities.len();
206        let mut features = DMatrix::zeros(num_entities, self.config.dimensions);
207
208        for (i, entity) in self.entities.iter().enumerate() {
209            if let Some(embedding) = self.entity_embeddings.get(entity) {
210                for (j, &value) in embedding.iter().enumerate() {
211                    features[(i, j)] = value;
212                }
213            }
214        }
215
216        // Perform forward pass
217        let updated_features = self.forward_pass(&features)?;
218
219        // Update entity embeddings with new features
220        for (i, entity) in self.entities.iter().enumerate() {
221            let new_embedding: Vec<f32> = (0..self.config.dimensions)
222                .map(|j| updated_features[(i, j)])
223                .collect();
224            self.entity_embeddings
225                .insert(entity.clone(), DVector::from_vec(new_embedding));
226        }
227
228        Ok(())
229    }
230}
231
232impl KGEmbeddingModel for GCN {
233    fn train(&mut self, triples: &[Triple]) -> Result<()> {
234        self.initialize(triples)?;
235
236        for epoch in 0..self.config.epochs {
237            self.train_gcn(triples)?;
238
239            if epoch % 10 == 0 {
240                println!("GCN training epoch {}/{}", epoch, self.config.epochs);
241            }
242        }
243
244        Ok(())
245    }
246
247    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
248        self.entity_embeddings
249            .get(entity)
250            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
251    }
252
253    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
254        self.relation_embeddings
255            .get(relation)
256            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
257    }
258
259    fn score_triple(&self, triple: &Triple) -> f32 {
260        // For GCN, we use cosine similarity between subject and object embeddings
261        // after considering the relation
262        if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
263            self.get_entity_embedding(&triple.subject),
264            self.get_relation_embedding(&triple.predicate),
265            self.get_entity_embedding(&triple.object),
266        ) {
267            // Simple approach: h + r should be close to t
268            let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
269            predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
270        } else {
271            0.0
272        }
273    }
274
275    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
276        if let (Some(head_emb), Some(rel_emb)) = (
277            self.get_entity_embedding(head),
278            self.get_relation_embedding(relation),
279        ) {
280            let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
281
282            let mut scores = Vec::new();
283            for entity in &self.entities {
284                if entity != head {
285                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
286                        let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
287                        scores.push((entity.clone(), score));
288                    }
289                }
290            }
291
292            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
293            scores.into_iter().take(k).collect()
294        } else {
295            Vec::new()
296        }
297    }
298
299    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
300        if let (Some(rel_emb), Some(tail_emb)) = (
301            self.get_relation_embedding(relation),
302            self.get_entity_embedding(tail),
303        ) {
304            let mut scores = Vec::new();
305            for entity in &self.entities {
306                if entity != tail {
307                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
308                        let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
309                        let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
310                        scores.push((entity.clone(), score));
311                    }
312                }
313            }
314
315            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
316            scores.into_iter().take(k).collect()
317        } else {
318            Vec::new()
319        }
320    }
321
322    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
323        // This is a bit tricky because we store DVector but need to return HashMap<String, Vector>
324        // For now, we'll return an empty HashMap - this should be refactored
325        HashMap::new()
326    }
327
328    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
329        // Same issue as above
330        HashMap::new()
331    }
332}
333
/// GraphSAGE (Graph Sample and Aggregate) embedding model
pub struct GraphSAGE {
    /// Shared KG-embedding hyperparameters (dimensions, epochs, seed, ...).
    config: KGEmbeddingConfig,
    /// Learned vector per entity, keyed by entity identifier.
    entity_embeddings: HashMap<String, DVector<f32>>,
    /// Learned vector per relation (used only when scoring triples).
    relation_embeddings: HashMap<String, DVector<f32>>,
    /// All distinct entities collected from the training triples.
    entities: Vec<String>,
    /// All distinct relations collected from the training triples.
    relations: Vec<String>,
    graph: HashMap<String, Vec<String>>, // Adjacency list (undirected; may hold duplicates)
    /// How sampled neighbor embeddings are combined into one vector.
    aggregator_type: AggregatorType,
    /// Number of GraphSAGE layers (currently not consumed by forward_node).
    num_layers: usize,
    /// Maximum number of neighbors sampled per node.
    sample_size: usize,
    /// Strategy used to pick which neighbors to sample.
    sampling_strategy: SamplingStrategy,
}
347
/// Strategy for combining sampled neighbor embeddings into a single vector.
#[derive(Debug, Clone, Copy)]
pub enum AggregatorType {
    /// Element-wise mean of neighbor embeddings.
    Mean,
    /// Sequential LSTM-style aggregation (simplified gates).
    LSTM,
    /// Element-wise max pooling.
    Pool,
    /// Dot-product attention-weighted sum.
    Attention,
}
355
/// Strategy for choosing which neighbors of a node to sample.
#[derive(Debug, Clone, Copy)]
pub enum SamplingStrategy {
    Uniform,  // Uniform random sampling
    Degree,   // Degree-based sampling (prefer high-degree neighbors)
    PageRank, // PageRank-based sampling (prefer important neighbors)
    Recent,   // Sample recently added neighbors (for temporal graphs)
}
363
364impl GraphSAGE {
365    pub fn new(config: KGEmbeddingConfig) -> Self {
366        Self {
367            config,
368            entity_embeddings: HashMap::new(),
369            relation_embeddings: HashMap::new(),
370            entities: Vec::new(),
371            relations: Vec::new(),
372            graph: HashMap::new(),
373            aggregator_type: AggregatorType::Mean,
374            num_layers: 2,
375            sample_size: 10, // Number of neighbors to sample
376            sampling_strategy: SamplingStrategy::Uniform,
377        }
378    }
379
380    pub fn with_aggregator(mut self, aggregator: AggregatorType) -> Self {
381        self.aggregator_type = aggregator;
382        self
383    }
384
385    pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
386        self.sampling_strategy = strategy;
387        self
388    }
389
390    pub fn with_sample_size(mut self, size: usize) -> Self {
391        self.sample_size = size;
392        self
393    }
394
395    /// Get embedding dimensions
396    pub fn dimensions(&self) -> usize {
397        self.config.dimensions
398    }
399
400    /// Initialize GraphSAGE model
401    fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
402        // Collect unique entities and relations
403        let mut entities = std::collections::HashSet::new();
404        let mut relations = std::collections::HashSet::new();
405
406        for triple in triples {
407            entities.insert(triple.subject.clone());
408            entities.insert(triple.object.clone());
409            relations.insert(triple.predicate.clone());
410        }
411
412        self.entities = entities.into_iter().collect();
413        self.relations = relations.into_iter().collect();
414
415        // Build graph adjacency list
416        self.build_graph(triples);
417
418        // Initialize embeddings
419        let mut rng = if let Some(seed) = self.config.random_seed {
420            Random::seed(seed)
421        } else {
422            Random::seed(42)
423        };
424
425        let normal = Normal::new(0.0, 0.1)
426            .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
427
428        for entity in &self.entities {
429            let embedding: Vec<f32> = (0..self.config.dimensions)
430                .map(|_| normal.sample(&mut rng))
431                .collect();
432            self.entity_embeddings
433                .insert(entity.clone(), DVector::from_vec(embedding));
434        }
435
436        for relation in &self.relations {
437            let embedding: Vec<f32> = (0..self.config.dimensions)
438                .map(|_| normal.sample(&mut rng))
439                .collect();
440            self.relation_embeddings
441                .insert(relation.clone(), DVector::from_vec(embedding));
442        }
443
444        Ok(())
445    }
446
447    /// Build graph adjacency list
448    fn build_graph(&mut self, triples: &[Triple]) {
449        for triple in triples {
450            self.graph
451                .entry(triple.subject.clone())
452                .or_default()
453                .push(triple.object.clone());
454
455            self.graph
456                .entry(triple.object.clone())
457                .or_default()
458                .push(triple.subject.clone());
459        }
460    }
461
462    /// Sample neighbors for a node using different strategies
463    fn sample_neighbors(&self, node: &str, rng: &mut impl Rng) -> Vec<String> {
464        if let Some(neighbors) = self.graph.get(node) {
465            if neighbors.len() <= self.sample_size {
466                neighbors.clone()
467            } else {
468                match self.sampling_strategy {
469                    SamplingStrategy::Uniform => {
470                        // Note: Using manual random selection instead of SliceRandom
471                        // Manually sample neighbors using reservoir sampling
472                        let mut sampled = Vec::new();
473                        let sample_size = std::cmp::min(self.sample_size, neighbors.len());
474                        for (i, neighbor) in neighbors.iter().enumerate() {
475                            if sampled.len() < sample_size {
476                                sampled.push(neighbor.clone());
477                            } else {
478                                let j = rng.gen_range(0..=i);
479                                if j < sample_size {
480                                    sampled[j] = neighbor.clone();
481                                }
482                            }
483                        }
484                        sampled
485                    }
486                    SamplingStrategy::Degree => self.degree_based_sampling(neighbors, rng),
487                    SamplingStrategy::PageRank => {
488                        // Simplified PageRank-based sampling (use degree as approximation)
489                        self.degree_based_sampling(neighbors, rng)
490                    }
491                    SamplingStrategy::Recent => {
492                        // For recent sampling, take the last added neighbors
493                        neighbors
494                            .iter()
495                            .rev()
496                            .take(self.sample_size)
497                            .cloned()
498                            .collect()
499                    }
500                }
501            }
502        } else {
503            Vec::new()
504        }
505    }
506
507    /// Degree-based sampling: prefer neighbors with higher degree
508    fn degree_based_sampling(&self, neighbors: &[String], rng: &mut impl Rng) -> Vec<String> {
509        let mut neighbor_degrees: Vec<(String, usize)> = neighbors
510            .iter()
511            .map(|neighbor| {
512                let degree = self.graph.get(neighbor).map(|n| n.len()).unwrap_or(0);
513                (neighbor.clone(), degree)
514            })
515            .collect();
516
517        // Sort by degree (descending) and add some randomization
518        neighbor_degrees.sort_by(|a, b| {
519            let degree_cmp = b.1.cmp(&a.1);
520            if degree_cmp == std::cmp::Ordering::Equal {
521                // Add randomization for ties
522                if rng.gen_bool(0.5) {
523                    std::cmp::Ordering::Greater
524                } else {
525                    std::cmp::Ordering::Less
526                }
527            } else {
528                degree_cmp
529            }
530        });
531
532        neighbor_degrees
533            .into_iter()
534            .take(self.sample_size)
535            .map(|(neighbor, _)| neighbor)
536            .collect()
537    }
538
539    /// Aggregate neighbor embeddings
540    fn aggregate_neighbors(&self, neighbors: &[String]) -> Result<DVector<f32>> {
541        if neighbors.is_empty() {
542            return Ok(DVector::zeros(self.config.dimensions));
543        }
544
545        match self.aggregator_type {
546            AggregatorType::Mean => {
547                let mut sum = DVector::zeros(self.config.dimensions);
548                let mut count = 0;
549
550                for neighbor in neighbors {
551                    if let Some(embedding) = self.entity_embeddings.get(neighbor) {
552                        sum += embedding;
553                        count += 1;
554                    }
555                }
556
557                if count > 0 {
558                    Ok(sum / count as f32)
559                } else {
560                    Ok(DVector::zeros(self.config.dimensions))
561                }
562            }
563            AggregatorType::Pool => {
564                // Max pooling aggregator
565                let mut max_embedding =
566                    DVector::from_element(self.config.dimensions, f32::NEG_INFINITY);
567
568                for neighbor in neighbors {
569                    if let Some(embedding) = self.entity_embeddings.get(neighbor) {
570                        for i in 0..self.config.dimensions {
571                            max_embedding[i] = max_embedding[i].max(embedding[i]);
572                        }
573                    }
574                }
575
576                // Replace negative infinity with zeros
577                for i in 0..self.config.dimensions {
578                    if max_embedding[i] == f32::NEG_INFINITY {
579                        max_embedding[i] = 0.0;
580                    }
581                }
582
583                Ok(max_embedding)
584            }
585            AggregatorType::LSTM => {
586                // LSTM-based aggregator
587                self.lstm_aggregate(neighbors)
588            }
589            AggregatorType::Attention => {
590                // Attention-based aggregator
591                self.attention_aggregate(neighbors)
592            }
593        }
594    }
595
596    /// LSTM-based aggregator (simplified implementation)
597    fn lstm_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
598        if neighbors.is_empty() {
599            return Ok(DVector::zeros(self.config.dimensions));
600        }
601
602        // Simplified LSTM: process neighbors sequentially with forget/input gates
603        let mut cell_state = DVector::zeros(self.config.dimensions);
604        let mut hidden_state = DVector::zeros(self.config.dimensions);
605
606        for neighbor in neighbors {
607            if let Some(embedding) = self.entity_embeddings.get(neighbor) {
608                // Simplified LSTM gates (using tanh and sigmoid approximations)
609                let forget_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp())); // sigmoid
610                let input_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
611                let candidate = embedding.map(|x| x.tanh()); // tanh
612
613                // Update cell state
614                cell_state =
615                    cell_state.component_mul(&forget_gate) + input_gate.component_mul(&candidate);
616
617                // Update hidden state
618                let output_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
619                hidden_state = output_gate.component_mul(&cell_state.map(|x| x.tanh()));
620            }
621        }
622
623        Ok(hidden_state)
624    }
625
626    /// Attention-based aggregator
627    fn attention_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
628        if neighbors.is_empty() {
629            return Ok(DVector::zeros(self.config.dimensions));
630        }
631
632        let neighbor_embeddings: Vec<&DVector<f32>> = neighbors
633            .iter()
634            .filter_map(|neighbor| self.entity_embeddings.get(neighbor))
635            .collect();
636
637        if neighbor_embeddings.is_empty() {
638            return Ok(DVector::zeros(self.config.dimensions));
639        }
640
641        // Simple attention mechanism using dot-product attention
642        let mut attention_scores = Vec::new();
643        let mut weighted_sum = DVector::zeros(self.config.dimensions);
644
645        // Calculate attention scores (simplified: using magnitude as query)
646        let query = DVector::from_element(self.config.dimensions, 1.0); // Simple uniform query
647
648        for embedding in &neighbor_embeddings {
649            let score = query.dot(embedding).exp(); // Softmax will normalize
650            attention_scores.push(score);
651        }
652
653        // Normalize attention scores (softmax)
654        let total_score: f32 = attention_scores.iter().sum();
655        if total_score > 0.0 {
656            for score in &mut attention_scores {
657                *score /= total_score;
658            }
659        }
660
661        // Calculate weighted sum
662        for (embedding, &score) in neighbor_embeddings.iter().zip(attention_scores.iter()) {
663            weighted_sum += *embedding * score;
664        }
665
666        Ok(weighted_sum)
667    }
668
669    /// Forward pass for a single node
670    fn forward_node(&self, node: &str, rng: &mut impl Rng) -> Result<DVector<f32>> {
671        let neighbors = self.sample_neighbors(node, rng);
672        let neighbor_aggregate = self.aggregate_neighbors(&neighbors)?;
673
674        if let Some(node_embedding) = self.entity_embeddings.get(node) {
675            // Concatenate node embedding with aggregated neighbor embeddings
676            // For simplicity, we'll just add them (should be concatenation + linear transformation)
677            Ok(node_embedding + neighbor_aggregate)
678        } else {
679            Ok(neighbor_aggregate)
680        }
681    }
682}
683
684impl KGEmbeddingModel for GraphSAGE {
685    fn train(&mut self, triples: &[Triple]) -> Result<()> {
686        self.initialize(triples)?;
687
688        let mut rng = if let Some(seed) = self.config.random_seed {
689            Random::seed(seed)
690        } else {
691            Random::seed(42)
692        };
693
694        for epoch in 0..self.config.epochs {
695            let mut new_embeddings = HashMap::new();
696
697            // Update embeddings for all entities
698            for entity in &self.entities {
699                let new_embedding = self.forward_node(entity, &mut rng)?;
700                new_embeddings.insert(entity.clone(), new_embedding);
701            }
702
703            // Update embeddings
704            self.entity_embeddings = new_embeddings;
705
706            if epoch % 10 == 0 {
707                println!("GraphSAGE training epoch {}/{}", epoch, self.config.epochs);
708            }
709        }
710
711        Ok(())
712    }
713
714    fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
715        self.entity_embeddings
716            .get(entity)
717            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
718    }
719
720    fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
721        self.relation_embeddings
722            .get(relation)
723            .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
724    }
725
726    fn score_triple(&self, triple: &Triple) -> f32 {
727        if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
728            self.get_entity_embedding(&triple.subject),
729            self.get_relation_embedding(&triple.predicate),
730            self.get_entity_embedding(&triple.object),
731        ) {
732            let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
733            predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
734        } else {
735            0.0
736        }
737    }
738
739    fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
740        if let (Some(head_emb), Some(rel_emb)) = (
741            self.get_entity_embedding(head),
742            self.get_relation_embedding(relation),
743        ) {
744            let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
745
746            let mut scores = Vec::new();
747            for entity in &self.entities {
748                if entity != head {
749                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
750                        let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
751                        scores.push((entity.clone(), score));
752                    }
753                }
754            }
755
756            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
757            scores.into_iter().take(k).collect()
758        } else {
759            Vec::new()
760        }
761    }
762
763    fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
764        if let (Some(rel_emb), Some(tail_emb)) = (
765            self.get_relation_embedding(relation),
766            self.get_entity_embedding(tail),
767        ) {
768            let mut scores = Vec::new();
769            for entity in &self.entities {
770                if entity != tail {
771                    if let Some(entity_emb) = self.get_entity_embedding(entity) {
772                        let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
773                        let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
774                        scores.push((entity.clone(), score));
775                    }
776                }
777            }
778
779            scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
780            scores.into_iter().take(k).collect()
781        } else {
782            Vec::new()
783        }
784    }
785
786    fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
787        HashMap::new()
788    }
789
790    fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
791        HashMap::new()
792    }
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kg_embeddings::KGEmbeddingModelType;

    /// Build a small test configuration; only the fields the tests vary are
    /// parameters, everything else uses shared defaults.
    fn test_config(
        model: KGEmbeddingModelType,
        dimensions: usize,
        batch_size: usize,
        epochs: usize,
    ) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions,
            learning_rate: 0.01,
            margin: 1.0,
            negative_samples: 5,
            batch_size,
            epochs,
            norm: 2,
            random_seed: Some(42),
            regularization: 0.01,
        }
    }

    #[test]
    fn test_gcn_creation() {
        let gcn = GCN::new(test_config(KGEmbeddingModelType::GCN, 64, 32, 10));
        assert_eq!(gcn.num_layers, 2);
    }

    #[test]
    fn test_graphsage_creation() {
        let graphsage = GraphSAGE::new(test_config(KGEmbeddingModelType::GraphSAGE, 64, 32, 10));
        assert_eq!(graphsage.sample_size, 10);
    }

    #[test]
    fn test_gnn_training() {
        let mut gcn = GCN::new(test_config(KGEmbeddingModelType::GCN, 32, 16, 5));

        let triples = vec![
            Triple::new(
                "entity1".to_string(),
                "relation1".to_string(),
                "entity2".to_string(),
            ),
            Triple::new(
                "entity2".to_string(),
                "relation2".to_string(),
                "entity3".to_string(),
            ),
            Triple::new(
                "entity1".to_string(),
                "relation3".to_string(),
                "entity3".to_string(),
            ),
        ];

        // Training on a tiny graph should succeed...
        gcn.train(&triples).unwrap();

        // ...and produce an embedding for every entity seen in the triples.
        assert!(gcn.get_entity_embedding("entity1").is_some());
        assert!(gcn.get_entity_embedding("entity2").is_some());
        assert!(gcn.get_entity_embedding("entity3").is_some());
    }
}