use crate::{
    EmbeddingError, EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use chrono::Utc;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;

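/// Supported graph neural network architectures for knowledge graph embedding.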
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum GNNType {
    GCN,
    GraphSAGE,
    GAT,
    GraphTransformer,
    GIN,
    PNA,
    HetGNN,
    TGN,
}

impl GNNType {
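    /// Default number of message-passing layers for each architecture.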
    pub fn default_layers(&self) -> usize {
        match self {
            GNNType::GCN => 2,
            GNNType::GraphSAGE => 2,
            GNNType::GAT => 2,
            GNNType::GraphTransformer => 4,
            GNNType::GIN => 3,
            GNNType::PNA => 3,
            GNNType::HetGNN => 2,
            GNNType::TGN => 2,
        }
    }

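    /// Whether the architecture requires attention weights (GAT and Graph Transformer).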
    pub fn requires_attention(&self) -> bool {
        matches!(self, GNNType::GAT | GNNType::GraphTransformer)
    }
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
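/// Strategies for aggregating neighbor features during message passing.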
pub enum AggregationType {
    Mean,
    Max,
    Sum,
    LSTM,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
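/// Configuration for a GNN embedding model, combining the shared `ModelConfig`
/// with GNN-specific hyperparameters.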
pub struct GNNConfig {
    pub base_config: ModelConfig,
    pub gnn_type: GNNType,
    pub num_layers: usize,
    pub hidden_dimensions: Vec<usize>,
    pub dropout: f64,
    pub aggregation: AggregationType,
    pub num_heads: Option<usize>,
    pub sample_neighbors: Option<usize>,
    pub residual_connections: bool,
    pub layer_norm: bool,
    pub edge_features: bool,
}

impl Default for GNNConfig {
    fn default() -> Self {
        Self {
            base_config: ModelConfig::default(),
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![128, 64],
            dropout: 0.1,
            aggregation: AggregationType::Mean,
            num_heads: None,
            sample_neighbors: None,
            residual_connections: true,
            layer_norm: true,
            edge_features: false,
        }
    }
}

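/// Knowledge graph embedding model based on graph neural networks.
/// Entities become graph nodes and triples define relation-labelled edges
/// over which features are propagated.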
pub struct GNNEmbedding {
    id: Uuid,
    config: GNNConfig,
    entity_embeddings: HashMap<String, Array1<f32>>,
    relation_embeddings: HashMap<String, Array1<f32>>,
    entity_to_idx: HashMap<String, usize>,
    relation_to_idx: HashMap<String, usize>,
    idx_to_entity: HashMap<usize, String>,
    idx_to_relation: HashMap<usize, String>,
    adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    reverse_adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    triples: Vec<Triple>,
    layers: Vec<GNNLayer>,
    is_trained: bool,
    creation_time: chrono::DateTime<Utc>,
    last_training_time: Option<chrono::DateTime<Utc>>,
}

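/// A single GNN layer: a linear transform with optional attention weights
/// and layer normalization.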
struct GNNLayer {
    weight_matrix: Array2<f32>,
    bias: Array1<f32>,
    attention_weights: Option<AttentionWeights>,
    layer_norm: Option<LayerNormalization>,
}

struct AttentionWeights {
    query_weights: Array2<f32>,
    key_weights: Array2<f32>,
    value_weights: Array2<f32>,
    num_heads: usize,
}

struct LayerNormalization {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    epsilon: f32,
}

impl GNNEmbedding {
    pub fn new(config: GNNConfig) -> Self {
        Self {
            id: Uuid::new_v4(),
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            entity_to_idx: HashMap::new(),
            relation_to_idx: HashMap::new(),
            idx_to_entity: HashMap::new(),
            idx_to_relation: HashMap::new(),
            adjacency_list: HashMap::new(),
            reverse_adjacency_list: HashMap::new(),
            triples: Vec::new(),
            layers: Vec::new(),
            is_trained: false,
            creation_time: Utc::now(),
            last_training_time: None,
        }
    }

    fn initialize_layers(&mut self) -> Result<()> {
        self.layers.clear();
        let mut rng = Random::seed(42);

        let mut input_dim = self.config.base_config.dimensions;
        let num_layers = self.config.num_layers;

        for i in 0..num_layers {
            let output_dim = if i == num_layers - 1 {
                self.config.base_config.dimensions
            } else if i < self.config.hidden_dimensions.len() {
                self.config.hidden_dimensions[i]
            } else {
                self.config.base_config.dimensions
            };

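            // Xavier/Glorot-style initialization: scale = sqrt(2 / (fan_in + fan_out)),
            // with weights drawn uniformly from [-scale, scale).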
            let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
            let weight_matrix = Array2::from_shape_fn((input_dim, output_dim), |_| {
                rng.gen_range(0.0..1.0) * scale * 2.0 - scale
            });

            let bias = Array1::zeros(output_dim);

            let attention_weights = if self.config.gnn_type.requires_attention() {
                let num_heads = self.config.num_heads.unwrap_or(8);
                let head_dim = output_dim / num_heads;

                let attention_dim = head_dim * num_heads;

                Some(AttentionWeights {
                    query_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    key_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    value_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    num_heads,
                })
            } else {
                None
            };

            let layer_norm = if self.config.layer_norm {
                Some(LayerNormalization {
                    gamma: Array1::ones(output_dim),
                    beta: Array1::zeros(output_dim),
                    epsilon: 1e-5,
                })
            } else {
                None
            };

            self.layers.push(GNNLayer {
                weight_matrix,
                bias,
                attention_weights,
                layer_norm,
            });

            input_dim = output_dim;
        }

        Ok(())
    }

    fn build_adjacency_lists(&mut self) {
        self.adjacency_list.clear();
        self.reverse_adjacency_list.clear();

        for triple in &self.triples {
            let subject_idx = self.entity_to_idx[&triple.subject.iri];
            let object_idx = self.entity_to_idx[&triple.object.iri];
            let relation_idx = self.relation_to_idx[&triple.predicate.iri];

            self.adjacency_list
                .entry(subject_idx)
                .or_default()
                .insert((object_idx, relation_idx));

            self.reverse_adjacency_list
                .entry(object_idx)
                .or_default()
                .insert((subject_idx, relation_idx));
        }
    }

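    /// Aggregates features from a node's outgoing and incoming neighbors
    /// using the configured aggregation strategy.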
    fn aggregate_neighbors(
        &self,
        node_idx: usize,
        node_features: &HashMap<usize, Array1<f32>>,
    ) -> Array1<f32> {
        let neighbors = self.adjacency_list.get(&node_idx);
        let reverse_neighbors = self.reverse_adjacency_list.get(&node_idx);

        let mut neighbor_features = Vec::new();

        if let Some(neighbors) = neighbors {
            for (neighbor_idx, _) in neighbors {
                if let Some(feature) = node_features.get(neighbor_idx) {
                    neighbor_features.push(feature.clone());
                }
            }
        }

        if let Some(reverse_neighbors) = reverse_neighbors {
            for (neighbor_idx, _) in reverse_neighbors {
                if let Some(feature) = node_features.get(neighbor_idx) {
                    neighbor_features.push(feature.clone());
                }
            }
        }

        if neighbor_features.is_empty() {
            return Array1::zeros(
                node_features
                    .values()
                    .next()
                    .expect("node_features should not be empty")
                    .len(),
            );
        }

        match self.config.aggregation {
            AggregationType::Mean => {
                let sum: Array1<f32> = neighbor_features
                    .iter()
                    .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x);
                sum / neighbor_features.len() as f32
            }
            AggregationType::Max => neighbor_features.iter().fold(
                Array1::from_elem(neighbor_features[0].len(), f32::NEG_INFINITY),
                |acc, x| {
                    let mut result = acc.clone();
                    for (i, &val) in x.iter().enumerate() {
                        result[i] = result[i].max(val);
                    }
                    result
                },
            ),
            AggregationType::Sum => neighbor_features
                .iter()
                .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x),
            AggregationType::LSTM => {
                self.aggregate_neighbors_lstm(&neighbor_features)
            }
        }
    }

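    /// Simplified LSTM-style aggregation: an exponential moving average over
    /// neighbor features rather than a learned recurrent cell.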
    fn aggregate_neighbors_lstm(&self, neighbor_features: &[Array1<f32>]) -> Array1<f32> {
        let mut aggregated = Array1::zeros(neighbor_features[0].len());
        for feature in neighbor_features {
            aggregated = aggregated * 0.8 + feature * 0.2;
        }
        aggregated
    }

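    /// Applies a single GNN layer to all node features, dispatching on the
    /// configured architecture. Architectures without a dedicated
    /// implementation fall back to the GCN update.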
    fn apply_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
    ) -> HashMap<usize, Array1<f32>> {
        let mut new_features = HashMap::new();

        match self.config.gnn_type {
            GNNType::GCN => self.apply_gcn_layer(layer, node_features, &mut new_features),
            GNNType::GraphSAGE => {
                self.apply_graphsage_layer(layer, node_features, &mut new_features)
            }
            GNNType::GAT => self.apply_gat_layer(layer, node_features, &mut new_features),
            GNNType::GIN => self.apply_gin_layer(layer, node_features, &mut new_features),
            _ => self.apply_gcn_layer(layer, node_features, &mut new_features),
        }

        new_features
    }

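    /// GCN-style update: the node's own features are combined with the
    /// aggregated neighborhood, linearly transformed, passed through ReLU,
    /// and optionally layer-normalized.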
    fn apply_gcn_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = feature + &aggregated;
            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;

            let activated = transformed.mapv(|x| x.max(0.0));

            let output = if let Some(ln) = &layer.layer_norm {
                self.apply_layer_norm(&activated, ln)
            } else {
                activated
            };

            new_features.insert(*node_idx, output);
        }
    }

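    /// GraphSAGE-style update: node and aggregated neighbor features are
    /// transformed separately, summed, passed through ReLU, and L2-normalized.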
    fn apply_graphsage_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);

            let node_transformed = feature.dot(&layer.weight_matrix) + &layer.bias;

            let neighbor_transformed = aggregated.dot(&layer.weight_matrix) + &layer.bias;

            let combined = &node_transformed + &neighbor_transformed;

            let activated = combined.mapv(|x| x.max(0.0));
            let normalized = &activated / (activated.dot(&activated).sqrt() + 1e-6);

            new_features.insert(*node_idx, normalized);
        }
    }

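    /// GAT-style update: neighbor contributions are weighted by scaled
    /// dot-product attention, then combined with the transformed node
    /// features, with an optional residual connection.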
    fn apply_gat_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        let attention = layer
            .attention_weights
            .as_ref()
            .expect("attention_weights should be initialized for GAT layer");

        for (node_idx, feature) in node_features {
            let mut neighbor_indices = Vec::new();
            if let Some(neighbors) = self.adjacency_list.get(node_idx) {
                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
            }
            if let Some(neighbors) = self.reverse_adjacency_list.get(node_idx) {
                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
            }

            if neighbor_indices.is_empty() {
                let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

            if feature.len() != attention.query_weights.shape()[0] {
                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
                let combined = feature + &aggregated;
                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

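            // Scaled dot-product attention: each neighbor's score is q · k,
            // scaled here by 1/sqrt(num_heads).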
            let query = feature.dot(&attention.query_weights);
            let mut attention_scores = Vec::new();
            let mut neighbor_values = Vec::new();

            for neighbor_idx in &neighbor_indices {
                if let Some(neighbor_feature) = node_features.get(neighbor_idx) {
                    if neighbor_feature.len() != attention.key_weights.shape()[0] {
                        continue;
                    }

                    let key = neighbor_feature.dot(&attention.key_weights);
                    let value = neighbor_feature.dot(&attention.value_weights);

                    if query.len() == key.len() {
                        let score = query.dot(&key) / (attention.num_heads as f32).sqrt();
                        attention_scores.push(score);
                        neighbor_values.push(value);
                    }
                }
            }

            if attention_scores.is_empty() {
                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
                let combined = feature + &aggregated;
                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

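            // Numerically stable softmax over the attention scores.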
            let max_score = attention_scores
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp_scores: Vec<f32> = attention_scores
                .iter()
                .map(|&s| (s - max_score).exp())
                .collect();
            let sum_exp = exp_scores.iter().sum::<f32>();
            let attention_weights: Vec<f32> =
                exp_scores.iter().copied().map(|e| e / sum_exp).collect();

            let output_dim = layer.weight_matrix.shape()[1];
            let mut aggregated = Array1::<f32>::zeros(output_dim);

            for (i, value) in neighbor_values.iter().enumerate() {
                let min_dim = aggregated.len().min(value.len());
                for j in 0..min_dim {
                    aggregated[j] += value[j] * attention_weights[i];
                }
            }

            let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
            let combined =
                if self.config.residual_connections && transformed.len() == aggregated.len() {
                    transformed + &aggregated
                } else {
                    transformed
                };

            let activated = combined.mapv(|x| x.max(0.0));
            new_features.insert(*node_idx, activated);
        }
    }

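    /// GIN-style update: h' = MLP((1 + ε) · h + Σ h_neighbor); here ε is fixed
    /// at 0 and a single linear layer stands in for the MLP.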
    fn apply_gin_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        let epsilon = 0.0;

        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = (1.0 + epsilon) * feature + aggregated;

            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
            let activated = transformed.mapv(|x| x.max(0.0));

            new_features.insert(*node_idx, activated);
        }
    }

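    /// Standard layer normalization: (x - mean) / sqrt(var + ε), scaled by
    /// gamma and shifted by beta.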
    fn apply_layer_norm(&self, input: &Array1<f32>, ln: &LayerNormalization) -> Array1<f32> {
        let mean = input.mean().unwrap_or(0.0);
        let variance = input.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(1.0);
        let normalized = input.mapv(|x| (x - mean) / (variance + ln.epsilon).sqrt());
        &normalized * &ln.gamma + &ln.beta
    }

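    /// Full forward pass: applies each GNN layer in sequence, followed by
    /// dropout on the resulting node features.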
    fn forward(
        &self,
        initial_features: HashMap<usize, Array1<f32>>,
    ) -> HashMap<usize, Array1<f32>> {
        let mut features = initial_features;

        for layer in self.layers.iter() {
            let new_features = self.apply_layer(layer, &features);

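            // Inverted dropout: kept activations are rescaled by 1/(1 - p) so the
            // expected magnitude is preserved; the fixed seed makes the mask deterministic.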
            let dropout_rate = self.config.dropout;
            let mut rng = Random::seed(42);

            features = new_features
                .into_iter()
                .map(|(idx, feat)| {
                    let masked = feat.mapv(|x| {
                        if rng.gen_range(0.0..1.0) > dropout_rate as f32 {
                            x / (1.0 - dropout_rate as f32)
                        } else {
                            0.0
                        }
                    });
                    (idx, masked)
                })
                .collect();
        }

        features
    }
}

#[async_trait]
impl EmbeddingModel for GNNEmbedding {
    fn config(&self) -> &ModelConfig {
        &self.config.base_config
    }

    fn model_id(&self) -> &Uuid {
        &self.id
    }

    fn model_type(&self) -> &'static str {
        "GNNEmbedding"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        let subject = triple.subject.iri.clone();
        let object = triple.object.iri.clone();
        let predicate = triple.predicate.iri.clone();

        if !self.entity_to_idx.contains_key(&subject) {
            let idx = self.entity_to_idx.len();
            self.entity_to_idx.insert(subject.clone(), idx);
            self.idx_to_entity.insert(idx, subject);
        }

        if !self.entity_to_idx.contains_key(&object) {
            let idx = self.entity_to_idx.len();
            self.entity_to_idx.insert(object.clone(), idx);
            self.idx_to_entity.insert(idx, object);
        }

        if !self.relation_to_idx.contains_key(&predicate) {
            let idx = self.relation_to_idx.len();
            self.relation_to_idx.insert(predicate.clone(), idx);
            self.idx_to_relation.insert(idx, predicate);
        }

        self.triples.push(triple);
        self.is_trained = false;
        Ok(())
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = std::time::Instant::now();
        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);

        self.build_adjacency_lists();

        self.initialize_layers()?;

        let mut rng = Random::seed(42);
        let dimensions = self.config.base_config.dimensions;

        let mut initial_features = HashMap::new();
        for idx in self.entity_to_idx.values() {
            let embedding =
                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
            initial_features.insert(*idx, embedding);
        }

        let mut loss_history = Vec::new();

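        // Simplified training loop: repeatedly runs the forward pass and tracks the
        // mean squared feature norm as the loss; layer weights are not updated by gradients.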
        for _epoch in 0..epochs {
            let output_features = self.forward(initial_features.clone());

            let loss = output_features
                .values()
                .map(|f| f.mapv(|x| x * x).sum())
                .sum::<f32>()
                / output_features.len() as f32;

            loss_history.push(loss as f64);

            initial_features = output_features;

            if loss < 0.001 {
                break;
            }
        }

        for (idx, embedding) in initial_features {
            if let Some(entity) = self.idx_to_entity.get(&idx) {
                self.entity_embeddings.insert(entity.clone(), embedding);
            }
        }

        for relation in self.relation_to_idx.keys() {
            let embedding =
                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
            self.relation_embeddings.insert(relation.clone(), embedding);
        }

        self.is_trained = true;
        self.last_training_time = Some(Utc::now());

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: *loss_history.last().unwrap_or(&0.0),
            training_time_seconds: start_time.elapsed().as_secs_f64(),
            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        self.entity_embeddings
            .get(entity)
            .map(|e| Vector::new(e.to_vec()))
            .ok_or_else(|| {
                EmbeddingError::EntityNotFound {
                    entity: entity.to_string(),
                }
                .into()
            })
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        self.relation_embeddings
            .get(relation)
            .map(|e| Vector::new(e.to_vec()))
            .ok_or_else(|| {
                EmbeddingError::RelationNotFound {
                    relation: relation.to_string(),
                }
                .into()
            })
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let subj_emb =
            self.entity_embeddings
                .get(subject)
                .ok_or_else(|| EmbeddingError::EntityNotFound {
                    entity: subject.to_string(),
                })?;

        let pred_emb = self.relation_embeddings.get(predicate).ok_or_else(|| {
            EmbeddingError::RelationNotFound {
                relation: predicate.to_string(),
            }
        })?;

        let obj_emb =
            self.entity_embeddings
                .get(object)
                .ok_or_else(|| EmbeddingError::EntityNotFound {
                    entity: object.to_string(),
                })?;

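        // Score = sum of element-wise (subject + predicate) * object, a simple
        // translational/bilinear-flavoured compatibility score.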
        let transformed = (subj_emb + pred_emb) * obj_emb;
        Ok(transformed.sum() as f64)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for entity in self.entity_to_idx.keys() {
            if let Ok(score) = self.score_triple(subject, predicate, entity) {
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for entity in self.entity_to_idx.keys() {
            if let Ok(score) = self.score_triple(entity, predicate, object) {
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for relation in self.relation_to_idx.keys() {
            if let Ok(score) = self.score_triple(subject, relation, object) {
                scores.push((relation.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.entity_to_idx.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        self.relation_to_idx.keys().cloned().collect()
    }

    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.entity_to_idx.len(),
            num_relations: self.relation_to_idx.len(),
            num_triples: self.triples.len(),
            dimensions: self.config.base_config.dimensions,
            is_trained: self.is_trained,
            model_type: format!("GNNEmbedding-{:?}", self.config.gnn_type),
            creation_time: self.creation_time,
            last_training_time: self.last_training_time,
        }
    }

    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn load(&mut self, _path: &str) -> Result<()> {
        Ok(())
    }

    fn clear(&mut self) {
        self.entity_embeddings.clear();
        self.relation_embeddings.clear();
        self.entity_to_idx.clear();
        self.relation_to_idx.clear();
        self.idx_to_entity.clear();
        self.idx_to_relation.clear();
        self.adjacency_list.clear();
        self.reverse_adjacency_list.clear();
        self.triples.clear();
        self.layers.clear();
        self.is_trained = false;
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    async fn test_gnn_embedding_basic() {
        let config = GNNConfig {
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![64, 32],
            ..Default::default()
        };

        let mut model = GNNEmbedding::new(config);

        let triple1 = Triple::new(
            NamedNode::new("http://example.org/Alice").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/Bob").unwrap(),
        );

        let triple2 = Triple::new(
            NamedNode::new("http://example.org/Bob").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/Charlie").unwrap(),
        );

        model.add_triple(triple1).unwrap();
        model.add_triple(triple2).unwrap();

        let _stats = model.train(Some(10)).await.unwrap();
        assert!(model.is_trained());

        let alice_emb = model
            .get_entity_embedding("http://example.org/Alice")
            .unwrap();
        assert_eq!(alice_emb.dimensions, 100);

        let predictions = model
            .predict_objects("http://example.org/Alice", "http://example.org/knows", 5)
            .unwrap();
        assert!(!predictions.is_empty());
    }

    #[tokio::test]
    async fn test_gnn_types() {
        for gnn_type in [GNNType::GCN, GNNType::GraphSAGE, GNNType::GAT, GNNType::GIN] {
            let config = GNNConfig {
                gnn_type,
                num_heads: if gnn_type == GNNType::GAT {
                    Some(4)
                } else {
                    None
                },
                ..Default::default()
            };

            let mut model = GNNEmbedding::new(config);

            let triple = Triple::new(
                NamedNode::new("http://example.org/A").unwrap(),
                NamedNode::new("http://example.org/rel").unwrap(),
                NamedNode::new("http://example.org/B").unwrap(),
            );

            model.add_triple(triple).unwrap();
            let _stats = model.train(Some(5)).await.unwrap();
            assert!(model.is_trained());
        }
    }
}