1use crate::random_utils::NormalSampler as Normal;
8use crate::{
9 kg_embeddings::{KGEmbeddingConfig, KGEmbeddingModel, Triple},
10 Vector,
11};
12use anyhow::{anyhow, Result};
13use nalgebra::{DMatrix, DVector};
14use scirs2_core::random::{Random, Rng};
15use std::collections::HashMap;
16
17pub struct GCN {
19 config: KGEmbeddingConfig,
20 entity_embeddings: HashMap<String, DVector<f32>>,
21 relation_embeddings: HashMap<String, DVector<f32>>,
22 entities: Vec<String>,
23 relations: Vec<String>,
24 adjacency_matrix: Option<DMatrix<f32>>,
25 weight_matrices: Vec<DMatrix<f32>>,
26 num_layers: usize,
27}
28
29impl GCN {
30 pub fn new(config: KGEmbeddingConfig) -> Self {
31 let num_layers = 2; Self {
33 config,
34 entity_embeddings: HashMap::new(),
35 relation_embeddings: HashMap::new(),
36 entities: Vec::new(),
37 relations: Vec::new(),
38 adjacency_matrix: None,
39 weight_matrices: Vec::new(),
40 num_layers,
41 }
42 }
43
44 pub fn with_layers(config: KGEmbeddingConfig, num_layers: usize) -> Self {
46 Self {
47 config,
48 entity_embeddings: HashMap::new(),
49 relation_embeddings: HashMap::new(),
50 entities: Vec::new(),
51 relations: Vec::new(),
52 adjacency_matrix: None,
53 weight_matrices: Vec::new(),
54 num_layers,
55 }
56 }
57
58 fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
60 let mut entities = std::collections::HashSet::new();
62 let mut relations = std::collections::HashSet::new();
63
64 for triple in triples {
65 entities.insert(triple.subject.clone());
66 entities.insert(triple.object.clone());
67 relations.insert(triple.predicate.clone());
68 }
69
70 self.entities = entities.into_iter().collect();
71 self.relations = relations.into_iter().collect();
72
73 let _num_entities = self.entities.len();
74
75 let mut rng = if let Some(seed) = self.config.random_seed {
77 Random::seed(seed)
78 } else {
79 Random::seed(42)
80 };
81
82 let normal = Normal::new(0.0, 0.1)
83 .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
84
85 for entity in &self.entities {
86 let embedding: Vec<f32> = (0..self.config.dimensions)
87 .map(|_| normal.sample(&mut rng))
88 .collect();
89 self.entity_embeddings
90 .insert(entity.clone(), DVector::from_vec(embedding));
91 }
92
93 for relation in &self.relations {
94 let embedding: Vec<f32> = (0..self.config.dimensions)
95 .map(|_| normal.sample(&mut rng))
96 .collect();
97 self.relation_embeddings
98 .insert(relation.clone(), DVector::from_vec(embedding));
99 }
100
101 self.build_adjacency_matrix(triples)?;
103
104 self.weight_matrices.clear();
106 for _ in 0..self.num_layers {
107 let weight_matrix =
108 DMatrix::from_fn(self.config.dimensions, self.config.dimensions, |_, _| {
109 normal.sample(&mut rng)
110 });
111 self.weight_matrices.push(weight_matrix);
112 }
113
114 Ok(())
115 }
116
117 fn build_adjacency_matrix(&mut self, triples: &[Triple]) -> Result<()> {
119 let num_entities = self.entities.len();
120 let mut adj_matrix = DMatrix::zeros(num_entities, num_entities);
121
122 let entity_to_index: HashMap<String, usize> = self
124 .entities
125 .iter()
126 .enumerate()
127 .map(|(i, entity)| (entity.clone(), i))
128 .collect();
129
130 for triple in triples {
132 if let (Some(&subject_idx), Some(&object_idx)) = (
133 entity_to_index.get(&triple.subject),
134 entity_to_index.get(&triple.object),
135 ) {
136 adj_matrix[(subject_idx, object_idx)] = 1.0;
137 adj_matrix[(object_idx, subject_idx)] = 1.0; }
139 }
140
141 for i in 0..num_entities {
143 adj_matrix[(i, i)] = 1.0;
144 }
145
146 self.adjacency_matrix = Some(self.normalize_adjacency_matrix(adj_matrix));
148
149 Ok(())
150 }
151
152 fn normalize_adjacency_matrix(&self, mut adj_matrix: DMatrix<f32>) -> DMatrix<f32> {
154 let num_nodes = adj_matrix.nrows();
155
156 let mut degrees = Vec::with_capacity(num_nodes);
158 for i in 0..num_nodes {
159 let degree: f32 = (0..num_nodes).map(|j| adj_matrix[(i, j)]).sum();
160 degrees.push(if degree > 0.0 {
161 1.0 / degree.sqrt()
162 } else {
163 0.0
164 });
165 }
166
167 for i in 0..num_nodes {
169 for j in 0..num_nodes {
170 adj_matrix[(i, j)] *= degrees[i] * degrees[j];
171 }
172 }
173
174 adj_matrix
175 }
176
177 fn forward_pass(&self, features: &DMatrix<f32>) -> Result<DMatrix<f32>> {
179 let adj_matrix = self
180 .adjacency_matrix
181 .as_ref()
182 .ok_or_else(|| anyhow!("Adjacency matrix not initialized"))?;
183
184 let mut hidden = features.clone();
185
186 for layer_idx in 0..self.num_layers {
187 let weight = &self.weight_matrices[layer_idx];
188
189 let linear_transform = &hidden * weight;
191 hidden = adj_matrix * &linear_transform;
192
193 if layer_idx < self.num_layers - 1 {
195 hidden = hidden.map(|x| x.max(0.0));
196 }
197 }
198
199 Ok(hidden)
200 }
201
202 fn train_gcn(&mut self, _triples: &[Triple]) -> Result<()> {
204 let num_entities = self.entities.len();
206 let mut features = DMatrix::zeros(num_entities, self.config.dimensions);
207
208 for (i, entity) in self.entities.iter().enumerate() {
209 if let Some(embedding) = self.entity_embeddings.get(entity) {
210 for (j, &value) in embedding.iter().enumerate() {
211 features[(i, j)] = value;
212 }
213 }
214 }
215
216 let updated_features = self.forward_pass(&features)?;
218
219 for (i, entity) in self.entities.iter().enumerate() {
221 let new_embedding: Vec<f32> = (0..self.config.dimensions)
222 .map(|j| updated_features[(i, j)])
223 .collect();
224 self.entity_embeddings
225 .insert(entity.clone(), DVector::from_vec(new_embedding));
226 }
227
228 Ok(())
229 }
230}
231
232impl KGEmbeddingModel for GCN {
233 fn train(&mut self, triples: &[Triple]) -> Result<()> {
234 self.initialize(triples)?;
235
236 for epoch in 0..self.config.epochs {
237 self.train_gcn(triples)?;
238
239 if epoch % 10 == 0 {
240 println!("GCN training epoch {}/{}", epoch, self.config.epochs);
241 }
242 }
243
244 Ok(())
245 }
246
247 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
248 self.entity_embeddings
249 .get(entity)
250 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
251 }
252
253 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
254 self.relation_embeddings
255 .get(relation)
256 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
257 }
258
259 fn score_triple(&self, triple: &Triple) -> f32 {
260 if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
263 self.get_entity_embedding(&triple.subject),
264 self.get_relation_embedding(&triple.predicate),
265 self.get_entity_embedding(&triple.object),
266 ) {
267 let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
269 predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
270 } else {
271 0.0
272 }
273 }
274
275 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
276 if let (Some(head_emb), Some(rel_emb)) = (
277 self.get_entity_embedding(head),
278 self.get_relation_embedding(relation),
279 ) {
280 let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
281
282 let mut scores = Vec::new();
283 for entity in &self.entities {
284 if entity != head {
285 if let Some(entity_emb) = self.get_entity_embedding(entity) {
286 let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
287 scores.push((entity.clone(), score));
288 }
289 }
290 }
291
292 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
293 scores.into_iter().take(k).collect()
294 } else {
295 Vec::new()
296 }
297 }
298
299 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
300 if let (Some(rel_emb), Some(tail_emb)) = (
301 self.get_relation_embedding(relation),
302 self.get_entity_embedding(tail),
303 ) {
304 let mut scores = Vec::new();
305 for entity in &self.entities {
306 if entity != tail {
307 if let Some(entity_emb) = self.get_entity_embedding(entity) {
308 let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
309 let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
310 scores.push((entity.clone(), score));
311 }
312 }
313 }
314
315 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
316 scores.into_iter().take(k).collect()
317 } else {
318 Vec::new()
319 }
320 }
321
322 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
323 HashMap::new()
326 }
327
328 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
329 HashMap::new()
331 }
332}
333
334pub struct GraphSAGE {
336 config: KGEmbeddingConfig,
337 entity_embeddings: HashMap<String, DVector<f32>>,
338 relation_embeddings: HashMap<String, DVector<f32>>,
339 entities: Vec<String>,
340 relations: Vec<String>,
341 graph: HashMap<String, Vec<String>>, aggregator_type: AggregatorType,
343 num_layers: usize,
344 sample_size: usize,
345 sampling_strategy: SamplingStrategy,
346}
347
/// Strategy used by `GraphSAGE` to combine the embeddings of sampled neighbours.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregatorType {
    /// Element-wise mean of the neighbour embeddings.
    Mean,
    /// Sequential gated recurrence over the neighbours, in list order.
    LSTM,
    /// Element-wise maximum of the neighbour embeddings.
    Pool,
    /// Softmax-weighted sum of the neighbour embeddings.
    Attention,
}
355
/// Strategy used by `GraphSAGE` to pick neighbours when a node has more than `sample_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Uniform random sample of the neighbours.
    Uniform,
    /// Prefer neighbours that themselves have the most neighbours.
    Degree,
    /// Currently handled the same way as `Degree`.
    PageRank,
    /// Take the entries at the end of the adjacency list.
    Recent,
}
363
364impl GraphSAGE {
365 pub fn new(config: KGEmbeddingConfig) -> Self {
366 Self {
367 config,
368 entity_embeddings: HashMap::new(),
369 relation_embeddings: HashMap::new(),
370 entities: Vec::new(),
371 relations: Vec::new(),
372 graph: HashMap::new(),
373 aggregator_type: AggregatorType::Mean,
374 num_layers: 2,
375 sample_size: 10, sampling_strategy: SamplingStrategy::Uniform,
377 }
378 }
379
380 pub fn with_aggregator(mut self, aggregator: AggregatorType) -> Self {
381 self.aggregator_type = aggregator;
382 self
383 }
384
385 pub fn with_sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
386 self.sampling_strategy = strategy;
387 self
388 }
389
390 pub fn with_sample_size(mut self, size: usize) -> Self {
391 self.sample_size = size;
392 self
393 }
394
395 pub fn dimensions(&self) -> usize {
397 self.config.dimensions
398 }
399
400 fn initialize(&mut self, triples: &[Triple]) -> Result<()> {
402 let mut entities = std::collections::HashSet::new();
404 let mut relations = std::collections::HashSet::new();
405
406 for triple in triples {
407 entities.insert(triple.subject.clone());
408 entities.insert(triple.object.clone());
409 relations.insert(triple.predicate.clone());
410 }
411
412 self.entities = entities.into_iter().collect();
413 self.relations = relations.into_iter().collect();
414
415 self.build_graph(triples);
417
418 let mut rng = if let Some(seed) = self.config.random_seed {
420 Random::seed(seed)
421 } else {
422 Random::seed(42)
423 };
424
425 let normal = Normal::new(0.0, 0.1)
426 .map_err(|e| anyhow!("Failed to create normal distribution: {}", e))?;
427
428 for entity in &self.entities {
429 let embedding: Vec<f32> = (0..self.config.dimensions)
430 .map(|_| normal.sample(&mut rng))
431 .collect();
432 self.entity_embeddings
433 .insert(entity.clone(), DVector::from_vec(embedding));
434 }
435
436 for relation in &self.relations {
437 let embedding: Vec<f32> = (0..self.config.dimensions)
438 .map(|_| normal.sample(&mut rng))
439 .collect();
440 self.relation_embeddings
441 .insert(relation.clone(), DVector::from_vec(embedding));
442 }
443
444 Ok(())
445 }
446
447 fn build_graph(&mut self, triples: &[Triple]) {
449 for triple in triples {
450 self.graph
451 .entry(triple.subject.clone())
452 .or_default()
453 .push(triple.object.clone());
454
455 self.graph
456 .entry(triple.object.clone())
457 .or_default()
458 .push(triple.subject.clone());
459 }
460 }
461
462 #[allow(deprecated)]
464 fn sample_neighbors(&self, node: &str, rng: &mut impl Rng) -> Vec<String> {
465 if let Some(neighbors) = self.graph.get(node) {
466 if neighbors.len() <= self.sample_size {
467 neighbors.clone()
468 } else {
469 match self.sampling_strategy {
470 SamplingStrategy::Uniform => {
471 let mut sampled = Vec::new();
474 let sample_size = std::cmp::min(self.sample_size, neighbors.len());
475 for (i, neighbor) in neighbors.iter().enumerate() {
476 if sampled.len() < sample_size {
477 sampled.push(neighbor.clone());
478 } else {
479 let j = rng.gen_range(0..=i);
480 if j < sample_size {
481 sampled[j] = neighbor.clone();
482 }
483 }
484 }
485 sampled
486 }
487 SamplingStrategy::Degree => self.degree_based_sampling(neighbors, rng),
488 SamplingStrategy::PageRank => {
489 self.degree_based_sampling(neighbors, rng)
491 }
492 SamplingStrategy::Recent => {
493 neighbors
495 .iter()
496 .rev()
497 .take(self.sample_size)
498 .cloned()
499 .collect()
500 }
501 }
502 }
503 } else {
504 Vec::new()
505 }
506 }
507
508 #[allow(deprecated)]
510 fn degree_based_sampling(&self, neighbors: &[String], rng: &mut impl Rng) -> Vec<String> {
511 let mut neighbor_degrees: Vec<(String, usize)> = neighbors
512 .iter()
513 .map(|neighbor| {
514 let degree = self.graph.get(neighbor).map(|n| n.len()).unwrap_or(0);
515 (neighbor.clone(), degree)
516 })
517 .collect();
518
519 neighbor_degrees.sort_by(|a, b| {
521 let degree_cmp = b.1.cmp(&a.1);
522 if degree_cmp == std::cmp::Ordering::Equal {
523 if rng.gen_bool(0.5) {
525 std::cmp::Ordering::Greater
526 } else {
527 std::cmp::Ordering::Less
528 }
529 } else {
530 degree_cmp
531 }
532 });
533
534 neighbor_degrees
535 .into_iter()
536 .take(self.sample_size)
537 .map(|(neighbor, _)| neighbor)
538 .collect()
539 }
540
541 fn aggregate_neighbors(&self, neighbors: &[String]) -> Result<DVector<f32>> {
543 if neighbors.is_empty() {
544 return Ok(DVector::zeros(self.config.dimensions));
545 }
546
547 match self.aggregator_type {
548 AggregatorType::Mean => {
549 let mut sum = DVector::zeros(self.config.dimensions);
550 let mut count = 0;
551
552 for neighbor in neighbors {
553 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
554 sum += embedding;
555 count += 1;
556 }
557 }
558
559 if count > 0 {
560 Ok(sum / count as f32)
561 } else {
562 Ok(DVector::zeros(self.config.dimensions))
563 }
564 }
565 AggregatorType::Pool => {
566 let mut max_embedding =
568 DVector::from_element(self.config.dimensions, f32::NEG_INFINITY);
569
570 for neighbor in neighbors {
571 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
572 for i in 0..self.config.dimensions {
573 max_embedding[i] = max_embedding[i].max(embedding[i]);
574 }
575 }
576 }
577
578 for i in 0..self.config.dimensions {
580 if max_embedding[i] == f32::NEG_INFINITY {
581 max_embedding[i] = 0.0;
582 }
583 }
584
585 Ok(max_embedding)
586 }
587 AggregatorType::LSTM => {
588 self.lstm_aggregate(neighbors)
590 }
591 AggregatorType::Attention => {
592 self.attention_aggregate(neighbors)
594 }
595 }
596 }
597
598 fn lstm_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
600 if neighbors.is_empty() {
601 return Ok(DVector::zeros(self.config.dimensions));
602 }
603
604 let mut cell_state = DVector::zeros(self.config.dimensions);
606 let mut hidden_state = DVector::zeros(self.config.dimensions);
607
608 for neighbor in neighbors {
609 if let Some(embedding) = self.entity_embeddings.get(neighbor) {
610 let forget_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp())); let input_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
613 let candidate = embedding.map(|x| x.tanh()); cell_state =
617 cell_state.component_mul(&forget_gate) + input_gate.component_mul(&candidate);
618
619 let output_gate = embedding.map(|x| 1.0 / (1.0 + (-x).exp()));
621 hidden_state = output_gate.component_mul(&cell_state.map(|x| x.tanh()));
622 }
623 }
624
625 Ok(hidden_state)
626 }
627
628 fn attention_aggregate(&self, neighbors: &[String]) -> Result<DVector<f32>> {
630 if neighbors.is_empty() {
631 return Ok(DVector::zeros(self.config.dimensions));
632 }
633
634 let neighbor_embeddings: Vec<&DVector<f32>> = neighbors
635 .iter()
636 .filter_map(|neighbor| self.entity_embeddings.get(neighbor))
637 .collect();
638
639 if neighbor_embeddings.is_empty() {
640 return Ok(DVector::zeros(self.config.dimensions));
641 }
642
643 let mut attention_scores = Vec::new();
645 let mut weighted_sum = DVector::zeros(self.config.dimensions);
646
647 let query = DVector::from_element(self.config.dimensions, 1.0); for embedding in &neighbor_embeddings {
651 let score = query.dot(embedding).exp(); attention_scores.push(score);
653 }
654
655 let total_score: f32 = attention_scores.iter().sum();
657 if total_score > 0.0 {
658 for score in &mut attention_scores {
659 *score /= total_score;
660 }
661 }
662
663 for (embedding, &score) in neighbor_embeddings.iter().zip(attention_scores.iter()) {
665 weighted_sum += *embedding * score;
666 }
667
668 Ok(weighted_sum)
669 }
670
671 fn forward_node(&self, node: &str, rng: &mut impl Rng) -> Result<DVector<f32>> {
673 let neighbors = self.sample_neighbors(node, rng);
674 let neighbor_aggregate = self.aggregate_neighbors(&neighbors)?;
675
676 if let Some(node_embedding) = self.entity_embeddings.get(node) {
677 Ok(node_embedding + neighbor_aggregate)
680 } else {
681 Ok(neighbor_aggregate)
682 }
683 }
684}
685
686impl KGEmbeddingModel for GraphSAGE {
687 fn train(&mut self, triples: &[Triple]) -> Result<()> {
688 self.initialize(triples)?;
689
690 let mut rng = if let Some(seed) = self.config.random_seed {
691 Random::seed(seed)
692 } else {
693 Random::seed(42)
694 };
695
696 for epoch in 0..self.config.epochs {
697 let mut new_embeddings = HashMap::new();
698
699 for entity in &self.entities {
701 let new_embedding = self.forward_node(entity, &mut rng)?;
702 new_embeddings.insert(entity.clone(), new_embedding);
703 }
704
705 self.entity_embeddings = new_embeddings;
707
708 if epoch % 10 == 0 {
709 println!("GraphSAGE training epoch {}/{}", epoch, self.config.epochs);
710 }
711 }
712
713 Ok(())
714 }
715
716 fn get_entity_embedding(&self, entity: &str) -> Option<Vector> {
717 self.entity_embeddings
718 .get(entity)
719 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
720 }
721
722 fn get_relation_embedding(&self, relation: &str) -> Option<Vector> {
723 self.relation_embeddings
724 .get(relation)
725 .map(|embedding| Vector::new(embedding.as_slice().to_vec()))
726 }
727
728 fn score_triple(&self, triple: &Triple) -> f32 {
729 if let (Some(subj_emb), Some(rel_emb), Some(obj_emb)) = (
730 self.get_entity_embedding(&triple.subject),
731 self.get_relation_embedding(&triple.predicate),
732 self.get_entity_embedding(&triple.object),
733 ) {
734 let predicted = subj_emb.add(&rel_emb).unwrap_or(subj_emb);
735 predicted.cosine_similarity(&obj_emb).unwrap_or(0.0)
736 } else {
737 0.0
738 }
739 }
740
741 fn predict_tail(&self, head: &str, relation: &str, k: usize) -> Vec<(String, f32)> {
742 if let (Some(head_emb), Some(rel_emb)) = (
743 self.get_entity_embedding(head),
744 self.get_relation_embedding(relation),
745 ) {
746 let query = head_emb.add(&rel_emb).unwrap_or(head_emb);
747
748 let mut scores = Vec::new();
749 for entity in &self.entities {
750 if entity != head {
751 if let Some(entity_emb) = self.get_entity_embedding(entity) {
752 let score = query.cosine_similarity(&entity_emb).unwrap_or(0.0);
753 scores.push((entity.clone(), score));
754 }
755 }
756 }
757
758 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
759 scores.into_iter().take(k).collect()
760 } else {
761 Vec::new()
762 }
763 }
764
765 fn predict_head(&self, relation: &str, tail: &str, k: usize) -> Vec<(String, f32)> {
766 if let (Some(rel_emb), Some(tail_emb)) = (
767 self.get_relation_embedding(relation),
768 self.get_entity_embedding(tail),
769 ) {
770 let mut scores = Vec::new();
771 for entity in &self.entities {
772 if entity != tail {
773 if let Some(entity_emb) = self.get_entity_embedding(entity) {
774 let predicted = entity_emb.add(&rel_emb).unwrap_or(entity_emb);
775 let score = predicted.cosine_similarity(&tail_emb).unwrap_or(0.0);
776 scores.push((entity.clone(), score));
777 }
778 }
779 }
780
781 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
782 scores.into_iter().take(k).collect()
783 } else {
784 Vec::new()
785 }
786 }
787
788 fn get_entity_embeddings(&self) -> HashMap<String, Vector> {
789 HashMap::new()
790 }
791
792 fn get_relation_embeddings(&self) -> HashMap<String, Vector> {
793 HashMap::new()
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests; individual tests override
    /// fields with struct-update syntax.
    fn base_config(model: crate::kg_embeddings::KGEmbeddingModelType) -> KGEmbeddingConfig {
        KGEmbeddingConfig {
            model,
            dimensions: 64,
            learning_rate: 0.01,
            margin: 1.0,
            negative_samples: 5,
            batch_size: 32,
            epochs: 10,
            norm: 2,
            random_seed: Some(42),
            regularization: 0.01,
        }
    }

    /// A new GCN defaults to two layers.
    #[test]
    fn test_gcn_creation() {
        let gcn = GCN::new(base_config(crate::kg_embeddings::KGEmbeddingModelType::GCN));
        assert_eq!(gcn.num_layers, 2);
    }

    /// A new GraphSAGE model defaults to sampling ten neighbours.
    #[test]
    fn test_graphsage_creation() {
        let graphsage =
            GraphSAGE::new(base_config(crate::kg_embeddings::KGEmbeddingModelType::GraphSAGE));
        assert_eq!(graphsage.sample_size, 10);
    }

    /// Training on a three-entity graph leaves an embedding for every entity.
    #[test]
    fn test_gnn_training() {
        let config = KGEmbeddingConfig {
            dimensions: 32,
            batch_size: 16,
            epochs: 5,
            ..base_config(crate::kg_embeddings::KGEmbeddingModelType::GCN)
        };

        let mut gcn = GCN::new(config);

        let triple = |subject: &str, predicate: &str, object: &str| {
            Triple::new(
                subject.to_string(),
                predicate.to_string(),
                object.to_string(),
            )
        };
        let triples = vec![
            triple("entity1", "relation1", "entity2"),
            triple("entity2", "relation2", "entity3"),
            triple("entity1", "relation3", "entity3"),
        ];

        gcn.train(&triples).unwrap();

        for name in &["entity1", "entity2", "entity3"] {
            assert!(gcn.get_entity_embedding(name).is_some());
        }
    }
}