1use crate::EmbeddingError;
10use anyhow::{anyhow, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Strategy used to combine a node's sampled neighbor features inside a
/// GraphSAGE layer's forward pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregatorType {
    /// Element-wise mean of the neighbor feature vectors.
    Mean,
    /// Element-wise max over neighbor features after a learned MLP + ReLU;
    /// `hidden_dim` is the width of that pooling MLP's output.
    MaxPool { hidden_dim: usize },
    /// Mean aggregation whose result is concatenated with the self
    /// representation (then truncated/padded to the layer's output width —
    /// see `SageLayer::forward`).
    MeanConcat,
}
24
/// Hyper-parameters for a GraphSAGE model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSageConfig {
    /// Dimensionality of the raw node features fed into the first layer.
    pub input_dim: usize,
    /// Output widths of the intermediate layers, in order.
    pub hidden_dims: Vec<usize>,
    /// Dimensionality of the final embeddings.
    pub output_dim: usize,
    /// Neighbor aggregation strategy shared by all layers.
    pub aggregator: AggregatorType,
    /// Per-layer neighbor sample counts; layers beyond this list fall back
    /// to a default of 25 (see `GraphSage::embed`).
    pub num_samples: Vec<usize>,
    /// Dropout probability. NOTE(review): stored but not applied anywhere in
    /// the forward pass visible in this file — confirm intended use.
    pub dropout: f64,
    /// Step size used when perturbing weights during training.
    pub learning_rate: f64,
    /// Number of training epochs for `train_unsupervised`.
    pub epochs: usize,
    /// Maximum number of nodes sampled per loss evaluation.
    pub batch_size: usize,
    /// When true, final embeddings are L2-normalized before being returned.
    pub normalize_output: bool,
    /// Seed for the deterministic RNG; same seed + same graph => same model.
    pub seed: u64,
}
51
52impl Default for GraphSageConfig {
53 fn default() -> Self {
54 Self {
55 input_dim: 64,
56 hidden_dims: vec![256, 128],
57 output_dim: 64,
58 aggregator: AggregatorType::Mean,
59 num_samples: vec![25, 10],
60 dropout: 0.5,
61 learning_rate: 0.01,
62 epochs: 10,
63 batch_size: 512,
64 normalize_output: true,
65 seed: 42,
66 }
67 }
68}
69
/// An in-memory graph: per-node feature vectors plus an adjacency list.
#[derive(Debug, Clone)]
pub struct GraphData {
    /// One feature vector per node; all vectors share the same length
    /// (enforced by `GraphData::new`).
    pub node_features: Vec<Vec<f64>>,
    /// `adjacency[i]` lists the neighbor indices of node `i`.
    pub adjacency: Vec<Vec<usize>>,
    /// Optional per-node class labels, same length as `node_features`
    /// (set via `with_labels`).
    pub labels: Option<Vec<usize>>,
}
80
81impl GraphData {
82 pub fn new(features: Vec<Vec<f64>>, adjacency: Vec<Vec<usize>>) -> Result<Self> {
86 let num_nodes = features.len();
87 if adjacency.len() != num_nodes {
88 return Err(anyhow!(
89 "Adjacency list length {} does not match number of nodes {}",
90 adjacency.len(),
91 num_nodes
92 ));
93 }
94 for (i, neighbors) in adjacency.iter().enumerate() {
96 for &neighbor in neighbors {
97 if neighbor >= num_nodes {
98 return Err(anyhow!(
99 "Node {} has neighbor index {} which is out of bounds (num_nodes={})",
100 i,
101 neighbor,
102 num_nodes
103 ));
104 }
105 }
106 }
107 if let Some(first) = features.first() {
109 let dim = first.len();
110 for (i, feat) in features.iter().enumerate() {
111 if feat.len() != dim {
112 return Err(anyhow!(
113 "Node {} has feature dimension {} but expected {}",
114 i,
115 feat.len(),
116 dim
117 ));
118 }
119 }
120 }
121 Ok(Self {
122 node_features: features,
123 adjacency,
124 labels: None,
125 })
126 }
127
128 pub fn num_nodes(&self) -> usize {
130 self.node_features.len()
131 }
132
133 pub fn feature_dim(&self) -> usize {
135 self.node_features.first().map(|f| f.len()).unwrap_or(0)
136 }
137
138 pub fn neighbors(&self, node: usize) -> &[usize] {
140 if node < self.adjacency.len() {
141 &self.adjacency[node]
142 } else {
143 &[]
144 }
145 }
146
147 pub fn sample_neighbors(&self, node: usize, k: usize, rng: &mut SimpleLcg) -> Vec<usize> {
149 let neighbors = self.neighbors(node);
150 if neighbors.is_empty() {
151 return Vec::new();
152 }
153 if neighbors.len() <= k {
154 return neighbors.to_vec();
155 }
156 let mut indices: Vec<usize> = (0..neighbors.len()).collect();
158 for i in 0..k {
159 let j = i + (rng.next_usize() % (indices.len() - i));
160 indices.swap(i, j);
161 }
162 indices[..k].iter().map(|&idx| neighbors[idx]).collect()
163 }
164
165 pub fn with_labels(mut self, labels: Vec<usize>) -> Result<Self> {
167 if labels.len() != self.num_nodes() {
168 return Err(anyhow!(
169 "Labels length {} does not match num_nodes {}",
170 labels.len(),
171 self.num_nodes()
172 ));
173 }
174 self.labels = Some(labels);
175 Ok(self)
176 }
177}
178
/// Small deterministic PRNG (64-bit linear congruential generator).
///
/// Used instead of an external RNG crate so that weight initialization and
/// neighbor sampling are fully reproducible from a `u64` seed.
#[derive(Debug, Clone)]
pub struct SimpleLcg {
    // Current generator state; also the value returned by `next_u64`.
    state: u64,
}

impl SimpleLcg {
    /// Creates a generator; the +1 keeps seed 0 from producing a trivially
    /// different first step than other seeds would.
    pub fn new(seed: u64) -> Self {
        SimpleLcg {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the LCG one step and returns the new state.
    pub fn next_u64(&mut self) -> u64 {
        // Knuth's MMIX multiplier/increment pair.
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(INC);
        self.state
    }

    /// Next value as `usize` (truncates on 32-bit targets).
    pub fn next_usize(&mut self) -> usize {
        self.next_u64() as usize
    }

    /// Uniform `f64` in `[0, 1)` built from the top 53 bits.
    pub fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Uniform `f64` in `(-scale, scale)`.
    pub fn next_f64_range(&mut self, scale: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * scale
    }
}
219
/// A fully-connected layer computing `W * input + bias`.
#[derive(Debug, Clone)]
struct DenseLayer {
    /// Row-major weight matrix: `weights[i][j]` scales input `j` for output `i`.
    weights: Vec<Vec<f64>>,
    /// One bias term per output unit.
    bias: Vec<f64>,
    /// Expected input length (checked with `debug_assert` in `forward`).
    input_dim: usize,
    /// Number of output units (equals `weights.len()`).
    output_dim: usize,
}
228
229impl DenseLayer {
230 fn new(input_dim: usize, output_dim: usize, rng: &mut SimpleLcg) -> Self {
232 let scale = (6.0 / (input_dim + output_dim) as f64).sqrt();
233 let weights = (0..output_dim)
234 .map(|_| (0..input_dim).map(|_| rng.next_f64_range(scale)).collect())
235 .collect();
236 let bias = vec![0.0; output_dim];
237 Self {
238 weights,
239 bias,
240 input_dim,
241 output_dim,
242 }
243 }
244
245 fn forward(&self, input: &[f64]) -> Vec<f64> {
247 debug_assert_eq!(input.len(), self.input_dim);
248 let mut output = self.bias.clone();
249 for (i, row) in self.weights.iter().enumerate() {
250 for (j, &w) in row.iter().enumerate() {
251 output[i] += w * input[j];
252 }
253 }
254 output
255 }
256
257 fn relu(x: &[f64]) -> Vec<f64> {
259 x.iter().map(|&v| v.max(0.0)).collect()
260 }
261}
262
/// One GraphSAGE layer: separate linear transforms for the node's own
/// representation and for the aggregated neighbor representation.
#[derive(Debug, Clone)]
struct SageLayer {
    /// Transform applied to the node's own features.
    self_transform: DenseLayer,
    /// Transform applied to the aggregated neighbor features.
    neigh_transform: DenseLayer,
    /// Pooling MLP used only by the `MaxPool` aggregator; `None` otherwise.
    pool_mlp: Option<DenseLayer>,
    /// Width of this layer's output vectors.
    output_dim: usize,
}
275
impl SageLayer {
    /// Builds a layer. `pool_hidden` is `Some(hidden_dim)` only when the
    /// model-wide aggregator is `MaxPool`; the pooling MLP then maps
    /// `neigh_dim -> hidden`. Transform construction order is fixed
    /// (self, neigh, pool) so RNG consumption is deterministic.
    fn new(
        input_dim: usize,
        neigh_dim: usize,
        output_dim: usize,
        pool_hidden: Option<usize>,
        rng: &mut SimpleLcg,
    ) -> Self {
        let self_transform = DenseLayer::new(input_dim, output_dim, rng);
        let neigh_transform = DenseLayer::new(neigh_dim, output_dim, rng);
        let pool_mlp = pool_hidden.map(|hidden| DenseLayer::new(neigh_dim, hidden, rng));
        Self {
            self_transform,
            neigh_transform,
            pool_mlp,
            output_dim,
        }
    }

    /// Element-wise mean of the neighbor vectors. Returns an empty vector
    /// for no neighbors (callers guard against that case). If vectors have
    /// uneven lengths, `zip` silently drops trailing elements beyond the
    /// first vector's length.
    fn aggregate_mean(neighbor_features: &[Vec<f64>]) -> Vec<f64> {
        if neighbor_features.is_empty() {
            return Vec::new();
        }
        let dim = neighbor_features[0].len();
        let mut result = vec![0.0f64; dim];
        for feat in neighbor_features {
            for (r, &v) in result.iter_mut().zip(feat.iter()) {
                *r += v;
            }
        }
        let n = neighbor_features.len() as f64;
        result.iter_mut().for_each(|v| *v /= n);
        result
    }

    /// Max-pool aggregation: each neighbor is passed through the pooling
    /// MLP + ReLU, then an element-wise max is taken. Result width is the
    /// pooling MLP's output width, not `neigh_dim` — the caller pads or
    /// truncates it back before the neighbor transform.
    fn aggregate_maxpool(neighbor_features: &[Vec<f64>], pool_layer: &DenseLayer) -> Vec<f64> {
        if neighbor_features.is_empty() {
            return Vec::new();
        }
        let transformed: Vec<Vec<f64>> = neighbor_features
            .iter()
            .map(|feat| DenseLayer::relu(&pool_layer.forward(feat)))
            .collect();
        let dim = transformed[0].len();
        // NEG_INFINITY start is safe: ReLU output is >= 0, so the max is
        // always overwritten by the first neighbor.
        let mut result = vec![f64::NEG_INFINITY; dim];
        for feat in &transformed {
            for (r, &v) in result.iter_mut().zip(feat.iter()) {
                if v > *r {
                    *r = v;
                }
            }
        }
        result
    }

    /// Forward pass for a single node: aggregate neighbors, transform self
    /// and neighbor representations, combine, and apply ReLU.
    fn forward(
        &self,
        self_feat: &[f64],
        neighbor_feats: &[Vec<f64>],
        aggregator: &AggregatorType,
    ) -> Vec<f64> {
        // Isolated nodes get a zero neighbor vector instead of skipping
        // the neighbor branch entirely.
        let agg = if neighbor_feats.is_empty() {
            vec![0.0; self_feat.len()]
        } else {
            match aggregator {
                AggregatorType::Mean | AggregatorType::MeanConcat => {
                    Self::aggregate_mean(neighbor_feats)
                }
                AggregatorType::MaxPool { .. } => {
                    if let Some(pool_layer) = &self.pool_mlp {
                        Self::aggregate_maxpool(neighbor_feats, pool_layer)
                    } else {
                        // Defensive fallback: MaxPool requested but no MLP
                        // was constructed.
                        Self::aggregate_mean(neighbor_feats)
                    }
                }
            }
        };

        // Pad/truncate the aggregate to the neighbor transform's input
        // width (the MaxPool path can produce a different width).
        // NOTE(review): truncation silently drops information when the
        // pooling hidden_dim exceeds neigh input width — confirm intended.
        let agg_padded = if agg.len() != self.neigh_transform.input_dim {
            let mut padded = vec![0.0f64; self.neigh_transform.input_dim];
            let copy_len = agg.len().min(self.neigh_transform.input_dim);
            padded[..copy_len].copy_from_slice(&agg[..copy_len]);
            padded
        } else {
            agg
        };

        // Same defensive pad/truncate for the self features.
        let self_padded = if self_feat.len() != self.self_transform.input_dim {
            let mut padded = vec![0.0f64; self.self_transform.input_dim];
            let copy_len = self_feat.len().min(self.self_transform.input_dim);
            padded[..copy_len].copy_from_slice(&self_feat[..copy_len]);
            padded
        } else {
            self_feat.to_vec()
        };

        let h_self = self.self_transform.forward(&self_padded);
        let h_neigh = self.neigh_transform.forward(&agg_padded);

        let combined = match aggregator {
            AggregatorType::MeanConcat => {
                // NOTE(review): h_self already has length output_dim, so
                // truncate(output_dim) discards ALL of h_neigh — MeanConcat
                // degenerates to a self-only transform here. Likely a bug;
                // a real concat would need a 2*dim -> dim transform. Confirm
                // before relying on MeanConcat.
                let mut concat = h_self;
                concat.extend(h_neigh);
                concat.truncate(self.output_dim);
                while concat.len() < self.output_dim {
                    concat.push(0.0);
                }
                concat
            }
            _ => {
                // Mean / MaxPool: element-wise sum of the two transforms.
                h_self
                    .iter()
                    .zip(h_neigh.iter())
                    .map(|(a, b)| a + b)
                    .collect()
            }
        };

        DenseLayer::relu(&combined)
    }
}
415
/// GraphSAGE model: a stack of `SageLayer`s plus the configuration and a
/// deterministic RNG (used for weight init and training perturbations).
#[derive(Debug, Clone)]
pub struct GraphSage {
    config: GraphSageConfig,
    layers: Vec<SageLayer>,
    rng: SimpleLcg,
}
427
impl GraphSage {
    /// Validates `config` and initializes every layer's weights
    /// deterministically from `config.seed`.
    ///
    /// # Errors
    /// Returns an error when `input_dim` or `output_dim` is zero, or when
    /// `num_samples` is empty.
    pub fn new(config: GraphSageConfig) -> Result<Self> {
        if config.input_dim == 0 {
            return Err(anyhow!("input_dim must be > 0"));
        }
        if config.output_dim == 0 {
            return Err(anyhow!("output_dim must be > 0"));
        }
        if config.num_samples.is_empty() {
            return Err(anyhow!("num_samples must have at least one entry"));
        }

        let mut rng = SimpleLcg::new(config.seed);
        let pool_hidden = match &config.aggregator {
            AggregatorType::MaxPool { hidden_dim } => Some(*hidden_dim),
            _ => None,
        };

        // Layer widths: input -> hidden_dims... -> output.
        let mut dims: Vec<usize> = vec![config.input_dim];
        dims.extend(config.hidden_dims.iter().copied());
        dims.push(config.output_dim);

        let num_layers = dims.len() - 1;
        let mut layers = Vec::with_capacity(num_layers);

        for i in 0..num_layers {
            let in_dim = dims[i];
            let out_dim = dims[i + 1];
            // Neighbor vectors entering layer i have the same width as the
            // self vectors: both are the previous layer's output.
            let neigh_dim = in_dim;
            layers.push(SageLayer::new(
                in_dim,
                neigh_dim,
                out_dim,
                pool_hidden,
                &mut rng,
            ));
        }

        Ok(Self {
            config,
            layers,
            rng,
        })
    }

    /// Runs a forward pass over the whole graph and returns one embedding
    /// per node. Deterministic: neighbor sampling is seeded from
    /// `config.seed` plus the layer index, so repeated calls on the same
    /// graph yield identical embeddings.
    ///
    /// # Errors
    /// Fails on an empty graph or when the graph's feature dimension does
    /// not match `config.input_dim`.
    pub fn embed(&self, graph: &GraphData) -> Result<GraphSageEmbeddings> {
        if graph.num_nodes() == 0 {
            return Err(anyhow!("Graph has no nodes"));
        }
        if graph.feature_dim() != self.config.input_dim {
            return Err(anyhow!(
                "Graph feature dim {} does not match model input_dim {}",
                graph.feature_dim(),
                self.config.input_dim
            ));
        }

        // h_prev holds the representations entering the current layer;
        // layer 0 consumes the raw node features.
        let mut h_prev: Vec<Vec<f64>> = graph.node_features.clone();

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let mut h_next: Vec<Vec<f64>> = Vec::with_capacity(graph.num_nodes());

            // Layers beyond the configured num_samples list fall back to 25.
            let num_samples = self
                .config
                .num_samples
                .get(layer_idx)
                .copied()
                .unwrap_or(25);

            // Fresh per-layer RNG keyed on seed + layer index keeps the
            // sampling reproducible regardless of prior RNG use.
            let mut local_rng = SimpleLcg::new(self.config.seed.wrapping_add(layer_idx as u64));

            for node in 0..graph.num_nodes() {
                let sampled = graph.sample_neighbors(node, num_samples, &mut local_rng);
                let neighbor_feats: Vec<Vec<f64>> = sampled
                    .iter()
                    .filter_map(|&n| h_prev.get(n).cloned())
                    .collect();

                let self_feat = h_prev.get(node).cloned().unwrap_or_default();
                let h = layer.forward(&self_feat, &neighbor_feats, &self.config.aggregator);
                h_next.push(h);
            }

            h_prev = h_next;
        }

        // Optional L2 normalization of the final representations.
        let embeddings: Vec<Vec<f64>> = if self.config.normalize_output {
            h_prev.into_iter().map(|v| Self::normalize(&v)).collect()
        } else {
            h_prev
        };

        let dim = self.config.output_dim;
        let num_nodes = graph.num_nodes();

        Ok(GraphSageEmbeddings {
            embeddings,
            config: self.config.clone(),
            num_nodes,
            dim,
        })
    }

    /// "Trains" the model with an unsupervised negative-sampling objective.
    ///
    /// NOTE(review): the update step is a loss-scaled random perturbation
    /// (see `apply_random_gradient_step`), not true gradient descent, so
    /// the loss is not guaranteed to decrease. With `config.epochs == 0`
    /// the history is empty, `final_loss` is NaN, and
    /// `convergence_achieved` is vacuously true — confirm that is intended.
    ///
    /// # Errors
    /// Fails for graphs with fewer than 2 nodes or mismatched feature dims.
    pub fn train_unsupervised(&mut self, graph: &GraphData) -> Result<GraphSageTrainingMetrics> {
        if graph.num_nodes() < 2 {
            return Err(anyhow!("Graph must have at least 2 nodes for training"));
        }
        if graph.feature_dim() != self.config.input_dim {
            return Err(anyhow!(
                "Graph feature dim {} != model input_dim {}",
                graph.feature_dim(),
                self.config.input_dim
            ));
        }

        let mut loss_history = Vec::with_capacity(self.config.epochs);

        for epoch in 0..self.config.epochs {
            let embeddings = self.embed(graph)?;
            let epoch_loss = self.compute_unsupervised_loss(&embeddings, graph);
            loss_history.push(epoch_loss);

            // Perturb weights; this also advances self.rng, so the next
            // epoch's loss sampling uses a different RNG stream.
            self.apply_random_gradient_step(epoch_loss);

            tracing::debug!(epoch = epoch, loss = epoch_loss, "GraphSAGE training step");
        }

        let final_loss = loss_history.last().copied().unwrap_or(f64::NAN);
        // "Converged" = every consecutive pair of losses differs by < 1e-4.
        let convergence = loss_history.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-4);

        Ok(GraphSageTrainingMetrics {
            epochs_completed: self.config.epochs,
            final_loss,
            loss_history,
            convergence_achieved: convergence,
        })
    }

    /// Mean negative-sampling loss over up to `batch_size` nodes: for each
    /// sampled node, one positive (a random true neighbor) and one negative
    /// (a uniformly random node) term.
    ///
    /// NOTE(review): the negative sample is not excluded from being the
    /// node itself or one of its true neighbors — confirm acceptable.
    fn compute_unsupervised_loss(
        &self,
        embeddings: &GraphSageEmbeddings,
        graph: &GraphData,
    ) -> f64 {
        let num_nodes = graph.num_nodes();
        if num_nodes < 2 {
            return 0.0;
        }

        let mut total_loss = 0.0;
        let mut count = 0usize;
        // Snapshot the model RNG state so this read-only method does not
        // mutate self.
        let mut local_rng = SimpleLcg::new(self.rng.state);

        // First min(num_nodes, batch_size) nodes; the `% num_nodes` is a
        // no-op here since i < num_nodes by construction.
        let sample_nodes: Vec<usize> = (0..num_nodes.min(self.config.batch_size))
            .map(|i| i % num_nodes)
            .collect();

        for &node in &sample_nodes {
            let neighbors = graph.neighbors(node);
            if neighbors.is_empty() {
                continue;
            }
            // Positive example: a uniformly chosen true neighbor.
            let pos_neighbor = neighbors[local_rng.next_usize() % neighbors.len()];

            // Negative example: any node, uniformly.
            let neg_node = local_rng.next_usize() % num_nodes;

            if let (Some(h_u), Some(h_pos), Some(h_neg)) = (
                embeddings.get(node),
                embeddings.get(pos_neighbor),
                embeddings.get(neg_node),
            ) {
                let pos_score = dot_product(h_u, h_pos);
                let neg_score = dot_product(h_u, h_neg);

                // -log(sigmoid(pos)) - log(1 - sigmoid(neg)); the max(1e-10)
                // floors the log argument to avoid -inf.
                let pos_loss = -sigmoid(pos_score).max(1e-10).ln();
                let neg_loss = -(1.0 - sigmoid(neg_score)).max(1e-10).ln();
                total_loss += pos_loss + neg_loss;
                count += 1;
            }
        }

        if count > 0 {
            total_loss / count as f64
        } else {
            0.0
        }
    }

    /// Applies a loss-scaled random perturbation to the weights.
    ///
    /// NOTE(review): despite the name this is not a gradient step — it adds
    /// seeded noise, and only to `self_transform` weights (neigh_transform,
    /// pool MLPs, and biases are untouched). Confirm this placeholder is
    /// intentional.
    fn apply_random_gradient_step(&mut self, loss: f64) {
        let noise_scale = self.config.learning_rate * loss.abs().min(1.0) * 0.01;
        for layer in self.layers.iter_mut() {
            for row in layer.self_transform.weights.iter_mut() {
                for w in row.iter_mut() {
                    *w -= noise_scale * self.rng.next_f64_range(1.0);
                }
            }
        }
    }

    /// L2-normalizes `v`. Vectors with (near-)zero norm are returned
    /// unchanged rather than divided by ~0.
    pub fn normalize(v: &[f64]) -> Vec<f64> {
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-12 {
            return v.to_vec();
        }
        v.iter().map(|x| x / norm).collect()
    }
}
664
/// Summary of a `train_unsupervised` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSageTrainingMetrics {
    /// Number of epochs that were executed (equals `config.epochs`).
    pub epochs_completed: usize,
    /// Loss of the last epoch; NaN when zero epochs were run.
    pub final_loss: f64,
    /// Per-epoch loss values, in order.
    pub loss_history: Vec<f64>,
    /// True when all consecutive losses differ by less than 1e-4
    /// (vacuously true for histories shorter than 2).
    pub convergence_achieved: bool,
}
673
/// Result of `GraphSage::embed`: one embedding vector per node, plus the
/// configuration that produced them.
#[derive(Debug, Clone)]
pub struct GraphSageEmbeddings {
    /// `embeddings[i]` is the vector for node `i`.
    pub embeddings: Vec<Vec<f64>>,
    /// Copy of the model configuration used to generate these embeddings.
    pub config: GraphSageConfig,
    /// Number of nodes (equals `embeddings.len()`).
    pub num_nodes: usize,
    /// Nominal embedding width (the model's `output_dim`).
    pub dim: usize,
}
686
687impl GraphSageEmbeddings {
688 pub fn get(&self, node: usize) -> Option<&[f64]> {
690 self.embeddings.get(node).map(|v| v.as_slice())
691 }
692
693 pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
695 let va = self.get(a)?;
696 let vb = self.get(b)?;
697 Some(cosine_similarity_vecs(va, vb))
698 }
699
700 pub fn top_k_similar(&self, node: usize, k: usize) -> Vec<(usize, f64)> {
704 let query = match self.get(node) {
705 Some(v) => v,
706 None => return Vec::new(),
707 };
708
709 let mut similarities: Vec<(usize, f64)> = (0..self.num_nodes)
710 .filter(|&i| i != node)
711 .filter_map(|i| self.get(i).map(|v| (i, cosine_similarity_vecs(query, v))))
712 .collect();
713
714 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
716 similarities.truncate(k);
717 similarities
718 }
719
720 pub fn labeled_embeddings(&self, labels: &[usize]) -> HashMap<usize, Vec<Vec<f64>>> {
722 let mut map: HashMap<usize, Vec<Vec<f64>>> = HashMap::new();
723 for (node, &label) in labels.iter().enumerate() {
724 if let Some(emb) = self.get(node) {
725 map.entry(label).or_default().push(emb.to_vec());
726 }
727 }
728 map
729 }
730}
731
/// Dot product of two slices; pairs beyond the shorter slice are ignored
/// (zip semantics), so mismatched lengths do not panic.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
736
/// Logistic sigmoid: maps any real `x` into (0, 1), with `sigmoid(0) = 0.5`.
pub fn sigmoid(x: f64) -> f64 {
    (1.0 + (-x).exp()).recip()
}
741
742pub fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
744 let dot = dot_product(a, b);
745 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
746 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
747 if norm_a < 1e-12 || norm_b < 1e-12 {
748 return 0.0;
749 }
750 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
751}
752
753pub fn embedding_err(msg: impl Into<String>) -> crate::EmbeddingError {
755 EmbeddingError::Other(anyhow!(msg.into()))
756}
757
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hub-and-spoke graph: node 0 is connected to every other
    /// node, with edges stored in both directions and random features.
    fn star_graph(n: usize, feat_dim: usize, seed: u64) -> GraphData {
        let mut rng = SimpleLcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        // Hub <-> spoke edges in both adjacency lists.
        for i in 1..n {
            adjacency[0].push(i);
            adjacency[i].push(0);
        }
        GraphData::new(features, adjacency).expect("star graph construction should succeed")
    }

    // Default config should be non-degenerate.
    #[test]
    fn test_graphsage_default_config() {
        let config = GraphSageConfig::default();
        assert_eq!(config.input_dim, 64);
        assert_eq!(config.output_dim, 64);
        assert!(!config.num_samples.is_empty());
    }

    // Basic accessors on a valid star graph.
    #[test]
    fn test_graphdata_construction() {
        let graph = star_graph(5, 8, 1);
        assert_eq!(graph.num_nodes(), 5);
        assert_eq!(graph.feature_dim(), 8);
        assert_eq!(graph.neighbors(0).len(), 4);
        assert_eq!(graph.neighbors(1).len(), 1);
        assert_eq!(graph.neighbors(1)[0], 0);
    }

    // Out-of-bounds neighbor index (99 with only 3 nodes) must be rejected.
    #[test]
    fn test_graphdata_invalid_adjacency() {
        let features = vec![vec![1.0, 2.0]; 3];
        let adjacency = vec![
            vec![1usize, 99],
            vec![0],
            vec![0],
        ];
        assert!(GraphData::new(features, adjacency).is_err());
    }

    // Embedding output should have one vector per node, each output_dim wide.
    #[test]
    fn test_graphsage_embed_shape() {
        let config = GraphSageConfig {
            input_dim: 8,
            hidden_dims: vec![16],
            output_dim: 4,
            num_samples: vec![3],
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model construction should succeed");
        let graph = star_graph(5, 8, 42);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        assert_eq!(embeddings.num_nodes, 5);
        assert_eq!(embeddings.dim, 4);
        for i in 0..5 {
            let emb = embeddings
                .get(i)
                .expect("should have embedding for every node");
            assert_eq!(emb.len(), 4);
        }
    }

    // With normalize_output, every embedding's L2 norm is at most 1.
    // (`<=` rather than `==` because `normalize` leaves all-zero vectors
    // unchanged, so a collapsed embedding keeps norm 0.)
    #[test]
    fn test_graphsage_normalized_output() {
        let config = GraphSageConfig {
            input_dim: 8,
            hidden_dims: vec![],
            output_dim: 8,
            num_samples: vec![5],
            normalize_output: true,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 8, 7);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        for i in 0..5 {
            let emb = embeddings.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(norm < 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    // Self-similarity should sit in [0, 1] (0 when the vector is all-zero).
    #[test]
    fn test_cosine_similarity() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            num_samples: vec![5],
            normalize_output: false,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 4, 13);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        if let Some(sim) = embeddings.cosine_similarity(0, 0) {
            assert!((0.0..=1.0 + 1e-6).contains(&sim));
        }
    }

    // top_k_similar must return at most k results in descending order.
    #[test]
    fn test_top_k_similar() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            num_samples: vec![5],
            normalize_output: true,
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(6, 4, 17);
        let embeddings = model.embed(&graph).expect("embed should succeed");

        let top3 = embeddings.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        // Scores must be non-increasing (small epsilon for float noise).
        for window in top3.windows(2) {
            assert!(window[0].1 >= window[1].1 - 1e-10);
        }
    }

    // MaxPool aggregator path should construct and embed without panicking.
    #[test]
    fn test_maxpool_aggregator() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator: AggregatorType::MaxPool { hidden_dim: 8 },
            num_samples: vec![3],
            epochs: 1,
            ..Default::default()
        };
        let model = GraphSage::new(config).expect("model should construct with MaxPool");
        let graph = star_graph(4, 4, 99);
        let embeddings = model.embed(&graph).expect("embed should succeed");
        assert_eq!(embeddings.num_nodes, 4);
    }

    // Training runs the configured number of epochs and records each loss.
    #[test]
    fn test_train_unsupervised() {
        let config = GraphSageConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            num_samples: vec![3],
            epochs: 3,
            batch_size: 4,
            ..Default::default()
        };
        let mut model = GraphSage::new(config).expect("model should construct");
        let graph = star_graph(5, 4, 42);
        let metrics = model
            .train_unsupervised(&graph)
            .expect("training should succeed");
        assert_eq!(metrics.epochs_completed, 3);
        assert_eq!(metrics.loss_history.len(), 3);
    }

    // Two RNGs with the same seed must produce identical streams.
    #[test]
    fn test_simplecg_reproducibility() {
        let mut rng1 = SimpleLcg::new(42);
        let mut rng2 = SimpleLcg::new(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    // Sampling returns at most k results, all of them true neighbors.
    #[test]
    fn test_sample_neighbors() {
        let graph = star_graph(10, 4, 1);
        let mut rng = SimpleLcg::new(55);
        let sampled = graph.sample_neighbors(0, 3, &mut rng);
        assert!(sampled.len() <= 3);
        for &n in &sampled {
            assert!(graph.neighbors(0).contains(&n));
        }
    }
}