oxirs_embed/models/
gnn.rs

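//! Graph Neural Network (GNN) embedding models
//!
//! This module provides GNN-based knowledge graph embeddings. Dedicated layer
//! implementations exist for GCN, GraphSAGE, GAT, and GIN; the remaining
//! architectures (Graph Transformer, PNA, HetGNN, TGN) currently fall back to
//! GCN layers.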
5
6use crate::{
7    EmbeddingError, EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector,
8};
9use anyhow::{anyhow, Result};
10use async_trait::async_trait;
11use chrono::Utc;
12use scirs2_core::ndarray_ext::{Array1, Array2};
13#[allow(unused_imports)]
14use scirs2_core::random::{Random, Rng};
15use serde::{Deserialize, Serialize};
16use std::collections::{HashMap, HashSet};
17use uuid::Uuid;
18
19/// Type of GNN architecture
20#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
21pub enum GNNType {
22    /// Graph Convolutional Network
23    GCN,
24    /// GraphSAGE - Sampling and aggregating
25    GraphSAGE,
26    /// Graph Attention Network
27    GAT,
28    /// Graph Transformer
29    GraphTransformer,
30    /// Graph Isomorphism Network
31    GIN,
32    /// Principal Neighbourhood Aggregation
33    PNA,
34    /// Heterogeneous Graph Network
35    HetGNN,
36    /// Temporal Graph Network
37    TGN,
38}
39
40impl GNNType {
41    pub fn default_layers(&self) -> usize {
42        match self {
43            GNNType::GCN => 2,
44            GNNType::GraphSAGE => 2,
45            GNNType::GAT => 2,
46            GNNType::GraphTransformer => 4,
47            GNNType::GIN => 3,
48            GNNType::PNA => 3,
49            GNNType::HetGNN => 2,
50            GNNType::TGN => 2,
51        }
52    }
53
54    pub fn requires_attention(&self) -> bool {
55        matches!(self, GNNType::GAT | GNNType::GraphTransformer)
56    }
57}
58
59/// Aggregation method for GraphSAGE
60#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
61pub enum AggregationType {
62    Mean,
63    Max,
64    Sum,
65    LSTM,
66}
67
68/// Configuration for GNN models
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct GNNConfig {
71    pub base_config: ModelConfig,
72    pub gnn_type: GNNType,
73    pub num_layers: usize,
74    pub hidden_dimensions: Vec<usize>,
75    pub dropout: f64,
76    pub aggregation: AggregationType,
77    pub num_heads: Option<usize>,        // For attention-based models
78    pub sample_neighbors: Option<usize>, // For GraphSAGE
79    pub residual_connections: bool,
80    pub layer_norm: bool,
81    pub edge_features: bool,
82}
83
84impl Default for GNNConfig {
85    fn default() -> Self {
86        Self {
87            base_config: ModelConfig::default(),
88            gnn_type: GNNType::GCN,
89            num_layers: 2,
90            hidden_dimensions: vec![128, 64],
91            dropout: 0.1,
92            aggregation: AggregationType::Mean,
93            num_heads: None,
94            sample_neighbors: None,
95            residual_connections: true,
96            layer_norm: true,
97            edge_features: false,
98        }
99    }
100}
101
102/// GNN-based embedding model
103pub struct GNNEmbedding {
104    id: Uuid,
105    config: GNNConfig,
106    entity_embeddings: HashMap<String, Array1<f32>>,
107    relation_embeddings: HashMap<String, Array1<f32>>,
108    entity_to_idx: HashMap<String, usize>,
109    relation_to_idx: HashMap<String, usize>,
110    idx_to_entity: HashMap<usize, String>,
111    idx_to_relation: HashMap<usize, String>,
112    adjacency_list: HashMap<usize, HashSet<(usize, usize)>>, // (neighbor, relation)
113    reverse_adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
114    triples: Vec<Triple>,
115    layers: Vec<GNNLayer>,
116    is_trained: bool,
117    creation_time: chrono::DateTime<Utc>,
118    last_training_time: Option<chrono::DateTime<Utc>>,
119}
120
121/// Single GNN layer
122struct GNNLayer {
123    weight_matrix: Array2<f32>,
124    bias: Array1<f32>,
125    attention_weights: Option<AttentionWeights>,
126    layer_norm: Option<LayerNormalization>,
127}
128
129/// Attention weights for GAT/GraphTransformer
130struct AttentionWeights {
131    query_weights: Array2<f32>,
132    key_weights: Array2<f32>,
133    value_weights: Array2<f32>,
134    num_heads: usize,
135}
136
137/// Layer normalization parameters
138struct LayerNormalization {
139    gamma: Array1<f32>,
140    beta: Array1<f32>,
141    epsilon: f32,
142}
143
144impl GNNEmbedding {
145    pub fn new(config: GNNConfig) -> Self {
146        Self {
147            id: Uuid::new_v4(),
148            config,
149            entity_embeddings: HashMap::new(),
150            relation_embeddings: HashMap::new(),
151            entity_to_idx: HashMap::new(),
152            relation_to_idx: HashMap::new(),
153            idx_to_entity: HashMap::new(),
154            idx_to_relation: HashMap::new(),
155            adjacency_list: HashMap::new(),
156            reverse_adjacency_list: HashMap::new(),
157            triples: Vec::new(),
158            layers: Vec::new(),
159            is_trained: false,
160            creation_time: Utc::now(),
161            last_training_time: None,
162        }
163    }
164
165    /// Initialize GNN layers
166    fn initialize_layers(&mut self) -> Result<()> {
167        self.layers.clear();
168        let mut rng = Random::seed(42);
169
170        let mut input_dim = self.config.base_config.dimensions;
171        let num_layers = self.config.num_layers;
172
173        for i in 0..num_layers {
174            let output_dim = if i == num_layers - 1 {
175                // Final layer should output back to original embedding dimension
176                self.config.base_config.dimensions
177            } else if i < self.config.hidden_dimensions.len() {
178                self.config.hidden_dimensions[i]
179            } else {
180                self.config.base_config.dimensions
181            };
182
183            // Initialize weight matrix
184            let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
185            let weight_matrix = Array2::from_shape_fn((input_dim, output_dim), |_| {
186                rng.gen_range(0.0..1.0) * scale * 2.0 - scale
187            });
188
189            let bias = Array1::zeros(output_dim);
190
191            // Initialize attention weights if needed
192            let attention_weights = if self.config.gnn_type.requires_attention() {
193                let num_heads = self.config.num_heads.unwrap_or(8);
194                let head_dim = output_dim / num_heads;
195
196                // For multi-head attention, each head processes a portion of the output
197                let attention_dim = head_dim * num_heads; // Should equal output_dim
198
199                Some(AttentionWeights {
200                    query_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
201                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
202                    }),
203                    key_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
204                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
205                    }),
206                    value_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
207                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
208                    }),
209                    num_heads,
210                })
211            } else {
212                None
213            };
214
215            // Initialize layer normalization if needed
216            let layer_norm = if self.config.layer_norm {
217                Some(LayerNormalization {
218                    gamma: Array1::ones(output_dim),
219                    beta: Array1::zeros(output_dim),
220                    epsilon: 1e-5,
221                })
222            } else {
223                None
224            };
225
226            self.layers.push(GNNLayer {
227                weight_matrix,
228                bias,
229                attention_weights,
230                layer_norm,
231            });
232
233            input_dim = output_dim;
234        }
235
236        Ok(())
237    }
238
239    /// Build adjacency lists from triples
240    fn build_adjacency_lists(&mut self) {
241        self.adjacency_list.clear();
242        self.reverse_adjacency_list.clear();
243
244        for triple in &self.triples {
245            let subject_idx = self.entity_to_idx[&triple.subject.iri];
246            let object_idx = self.entity_to_idx[&triple.object.iri];
247            let relation_idx = self.relation_to_idx[&triple.predicate.iri];
248
249            // Forward adjacency
250            self.adjacency_list
251                .entry(subject_idx)
252                .or_default()
253                .insert((object_idx, relation_idx));
254
255            // Reverse adjacency
256            self.reverse_adjacency_list
257                .entry(object_idx)
258                .or_default()
259                .insert((subject_idx, relation_idx));
260        }
261    }
262
263    /// Aggregate neighbor features
264    fn aggregate_neighbors(
265        &self,
266        node_idx: usize,
267        node_features: &HashMap<usize, Array1<f32>>,
268    ) -> Array1<f32> {
269        let neighbors = self.adjacency_list.get(&node_idx);
270        let reverse_neighbors = self.reverse_adjacency_list.get(&node_idx);
271
272        let mut neighbor_features = Vec::new();
273
274        // Collect forward neighbors
275        if let Some(neighbors) = neighbors {
276            for (neighbor_idx, _) in neighbors {
277                if let Some(feature) = node_features.get(neighbor_idx) {
278                    neighbor_features.push(feature.clone());
279                }
280            }
281        }
282
283        // Collect reverse neighbors
284        if let Some(reverse_neighbors) = reverse_neighbors {
285            for (neighbor_idx, _) in reverse_neighbors {
286                if let Some(feature) = node_features.get(neighbor_idx) {
287                    neighbor_features.push(feature.clone());
288                }
289            }
290        }
291
292        if neighbor_features.is_empty() {
293            // Return zero vector if no neighbors
294            return Array1::zeros(node_features.values().next().unwrap().len());
295        }
296
297        // Aggregate based on configuration
298        match self.config.aggregation {
299            AggregationType::Mean => {
300                let sum: Array1<f32> = neighbor_features
301                    .iter()
302                    .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x);
303                sum / neighbor_features.len() as f32
304            }
305            AggregationType::Max => neighbor_features.iter().fold(
306                Array1::from_elem(neighbor_features[0].len(), f32::NEG_INFINITY),
307                |acc, x| {
308                    let mut result = acc.clone();
309                    for (i, &val) in x.iter().enumerate() {
310                        result[i] = result[i].max(val);
311                    }
312                    result
313                },
314            ),
315            AggregationType::Sum => neighbor_features
316                .iter()
317                .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x),
318            AggregationType::LSTM => {
319                // Simplified LSTM aggregation - in practice would use actual LSTM
320                self.aggregate_neighbors_lstm(&neighbor_features)
321            }
322        }
323    }
324
325    /// LSTM aggregation (simplified)
326    fn aggregate_neighbors_lstm(&self, neighbor_features: &[Array1<f32>]) -> Array1<f32> {
327        // Simplified version - real implementation would use LSTM cells
328        let mut aggregated = Array1::zeros(neighbor_features[0].len());
329        for feature in neighbor_features {
330            aggregated = aggregated * 0.8 + feature * 0.2; // Simple weighted average
331        }
332        aggregated
333    }
334
335    /// Apply GNN layer
336    fn apply_layer(
337        &self,
338        layer: &GNNLayer,
339        node_features: &HashMap<usize, Array1<f32>>,
340    ) -> HashMap<usize, Array1<f32>> {
341        let mut new_features = HashMap::new();
342
343        match self.config.gnn_type {
344            GNNType::GCN => self.apply_gcn_layer(layer, node_features, &mut new_features),
345            GNNType::GraphSAGE => {
346                self.apply_graphsage_layer(layer, node_features, &mut new_features)
347            }
348            GNNType::GAT => self.apply_gat_layer(layer, node_features, &mut new_features),
349            GNNType::GIN => self.apply_gin_layer(layer, node_features, &mut new_features),
350            _ => self.apply_gcn_layer(layer, node_features, &mut new_features), // Default to GCN
351        }
352
353        new_features
354    }
355
356    /// Apply GCN layer
357    fn apply_gcn_layer(
358        &self,
359        layer: &GNNLayer,
360        node_features: &HashMap<usize, Array1<f32>>,
361        new_features: &mut HashMap<usize, Array1<f32>>,
362    ) {
363        for (node_idx, feature) in node_features {
364            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
365            let combined = feature + &aggregated;
366            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
367
368            // Apply activation (ReLU)
369            let activated = transformed.mapv(|x| x.max(0.0));
370
371            // Apply layer norm if configured
372            let output = if let Some(ln) = &layer.layer_norm {
373                self.apply_layer_norm(&activated, ln)
374            } else {
375                activated
376            };
377
378            new_features.insert(*node_idx, output);
379        }
380    }
381
382    /// Apply GraphSAGE layer
383    fn apply_graphsage_layer(
384        &self,
385        layer: &GNNLayer,
386        node_features: &HashMap<usize, Array1<f32>>,
387        new_features: &mut HashMap<usize, Array1<f32>>,
388    ) {
389        for (node_idx, feature) in node_features {
390            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
391
392            // For GraphSAGE, we apply separate transformations and then combine
393            // Transform node feature
394            let node_transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
395
396            // Transform aggregated neighbor features (reuse same weight matrix for simplicity)
397            let neighbor_transformed = aggregated.dot(&layer.weight_matrix) + &layer.bias;
398
399            // Combine the transformed features
400            let combined = &node_transformed + &neighbor_transformed;
401
402            // Apply activation and normalization
403            let activated = combined.mapv(|x| x.max(0.0));
404            let normalized = &activated / (activated.dot(&activated).sqrt() + 1e-6);
405
406            new_features.insert(*node_idx, normalized);
407        }
408    }
409
410    /// Apply GAT layer
411    fn apply_gat_layer(
412        &self,
413        layer: &GNNLayer,
414        node_features: &HashMap<usize, Array1<f32>>,
415        new_features: &mut HashMap<usize, Array1<f32>>,
416    ) {
417        // Simplified GAT - real implementation would compute attention scores
418        let attention = layer.attention_weights.as_ref().unwrap();
419
420        for (node_idx, feature) in node_features {
421            // Get neighbors
422            let mut neighbor_indices = Vec::new();
423            if let Some(neighbors) = self.adjacency_list.get(node_idx) {
424                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
425            }
426            if let Some(neighbors) = self.reverse_adjacency_list.get(node_idx) {
427                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
428            }
429
430            if neighbor_indices.is_empty() {
431                // Apply linear transformation even when no neighbors
432                let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
433                let activated = transformed.mapv(|x| x.max(0.0));
434                new_features.insert(*node_idx, activated);
435                continue;
436            }
437
438            // Ensure feature dimensions match weight matrix input dimensions
439            if feature.len() != attention.query_weights.shape()[0] {
440                // Fallback to simple aggregation if dimensions don't match
441                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
442                let combined = feature + &aggregated;
443                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
444                let activated = transformed.mapv(|x| x.max(0.0));
445                new_features.insert(*node_idx, activated);
446                continue;
447            }
448
449            // Compute attention scores (simplified)
450            let query = feature.dot(&attention.query_weights);
451            let mut attention_scores = Vec::new();
452            let mut neighbor_values = Vec::new();
453
454            for neighbor_idx in &neighbor_indices {
455                if let Some(neighbor_feature) = node_features.get(neighbor_idx) {
456                    // Check dimension compatibility before computing attention
457                    if neighbor_feature.len() != attention.key_weights.shape()[0] {
458                        continue;
459                    }
460
461                    let key = neighbor_feature.dot(&attention.key_weights);
462                    let value = neighbor_feature.dot(&attention.value_weights);
463
464                    // Compute attention score with proper dimension checking
465                    if query.len() == key.len() {
466                        let score = query.dot(&key) / (attention.num_heads as f32).sqrt();
467                        attention_scores.push(score);
468                        neighbor_values.push(value);
469                    }
470                }
471            }
472
473            if attention_scores.is_empty() {
474                // Fallback to simple aggregation if no valid attention scores
475                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
476                let combined = feature + &aggregated;
477                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
478                let activated = transformed.mapv(|x| x.max(0.0));
479                new_features.insert(*node_idx, activated);
480                continue;
481            }
482
483            // Softmax
484            let max_score = attention_scores
485                .iter()
486                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
487            let exp_scores: Vec<f32> = attention_scores
488                .iter()
489                .map(|&s| (s - max_score).exp())
490                .collect();
491            let sum_exp = exp_scores.iter().sum::<f32>();
492            let attention_weights: Vec<f32> =
493                exp_scores.iter().copied().map(|e| e / sum_exp).collect();
494
495            // Apply attention with proper output dimensions
496            let output_dim = layer.weight_matrix.shape()[1];
497            let mut aggregated = Array1::<f32>::zeros(output_dim);
498
499            for (i, value) in neighbor_values.iter().enumerate() {
500                // Ensure value dimension matches output dimension
501                let min_dim = aggregated.len().min(value.len());
502                for j in 0..min_dim {
503                    aggregated[j] += value[j] * attention_weights[i];
504                }
505            }
506
507            // Apply linear transformation
508            let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
509            let combined =
510                if self.config.residual_connections && transformed.len() == aggregated.len() {
511                    transformed + &aggregated
512                } else {
513                    transformed
514                };
515
516            let activated = combined.mapv(|x| x.max(0.0));
517            new_features.insert(*node_idx, activated);
518        }
519    }
520
521    /// Apply GIN layer
522    fn apply_gin_layer(
523        &self,
524        layer: &GNNLayer,
525        node_features: &HashMap<usize, Array1<f32>>,
526        new_features: &mut HashMap<usize, Array1<f32>>,
527    ) {
528        let epsilon = 0.0; // GIN epsilon parameter
529
530        for (node_idx, feature) in node_features {
531            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
532            let combined = (1.0 + epsilon) * feature + aggregated;
533
534            // MLP transformation (simplified as single linear layer)
535            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
536            let activated = transformed.mapv(|x| x.max(0.0));
537
538            new_features.insert(*node_idx, activated);
539        }
540    }
541
542    /// Apply layer normalization
543    fn apply_layer_norm(&self, input: &Array1<f32>, ln: &LayerNormalization) -> Array1<f32> {
544        let mean = input.mean().unwrap_or(0.0);
545        let variance = input.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(1.0);
546        let normalized = input.mapv(|x| (x - mean) / (variance + ln.epsilon).sqrt());
547        &normalized * &ln.gamma + &ln.beta
548    }
549
550    /// Forward pass through all GNN layers
551    fn forward(
552        &self,
553        initial_features: HashMap<usize, Array1<f32>>,
554    ) -> HashMap<usize, Array1<f32>> {
555        let mut features = initial_features;
556
557        for layer in self.layers.iter() {
558            let new_features = self.apply_layer(layer, &features);
559
560            // Apply dropout during training (simplified - always applied here)
561            let dropout_rate = self.config.dropout;
562            let mut rng = Random::seed(42);
563
564            features = new_features
565                .into_iter()
566                .map(|(idx, feat)| {
567                    let masked = feat.mapv(|x| {
568                        if rng.gen_range(0.0..1.0) > dropout_rate as f32 {
569                            x / (1.0 - dropout_rate as f32)
570                        } else {
571                            0.0
572                        }
573                    });
574                    (idx, masked)
575                })
576                .collect();
577        }
578
579        features
580    }
581}
582
583#[async_trait]
584impl EmbeddingModel for GNNEmbedding {
585    fn config(&self) -> &ModelConfig {
586        &self.config.base_config
587    }
588
589    fn model_id(&self) -> &Uuid {
590        &self.id
591    }
592
593    fn model_type(&self) -> &'static str {
594        "GNNEmbedding"
595    }
596
597    fn add_triple(&mut self, triple: Triple) -> Result<()> {
598        // Add entities to index
599        let subject = triple.subject.iri.clone();
600        let object = triple.object.iri.clone();
601        let predicate = triple.predicate.iri.clone();
602
603        if !self.entity_to_idx.contains_key(&subject) {
604            let idx = self.entity_to_idx.len();
605            self.entity_to_idx.insert(subject.clone(), idx);
606            self.idx_to_entity.insert(idx, subject);
607        }
608
609        if !self.entity_to_idx.contains_key(&object) {
610            let idx = self.entity_to_idx.len();
611            self.entity_to_idx.insert(object.clone(), idx);
612            self.idx_to_entity.insert(idx, object);
613        }
614
615        if !self.relation_to_idx.contains_key(&predicate) {
616            let idx = self.relation_to_idx.len();
617            self.relation_to_idx.insert(predicate.clone(), idx);
618            self.idx_to_relation.insert(idx, predicate);
619        }
620
621        self.triples.push(triple);
622        self.is_trained = false;
623        Ok(())
624    }
625
626    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
627        let start_time = std::time::Instant::now();
628        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
629
630        // Build adjacency lists
631        self.build_adjacency_lists();
632
633        // Initialize layers
634        self.initialize_layers()?;
635
636        // Initialize random embeddings
637        let mut rng = Random::seed(42);
638        let dimensions = self.config.base_config.dimensions;
639
640        let mut initial_features = HashMap::new();
641        for idx in self.entity_to_idx.values() {
642            let embedding =
643                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
644            initial_features.insert(*idx, embedding);
645        }
646
647        // Training loop (simplified)
648        let mut loss_history = Vec::new();
649
650        for _epoch in 0..epochs {
651            // Forward pass
652            let output_features = self.forward(initial_features.clone());
653
654            // Compute loss (simplified - just using L2 regularization)
655            let loss = output_features
656                .values()
657                .map(|f| f.mapv(|x| x * x).sum())
658                .sum::<f32>()
659                / output_features.len() as f32;
660
661            loss_history.push(loss as f64);
662
663            // Update initial features with output (simplified training)
664            initial_features = output_features;
665
666            // Early stopping
667            if loss < 0.001 {
668                break;
669            }
670        }
671
672        // Store final embeddings
673        for (idx, embedding) in initial_features {
674            if let Some(entity) = self.idx_to_entity.get(&idx) {
675                self.entity_embeddings.insert(entity.clone(), embedding);
676            }
677        }
678
679        // Generate relation embeddings (simplified - using random initialization)
680        for relation in self.relation_to_idx.keys() {
681            let embedding =
682                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
683            self.relation_embeddings.insert(relation.clone(), embedding);
684        }
685
686        self.is_trained = true;
687        self.last_training_time = Some(Utc::now());
688
689        Ok(TrainingStats {
690            epochs_completed: loss_history.len(),
691            final_loss: *loss_history.last().unwrap_or(&0.0),
692            training_time_seconds: start_time.elapsed().as_secs_f64(),
693            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
694            loss_history,
695        })
696    }
697
698    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
699        if !self.is_trained {
700            return Err(EmbeddingError::ModelNotTrained.into());
701        }
702
703        self.entity_embeddings
704            .get(entity)
705            .map(|e| Vector::new(e.to_vec()))
706            .ok_or_else(|| {
707                EmbeddingError::EntityNotFound {
708                    entity: entity.to_string(),
709                }
710                .into()
711            })
712    }
713
714    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
715        if !self.is_trained {
716            return Err(EmbeddingError::ModelNotTrained.into());
717        }
718
719        self.relation_embeddings
720            .get(relation)
721            .map(|e| Vector::new(e.to_vec()))
722            .ok_or_else(|| {
723                EmbeddingError::RelationNotFound {
724                    relation: relation.to_string(),
725                }
726                .into()
727            })
728    }
729
730    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
731        if !self.is_trained {
732            return Err(EmbeddingError::ModelNotTrained.into());
733        }
734
735        let subj_emb =
736            self.entity_embeddings
737                .get(subject)
738                .ok_or_else(|| EmbeddingError::EntityNotFound {
739                    entity: subject.to_string(),
740                })?;
741
742        let pred_emb = self.relation_embeddings.get(predicate).ok_or_else(|| {
743            EmbeddingError::RelationNotFound {
744                relation: predicate.to_string(),
745            }
746        })?;
747
748        let obj_emb =
749            self.entity_embeddings
750                .get(object)
751                .ok_or_else(|| EmbeddingError::EntityNotFound {
752                    entity: object.to_string(),
753                })?;
754
755        // Simple scoring: dot product of transformed embeddings
756        let transformed = (subj_emb + pred_emb) * obj_emb;
757        Ok(transformed.sum() as f64)
758    }
759
760    fn predict_objects(
761        &self,
762        subject: &str,
763        predicate: &str,
764        k: usize,
765    ) -> Result<Vec<(String, f64)>> {
766        if !self.is_trained {
767            return Err(EmbeddingError::ModelNotTrained.into());
768        }
769
770        let mut scores = Vec::new();
771
772        for entity in self.entity_to_idx.keys() {
773            if let Ok(score) = self.score_triple(subject, predicate, entity) {
774                scores.push((entity.clone(), score));
775            }
776        }
777
778        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
779        scores.truncate(k);
780
781        Ok(scores)
782    }
783
784    fn predict_subjects(
785        &self,
786        predicate: &str,
787        object: &str,
788        k: usize,
789    ) -> Result<Vec<(String, f64)>> {
790        if !self.is_trained {
791            return Err(EmbeddingError::ModelNotTrained.into());
792        }
793
794        let mut scores = Vec::new();
795
796        for entity in self.entity_to_idx.keys() {
797            if let Ok(score) = self.score_triple(entity, predicate, object) {
798                scores.push((entity.clone(), score));
799            }
800        }
801
802        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
803        scores.truncate(k);
804
805        Ok(scores)
806    }
807
808    fn predict_relations(
809        &self,
810        subject: &str,
811        object: &str,
812        k: usize,
813    ) -> Result<Vec<(String, f64)>> {
814        if !self.is_trained {
815            return Err(EmbeddingError::ModelNotTrained.into());
816        }
817
818        let mut scores = Vec::new();
819
820        for relation in self.relation_to_idx.keys() {
821            if let Ok(score) = self.score_triple(subject, relation, object) {
822                scores.push((relation.clone(), score));
823            }
824        }
825
826        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
827        scores.truncate(k);
828
829        Ok(scores)
830    }
831
832    fn get_entities(&self) -> Vec<String> {
833        self.entity_to_idx.keys().cloned().collect()
834    }
835
836    fn get_relations(&self) -> Vec<String> {
837        self.relation_to_idx.keys().cloned().collect()
838    }
839
840    fn get_stats(&self) -> ModelStats {
841        ModelStats {
842            num_entities: self.entity_to_idx.len(),
843            num_relations: self.relation_to_idx.len(),
844            num_triples: self.triples.len(),
845            dimensions: self.config.base_config.dimensions,
846            is_trained: self.is_trained,
847            model_type: format!("GNNEmbedding-{:?}", self.config.gnn_type),
848            creation_time: self.creation_time,
849            last_training_time: self.last_training_time,
850        }
851    }
852
853    fn save(&self, _path: &str) -> Result<()> {
854        // Implementation would save model weights and configuration
855        Ok(())
856    }
857
858    fn load(&mut self, _path: &str) -> Result<()> {
859        // Implementation would load model weights and configuration
860        Ok(())
861    }
862
863    fn clear(&mut self) {
864        self.entity_embeddings.clear();
865        self.relation_embeddings.clear();
866        self.entity_to_idx.clear();
867        self.relation_to_idx.clear();
868        self.idx_to_entity.clear();
869        self.idx_to_relation.clear();
870        self.adjacency_list.clear();
871        self.reverse_adjacency_list.clear();
872        self.triples.clear();
873        self.layers.clear();
874        self.is_trained = false;
875    }
876
877    fn is_trained(&self) -> bool {
878        self.is_trained
879    }
880
881    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
882        Err(anyhow!(
883            "Knowledge graph embedding model does not support text encoding"
884        ))
885    }
886}
887
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    /// Builds a triple from three IRIs; panics on an invalid IRI.
    fn iri_triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    #[tokio::test]
    async fn test_gnn_embedding_basic() {
        let mut model = GNNEmbedding::new(GNNConfig {
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![64, 32],
            ..Default::default()
        });

        let knows = "http://example.org/knows";
        let edges = [
            ("http://example.org/Alice", "http://example.org/Bob"),
            ("http://example.org/Bob", "http://example.org/Charlie"),
        ];
        for (subject, object) in edges {
            model.add_triple(iri_triple(subject, knows, object)).unwrap();
        }

        let _stats = model.train(Some(10)).await.unwrap();
        assert!(model.is_trained());

        // Embeddings keep the default model dimensionality.
        let alice = model
            .get_entity_embedding("http://example.org/Alice")
            .unwrap();
        assert_eq!(alice.dimensions, 100);

        let predictions = model
            .predict_objects("http://example.org/Alice", knows, 5)
            .unwrap();
        assert!(!predictions.is_empty());
    }

    #[tokio::test]
    async fn test_gnn_types() {
        for gnn_type in [GNNType::GCN, GNNType::GraphSAGE, GNNType::GAT, GNNType::GIN] {
            // Only the attention-based variant needs an explicit head count.
            let num_heads = if gnn_type == GNNType::GAT { Some(4) } else { None };
            let mut model = GNNEmbedding::new(GNNConfig {
                gnn_type,
                num_heads,
                ..Default::default()
            });

            model
                .add_triple(iri_triple(
                    "http://example.org/A",
                    "http://example.org/rel",
                    "http://example.org/B",
                ))
                .unwrap();

            let _stats = model.train(Some(5)).await.unwrap();
            assert!(model.is_trained());
        }
    }
}