1use crate::{
8 kg_embeddings::{KGEmbeddingConfig, KGEmbeddingModel, Triple},
9 Vector,
10};
11use anyhow::{anyhow, Result};
12use nalgebra::{DMatrix, DVector};
13use crate::random_utils::NormalSampler as Normal;
14use scirs2_core::random::{Random, Rng};
15use std::collections::HashMap;
16
17pub struct GCN {
19 config: KGEmbeddingConfig,
20 entity_embeddings: HashMap<String, DVector<f32>>,
21 relation_embeddings: HashMap<String, DVector<f32>>,
22 entities: Vec<String>,
23 relations: Vec<String>,
24 adjacency_matrix: Option<DMatrix<f32>>,
25 weight_matrices: Vec<DMatrix<f32>>,
26 num_layers: usize,
27}
28
29impl GCN {
30 pub fn new(config: KGEmbeddingConfig) -> Self {
31 let num_layers = 2; Self {
33 config,
34 entity_embeddings: HashMap::new(),
35 relation_embeddings: HashMap::new(),
36 entities: Vec::new(),
37 relations: Vec::new(),
38 adjacency_matrix: None,
39 weight_matrices: Vec::new(),
40 num_layers,
41 }
42 }
43
44 pub fn with_layers(config: KGEmbeddingConfig, num_layers: usize) -> Self {
46 Self {
47 config,
48 entity_embeddings: HashMap::new(),
49 relation_embeddings: HashMap::new(),
50 entities: Vec::new(),
51 relations: Vec::new(),
52 adjacency_matrix: None,
53 weight_matrices: Vec::new(),
54 num_layers,
55 }
56 }
57
58 fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
60 let mut entities = std::collections::HashSet::new();
62 let mut relations = std::collections::HashSet::new();
63
64 for triple in triples {
65 entities.insert(triple.subject.clone());
66 entities.insert(triple.object.clone());
67 relations.insert(triple.predicate.clone());
68 }
69
70 self.entities = entities.into_iter().collect();
71 self.relations = relations.into_iter().collect();
72
73 let _num_entities = self.entities.len();
74
75 let mut rng = if let Some(seed) = self.config.random_seed {
77 Random::seed(seed)
78 } else {
79 Random::seed(42)
80 };
81
82 let normal = Normal::new(0.0, 0.1)
83 .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
84
85 for entity in &self.entities {
86 let embedding: Vec<f32> = (0..self.config.dimensions)
87 .map(|_| normal.sample(&mut rng))
88 .collect();
89 self.entity_embeddings
90 .insert(entity.clone(), DVector::from_vec(embedding));
91 }
92
93 for relation in &self.relations {
94 let embedding: Vec<f32> = (0..self.config.dimensions)
95 .map(|_| normal.sample(&mut rng))
96 .collect();
97 self.relation_embeddings
98 .insert(relation.clone(), DVector::from_vec(embedding));
99 }
100
101 self.build_adjacency_matrix(triples)?;
103
104 self.weight_matrices.clear();
106 for _ in 0..self.num_layers {
107 let weight_matrix =
108 DMatrix::from_fn(self.config.dimensions, self.config.dimensions, |_, _| {
109 normal.sample(&mut rng)
110 });
111 self.weight_matrices.push(weight_matrix);
112 }
113
114 Ok(())
115 }
116
117 fn build_adjacency_matrix(&mut self, triples: &[Triple]) -> Result<()> {
119 let num_entities = self.entities.len();
120 let mut adj_matrix = DMatrix::zeros(num_entities, num_entities);
121
122 let entity_to_index: HashMap<String, usize> = self
124 .entities
125 .iter()
126 .enumerate()
127 .map(|(i, entity)| (entity.clone(), i))
128 .collect();
129
130 for triple in triples {
132 if let (Some(&subject_idx), Some(&object_idx)) = (
133 entity_to_index.get(&triple.subject),
134 entity_to_index.get(&triple.object),
135 ) {
136 adj_matrix[(subject_idx, object_idx)] = 1.0;
137 adj_matrix[(object_idx, subject_idx)] = 1.0; }
139 }
140
141 for i in 0..num_entities {
143 adj_matrix[(i, i)] = 1.0;
144 }
145
146 self.adjacency_matrix = Some(self.normalize_adjacency_matrix(adj_matrix));
148
149 Ok(())
150 }
151
152 fn normalize_adjacency_matrix(&self, mut adj_matrix: DMatrix<f32>) -> DMatrix<f32> {
154 let num_nodes = adj_matrix.nrows();
155
156 let mut degrees = Vec::with_capacity(num_nodes);
158 for i in 0..num_nodes {
159 let degree: f32 = (0..num_nodes).map(|j| adj_matrix[(i, j)]).sum();
160 degrees.push(if degree > 0.0 {
161 1.0 / degree.sqrt()
162 } else {
163 0.0
164 });
165 }
166
167 for i in 0..num_nodes {
169 for j in 0..num_nodes {
170 adj_matrix[(i, j)] *= degrees[i] * degrees[j];
171 }
172 }
173
174 adj_matrix
175 }
176
177 fn forward_pass(&self, features: &DMatrix<f32>) -> Result<DMatrix<f32>> {
179 let adj_matrix = self
180 .adjacency_matrix
181 .as_ref()
182 .ok_or_else(|| anyhow!("Adjacency matrix not initialized"))?;
183
184 let mut hidden = features.clone();
185
186 for layer_idx in 0..self.num_layers {
187 let weight = &self.weight_matrices[layer_idx];
188
189 let linear_transform = &hidden * weight;
191 hidden = adj_matrix * &linear_transform;
192
193 if layer_idx < self.num_layers - 1 {
195 hidden = hidden.map(|x| x.max(0.0));
196 }
197 }
198
199 Ok(hidden)
200 }
201
202 fn train_gcn(&mut self, _triples: &[Triple]) -> Result<()> {
204 let num_entities = self.entities.len();
206 let mut features = DMatrix::zeros(num_entities, self.config.dimensions);
207
208 for (i, entity) in self.entities.iter().enumerate() {
209 if let Some(embedding) = self.entity_embeddings.get(entity) {
210 for (j, &value) in embedding.iter().enumerate() {
211 features[(i, j)] = value;
212 }
213 }
214 }
215
216 let updated_features = self.forward_pass(&features)?;
218
219 for (i, entity) in self.entities.iter().enumerate() {
221 let new_embedding: Vec<f32> = (0..self.config.dimensions)
222 .map(|j| updated_features[(i, j)])
223 .collect();
224 self.entity_embeddings
225 .insert(entity.clone(), DVector::from_vec(new_embedding));
226 }
227
228 Ok(())
229 }
230}
231
232impl KGEmbeddingModel for GCN {
233 fn train(&mut self, triples: &[Triple]) -> Result<()> {
234 self.initialize(triples)?;
235
236 for epoch in 0..self.config.epochs {
237 self.train_gcn(triples)?;
238
239 if epoch % 10 == 0 {
240 println!("GCN training epoch {}/{}", epoch, self.config.epochs);
241 }
242 }
243
244 Ok(())
245 }
246
247 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
248 self.entity_embeddings
249 .get(entity)
250 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
251 }
252
253 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
254 self.relation_embeddings
255 .get(relation)
256 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
257 }
258
259 fn score_triple(&self, triple: &Triple) -> f32 {
260 if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
263 self.get_entity_embedding(&triple.subject),
264 self.get_relation_embedding(&triple.predicate),
265 self.get_entity_embedding(&triple.object),
266 ) {
267 let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
269 predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
270 } else {
271 0.0
272 }
273 }
274
275 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
276 if let (Some(head_emb), Some(rel_emb)) = (
277 self.get_entity_embedding(head),
278 self.get_relation_embedding(relation),
279 ) {
280 let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
281
282 let mut scores = Vec::new();
283 for entity in &self.entities {
284 if entity != head {
285 if let Some(entity_emb) = self.get_entity_embedding(entity) {
286 let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
287 scores.push((entity.clone(), score));
288 }
289 }
290 }
291
292 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
293 scores.into_iter().take(k).collect()
294 } else {
295 Vec::new()
296 }
297 }
298
299 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
300 if let (Some(rel_emb), Some(tail_emb)) = (
301 self.get_relation_embedding(relation),
302 self.get_entity_embedding(tail),
303 ) {
304 let mut scores = Vec::new();
305 for entity in &self.entities {
306 if entity != tail {
307 if let Some(entity_emb) = self.get_entity_embedding(entity) {
308 let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
309 let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
310 scores.push((entity.clone(), score));
311 }
312 }
313 }
314
315 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
316 scores.into_iter().take(k).collect()
317 } else {
318 Vec::new()
319 }
320 }
321
322 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
323 HashMap::new()
326 }
327
328 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
329 HashMap::new()
331 }
332}
333
334pub struct GraphSAGE {
336 config: KGEmbeddingConfig,
337 entity_embeddings: HashMap<String, DVector<f32>>,
338 relation_embeddings: HashMap<String, DVector<f32>>,
339 entities: Vec<String>,
340 relations: Vec<String>,
341 graph: HashMap<String, Vec<String>>, aggregator_type: AggregatorType,
343 num_layers: usize,
344 sample_size: usize,
345 sampling_strategy: SamplingStrategy,
346}
347
/// Strategy used by [`GraphSAGE`] to combine neighbor embeddings.
// `LSTM` is a public variant name; renaming it to satisfy clippy would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregatorType {
    /// Element-wise mean of the neighbor embeddings.
    Mean,
    /// Parameter-free, LSTM-like gated recurrence over the neighbors in order.
    LSTM,
    /// Element-wise maximum of the neighbor embeddings.
    Pool,
    /// Softmax-weighted sum of the neighbor embeddings.
    Attention,
}
355
/// How [`GraphSAGE`] picks neighbors when a node has more than `sample_size`
/// of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Uniform random sample (reservoir sampling).
    Uniform,
    /// Highest-degree neighbors first; ties broken randomly.
    Degree,
    /// Currently handled exactly like `Degree`; no PageRank scores are computed.
    PageRank,
    /// The most recently added neighbors (last in insertion order).
    Recent,
}
363
364impl GraphSAGE {
365 pub fn new(config: KGEmbeddingConfig) -> Self {
366 Self {
367 config,
368 entity_embeddings: HashMap::new(),
369 relation_embeddings: HashMap::new(),
370 entities: Vec::new(),
371 relations: Vec::new(),
372 graph: HashMap::new(),
373 aggregator_type: AggregatorType::Mean,
374 num_layers: 2,
375 sample_size: 10, sampling_strategy: SamplingStrategy::Uniform,
377 }
378 }
379
380 pub fn with_aggregator(mut self, aggregator: AggregatorType) -> Self {
381 self.aggregator_type = aggregator;
382 self
383 }
384
385 pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
386 self.sampling_strategy = strategy;
387 self
388 }
389
390 pub fn with_sample_size(mut self, size: usize) -> Self {
391 self.sample_size = size;
392 self
393 }
394
395 pub fn dimensions(&self) -> usize {
397 self.config.dimensions
398 }
399
400 fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
402 let mut entities = std::collections::HashSet::new();
404 let mut relations = std::collections::HashSet::new();
405
406 for triple in triples {
407 entities.insert(triple.subject.clone());
408 entities.insert(triple.object.clone());
409 relations.insert(triple.predicate.clone());
410 }
411
412 self.entities = entities.into_iter().collect();
413 self.relations = relations.into_iter().collect();
414
415 self.build_graph(triples);
417
418 let mut rng = if let Some(seed) = self.config.random_seed {
420 Random::seed(seed)
421 } else {
422 Random::seed(42)
423 };
424
425 let normal = Normal::new(0.0, 0.1)
426 .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
427
428 for entity in &self.entities {
429 let embedding: Vec<f32> = (0..self.config.dimensions)
430 .map(|_| normal.sample(&mut rng))
431 .collect();
432 self.entity_embeddings
433 .insert(entity.clone(), DVector::from_vec(embedding));
434 }
435
436 for relation in &self.relations {
437 let embedding: Vec<f32> = (0..self.config.dimensions)
438 .map(|_| normal.sample(&mut rng))
439 .collect();
440 self.relation_embeddings
441 .insert(relation.clone(), DVector::from_vec(embedding));
442 }
443
444 Ok(())
445 }
446
447 fn build_graph(&mut self, triples: &[Triple]) {
449 for triple in triples {
450 self.graph
451 .entry(triple.subject.clone())
452 .or_default()
453 .push(triple.object.clone());
454
455 self.graph
456 .entry(triple.object.clone())
457 .or_default()
458 .push(triple.subject.clone());
459 }
460 }
461
462 fn sample_neighbors(&self, node: &str, rng: &mut impl Rng) -> Vec<String> {
464 if let Some(neighbors) = self.graph.get(node) {
465 if neighbors.len() <= self.sample_size {
466 neighbors.clone()
467 } else {
468 match self.sampling_strategy {
469 SamplingStrategy::Uniform => {
470 let mut sampled = Vec::new();
473 let sample_size = std::cmp::min(self.sample_size, neighbors.len());
474 for (i, neighbor) in neighbors.iter().enumerate() {
475 if sampled.len() < sample_size {
476 sampled.push(neighbor.clone());
477 } else {
478 let j = rng.gen_range(0..=i);
479 if j < sample_size {
480 sampled[j] = neighbor.clone();
481 }
482 }
483 }
484 sampled
485 }
486 SamplingStrategy::Degree => self.degree_based_sampling(neighbors, rng),
487 SamplingStrategy::PageRank => {
488 self.degree_based_sampling(neighbors, rng)
490 }
491 SamplingStrategy::Recent => {
492 neighbors
494 .iter()
495 .rev()
496 .take(self.sample_size)
497 .cloned()
498 .collect()
499 }
500 }
501 }
502 } else {
503 Vec::new()
504 }
505 }
506
507 fn degree_based_sampling(&self, neighbors: &[String], rng: &mut impl Rng) -> Vec<String> {
509 let mut neighbor_degrees: Vec<(String, usize)> = neighbors
510 .iter()
511 .map(|neighbor| {
512 let degree = self.graph.get(neighbor).map(|n| n.len()).unwrap_or(0);
513 (neighbor.clone(), degree)
514 })
515 .collect();
516
517 neighbor_degrees.sort_by(|a, b| {
519 let degree_cmp = b.1.cmp(&a.1);
520 if degree_cmp == std::cmp::Ordering::Equal {
521 if rng.gen_bool(0.5) {
523 std::cmp::Ordering::Greater
524 } else {
525 std::cmp::Ordering::Less
526 }
527 } else {
528 degree_cmp
529 }
530 });
531
532 neighbor_degrees
533 .into_iter()
534 .take(self.sample_size)
535 .map(|(neighbor, _)| neighbor)
536 .collect()
537 }
538
539 fn aggregate_neighbors(&self, neighbors: &[String]) -> Result<DVector<f32>> {
541 if neighbors.is_empty() {
542 return Ok(DVector::zeros(self.config.dimensions));
543 }
544
545 match self.aggregator_type {
546 AggregatorType::Mean => {
547 let mut sum = DVector::zeros(self.config.dimensions);
548 let mut count = 0;
549
550 for neighbor in neighbors {
551 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
552 sum += embedding;
553 count += 1;
554 }
555 }
556
557 if count > 0 {
558 Ok(sum / count as f32)
559 } else {
560 Ok(DVector::zeros(self.config.dimensions))
561 }
562 }
563 AggregatorType::Pool => {
564 let mut max_embedding =
566 DVector::from_element(self.config.dimensions, f32::NEG_INFINITY);
567
568 for neighbor in neighbors {
569 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
570 for i in 0..self.config.dimensions {
571 max_embedding[i] = max_embedding[i].max(embedding[i]);
572 }
573 }
574 }
575
576 for i in 0..self.config.dimensions {
578 if max_embedding[i] == f32::NEG_INFINITY {
579 max_embedding[i] = 0.0;
580 }
581 }
582
583 Ok(max_embedding)
584 }
585 AggregatorType::LSTM => {
586 self.lstm_aggregate(neighbors)
588 }
589 AggregatorType::Attention => {
590 self.attention_aggregate(neighbors)
592 }
593 }
594 }
595
596 fn lstm_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
598 if neighbors.is_empty() {
599 return Ok(DVector::zeros(self.config.dimensions));
600 }
601
602 let mut cell_state = DVector::zeros(self.config.dimensions);
604 let mut hidden_state = DVector::zeros(self.config.dimensions);
605
606 for neighbor in neighbors {
607 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
608 let forget_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp())); let input_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
611 let candidate = embedding.map(|x| x.tanh()); cell_state =
615 cell_state.component_mul(&forget_gate) + input_gate.component_mul(&candidate);
616
617 let output_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
619 hidden_state = output_gate.component_mul(&cell_state.map(|x| x.tanh()));
620 }
621 }
622
623 Ok(hidden_state)
624 }
625
626 fn attention_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
628 if neighbors.is_empty() {
629 return Ok(DVector::zeros(self.config.dimensions));
630 }
631
632 let neighbor_embeddings: Vec<&DVector<f32>> = neighbors
633 .iter()
634 .filter_map(|neighbor| self.entity_embeddings.get(neighbor))
635 .collect();
636
637 if neighbor_embeddings.is_empty() {
638 return Ok(DVector::zeros(self.config.dimensions));
639 }
640
641 let mut attention_scores = Vec::new();
643 let mut weighted_sum = DVector::zeros(self.config.dimensions);
644
645 let query = DVector::from_element(self.config.dimensions, 1.0); for embedding in &neighbor_embeddings {
649 let score = query.dot(embedding).exp(); attention_scores.push(score);
651 }
652
653 let total_score: f32 = attention_scores.iter().sum();
655 if total_score > 0.0 {
656 for score in &mut attention_scores {
657 *score /= total_score;
658 }
659 }
660
661 for (embedding, &score) in neighbor_embeddings.iter().zip(attention_scores.iter()) {
663 weighted_sum += *embedding * score;
664 }
665
666 Ok(weighted_sum)
667 }
668
669 fn forward_node(&self, node: &str, rng: &mut impl Rng) -> Result<DVector<f32>> {
671 let neighbors = self.sample_neighbors(node, rng);
672 let neighbor_aggregate = self.aggregate_neighbors(&neighbors)?;
673
674 if let Some(node_embedding) = self.entity_embeddings.get(node) {
675 Ok(node_embedding + neighbor_aggregate)
678 } else {
679 Ok(neighbor_aggregate)
680 }
681 }
682}
683
684impl KGEmbeddingModel for GraphSAGE {
685 fn train(&mut self, triples: &[Triple]) -> Result<()> {
686 self.initialize(triples)?;
687
688 let mut rng = if let Some(seed) = self.config.random_seed {
689 Random::seed(seed)
690 } else {
691 Random::seed(42)
692 };
693
694 for epoch in 0..self.config.epochs {
695 let mut new_embeddings = HashMap::new();
696
697 for entity in &self.entities {
699 let new_embedding = self.forward_node(entity, &mut rng)?;
700 new_embeddings.insert(entity.clone(), new_embedding);
701 }
702
703 self.entity_embeddings = new_embeddings;
705
706 if epoch % 10 == 0 {
707 println!("GraphSAGE training epoch {}/{}", epoch, self.config.epochs);
708 }
709 }
710
711 Ok(())
712 }
713
714 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
715 self.entity_embeddings
716 .get(entity)
717 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
718 }
719
720 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
721 self.relation_embeddings
722 .get(relation)
723 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
724 }
725
726 fn score_triple(&self, triple: &Triple) -> f32 {
727 if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
728 self.get_entity_embedding(&triple.subject),
729 self.get_relation_embedding(&triple.predicate),
730 self.get_entity_embedding(&triple.object),
731 ) {
732 let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
733 predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
734 } else {
735 0.0
736 }
737 }
738
739 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
740 if let (Some(head_emb), Some(rel_emb)) = (
741 self.get_entity_embedding(head),
742 self.get_relation_embedding(relation),
743 ) {
744 let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
745
746 let mut scores = Vec::new();
747 for entity in &self.entities {
748 if entity != head {
749 if let Some(entity_emb) = self.get_entity_embedding(entity) {
750 let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
751 scores.push((entity.clone(), score));
752 }
753 }
754 }
755
756 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
757 scores.into_iter().take(k).collect()
758 } else {
759 Vec::new()
760 }
761 }
762
763 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
764 if let (Some(rel_emb), Some(tail_emb)) = (
765 self.get_relation_embedding(relation),
766 self.get_entity_embedding(tail),
767 ) {
768 let mut scores = Vec::new();
769 for entity in &self.entities {
770 if entity != tail {
771 if let Some(entity_emb) = self.get_entity_embedding(entity) {
772 let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
773 let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
774 scores.push((entity.clone(), score));
775 }
776 }
777 }
778
779 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
780 scores.into_iter().take(k).collect()
781 } else {
782 Vec::new()
783 }
784 }
785
786 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
787 HashMap::new()
788 }
789
790 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
791 HashMap::new()
792 }
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kg_embeddings::KGEmbeddingModelType;

    /// Configuration shared by all tests; individual tests override fields
    /// with struct-update syntax.
    fn base_config(model: KGEmbeddingModelType) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions: 64,
            learning_rate: 0.01,
            margin: 1.0,
            negative_samples: 5,
            batch_size: 32,
            epochs: 10,
            norm: 2,
            random_seed: Some(42),
            regularization: 0.01,
        }
    }

    /// Small connected graph: entity1 -> entity2 -> entity3 and entity1 -> entity3.
    fn sample_triples() -> Vec<Triple> {
        vec![
            Triple::new(
                "entity1".to_string(),
                "relation1".to_string(),
                "entity2".to_string(),
            ),
            Triple::new(
                "entity2".to_string(),
                "relation2".to_string(),
                "entity3".to_string(),
            ),
            Triple::new(
                "entity1".to_string(),
                "relation3".to_string(),
                "entity3".to_string(),
            ),
        ]
    }

    #[test]
    fn test_gcn_creation() {
        let gcn = GCN::new(base_config(KGEmbeddingModelType::GCN));
        assert_eq!(gcn.num_layers, 2);
    }

    #[test]
    fn test_gcn_with_layers() {
        let gcn = GCN::with_layers(base_config(KGEmbeddingModelType::GCN), 3);
        assert_eq!(gcn.num_layers, 3);
    }

    #[test]
    fn test_graphsage_creation() {
        let graphsage = GraphSAGE::new(base_config(KGEmbeddingModelType::GraphSAGE));
        assert_eq!(graphsage.sample_size, 10);

        let graphsage = graphsage.with_sample_size(3);
        assert_eq!(graphsage.sample_size, 3);
    }

    #[test]
    fn test_gnn_training() {
        let mut gcn = GCN::new(KGEmbeddingConfig {
            dimensions: 32,
            batch_size: 16,
            epochs: 5,
            ..base_config(KGEmbeddingModelType::GCN)
        });

        gcn.train(&sample_triples()).unwrap();

        for entity in &["entity1", "entity2", "entity3"] {
            assert!(gcn.get_entity_embedding(entity).is_some());
        }
    }

    #[test]
    fn test_graphsage_training() {
        let mut graphsage = GraphSAGE::new(KGEmbeddingConfig {
            dimensions: 32,
            batch_size: 16,
            epochs: 5,
            ..base_config(KGEmbeddingModelType::GraphSAGE)
        });

        graphsage.train(&sample_triples()).unwrap();

        for entity in &["entity1", "entity2", "entity3"] {
            assert!(graphsage.get_entity_embedding(entity).is_some());
        }
        assert!(graphsage.get_relation_embedding("relation1").is_some());
        // Two candidates remain once the head entity itself is excluded.
        assert_eq!(graphsage.predict_tail("entity1", "relation1", 5).len(), 2);
    }
}