use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone)]
pub struct Lcg {
    state: u64,
}

impl Lcg {
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator one step (Knuth's MMIX multiplier/increment).
    pub fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state
    }

    /// Maps the next value into `0..n` by modulo; slightly biased for `n` that
    /// do not divide 2^64, and panics if `n == 0`.
    pub fn next_usize_mod(&mut self, n: usize) -> usize {
        (self.next_u64() as usize) % n
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Uniform sample in `[-scale, scale)`.
    pub fn next_f64_range(&mut self, scale: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * scale
    }

    /// Standard normal sample via the Box-Muller transform.
    pub fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}
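
// Usage sketch (illustrative only): the generator is fully deterministic, so
// the same seed always yields the same stream, which is what keeps embeddings
// reproducible across runs.
//
//     let mut rng = Lcg::new(7);
//     let u = rng.next_f64();          // uniform in [0, 1)
//     let s = rng.next_f64_range(0.5); // uniform in [-0.5, 0.5)
//     let z = rng.next_normal();       // standard normal via Box-Muller
//     assert!((0.0..1.0).contains(&u) && s.abs() <= 0.5 && z.is_finite());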

#[derive(Debug, Clone)]
pub struct Graph {
    pub node_features: Vec<Vec<f64>>,
    pub adjacency: Vec<Vec<usize>>,
    pub labels: Option<Vec<usize>>,
}

impl Graph {
    /// Builds a graph, validating that the adjacency list matches the node
    /// count, all feature vectors share one dimension, and every neighbor
    /// index is in bounds.
    pub fn new(node_features: Vec<Vec<f64>>, adjacency: Vec<Vec<usize>>) -> Result<Self> {
        let n = node_features.len();
        if adjacency.len() != n {
            return Err(anyhow!(
                "adjacency list length {} != num_nodes {}",
                adjacency.len(),
                n
            ));
        }
        if let Some(first) = node_features.first() {
            let dim = first.len();
            for (i, feat) in node_features.iter().enumerate() {
                if feat.len() != dim {
                    return Err(anyhow!(
                        "node {} feature dim {} != expected {}",
                        i,
                        feat.len(),
                        dim
                    ));
                }
            }
        }
        for (i, nbrs) in adjacency.iter().enumerate() {
            for &j in nbrs {
                if j >= n {
                    return Err(anyhow!("node {} has out-of-bounds neighbor {}", i, j));
                }
            }
        }
        Ok(Self {
            node_features,
            adjacency,
            labels: None,
        })
    }

    pub fn with_labels(mut self, labels: Vec<usize>) -> Result<Self> {
        if labels.len() != self.num_nodes() {
            return Err(anyhow!(
                "label count {} != num_nodes {}",
                labels.len(),
                self.num_nodes()
            ));
        }
        self.labels = Some(labels);
        Ok(self)
    }

    pub fn num_nodes(&self) -> usize {
        self.node_features.len()
    }

    pub fn feature_dim(&self) -> usize {
        self.node_features.first().map(|f| f.len()).unwrap_or(0)
    }

    pub fn neighbors(&self, v: usize) -> &[usize] {
        self.adjacency.get(v).map(|v| v.as_slice()).unwrap_or(&[])
    }

    /// Samples up to `k` distinct neighbors of `v` without replacement using a
    /// partial Fisher-Yates shuffle of the neighbor indices.
    pub fn sample_neighbors(&self, v: usize, k: usize, rng: &mut Lcg) -> Vec<usize> {
        let nbrs = self.neighbors(v);
        if nbrs.is_empty() || k == 0 {
            return Vec::new();
        }
        if nbrs.len() <= k {
            return nbrs.to_vec();
        }
        let mut idx: Vec<usize> = (0..nbrs.len()).collect();
        for i in 0..k {
            let j = i + rng.next_usize_mod(nbrs.len() - i);
            idx.swap(i, j);
        }
        idx[..k].iter().map(|&i| nbrs[i]).collect()
    }
}
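
// Construction sketch (illustrative, assuming an undirected triangle):
//
//     let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
//     let adjacency = vec![vec![1, 2], vec![0, 2], vec![0, 1]];
//     let graph = Graph::new(features, adjacency)?;
//     let mut rng = Lcg::new(0);
//     let sampled = graph.sample_neighbors(0, 1, &mut rng);
//     assert_eq!(sampled.len(), 1); // one of node 0's two neighbors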

#[derive(Debug, Clone)]
pub struct DenseLayer {
    weights: Vec<Vec<f64>>,
    bias: Vec<f64>,
    pub in_dim: usize,
    pub out_dim: usize,
}

impl DenseLayer {
    /// Initializes weights uniformly in `[-sqrt(6/(in+out)), sqrt(6/(in+out)))`
    /// (Xavier/Glorot initialization) with zero biases.
    pub fn new_xavier(in_dim: usize, out_dim: usize, rng: &mut Lcg) -> Self {
        let scale = (6.0 / (in_dim + out_dim) as f64).sqrt();
        let weights = (0..out_dim)
            .map(|_| (0..in_dim).map(|_| rng.next_f64_range(scale)).collect())
            .collect();
        Self {
            weights,
            bias: vec![0.0; out_dim],
            in_dim,
            out_dim,
        }
    }

    /// Computes `W x + b`; `x` must have length `in_dim`.
    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
        let mut out = self.bias.clone();
        for (i, row) in self.weights.iter().enumerate() {
            for (j, &w) in row.iter().enumerate() {
                out[i] += w * x[j];
            }
        }
        out
    }

    pub fn relu(x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| v.max(0.0)).collect()
    }

    pub fn tanh(x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| v.tanh()).collect()
    }
}
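
// Forward-pass sketch: a 3 -> 2 layer maps a length-3 input to a length-2
// output; with zero-initialized biases the output is just `W x`.
//
//     let mut rng = Lcg::new(1);
//     let layer = DenseLayer::new_xavier(3, 2, &mut rng);
//     let y = DenseLayer::relu(&layer.forward(&[0.1, 0.2, 0.3]));
//     assert_eq!(y.len(), 2);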

/// Pools a set of neighbor feature vectors into a single vector.
pub trait Aggregator: std::fmt::Debug + Send + Sync {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], input_dim: usize) -> Vec<f64>;

    /// Dimension of the vector returned by `aggregate` for a given input dimension.
    fn output_dim(&self, input_dim: usize) -> usize;
}

#[derive(Debug, Clone, Default)]
pub struct MeanAggregator;

impl Aggregator for MeanAggregator {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], input_dim: usize) -> Vec<f64> {
        if neighbor_features.is_empty() {
            return vec![0.0; input_dim];
        }
        let mut mean = vec![0.0f64; input_dim];
        for feat in neighbor_features {
            for (i, &v) in feat.iter().enumerate().take(input_dim) {
                mean[i] += v;
            }
        }
        let n = neighbor_features.len() as f64;
        mean.iter_mut().for_each(|v| *v /= n);
        mean
    }

    fn output_dim(&self, input_dim: usize) -> usize {
        input_dim
    }
}
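
// Aggregation sketch: the mean aggregator is permutation invariant, so the
// order of the neighbor features does not matter.
//
//     let agg = MeanAggregator;
//     let a = agg.aggregate(&[vec![1.0, 2.0], vec![3.0, 4.0]], 2);
//     let b = agg.aggregate(&[vec![3.0, 4.0], vec![1.0, 2.0]], 2);
//     assert_eq!(a, b); // both are [2.0, 3.0]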

/// Passes each neighbor through a one-layer ReLU MLP and takes the
/// element-wise maximum.
#[derive(Debug, Clone)]
pub struct MaxPoolAggregator {
    mlp: DenseLayer,
    hidden_dim: usize,
}

impl MaxPoolAggregator {
    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
        Self {
            mlp: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
            hidden_dim,
        }
    }
}

impl Aggregator for MaxPoolAggregator {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
        if neighbor_features.is_empty() {
            return vec![0.0; self.hidden_dim];
        }
        let mut pool = vec![f64::NEG_INFINITY; self.hidden_dim];
        for feat in neighbor_features {
            let transformed = DenseLayer::relu(&self.mlp.forward(feat));
            for (i, &v) in transformed.iter().enumerate() {
                if v > pool[i] {
                    pool[i] = v;
                }
            }
        }
        pool.iter_mut().for_each(|v| {
            if v.is_infinite() {
                *v = 0.0;
            }
        });
        pool
    }

    fn output_dim(&self, _input_dim: usize) -> usize {
        self.hidden_dim
    }
}

#[derive(Debug, Clone)]
pub struct MeanPoolAggregator {
    mlp: DenseLayer,
    hidden_dim: usize,
}

impl MeanPoolAggregator {
    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
        Self {
            mlp: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
            hidden_dim,
        }
    }
}

impl Aggregator for MeanPoolAggregator {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
        if neighbor_features.is_empty() {
            return vec![0.0; self.hidden_dim];
        }
        let mut mean = vec![0.0f64; self.hidden_dim];
        for feat in neighbor_features {
            let transformed = DenseLayer::relu(&self.mlp.forward(feat));
            for (i, &v) in transformed.iter().enumerate() {
                mean[i] += v;
            }
        }
        let n = neighbor_features.len() as f64;
        mean.iter_mut().for_each(|v| *v /= n);
        mean
    }

    fn output_dim(&self, _input_dim: usize) -> usize {
        self.hidden_dim
    }
}

#[derive(Debug, Clone)]
pub struct LSTMAggregator {
    w_r_x: DenseLayer,
    w_r_h: DenseLayer,
    w_z_x: DenseLayer,
    w_z_h: DenseLayer,
    w_n_x: DenseLayer,
    w_n_h: DenseLayer,
    hidden_dim: usize,
}

impl LSTMAggregator {
    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
        Self {
            w_r_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
            w_r_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
            w_z_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
            w_z_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
            w_n_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
            w_n_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
            hidden_dim,
        }
    }

    fn sigmoid_vec(x: &[f64]) -> Vec<f64> {
        x.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect()
    }

    fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
        a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
    }

    fn vec_mul_elem(a: &[f64], b: &[f64]) -> Vec<f64> {
        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect()
    }

    /// One recurrent step. Despite the struct name, the gating is GRU-style:
    /// reset gate `r`, update gate `z`, candidate state `n`, and
    /// `h' = (1 - z) * h + z * n`.
    fn gru_step(&self, h: &[f64], x: &[f64]) -> Vec<f64> {
        let r_in = Self::vec_add(&self.w_r_x.forward(x), &self.w_r_h.forward(h));
        let r = Self::sigmoid_vec(&r_in);

        let z_in = Self::vec_add(&self.w_z_x.forward(x), &self.w_z_h.forward(h));
        let z = Self::sigmoid_vec(&z_in);

        let r_h = Self::vec_mul_elem(&r, h);
        let n_in = Self::vec_add(&self.w_n_x.forward(x), &self.w_n_h.forward(&r_h));
        let n = DenseLayer::tanh(&n_in);

        z.iter()
            .zip(n.iter())
            .zip(h.iter())
            .map(|((&zi, &ni), &hi)| (1.0 - zi) * hi + zi * ni)
            .collect()
    }
}

impl Aggregator for LSTMAggregator {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
        let mut h = vec![0.0f64; self.hidden_dim];
        for feat in neighbor_features {
            h = self.gru_step(&h, feat);
        }
        h
    }

    fn output_dim(&self, _input_dim: usize) -> usize {
        self.hidden_dim
    }
}
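
// Note: unlike the mean and pooling aggregators, this sequential aggregator is
// order sensitive. The original GraphSAGE formulation applies its LSTM
// aggregator to a random permutation of the neighbors; a caller wanting the
// same effect here could shuffle the neighbor features first (sketch, with
// `neighbor_feats`, `rng`, `agg`, and `input_dim` standing in for the caller's
// own bindings):
//
//     let mut feats: Vec<Vec<f64>> = neighbor_feats.clone();
//     for i in (1..feats.len()).rev() {
//         feats.swap(i, rng.next_usize_mod(i + 1)); // Fisher-Yates shuffle
//     }
//     let h = agg.aggregate(&feats, input_dim);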

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregatorKind {
    Mean,
    MaxPool { hidden_dim: usize },
    MeanPool { hidden_dim: usize },
    Lstm { hidden_dim: usize },
}

pub struct GraphSAGELayer {
    combine: DenseLayer,
    aggregator: Box<dyn Aggregator>,
    pub in_dim: usize,
    pub out_dim: usize,
    num_samples: usize,
}

impl std::fmt::Debug for GraphSAGELayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GraphSAGELayer")
            .field("in_dim", &self.in_dim)
            .field("out_dim", &self.out_dim)
            .field("num_samples", &self.num_samples)
            .finish()
    }
}

impl GraphSAGELayer {
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        num_samples: usize,
        kind: &AggregatorKind,
        rng: &mut Lcg,
    ) -> Result<Self> {
        if in_dim == 0 || out_dim == 0 {
            return Err(anyhow!("GraphSAGELayer dimensions must be > 0"));
        }
        let aggregator: Box<dyn Aggregator> = match kind {
            AggregatorKind::Mean => Box::new(MeanAggregator),
            AggregatorKind::MaxPool { hidden_dim } => {
                Box::new(MaxPoolAggregator::new(in_dim, *hidden_dim, rng))
            }
            AggregatorKind::MeanPool { hidden_dim } => {
                Box::new(MeanPoolAggregator::new(in_dim, *hidden_dim, rng))
            }
            AggregatorKind::Lstm { hidden_dim } => {
                Box::new(LSTMAggregator::new(in_dim, *hidden_dim, rng))
            }
        };
        let agg_out = aggregator.output_dim(in_dim);
        let combine = DenseLayer::new_xavier(in_dim + agg_out, out_dim, rng);
        Ok(Self {
            combine,
            aggregator,
            in_dim,
            out_dim,
            num_samples,
        })
    }

    /// One GraphSAGE propagation step: for every node, sample up to
    /// `num_samples` neighbors, aggregate their current embeddings,
    /// concatenate the node's own embedding with the aggregate, and apply the
    /// shared linear layer followed by ReLU.
    pub fn forward(
        &self,
        graph: &Graph,
        current_embeddings: &[Vec<f64>],
        rng: &mut Lcg,
    ) -> Vec<Vec<f64>> {
        let n = graph.num_nodes();
        let mut new_embeddings = Vec::with_capacity(n);
        for v in 0..n {
            let sampled = graph.sample_neighbors(v, self.num_samples, rng);
            let neighbor_feats: Vec<Vec<f64>> = sampled
                .iter()
                .filter_map(|&u| current_embeddings.get(u).cloned())
                .collect();
            let agg = self.aggregator.aggregate(&neighbor_feats, self.in_dim);
            let self_feat = current_embeddings
                .get(v)
                .cloned()
                .unwrap_or_else(|| vec![0.0; self.in_dim]);
            let concat: Vec<f64> = self_feat.iter().chain(agg.iter()).copied().collect();
            let out = DenseLayer::relu(&self.combine.forward(&concat));
            new_embeddings.push(out);
        }
        new_embeddings
    }
}
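
// Single-layer sketch: one mean-aggregator layer mapping 4-dim node features
// to 8-dim embeddings, sampling at most 3 neighbors per node. `graph` stands
// in for any `Graph` with 4-dim features (illustrative only).
//
//     let mut rng = Lcg::new(42);
//     let layer = GraphSAGELayer::new(4, 8, 3, &AggregatorKind::Mean, &mut rng)?;
//     let embeddings = layer.forward(&graph, &graph.node_features, &mut rng);
//     assert!(embeddings.iter().all(|e| e.len() == 8));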

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSAGEConfig {
    pub input_dim: usize,
    pub hidden_dims: Vec<usize>,
    pub output_dim: usize,
    pub aggregator_kind: AggregatorKind,
    pub num_samples_per_layer: Vec<usize>,
    pub normalize_output: bool,
    pub seed: u64,
}

impl Default for GraphSAGEConfig {
    fn default() -> Self {
        Self {
            input_dim: 64,
            hidden_dims: vec![256, 128],
            output_dim: 64,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![25, 10],
            normalize_output: true,
            seed: 42,
        }
    }
}
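
// Configuration sketch: struct-update syntax keeps the defaults and overrides
// only what differs, here a single max-pool layer.
//
//     let config = GraphSAGEConfig {
//         input_dim: 16,
//         hidden_dims: vec![],
//         output_dim: 16,
//         aggregator_kind: AggregatorKind::MaxPool { hidden_dim: 32 },
//         ..GraphSAGEConfig::default()
//     };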

pub struct GraphSAGEModel {
    layers: Vec<GraphSAGELayer>,
    config: GraphSAGEConfig,
}

impl std::fmt::Debug for GraphSAGEModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GraphSAGEModel")
            .field("num_layers", &self.layers.len())
            .field("output_dim", &self.config.output_dim)
            .finish()
    }
}

impl GraphSAGEModel {
    /// Builds the layer stack `input_dim -> hidden_dims... -> output_dim`.
    /// If `num_samples_per_layer` is shorter than the number of layers, the
    /// last entry (or 10) is repeated for the remaining layers.
    pub fn new(config: GraphSAGEConfig) -> Result<Self> {
        if config.input_dim == 0 {
            return Err(anyhow!("input_dim must be > 0"));
        }
        if config.output_dim == 0 {
            return Err(anyhow!("output_dim must be > 0"));
        }
        let mut rng = Lcg::new(config.seed);
        let mut dims: Vec<usize> = vec![config.input_dim];
        dims.extend_from_slice(&config.hidden_dims);
        dims.push(config.output_dim);

        let num_layers = dims.len() - 1;
        let mut samples = config.num_samples_per_layer.clone();
        while samples.len() < num_layers {
            samples.push(samples.last().copied().unwrap_or(10));
        }

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let layer = GraphSAGELayer::new(
                dims[i],
                dims[i + 1],
                samples[i],
                &config.aggregator_kind,
                &mut rng,
            )?;
            layers.push(layer);
        }

        Ok(Self { layers, config })
    }

    /// Runs every layer over the full graph and, if configured, L2-normalizes
    /// the final embeddings.
    pub fn embed(&self, graph: &Graph) -> Result<GraphSAGEEmbeddings> {
        if graph.num_nodes() == 0 {
            return Err(anyhow!("Graph has no nodes"));
        }
        let mut rng = Lcg::new(self.config.seed.wrapping_add(0xdead_beef));
        let mut current: Vec<Vec<f64>> = graph.node_features.clone();
        for layer in &self.layers {
            current = layer.forward(graph, &current, &mut rng);
        }
        if self.config.normalize_output {
            for emb in &mut current {
                l2_normalize_inplace(emb);
            }
        }
        let dim = self.config.output_dim;
        Ok(GraphSAGEEmbeddings {
            embeddings: current,
            num_nodes: graph.num_nodes(),
            dim,
        })
    }

    /// Approximate inductive embedding for a node that was not part of the
    /// embedded graph: the supplied neighbor embeddings are reused as the
    /// node's neighborhood at every layer.
    pub fn embed_new_node(
        &self,
        node_features: &[f64],
        neighbor_embeddings: &[Vec<f64>],
    ) -> Result<Vec<f64>> {
        if node_features.len() != self.config.input_dim {
            return Err(anyhow!(
                "node_features dim {} != input_dim {}",
                node_features.len(),
                self.config.input_dim
            ));
        }
        let mut rng = Lcg::new(self.config.seed);
        let features = vec![node_features.to_vec()];
        let adjacency = vec![Vec::<usize>::new()];
        let mini_graph = Graph::new(features, adjacency)?;

        let mut current_self = node_features.to_vec();
        for layer in &self.layers {
            let sampled: Vec<Vec<f64>> = if neighbor_embeddings.is_empty() {
                Vec::new()
            } else {
                let k = layer.num_samples.min(neighbor_embeddings.len());
                neighbor_embeddings[..k].to_vec()
            };
            let agg = layer.aggregator.aggregate(&sampled, layer.in_dim);
            let concat: Vec<f64> = current_self.iter().chain(agg.iter()).copied().collect();
            current_self = DenseLayer::relu(&layer.combine.forward(&concat));
            // No-op sample (k == 0) that keeps the single-node helper graph and
            // the RNG exercised without affecting the result.
            let _ = mini_graph.sample_neighbors(0, 0, &mut rng);
        }
        if self.config.normalize_output {
            l2_normalize_inplace(&mut current_self);
        }
        Ok(current_self)
    }
}
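
// End-to-end sketch (illustrative): embed every node of a graph, then embed an
// unseen node inductively from its own features plus the embeddings of a few
// known neighbors. `graph` and `new_features` are placeholders for the
// caller's data; feature lengths must match `input_dim`.
//
//     let model = GraphSAGEModel::new(GraphSAGEConfig::default())?;
//     let embs = model.embed(&graph)?;
//     let new_emb = model.embed_new_node(
//         &new_features,
//         &[embs.get(0).expect("node 0 exists").to_vec()],
//     )?;
//     // `new_emb` has length `output_dim` and, with normalize_output, unit norm.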

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiniBatchConfig {
    pub epochs: usize,
    pub batch_size: usize,
    pub num_negative_samples: usize,
    pub learning_rate: f64,
    pub seed: u64,
}

impl Default for MiniBatchConfig {
    fn default() -> Self {
        Self {
            epochs: 10,
            batch_size: 256,
            num_negative_samples: 20,
            learning_rate: 0.01,
            seed: 0,
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub epochs_completed: usize,
    pub loss_history: Vec<f64>,
    pub final_loss: f64,
}

pub struct MiniBatchGraphSAGE {
    model: GraphSAGEModel,
    batch_cfg: MiniBatchConfig,
}

impl std::fmt::Debug for MiniBatchGraphSAGE {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MiniBatchGraphSAGE")
            .field("model", &self.model)
            .finish()
    }
}

impl MiniBatchGraphSAGE {
    pub fn new(sage_config: GraphSAGEConfig, batch_cfg: MiniBatchConfig) -> Result<Self> {
        let model = GraphSAGEModel::new(sage_config)?;
        Ok(Self { model, batch_cfg })
    }

    /// Evaluates the unsupervised GraphSAGE objective over mini-batches: for
    /// each node, one observed neighbor is treated as a positive pair and
    /// `num_negative_samples` random nodes as negatives, contributing
    /// `-ln sigmoid(z_v . z_u) - sum ln sigmoid(-z_v . z_neg)` to the epoch
    /// loss. This loop only tracks the per-epoch loss; layer weights are not
    /// updated here, so `learning_rate` is carried in the config but unused.
    pub fn train(&mut self, graph: &Graph) -> Result<TrainingMetrics> {
        let n = graph.num_nodes();
        if n < 2 {
            return Err(anyhow!("Graph must have at least 2 nodes for training"));
        }
        let mut rng = Lcg::new(self.batch_cfg.seed);
        let mut loss_history = Vec::with_capacity(self.batch_cfg.epochs);

        for epoch in 0..self.batch_cfg.epochs {
            let embeddings = self.model.embed(graph)?;
            let mut epoch_loss = 0.0f64;
            let mut num_pairs: usize = 0;

            let batch_size = self.batch_cfg.batch_size.min(n);
            for batch_start in (0..n).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(n);
                for v in batch_start..batch_end {
                    let nbrs = graph.neighbors(v);
                    if nbrs.is_empty() {
                        continue;
                    }
                    let pos_u = nbrs[rng.next_usize_mod(nbrs.len())];
                    let v_emb = embeddings.get(v).unwrap_or(&[]);
                    let u_emb = embeddings.get(pos_u).unwrap_or(&[]);
                    let pos_score = dot_product(v_emb, u_emb);
                    epoch_loss -= log_sigmoid(pos_score);

                    for _ in 0..self.batch_cfg.num_negative_samples {
                        let neg = rng.next_usize_mod(n);
                        if neg == v {
                            continue;
                        }
                        let neg_emb = embeddings.get(neg).unwrap_or(&[]);
                        let neg_score = dot_product(v_emb, neg_emb);
                        epoch_loss -= log_sigmoid(-neg_score);
                    }
                    num_pairs += 1;
                }
            }
            if num_pairs > 0 {
                epoch_loss /= num_pairs as f64;
            }
            loss_history.push(epoch_loss);
            tracing::debug!(
                "MiniBatchGraphSAGE epoch {}/{}: loss={:.6}",
                epoch + 1,
                self.batch_cfg.epochs,
                epoch_loss
            );
        }

        let final_loss = loss_history.last().copied().unwrap_or(f64::NAN);
        Ok(TrainingMetrics {
            epochs_completed: self.batch_cfg.epochs,
            loss_history,
            final_loss,
        })
    }

    pub fn embed(&self, graph: &Graph) -> Result<GraphSAGEEmbeddings> {
        self.model.embed(graph)
    }
}
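
// Training sketch (illustrative): run the unsupervised objective for a few
// epochs and read back the loss history. `graph` is a placeholder for the
// caller's `Graph`.
//
//     let mut trainer = MiniBatchGraphSAGE::new(
//         GraphSAGEConfig::default(),
//         MiniBatchConfig { epochs: 5, ..MiniBatchConfig::default() },
//     )?;
//     let metrics = trainer.train(&graph)?;
//     println!("final loss: {:.4}", metrics.final_loss);
//     let embeddings = trainer.embed(&graph)?;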

#[derive(Debug, Clone)]
pub struct GraphSAGEEmbeddings {
    pub embeddings: Vec<Vec<f64>>,
    pub num_nodes: usize,
    pub dim: usize,
}

impl GraphSAGEEmbeddings {
    pub fn get(&self, v: usize) -> Option<&[f64]> {
        self.embeddings.get(v).map(|e| e.as_slice())
    }

    pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
        let ea = self.embeddings.get(a)?;
        let eb = self.embeddings.get(b)?;
        Some(cosine_similarity_vecs(ea, eb))
    }

    /// Returns up to `k` other nodes ranked by cosine similarity to
    /// `query_node`, highest first.
    pub fn top_k_similar(&self, query_node: usize, k: usize) -> Vec<(usize, f64)> {
        let query_emb = match self.embeddings.get(query_node) {
            Some(e) => e,
            None => return Vec::new(),
        };
        let mut sims: Vec<(usize, f64)> = self
            .embeddings
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_node)
            .map(|(i, e)| (i, cosine_similarity_vecs(query_emb, e)))
            .collect();
        sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        sims.truncate(k);
        sims
    }

    /// Pairs each embedding with the label at the same index; nodes without a
    /// label are skipped.
    pub fn labeled_embeddings(&self, labels: &[usize]) -> HashMap<usize, (Vec<f64>, usize)> {
        self.embeddings
            .iter()
            .enumerate()
            .filter_map(|(i, emb)| labels.get(i).map(|&l| (i, (emb.clone(), l))))
            .collect()
    }
}
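
// Query sketch: given some `embs: GraphSAGEEmbeddings` (illustrative binding),
// nearest neighbors and pairwise similarity come straight from the cosine
// helpers.
//
//     let top5 = embs.top_k_similar(0, 5); // (node, similarity), descending
//     if let Some(sim) = embs.cosine_similarity(0, 1) {
//         assert!((-1.0..=1.0).contains(&sim));
//     }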

fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}

/// Numerically stable `ln(sigmoid(x))`: uses `-ln(1 + e^-x)` for `x >= 0` and
/// the equivalent `x - ln(1 + e^x)` for `x < 0` to avoid overflow.
fn log_sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        -(1.0 + (-x).exp()).ln()
    } else {
        x - (1.0 + x.exp()).ln()
    }
}

fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    let na: f64 = a.iter().map(|&x| x * x).sum::<f64>().sqrt();
    let nb: f64 = b.iter().map(|&x| x * x).sum::<f64>().sqrt();
    if na < 1e-12 || nb < 1e-12 {
        return 0.0;
    }
    (dot / (na * nb)).clamp(-1.0, 1.0)
}

fn l2_normalize_inplace(v: &mut [f64]) {
    let norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    if norm > 1e-12 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn ring_graph(n: usize, feat_dim: usize, seed: u64) -> Graph {
        let mut rng = Lcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n {
            let next = (i + 1) % n;
            adjacency[i].push(next);
            adjacency[next].push(i);
        }
        for nbrs in &mut adjacency {
            nbrs.sort_unstable();
            nbrs.dedup();
        }
        Graph::new(features, adjacency).expect("ring graph construction should succeed")
    }

    #[test]
    fn test_graph_construction() {
        let g = ring_graph(6, 8, 1);
        assert_eq!(g.num_nodes(), 6);
        assert_eq!(g.feature_dim(), 8);
        assert_eq!(g.neighbors(0).len(), 2);
    }

    #[test]
    fn test_graph_invalid_adjacency() {
        let feats = vec![vec![1.0f64; 4]; 3];
        let adj = vec![vec![1usize, 99], vec![0], vec![0]];
        assert!(Graph::new(feats, adj).is_err());
    }

    #[test]
    fn test_mean_aggregator() {
        let agg = MeanAggregator;
        let feats = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = agg.aggregate(&feats, 2);
        assert_eq!(result, vec![2.0, 3.0]);
        assert_eq!(agg.output_dim(2), 2);
    }

    #[test]
    fn test_mean_aggregator_empty() {
        let agg = MeanAggregator;
        let result = agg.aggregate(&[], 4);
        assert_eq!(result, vec![0.0; 4]);
    }

    #[test]
    fn test_maxpool_aggregator() {
        let mut rng = Lcg::new(1);
        let agg = MaxPoolAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![1.0f64; 4], vec![-1.0f64; 4]];
        let result = agg.aggregate(&feats, 4);
        assert_eq!(result.len(), 8);
        for &v in &result {
            assert!(v >= 0.0, "MaxPool result should be non-negative after ReLU");
        }
    }

    #[test]
    fn test_meanpool_aggregator() {
        let mut rng = Lcg::new(2);
        let agg = MeanPoolAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![1.0f64; 4]; 3];
        let result = agg.aggregate(&feats, 4);
        assert_eq!(result.len(), 8);
    }

    #[test]
    fn test_lstm_aggregator() {
        let mut rng = Lcg::new(3);
        let agg = LSTMAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![0.5f64; 4]; 5];
        let result = agg.aggregate(&feats, 4);
        assert_eq!(result.len(), 8);
        for &v in &result {
            assert!(v.is_finite(), "LSTM output should be finite");
        }
    }

    #[test]
    fn test_graphsage_layer_mean() {
        let mut rng = Lcg::new(42);
        let layer = GraphSAGELayer::new(4, 8, 3, &AggregatorKind::Mean, &mut rng)
            .expect("layer should construct");
        let g = ring_graph(5, 4, 7);
        let embeddings = layer.forward(&g, &g.node_features, &mut rng);
        assert_eq!(embeddings.len(), 5);
        for emb in &embeddings {
            assert_eq!(emb.len(), 8);
        }
    }

    #[test]
    fn test_graphsage_model_default() {
        let config = GraphSAGEConfig {
            input_dim: 8,
            hidden_dims: vec![16],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 1,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(6, 8, 5);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 6);
        assert_eq!(embs.dim, 4);
        for i in 0..6 {
            assert_eq!(embs.get(i).expect("embedding exists").len(), 4);
        }
    }

    #[test]
    fn test_graphsage_model_maxpool() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::MaxPool { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: false,
            seed: 2,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 2);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
    }

    #[test]
    fn test_graphsage_model_meanpool() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::MeanPool { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: false,
            seed: 3,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 3);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
    }

    #[test]
    fn test_graphsage_model_lstm() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Lstm { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: true,
            seed: 4,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 4);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
        for i in 0..4 {
            let emb = embs.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    #[test]
    fn test_graphsage_top_k_similar() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 5,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(8, 4, 6);
        let embs = model.embed(&g).expect("embed should succeed");
        let top3 = embs.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        for window in top3.windows(2) {
            assert!(
                window[0].1 >= window[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    #[test]
    fn test_graphsage_inductive_embed_new_node() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 9,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(5, 4, 10);
        let embs = model.embed(&g).expect("embed should succeed");
        let neighbor_embs: Vec<Vec<f64>> = vec![
            embs.get(0).expect("exists").to_vec(),
            embs.get(1).expect("exists").to_vec(),
        ];
        let new_node_features = vec![0.5f64; 4];
        let new_emb = model
            .embed_new_node(&new_node_features, &neighbor_embs)
            .expect("inductive embed should succeed");
        assert_eq!(new_emb.len(), 4);
        let norm: f64 = new_emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(
            norm <= 1.0 + 1e-6,
            "normalized embedding norm should be <= 1"
        );
    }

    #[test]
    fn test_minibatch_graphsage_train() {
        let sage_cfg = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 7,
        };
        let batch_cfg = MiniBatchConfig {
            epochs: 3,
            batch_size: 4,
            num_negative_samples: 2,
            learning_rate: 0.01,
            seed: 7,
        };
        let mut trainer =
            MiniBatchGraphSAGE::new(sage_cfg, batch_cfg).expect("trainer should construct");
        let g = ring_graph(8, 4, 8);
        let metrics = trainer.train(&g).expect("training should succeed");
        assert_eq!(metrics.epochs_completed, 3);
        assert_eq!(metrics.loss_history.len(), 3);
        for &loss in &metrics.loss_history {
            assert!(loss.is_finite(), "loss should be finite");
        }
    }

    #[test]
    fn test_minibatch_graphsage_embed_after_train() {
        let sage_cfg = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3],
            normalize_output: true,
            seed: 11,
        };
        let batch_cfg = MiniBatchConfig {
            epochs: 2,
            batch_size: 3,
            num_negative_samples: 1,
            learning_rate: 0.01,
            seed: 11,
        };
        let mut trainer =
            MiniBatchGraphSAGE::new(sage_cfg, batch_cfg).expect("trainer should construct");
        let g = ring_graph(5, 4, 12);
        trainer.train(&g).expect("training should succeed");
        let embs = trainer.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 5);
        assert_eq!(embs.dim, 4);
    }

    #[test]
    fn test_graphsage_with_labels() {
        let g = ring_graph(4, 4, 20)
            .with_labels(vec![0, 1, 0, 1])
            .expect("labels should attach");
        assert!(g.labels.is_some());
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3],
            normalize_output: true,
            seed: 20,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let embs = model.embed(&g).expect("embed should succeed");
        let labels = g.labels.as_ref().expect("labels exist");
        let labeled = embs.labeled_embeddings(labels);
        assert_eq!(labeled.len(), 4);
    }

    #[test]
    fn test_lcg_reproducibility() {
        let mut a = Lcg::new(99);
        let mut b = Lcg::new(99);
        for _ in 0..200 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    #[test]
    fn test_graphsage_invalid_config() {
        assert!(GraphSAGEModel::new(GraphSAGEConfig {
            input_dim: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GraphSAGEModel::new(GraphSAGEConfig {
            output_dim: 0,
            ..Default::default()
        })
        .is_err());
    }
}