use crate::{
    EmbeddingError, EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use chrono::Utc;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;

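/// Graph neural network architecture used by [`GNNEmbedding`] for message passing.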
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum GNNType {
    /// Graph Convolutional Network
    GCN,
    /// GraphSAGE (sample-and-aggregate)
    GraphSAGE,
    /// Graph Attention Network
    GAT,
    /// Graph Transformer
    GraphTransformer,
    /// Graph Isomorphism Network
    GIN,
    /// Principal Neighbourhood Aggregation
    PNA,
    /// Heterogeneous Graph Neural Network
    HetGNN,
    /// Temporal Graph Network
    TGN,
}

impl GNNType {
    /// Default number of message-passing layers for this architecture.
    pub fn default_layers(&self) -> usize {
        match self {
            GNNType::GCN => 2,
            GNNType::GraphSAGE => 2,
            GNNType::GAT => 2,
            GNNType::GraphTransformer => 4,
            GNNType::GIN => 3,
            GNNType::PNA => 3,
            GNNType::HetGNN => 2,
            GNNType::TGN => 2,
        }
    }

    /// Whether this architecture requires attention weights in each layer.
    pub fn requires_attention(&self) -> bool {
        matches!(self, GNNType::GAT | GNNType::GraphTransformer)
    }
}

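/// Strategy for aggregating neighbor features during message passing.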
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AggregationType {
    Mean,
    Max,
    Sum,
    LSTM,
}

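/// Configuration for a [`GNNEmbedding`] model.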
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    pub base_config: ModelConfig,
    pub gnn_type: GNNType,
    pub num_layers: usize,
    pub hidden_dimensions: Vec<usize>,
    pub dropout: f64,
    pub aggregation: AggregationType,
    pub num_heads: Option<usize>,
    pub sample_neighbors: Option<usize>,
    pub residual_connections: bool,
    pub layer_norm: bool,
    pub edge_features: bool,
}

impl Default for GNNConfig {
    fn default() -> Self {
        Self {
            base_config: ModelConfig::default(),
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![128, 64],
            dropout: 0.1,
            aggregation: AggregationType::Mean,
            num_heads: None,
            sample_neighbors: None,
            residual_connections: true,
            layer_norm: true,
            edge_features: false,
        }
    }
}

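/// Knowledge graph embedding model based on graph neural network message passing.
///
/// Typical usage (a minimal sketch mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let mut model = GNNEmbedding::new(GNNConfig::default());
/// model.add_triple(Triple::new(
///     NamedNode::new("http://example.org/Alice").unwrap(),
///     NamedNode::new("http://example.org/knows").unwrap(),
///     NamedNode::new("http://example.org/Bob").unwrap(),
/// ))?;
/// let _stats = model.train(Some(10)).await?;
/// let alice = model.get_entity_embedding("http://example.org/Alice")?;
/// ```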
pub struct GNNEmbedding {
    id: Uuid,
    config: GNNConfig,
    entity_embeddings: HashMap<String, Array1<f32>>,
    relation_embeddings: HashMap<String, Array1<f32>>,
    entity_to_idx: HashMap<String, usize>,
    relation_to_idx: HashMap<String, usize>,
    idx_to_entity: HashMap<usize, String>,
    idx_to_relation: HashMap<usize, String>,
    adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    reverse_adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    triples: Vec<Triple>,
    layers: Vec<GNNLayer>,
    is_trained: bool,
    creation_time: chrono::DateTime<Utc>,
    last_training_time: Option<chrono::DateTime<Utc>>,
}

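/// Parameters of a single GNN message-passing layer.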
struct GNNLayer {
    weight_matrix: Array2<f32>,
    bias: Array1<f32>,
    attention_weights: Option<AttentionWeights>,
    layer_norm: Option<LayerNormalization>,
}

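/// Query/key/value projection matrices and head count for attention-based layers.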
struct AttentionWeights {
    query_weights: Array2<f32>,
    key_weights: Array2<f32>,
    value_weights: Array2<f32>,
    num_heads: usize,
}

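/// Per-feature scale (`gamma`) and shift (`beta`) parameters for layer normalization.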
struct LayerNormalization {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    epsilon: f32,
}

impl GNNEmbedding {
    pub fn new(config: GNNConfig) -> Self {
        Self {
            id: Uuid::new_v4(),
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            entity_to_idx: HashMap::new(),
            relation_to_idx: HashMap::new(),
            idx_to_entity: HashMap::new(),
            idx_to_relation: HashMap::new(),
            adjacency_list: HashMap::new(),
            reverse_adjacency_list: HashMap::new(),
            triples: Vec::new(),
            layers: Vec::new(),
            is_trained: false,
            creation_time: Utc::now(),
            last_training_time: None,
        }
    }

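    /// Builds the layer stack, sizing each weight matrix from the configured
    /// input/hidden/output dimensions and initializing weights with a
    /// Xavier-style uniform scheme.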
    fn initialize_layers(&mut self) -> Result<()> {
        self.layers.clear();
        let mut rng = Random::seed(42);

        let mut input_dim = self.config.base_config.dimensions;
        let num_layers = self.config.num_layers;

        for i in 0..num_layers {
            let output_dim = if i == num_layers - 1 {
                self.config.base_config.dimensions
            } else if i < self.config.hidden_dimensions.len() {
                self.config.hidden_dimensions[i]
            } else {
                self.config.base_config.dimensions
            };

            // Xavier-style uniform initialization in [-scale, scale]
            let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
            let weight_matrix = Array2::from_shape_fn((input_dim, output_dim), |_| {
                rng.gen_range(0.0..1.0) * scale * 2.0 - scale
            });

            let bias = Array1::zeros(output_dim);

            let attention_weights = if self.config.gnn_type.requires_attention() {
                let num_heads = self.config.num_heads.unwrap_or(8);
                let head_dim = output_dim / num_heads;
                // Round the attention dimension down to a multiple of the head count
                let attention_dim = head_dim * num_heads;

                Some(AttentionWeights {
                    query_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    key_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    value_weights: Array2::from_shape_fn((input_dim, attention_dim), |_| {
                        rng.gen_range(0.0..1.0) * scale * 2.0 - scale
                    }),
                    num_heads,
                })
            } else {
                None
            };

            let layer_norm = if self.config.layer_norm {
                Some(LayerNormalization {
                    gamma: Array1::ones(output_dim),
                    beta: Array1::zeros(output_dim),
                    epsilon: 1e-5,
                })
            } else {
                None
            };

            self.layers.push(GNNLayer {
                weight_matrix,
                bias,
                attention_weights,
                layer_norm,
            });

            input_dim = output_dim;
        }

        Ok(())
    }

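    /// Rebuilds forward and reverse adjacency lists from the stored triples.
    /// Each entry maps a node index to `(neighbor_index, relation_index)` pairs.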
    fn build_adjacency_lists(&mut self) {
        self.adjacency_list.clear();
        self.reverse_adjacency_list.clear();

        for triple in &self.triples {
            let subject_idx = self.entity_to_idx[&triple.subject.iri];
            let object_idx = self.entity_to_idx[&triple.object.iri];
            let relation_idx = self.relation_to_idx[&triple.predicate.iri];

            // Outgoing edge: subject -> object
            self.adjacency_list
                .entry(subject_idx)
                .or_default()
                .insert((object_idx, relation_idx));

            // Incoming edge: object <- subject
            self.reverse_adjacency_list
                .entry(object_idx)
                .or_default()
                .insert((subject_idx, relation_idx));
        }
    }

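    /// Aggregates the features of a node's neighbors (over both outgoing and
    /// incoming edges) using the configured aggregation strategy.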
    fn aggregate_neighbors(
        &self,
        node_idx: usize,
        node_features: &HashMap<usize, Array1<f32>>,
    ) -> Array1<f32> {
        let neighbors = self.adjacency_list.get(&node_idx);
        let reverse_neighbors = self.reverse_adjacency_list.get(&node_idx);

        let mut neighbor_features = Vec::new();

        if let Some(neighbors) = neighbors {
            for (neighbor_idx, _) in neighbors {
                if let Some(feature) = node_features.get(neighbor_idx) {
                    neighbor_features.push(feature.clone());
                }
            }
        }

        if let Some(reverse_neighbors) = reverse_neighbors {
            for (neighbor_idx, _) in reverse_neighbors {
                if let Some(feature) = node_features.get(neighbor_idx) {
                    neighbor_features.push(feature.clone());
                }
            }
        }

        if neighbor_features.is_empty() {
            return Array1::zeros(node_features.values().next().unwrap().len());
        }

        match self.config.aggregation {
            AggregationType::Mean => {
                let sum: Array1<f32> = neighbor_features
                    .iter()
                    .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x);
                sum / neighbor_features.len() as f32
            }
            AggregationType::Max => neighbor_features.iter().fold(
                Array1::from_elem(neighbor_features[0].len(), f32::NEG_INFINITY),
                |acc, x| {
                    let mut result = acc.clone();
                    for (i, &val) in x.iter().enumerate() {
                        result[i] = result[i].max(val);
                    }
                    result
                },
            ),
            AggregationType::Sum => neighbor_features
                .iter()
                .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x),
            AggregationType::LSTM => self.aggregate_neighbors_lstm(&neighbor_features),
        }
    }

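    /// Simplified stand-in for LSTM aggregation: folds neighbor features into an
    /// exponential moving average rather than running a learned recurrent cell.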
    fn aggregate_neighbors_lstm(&self, neighbor_features: &[Array1<f32>]) -> Array1<f32> {
        let mut aggregated = Array1::zeros(neighbor_features[0].len());
        for feature in neighbor_features {
            // Exponential moving average over the neighbor sequence
            aggregated = aggregated * 0.8 + feature * 0.2;
        }
        aggregated
    }

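    /// Applies one GNN layer to all node features, dispatching on the
    /// configured architecture.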
    fn apply_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
    ) -> HashMap<usize, Array1<f32>> {
        let mut new_features = HashMap::new();

        match self.config.gnn_type {
            GNNType::GCN => self.apply_gcn_layer(layer, node_features, &mut new_features),
            GNNType::GraphSAGE => {
                self.apply_graphsage_layer(layer, node_features, &mut new_features)
            }
            GNNType::GAT => self.apply_gat_layer(layer, node_features, &mut new_features),
            GNNType::GIN => self.apply_gin_layer(layer, node_features, &mut new_features),
            // Other architectures currently fall back to the GCN update
            _ => self.apply_gcn_layer(layer, node_features, &mut new_features),
        }

        new_features
    }

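    /// GCN-style update: combine each node's feature with its aggregated
    /// neighborhood, apply the linear transform, ReLU, and optional layer norm.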
    fn apply_gcn_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = feature + &aggregated;
            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;

            // ReLU activation
            let activated = transformed.mapv(|x| x.max(0.0));

            // Optional layer normalization
            let output = if let Some(ln) = &layer.layer_norm {
                self.apply_layer_norm(&activated, ln)
            } else {
                activated
            };

            new_features.insert(*node_idx, output);
        }
    }

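    /// GraphSAGE-style update: transform the node and its aggregated
    /// neighborhood separately, combine them, apply ReLU, and L2-normalize.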
    fn apply_graphsage_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);

            // Transform the node's own features
            let node_transformed = feature.dot(&layer.weight_matrix) + &layer.bias;

            // Transform the aggregated neighborhood features
            let neighbor_transformed = aggregated.dot(&layer.weight_matrix) + &layer.bias;

            let combined = &node_transformed + &neighbor_transformed;

            // ReLU followed by L2 normalization
            let activated = combined.mapv(|x| x.max(0.0));
            let normalized = &activated / (activated.dot(&activated).sqrt() + 1e-6);

            new_features.insert(*node_idx, normalized);
        }
    }

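    /// GAT-style update: score each neighbor with scaled dot-product attention,
    /// softmax the scores, and mix the attention-weighted neighbor values into
    /// the transformed node feature (with an optional residual connection).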
    fn apply_gat_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        let attention = layer.attention_weights.as_ref().unwrap();

        for (node_idx, feature) in node_features {
            // Collect neighbors from both edge directions
            let mut neighbor_indices = Vec::new();
            if let Some(neighbors) = self.adjacency_list.get(node_idx) {
                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
            }
            if let Some(neighbors) = self.reverse_adjacency_list.get(node_idx) {
                neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
            }

            if neighbor_indices.is_empty() {
                let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

            // Fall back to a GCN-style update if the feature dimension does not
            // match the attention projections
            if feature.len() != attention.query_weights.shape()[0] {
                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
                let combined = feature + &aggregated;
                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

            let query = feature.dot(&attention.query_weights);
            let mut attention_scores = Vec::new();
            let mut neighbor_values = Vec::new();

            for neighbor_idx in &neighbor_indices {
                if let Some(neighbor_feature) = node_features.get(neighbor_idx) {
                    if neighbor_feature.len() != attention.key_weights.shape()[0] {
                        continue;
                    }

                    let key = neighbor_feature.dot(&attention.key_weights);
                    let value = neighbor_feature.dot(&attention.value_weights);

                    if query.len() == key.len() {
                        // Scaled dot-product attention score
                        let score = query.dot(&key) / (attention.num_heads as f32).sqrt();
                        attention_scores.push(score);
                        neighbor_values.push(value);
                    }
                }
            }

            if attention_scores.is_empty() {
                let aggregated = self.aggregate_neighbors(*node_idx, node_features);
                let combined = feature + &aggregated;
                let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                new_features.insert(*node_idx, activated);
                continue;
            }

            // Numerically stable softmax over the attention scores
            let max_score = attention_scores
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp_scores: Vec<f32> = attention_scores
                .iter()
                .map(|&s| (s - max_score).exp())
                .collect();
            let sum_exp = exp_scores.iter().sum::<f32>();
            let attention_weights: Vec<f32> =
                exp_scores.iter().copied().map(|e| e / sum_exp).collect();

            // Attention-weighted sum of neighbor values
            let output_dim = layer.weight_matrix.shape()[1];
            let mut aggregated = Array1::<f32>::zeros(output_dim);

            for (i, value) in neighbor_values.iter().enumerate() {
                let min_dim = aggregated.len().min(value.len());
                for j in 0..min_dim {
                    aggregated[j] += value[j] * attention_weights[i];
                }
            }

            let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
            let combined =
                if self.config.residual_connections && transformed.len() == aggregated.len() {
                    transformed + &aggregated
                } else {
                    transformed
                };

            let activated = combined.mapv(|x| x.max(0.0));
            new_features.insert(*node_idx, activated);
        }
    }

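    /// GIN-style update: combine `(1 + epsilon) * h_v` with the aggregated
    /// neighborhood, then apply the linear transform and ReLU.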
    fn apply_gin_layer(
        &self,
        layer: &GNNLayer,
        node_features: &HashMap<usize, Array1<f32>>,
        new_features: &mut HashMap<usize, Array1<f32>>,
    ) {
        // GIN epsilon, kept fixed at 0 here instead of being learned
        let epsilon = 0.0;

        for (node_idx, feature) in node_features {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = (1.0 + epsilon) * feature + aggregated;

            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
            let activated = transformed.mapv(|x| x.max(0.0));

            new_features.insert(*node_idx, activated);
        }
    }

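    /// Layer normalization: `gamma * (x - mean) / sqrt(var + eps) + beta`.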
    fn apply_layer_norm(&self, input: &Array1<f32>, ln: &LayerNormalization) -> Array1<f32> {
        let mean = input.mean().unwrap_or(0.0);
        let variance = input.mapv(|x| (x - mean).powi(2)).mean().unwrap_or(1.0);
        let normalized = input.mapv(|x| (x - mean) / (variance + ln.epsilon).sqrt());
        &normalized * &ln.gamma + &ln.beta
    }

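    /// Runs all layers over the node features, applying inverted dropout with a
    /// fixed seed after each layer.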
    fn forward(
        &self,
        initial_features: HashMap<usize, Array1<f32>>,
    ) -> HashMap<usize, Array1<f32>> {
        let mut features = initial_features;

        for layer in self.layers.iter() {
            let new_features = self.apply_layer(layer, &features);

            // Inverted dropout with a fixed seed (deterministic masks)
            let dropout_rate = self.config.dropout;
            let mut rng = Random::seed(42);

            features = new_features
                .into_iter()
                .map(|(idx, feat)| {
                    let masked = feat.mapv(|x| {
                        if rng.gen_range(0.0..1.0) > dropout_rate as f32 {
                            x / (1.0 - dropout_rate as f32)
                        } else {
                            0.0
                        }
                    });
                    (idx, masked)
                })
                .collect();
        }

        features
    }
}

#[async_trait]
impl EmbeddingModel for GNNEmbedding {
    fn config(&self) -> &ModelConfig {
        &self.config.base_config
    }

    fn model_id(&self) -> &Uuid {
        &self.id
    }

    fn model_type(&self) -> &'static str {
        "GNNEmbedding"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        let subject = triple.subject.iri.clone();
        let object = triple.object.iri.clone();
        let predicate = triple.predicate.iri.clone();

        // Register previously unseen entities and relations
        if !self.entity_to_idx.contains_key(&subject) {
            let idx = self.entity_to_idx.len();
            self.entity_to_idx.insert(subject.clone(), idx);
            self.idx_to_entity.insert(idx, subject);
        }

        if !self.entity_to_idx.contains_key(&object) {
            let idx = self.entity_to_idx.len();
            self.entity_to_idx.insert(object.clone(), idx);
            self.idx_to_entity.insert(idx, object);
        }

        if !self.relation_to_idx.contains_key(&predicate) {
            let idx = self.relation_to_idx.len();
            self.relation_to_idx.insert(predicate.clone(), idx);
            self.idx_to_relation.insert(idx, predicate);
        }

        self.triples.push(triple);
        self.is_trained = false;
        Ok(())
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = std::time::Instant::now();
        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);

        // Build the graph structure and layer parameters
        self.build_adjacency_lists();
        self.initialize_layers()?;

        // Randomly initialize node features
        let mut rng = Random::seed(42);
        let dimensions = self.config.base_config.dimensions;

        let mut initial_features = HashMap::new();
        for idx in self.entity_to_idx.values() {
            let embedding =
                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
            initial_features.insert(*idx, embedding);
        }

        let mut loss_history = Vec::new();

        for _epoch in 0..epochs {
            // Forward pass through the GNN layers
            let output_features = self.forward(initial_features.clone());

            // Simple magnitude-based loss over the output features
            let loss = output_features
                .values()
                .map(|f| f.mapv(|x| x * x).sum())
                .sum::<f32>()
                / output_features.len() as f32;

            loss_history.push(loss as f64);

            // Feed the outputs back in as the next epoch's features
            initial_features = output_features;

            if loss < 0.001 {
                break;
            }
        }

        // Store the final node features as entity embeddings
        for (idx, embedding) in initial_features {
            if let Some(entity) = self.idx_to_entity.get(&idx) {
                self.entity_embeddings.insert(entity.clone(), embedding);
            }
        }

        // Relations receive randomly initialized embeddings
        for relation in self.relation_to_idx.keys() {
            let embedding =
                Array1::from_shape_fn(dimensions, |_| rng.gen_range(0.0..1.0) * 0.1 - 0.05);
            self.relation_embeddings.insert(relation.clone(), embedding);
        }

        self.is_trained = true;
        self.last_training_time = Some(Utc::now());

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: *loss_history.last().unwrap_or(&0.0),
            training_time_seconds: start_time.elapsed().as_secs_f64(),
            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        self.entity_embeddings
            .get(entity)
            .map(|e| Vector::new(e.to_vec()))
            .ok_or_else(|| {
                EmbeddingError::EntityNotFound {
                    entity: entity.to_string(),
                }
                .into()
            })
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        self.relation_embeddings
            .get(relation)
            .map(|e| Vector::new(e.to_vec()))
            .ok_or_else(|| {
                EmbeddingError::RelationNotFound {
                    relation: relation.to_string(),
                }
                .into()
            })
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let subj_emb =
            self.entity_embeddings
                .get(subject)
                .ok_or_else(|| EmbeddingError::EntityNotFound {
                    entity: subject.to_string(),
                })?;

        let pred_emb = self.relation_embeddings.get(predicate).ok_or_else(|| {
            EmbeddingError::RelationNotFound {
                relation: predicate.to_string(),
            }
        })?;

        let obj_emb =
            self.entity_embeddings
                .get(object)
                .ok_or_else(|| EmbeddingError::EntityNotFound {
                    entity: object.to_string(),
                })?;

        // Element-wise (subject + predicate) * object score, summed over dimensions
        let transformed = (subj_emb + pred_emb) * obj_emb;
        Ok(transformed.sum() as f64)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for entity in self.entity_to_idx.keys() {
            if let Ok(score) = self.score_triple(subject, predicate, entity) {
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for entity in self.entity_to_idx.keys() {
            if let Ok(score) = self.score_triple(entity, predicate, object) {
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.is_trained {
            return Err(EmbeddingError::ModelNotTrained.into());
        }

        let mut scores = Vec::new();

        for relation in self.relation_to_idx.keys() {
            if let Ok(score) = self.score_triple(subject, relation, object) {
                scores.push((relation.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.entity_to_idx.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        self.relation_to_idx.keys().cloned().collect()
    }

    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.entity_to_idx.len(),
            num_relations: self.relation_to_idx.len(),
            num_triples: self.triples.len(),
            dimensions: self.config.base_config.dimensions,
            is_trained: self.is_trained,
            model_type: format!("GNNEmbedding-{:?}", self.config.gnn_type),
            creation_time: self.creation_time,
            last_training_time: self.last_training_time,
        }
    }

    fn save(&self, _path: &str) -> Result<()> {
        // Persistence is not implemented yet; saving is a no-op
        Ok(())
    }

    fn load(&mut self, _path: &str) -> Result<()> {
        // Persistence is not implemented yet; loading is a no-op
        Ok(())
    }

    fn clear(&mut self) {
        self.entity_embeddings.clear();
        self.relation_embeddings.clear();
        self.entity_to_idx.clear();
        self.relation_to_idx.clear();
        self.idx_to_entity.clear();
        self.idx_to_relation.clear();
        self.adjacency_list.clear();
        self.reverse_adjacency_list.clear();
        self.triples.clear();
        self.layers.clear();
        self.is_trained = false;
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    async fn test_gnn_embedding_basic() {
        let config = GNNConfig {
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![64, 32],
            ..Default::default()
        };

        let mut model = GNNEmbedding::new(config);

        let triple1 = Triple::new(
            NamedNode::new("http://example.org/Alice").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/Bob").unwrap(),
        );

        let triple2 = Triple::new(
            NamedNode::new("http://example.org/Bob").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/Charlie").unwrap(),
        );

        model.add_triple(triple1).unwrap();
        model.add_triple(triple2).unwrap();

        let _stats = model.train(Some(10)).await.unwrap();
        assert!(model.is_trained());

        let alice_emb = model
            .get_entity_embedding("http://example.org/Alice")
            .unwrap();
        // Embedding dimensionality comes from the default ModelConfig
        assert_eq!(alice_emb.dimensions, 100);

        let predictions = model
            .predict_objects("http://example.org/Alice", "http://example.org/knows", 5)
            .unwrap();
        assert!(!predictions.is_empty());
    }

    #[tokio::test]
    async fn test_gnn_types() {
        for gnn_type in [GNNType::GCN, GNNType::GraphSAGE, GNNType::GAT, GNNType::GIN] {
            let config = GNNConfig {
                gnn_type,
                num_heads: if gnn_type == GNNType::GAT {
                    Some(4)
                } else {
                    None
                },
                ..Default::default()
            };

            let mut model = GNNEmbedding::new(config);

            let triple = Triple::new(
                NamedNode::new("http://example.org/A").unwrap(),
                NamedNode::new("http://example.org/rel").unwrap(),
                NamedNode::new("http://example.org/B").unwrap(),
            );

            model.add_triple(triple).unwrap();
            let _stats = model.train(Some(5)).await.unwrap();
            assert!(model.is_trained());
        }
    }
}