// oxirs_embed/graph_models/graphsage.rs
//! GraphSAGE: Inductive Representation Learning on Large Graphs
//!
//! This module provides v0.3.0 GraphSAGE implementations with:
//! - `GraphSAGELayer`: inductive representation learning via neighbor sampling
//! - `MeanAggregator`, `MaxPoolAggregator`, `MeanPoolAggregator`, `LSTMAggregator`
//! - `GraphSAGEModel`: multi-layer GraphSAGE with configurable depth and hidden dims
//! - `MiniBatchGraphSAGE`: mini-batch training for large graphs
//!
//! Reference: Hamilton, Ying, Leskovec (2017) - NeurIPS
//! "Inductive Representation Learning on Large Graphs"

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// ---------------------------------------------------------------------------
// Utility: Simple LCG PRNG (no external rand dependency)
// ---------------------------------------------------------------------------

/// Minimal Linear Congruential Generator for reproducible sampling.
///
/// Uses the Knuth MMIX multiplier/increment. Not cryptographically secure;
/// intended only for deterministic neighbor sampling and weight init.
#[derive(Debug, Clone)]
pub struct Lcg {
    state: u64,
}

impl Lcg {
    /// Create a generator from `seed`. The `+1` offset avoids a degenerate
    /// all-zero starting state for `seed == u64::MAX.wrapping_add(1)` inputs.
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the state and return the next raw 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        // MMIX LCG constants (Knuth).
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.state
    }

    /// Uniform index in `[0, n)`.
    ///
    /// Fix: returns 0 when `n == 0` instead of panicking on a
    /// modulo-by-zero as the previous revision did. The state is not
    /// advanced in that case.
    pub fn next_usize_mod(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        (self.next_u64() as usize) % n
    }

    /// Uniform f64 in [0, 1), built from the top 53 bits of the state.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Uniform in [-scale, scale)
    pub fn next_f64_range(&mut self, scale: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * scale
    }

    /// Standard normal approximation via Box-Muller
    pub fn next_normal(&mut self) -> f64 {
        // Clamp u1 away from 0 so ln(u1) stays finite.
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

62// ---------------------------------------------------------------------------
63// Graph data structures
64// ---------------------------------------------------------------------------
65
66/// A homogeneous graph with node feature vectors and adjacency lists.
67#[derive(Debug, Clone)]
68pub struct Graph {
69    /// `node_features[i]` = feature vector for node `i`.
70    pub node_features: Vec<Vec<f64>>,
71    /// `adjacency[i]` = sorted list of neighbor indices for node `i`.
72    pub adjacency: Vec<Vec<usize>>,
73    /// Optional node-level class labels (for supervised training).
74    pub labels: Option<Vec<usize>>,
75}
76
77impl Graph {
78    /// Construct and validate a new graph.
79    pub fn new(node_features: Vec<Vec<f64>>, adjacency: Vec<Vec<usize>>) -> Result<Self> {
80        let n = node_features.len();
81        if adjacency.len() != n {
82            return Err(anyhow!(
83                "adjacency list length {} != num_nodes {}",
84                adjacency.len(),
85                n
86            ));
87        }
88        // Validate feature dimension consistency
89        if let Some(first) = node_features.first() {
90            let dim = first.len();
91            for (i, feat) in node_features.iter().enumerate() {
92                if feat.len() != dim {
93                    return Err(anyhow!(
94                        "node {} feature dim {} != expected {}",
95                        i,
96                        feat.len(),
97                        dim
98                    ));
99                }
100            }
101        }
102        // Validate adjacency bounds
103        for (i, nbrs) in adjacency.iter().enumerate() {
104            for &j in nbrs {
105                if j >= n {
106                    return Err(anyhow!("node {} has out-of-bounds neighbor {}", i, j));
107                }
108            }
109        }
110        Ok(Self {
111            node_features,
112            adjacency,
113            labels: None,
114        })
115    }
116
117    /// Attach labels (must match `num_nodes()`).
118    pub fn with_labels(mut self, labels: Vec<usize>) -> Result<Self> {
119        if labels.len() != self.num_nodes() {
120            return Err(anyhow!(
121                "label count {} != num_nodes {}",
122                labels.len(),
123                self.num_nodes()
124            ));
125        }
126        self.labels = Some(labels);
127        Ok(self)
128    }
129
130    /// Number of nodes.
131    pub fn num_nodes(&self) -> usize {
132        self.node_features.len()
133    }
134
135    /// Feature dimension (0 if graph is empty).
136    pub fn feature_dim(&self) -> usize {
137        self.node_features.first().map(|f| f.len()).unwrap_or(0)
138    }
139
140    /// Get neighbors of node `v`.
141    pub fn neighbors(&self, v: usize) -> &[usize] {
142        self.adjacency.get(v).map(|v| v.as_slice()).unwrap_or(&[])
143    }
144
145    /// Sample up to `k` neighbors uniformly without replacement.
146    pub fn sample_neighbors(&self, v: usize, k: usize, rng: &mut Lcg) -> Vec<usize> {
147        let nbrs = self.neighbors(v);
148        if nbrs.is_empty() || k == 0 {
149            return Vec::new();
150        }
151        if nbrs.len() <= k {
152            return nbrs.to_vec();
153        }
154        // Partial Fisher-Yates
155        let mut idx: Vec<usize> = (0..nbrs.len()).collect();
156        for i in 0..k {
157            let j = i + rng.next_usize_mod(nbrs.len() - i);
158            idx.swap(i, j);
159        }
160        idx[..k].iter().map(|&i| nbrs[i]).collect()
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Dense layer utility
166// ---------------------------------------------------------------------------
167
168/// A fully-connected layer: `output = W * input + bias`.
169#[derive(Debug, Clone)]
170pub struct DenseLayer {
171    weights: Vec<Vec<f64>>, // [out_dim][in_dim]
172    bias: Vec<f64>,
173    pub in_dim: usize,
174    pub out_dim: usize,
175}
176
177impl DenseLayer {
178    /// Xavier/Glorot uniform initialization.
179    pub fn new_xavier(in_dim: usize, out_dim: usize, rng: &mut Lcg) -> Self {
180        let scale = (6.0 / (in_dim + out_dim) as f64).sqrt();
181        let weights = (0..out_dim)
182            .map(|_| (0..in_dim).map(|_| rng.next_f64_range(scale)).collect())
183            .collect();
184        Self {
185            weights,
186            bias: vec![0.0; out_dim],
187            in_dim,
188            out_dim,
189        }
190    }
191
192    /// Forward pass: W * x + b.
193    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
194        let mut out = self.bias.clone();
195        for (i, row) in self.weights.iter().enumerate() {
196            for (j, &w) in row.iter().enumerate() {
197                out[i] += w * x[j];
198            }
199        }
200        out
201    }
202
203    /// ReLU activation.
204    pub fn relu(x: &[f64]) -> Vec<f64> {
205        x.iter().map(|&v| v.max(0.0)).collect()
206    }
207
208    /// Tanh activation.
209    pub fn tanh(x: &[f64]) -> Vec<f64> {
210        x.iter().map(|&v| v.tanh()).collect()
211    }
212}
213
// ---------------------------------------------------------------------------
// Aggregator trait and implementations
// ---------------------------------------------------------------------------

/// Aggregates a set of neighbor feature vectors into a single vector.
pub trait Aggregator: std::fmt::Debug + Send + Sync {
    /// Aggregate `neighbor_features` (each of length `input_dim`) into one vector.
    fn aggregate(&self, neighbor_features: &[Vec<f64>], input_dim: usize) -> Vec<f64>;

    /// Output dimensionality produced by this aggregator given `input_dim`.
    fn output_dim(&self, input_dim: usize) -> usize;
}

/// Mean aggregator: element-wise mean of neighbor features.
///
/// An empty neighbor set aggregates to the zero vector.
#[derive(Debug, Clone, Default)]
pub struct MeanAggregator;

impl Aggregator for MeanAggregator {
    fn aggregate(&self, neighbor_features: &[Vec<f64>], input_dim: usize) -> Vec<f64> {
        let count = neighbor_features.len();
        if count == 0 {
            return vec![0.0; input_dim];
        }
        // Accumulate element-wise sums; vectors longer than input_dim are
        // truncated, shorter ones contribute only their entries.
        let mut sums = vec![0.0f64; input_dim];
        for feat in neighbor_features {
            for (slot, &v) in sums.iter_mut().zip(feat.iter()) {
                *slot += v;
            }
        }
        let n = count as f64;
        for slot in sums.iter_mut() {
            *slot /= n;
        }
        sums
    }

    fn output_dim(&self, input_dim: usize) -> usize {
        input_dim
    }
}

252/// Max-pool aggregator: element-wise max of neighbor features after MLP.
253#[derive(Debug, Clone)]
254pub struct MaxPoolAggregator {
255    mlp: DenseLayer,
256    hidden_dim: usize,
257}
258
259impl MaxPoolAggregator {
260    /// Create with a single hidden MLP applied before max-pooling.
261    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
262        Self {
263            mlp: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
264            hidden_dim,
265        }
266    }
267}
268
269impl Aggregator for MaxPoolAggregator {
270    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
271        if neighbor_features.is_empty() {
272            return vec![0.0; self.hidden_dim];
273        }
274        let mut pool = vec![f64::NEG_INFINITY; self.hidden_dim];
275        for feat in neighbor_features {
276            let transformed = DenseLayer::relu(&self.mlp.forward(feat));
277            for (i, &v) in transformed.iter().enumerate() {
278                if v > pool[i] {
279                    pool[i] = v;
280                }
281            }
282        }
283        // Replace -inf with 0 for isolated situations
284        pool.iter_mut().for_each(|v| {
285            if v.is_infinite() {
286                *v = 0.0;
287            }
288        });
289        pool
290    }
291
292    fn output_dim(&self, _input_dim: usize) -> usize {
293        self.hidden_dim
294    }
295}
296
297/// Mean-pool aggregator: element-wise mean of neighbor features after MLP.
298#[derive(Debug, Clone)]
299pub struct MeanPoolAggregator {
300    mlp: DenseLayer,
301    hidden_dim: usize,
302}
303
304impl MeanPoolAggregator {
305    /// Create with a single hidden MLP applied before mean-pooling.
306    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
307        Self {
308            mlp: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
309            hidden_dim,
310        }
311    }
312}
313
314impl Aggregator for MeanPoolAggregator {
315    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
316        if neighbor_features.is_empty() {
317            return vec![0.0; self.hidden_dim];
318        }
319        let mut mean = vec![0.0f64; self.hidden_dim];
320        for feat in neighbor_features {
321            let transformed = DenseLayer::relu(&self.mlp.forward(feat));
322            for (i, &v) in transformed.iter().enumerate() {
323                mean[i] += v;
324            }
325        }
326        let n = neighbor_features.len() as f64;
327        mean.iter_mut().for_each(|v| *v /= n);
328        mean
329    }
330
331    fn output_dim(&self, _input_dim: usize) -> usize {
332        self.hidden_dim
333    }
334}
335
336/// LSTM-style aggregator: processes neighbors sequentially with a GRU cell.
337///
338/// A simplified GRU (Gated Recurrent Unit) is used for efficiency:
339/// - Reset gate:  r = sigmoid(W_r * [h; x] + b_r)
340/// - Update gate: z = sigmoid(W_z * [h; x] + b_z)
341/// - New state:   n = tanh(W_n * [h*r; x] + b_n)
342/// - Output:      h' = (1-z)*h + z*n
343#[derive(Debug, Clone)]
344pub struct LSTMAggregator {
345    /// GRU weight matrices for input (in_dim) and hidden (hidden_dim)
346    w_r_x: DenseLayer,
347    w_r_h: DenseLayer,
348    w_z_x: DenseLayer,
349    w_z_h: DenseLayer,
350    w_n_x: DenseLayer,
351    w_n_h: DenseLayer,
352    hidden_dim: usize,
353}
354
355impl LSTMAggregator {
356    /// Create a GRU aggregator.
357    pub fn new(input_dim: usize, hidden_dim: usize, rng: &mut Lcg) -> Self {
358        Self {
359            w_r_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
360            w_r_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
361            w_z_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
362            w_z_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
363            w_n_x: DenseLayer::new_xavier(input_dim, hidden_dim, rng),
364            w_n_h: DenseLayer::new_xavier(hidden_dim, hidden_dim, rng),
365            hidden_dim,
366        }
367    }
368
369    fn sigmoid_vec(x: &[f64]) -> Vec<f64> {
370        x.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect()
371    }
372
373    fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
374        a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect()
375    }
376
377    fn vec_mul_elem(a: &[f64], b: &[f64]) -> Vec<f64> {
378        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect()
379    }
380
381    fn gru_step(&self, h: &[f64], x: &[f64]) -> Vec<f64> {
382        // r = sigmoid(W_r_x * x + W_r_h * h)
383        let r_in = Self::vec_add(&self.w_r_x.forward(x), &self.w_r_h.forward(h));
384        let r = Self::sigmoid_vec(&r_in);
385
386        // z = sigmoid(W_z_x * x + W_z_h * h)
387        let z_in = Self::vec_add(&self.w_z_x.forward(x), &self.w_z_h.forward(h));
388        let z = Self::sigmoid_vec(&z_in);
389
390        // n = tanh(W_n_x * x + W_n_h * (r * h))
391        let r_h = Self::vec_mul_elem(&r, h);
392        let n_in = Self::vec_add(&self.w_n_x.forward(x), &self.w_n_h.forward(&r_h));
393        let n = DenseLayer::tanh(&n_in);
394
395        // h' = (1-z)*h + z*n
396        z.iter()
397            .zip(n.iter())
398            .zip(h.iter())
399            .map(|((&zi, &ni), &hi)| (1.0 - zi) * hi + zi * ni)
400            .collect()
401    }
402}
403
404impl Aggregator for LSTMAggregator {
405    fn aggregate(&self, neighbor_features: &[Vec<f64>], _input_dim: usize) -> Vec<f64> {
406        let mut h = vec![0.0f64; self.hidden_dim];
407        for feat in neighbor_features {
408            h = self.gru_step(&h, feat);
409        }
410        h
411    }
412
413    fn output_dim(&self, _input_dim: usize) -> usize {
414        self.hidden_dim
415    }
416}
417
// ---------------------------------------------------------------------------
// GraphSAGELayer
// ---------------------------------------------------------------------------

/// Aggregator variant selector for `GraphSAGELayer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregatorKind {
    /// Element-wise mean of neighbor features; no extra parameters.
    Mean,
    /// MLP + element-wise max; `hidden_dim` is the MLP output size.
    MaxPool { hidden_dim: usize },
    /// MLP + element-wise mean; `hidden_dim` is the MLP output size.
    MeanPool { hidden_dim: usize },
    /// Sequential GRU over neighbors; `hidden_dim` is the GRU state size.
    Lstm { hidden_dim: usize },
}

431/// A single GraphSAGE layer: aggregate neighbors then combine with self.
432///
433/// For each node `v`:
434///   agg = AGGREGATE({ h_u | u ∈ N(v) })
435///   h_v = σ( W · CONCAT(h_v, agg) )
436pub struct GraphSAGELayer {
437    /// W: maps concat(self, agg) -> output
438    combine: DenseLayer,
439    aggregator: Box<dyn Aggregator>,
440    pub in_dim: usize,
441    pub out_dim: usize,
442    num_samples: usize,
443}
444
445impl std::fmt::Debug for GraphSAGELayer {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        f.debug_struct("GraphSAGELayer")
448            .field("in_dim", &self.in_dim)
449            .field("out_dim", &self.out_dim)
450            .field("num_samples", &self.num_samples)
451            .finish()
452    }
453}
454
455impl GraphSAGELayer {
456    /// Build a new layer.
457    ///
458    /// `in_dim`      - input feature dimension for this layer  
459    /// `out_dim`     - output embedding dimension  
460    /// `num_samples` - neighborhood sample size  
461    /// `kind`        - aggregator type  
462    /// `rng`         - seeded RNG for weight initialization
463    pub fn new(
464        in_dim: usize,
465        out_dim: usize,
466        num_samples: usize,
467        kind: &AggregatorKind,
468        rng: &mut Lcg,
469    ) -> Result<Self> {
470        if in_dim == 0 || out_dim == 0 {
471            return Err(anyhow!("GraphSAGELayer dimensions must be > 0"));
472        }
473        let aggregator: Box<dyn Aggregator> = match kind {
474            AggregatorKind::Mean => Box::new(MeanAggregator),
475            AggregatorKind::MaxPool { hidden_dim } => {
476                Box::new(MaxPoolAggregator::new(in_dim, *hidden_dim, rng))
477            }
478            AggregatorKind::MeanPool { hidden_dim } => {
479                Box::new(MeanPoolAggregator::new(in_dim, *hidden_dim, rng))
480            }
481            AggregatorKind::Lstm { hidden_dim } => {
482                Box::new(LSTMAggregator::new(in_dim, *hidden_dim, rng))
483            }
484        };
485        let agg_out = aggregator.output_dim(in_dim);
486        // combine layer: takes [self_feat | agg_feat] -> out_dim
487        let combine = DenseLayer::new_xavier(in_dim + agg_out, out_dim, rng);
488        Ok(Self {
489            combine,
490            aggregator,
491            in_dim,
492            out_dim,
493            num_samples,
494        })
495    }
496
497    /// Forward pass: compute new embeddings for all nodes.
498    ///
499    /// `current_embeddings[v]` = current feature vector for node `v`.
500    pub fn forward(
501        &self,
502        graph: &Graph,
503        current_embeddings: &[Vec<f64>],
504        rng: &mut Lcg,
505    ) -> Vec<Vec<f64>> {
506        let n = graph.num_nodes();
507        let mut new_embeddings = Vec::with_capacity(n);
508        for v in 0..n {
509            // Sample neighbors
510            let sampled = graph.sample_neighbors(v, self.num_samples, rng);
511            // Gather neighbor features
512            let neighbor_feats: Vec<Vec<f64>> = sampled
513                .iter()
514                .filter_map(|&u| current_embeddings.get(u).cloned())
515                .collect();
516            // Aggregate
517            let agg = self.aggregator.aggregate(&neighbor_feats, self.in_dim);
518            // Concatenate self + aggregate
519            let self_feat = current_embeddings
520                .get(v)
521                .cloned()
522                .unwrap_or_else(|| vec![0.0; self.in_dim]);
523            let concat: Vec<f64> = self_feat.iter().chain(agg.iter()).copied().collect();
524            // Linear transform + ReLU
525            let out = DenseLayer::relu(&self.combine.forward(&concat));
526            new_embeddings.push(out);
527        }
528        new_embeddings
529    }
530}
531
// ---------------------------------------------------------------------------
// GraphSAGEModel
// ---------------------------------------------------------------------------

/// Configuration for a multi-layer GraphSAGE model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSAGEConfig {
    /// Dimension of raw node features.
    pub input_dim: usize,
    /// Sizes of hidden layers (excluding input and output).
    pub hidden_dims: Vec<usize>,
    /// Final output embedding dimension.
    pub output_dim: usize,
    /// Aggregator kind applied at each layer.
    pub aggregator_kind: AggregatorKind,
    /// Number of neighbors to sample at each layer. If shorter than the
    /// layer count, the last entry is repeated (10 when empty) — see
    /// `GraphSAGEModel::new`.
    pub num_samples_per_layer: Vec<usize>,
    /// L2-normalize final embeddings.
    pub normalize_output: bool,
    /// Random seed (drives both weight init and neighbor sampling).
    pub seed: u64,
}

impl Default for GraphSAGEConfig {
    /// Defaults: two hidden layers (256, 128), mean aggregation,
    /// fan-out 25 then 10, normalized 64-dim output.
    fn default() -> Self {
        Self {
            input_dim: 64,
            hidden_dims: vec![256, 128],
            output_dim: 64,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![25, 10],
            normalize_output: true,
            seed: 42,
        }
    }
}

569/// Multi-layer GraphSAGE model.
570pub struct GraphSAGEModel {
571    layers: Vec<GraphSAGELayer>,
572    config: GraphSAGEConfig,
573}
574
575impl std::fmt::Debug for GraphSAGEModel {
576    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
577        f.debug_struct("GraphSAGEModel")
578            .field("num_layers", &self.layers.len())
579            .field("output_dim", &self.config.output_dim)
580            .finish()
581    }
582}
583
584impl GraphSAGEModel {
585    /// Construct from configuration.
586    pub fn new(config: GraphSAGEConfig) -> Result<Self> {
587        if config.input_dim == 0 {
588            return Err(anyhow!("input_dim must be > 0"));
589        }
590        if config.output_dim == 0 {
591            return Err(anyhow!("output_dim must be > 0"));
592        }
593        let mut rng = Lcg::new(config.seed);
594        // Build layer dimensions: [input_dim, hidden_dims..., output_dim]
595        let mut dims: Vec<usize> = vec![config.input_dim];
596        dims.extend_from_slice(&config.hidden_dims);
597        dims.push(config.output_dim);
598
599        let num_layers = dims.len() - 1;
600        // Pad num_samples to match layers
601        let mut samples = config.num_samples_per_layer.clone();
602        while samples.len() < num_layers {
603            samples.push(samples.last().copied().unwrap_or(10));
604        }
605
606        let mut layers = Vec::with_capacity(num_layers);
607        for i in 0..num_layers {
608            let layer = GraphSAGELayer::new(
609                dims[i],
610                dims[i + 1],
611                samples[i],
612                &config.aggregator_kind,
613                &mut rng,
614            )?;
615            layers.push(layer);
616        }
617
618        Ok(Self { layers, config })
619    }
620
621    /// Compute embeddings for all nodes in `graph`.
622    pub fn embed(&self, graph: &Graph) -> Result<GraphSAGEEmbeddings> {
623        if graph.num_nodes() == 0 {
624            return Err(anyhow!("Graph has no nodes"));
625        }
626        let mut rng = Lcg::new(self.config.seed.wrapping_add(0xdead_beef));
627        let mut current: Vec<Vec<f64>> = graph.node_features.clone();
628        for layer in &self.layers {
629            current = layer.forward(graph, &current, &mut rng);
630        }
631        if self.config.normalize_output {
632            for emb in &mut current {
633                l2_normalize_inplace(emb);
634            }
635        }
636        let dim = self.config.output_dim;
637        Ok(GraphSAGEEmbeddings {
638            embeddings: current,
639            num_nodes: graph.num_nodes(),
640            dim,
641        })
642    }
643
644    /// Inductive inference: embed a single new node given its features and
645    /// the embeddings of its known neighbors.
646    pub fn embed_new_node(
647        &self,
648        node_features: &[f64],
649        neighbor_embeddings: &[Vec<f64>],
650    ) -> Result<Vec<f64>> {
651        if node_features.len() != self.config.input_dim {
652            return Err(anyhow!(
653                "node_features dim {} != input_dim {}",
654                node_features.len(),
655                self.config.input_dim
656            ));
657        }
658        let mut rng = Lcg::new(self.config.seed);
659        // Create a tiny 1-node graph for this new node
660        let features = vec![node_features.to_vec()];
661        let adjacency = vec![Vec::<usize>::new()]; // isolated during first layer
662        let mini_graph = Graph::new(features, adjacency)?;
663
664        // Run through layers using neighbor_embeddings as input for layer 0
665        // For simplicity: feed neighbor_embeddings through aggregators manually
666        let mut current_self = node_features.to_vec();
667        for layer in &self.layers {
668            let sampled: Vec<Vec<f64>> = if neighbor_embeddings.is_empty() {
669                Vec::new()
670            } else {
671                let k = layer.num_samples.min(neighbor_embeddings.len());
672                neighbor_embeddings[..k].to_vec()
673            };
674            let agg = layer.aggregator.aggregate(&sampled, layer.in_dim);
675            let concat: Vec<f64> = current_self.iter().chain(agg.iter()).copied().collect();
676            current_self = DenseLayer::relu(&layer.combine.forward(&concat));
677            // Dummy call to suppress rng warning
678            let _ = mini_graph.sample_neighbors(0, 0, &mut rng);
679        }
680        if self.config.normalize_output {
681            l2_normalize_inplace(&mut current_self);
682        }
683        Ok(current_self)
684    }
685}
686
// ---------------------------------------------------------------------------
// MiniBatchGraphSAGE
// ---------------------------------------------------------------------------

/// Training configuration for `MiniBatchGraphSAGE`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiniBatchConfig {
    /// Number of training epochs.
    pub epochs: usize,
    /// Batch size (number of anchor nodes per mini-batch).
    pub batch_size: usize,
    /// Number of negative samples per anchor.
    pub num_negative_samples: usize,
    /// Learning rate (SGD step size for unsupervised loss).
    /// NOTE(review): currently never read by `MiniBatchGraphSAGE::train`,
    /// which evaluates the loss but applies no parameter updates.
    pub learning_rate: f64,
    /// Random seed for reproducibility.
    pub seed: u64,
}

impl Default for MiniBatchConfig {
    /// Defaults: 10 epochs, batch size 256, 20 negatives, lr 0.01, seed 0.
    fn default() -> Self {
        Self {
            epochs: 10,
            batch_size: 256,
            num_negative_samples: 20,
            learning_rate: 0.01,
            seed: 0,
        }
    }
}

/// Training metrics returned after `MiniBatchGraphSAGE::train`.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Number of epochs that ran (equals the configured `epochs`).
    pub epochs_completed: usize,
    /// Mean per-anchor loss recorded at the end of each epoch.
    pub loss_history: Vec<f64>,
    /// Last entry of `loss_history` (NaN when `epochs == 0`).
    pub final_loss: f64,
}

/// Mini-batch GraphSAGE trainer for large graphs.
///
/// Uses unsupervised loss: cross-entropy between positive (edge) pairs
/// and negative (non-edge) pairs, approximated via sigmoid.
///
/// NOTE(review): `train` only *evaluates* this loss; no gradient step is
/// applied to the model weights, and `MiniBatchConfig::learning_rate` is
/// never read. Because `GraphSAGEModel::embed` derives its sampling RNG
/// from the fixed model seed, embeddings — and hence the reported loss —
/// are identical across epochs.
pub struct MiniBatchGraphSAGE {
    model: GraphSAGEModel,
    batch_cfg: MiniBatchConfig,
}

impl std::fmt::Debug for MiniBatchGraphSAGE {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MiniBatchGraphSAGE")
            .field("model", &self.model)
            .finish()
    }
}

impl MiniBatchGraphSAGE {
    /// Create a new mini-batch trainer.
    ///
    /// # Errors
    /// Propagates any error from `GraphSAGEModel::new` (e.g. zero dims).
    pub fn new(sage_config: GraphSAGEConfig, batch_cfg: MiniBatchConfig) -> Result<Self> {
        let model = GraphSAGEModel::new(sage_config)?;
        Ok(Self { model, batch_cfg })
    }

    /// Run unsupervised mini-batch training on `graph`.
    ///
    /// For every anchor node with at least one neighbor, scores one sampled
    /// positive (neighbor) pair and up to `num_negative_samples` random
    /// negative pairs with the negative-sampling objective
    /// `-log σ(z_v·z_u) - Σ log σ(-z_v·z_neg)`.
    ///
    /// After training, call `embed()` to retrieve node embeddings.
    ///
    /// # Errors
    /// Returns an error when the graph has fewer than 2 nodes.
    pub fn train(&mut self, graph: &Graph) -> Result<TrainingMetrics> {
        let n = graph.num_nodes();
        if n < 2 {
            return Err(anyhow!("Graph must have at least 2 nodes for training"));
        }
        let mut rng = Lcg::new(self.batch_cfg.seed);
        let mut loss_history = Vec::with_capacity(self.batch_cfg.epochs);

        for epoch in 0..self.batch_cfg.epochs {
            // Compute current embeddings (deterministic per model seed).
            let embeddings = self.model.embed(graph)?;
            let mut epoch_loss = 0.0f64;
            let mut num_pairs: usize = 0;

            // Process mini-batches of anchor nodes. The batching only
            // partitions the node range; every node is visited each epoch.
            let batch_size = self.batch_cfg.batch_size.min(n);
            for batch_start in (0..n).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(n);
                for v in batch_start..batch_end {
                    let nbrs = graph.neighbors(v);
                    if nbrs.is_empty() {
                        // Isolated nodes contribute no pairs.
                        continue;
                    }
                    // Positive sample: a real neighbor
                    let pos_u = nbrs[rng.next_usize_mod(nbrs.len())];
                    let v_emb = embeddings.get(v).unwrap_or(&[]);
                    let u_emb = embeddings.get(pos_u).unwrap_or(&[]);
                    let pos_score = dot_product(v_emb, u_emb);
                    // Log-sigmoid of positive score
                    epoch_loss -= log_sigmoid(pos_score);

                    // Negative samples. A draw that hits the anchor itself
                    // is skipped, so slightly fewer negatives than the
                    // configured count may be scored.
                    for _ in 0..self.batch_cfg.num_negative_samples {
                        let neg = rng.next_usize_mod(n);
                        if neg == v {
                            continue;
                        }
                        let neg_emb = embeddings.get(neg).unwrap_or(&[]);
                        let neg_score = dot_product(v_emb, neg_emb);
                        // Log-sigmoid of negative
                        epoch_loss -= log_sigmoid(-neg_score);
                    }
                    num_pairs += 1;
                }
            }
            // Report mean loss per anchor that produced pairs.
            if num_pairs > 0 {
                epoch_loss /= num_pairs as f64;
            }
            loss_history.push(epoch_loss);
            tracing::debug!(
                "MiniBatchGraphSAGE epoch {}/{}: loss={:.6}",
                epoch + 1,
                self.batch_cfg.epochs,
                epoch_loss
            );
        }

        let final_loss = loss_history.last().copied().unwrap_or(f64::NAN);
        Ok(TrainingMetrics {
            epochs_completed: self.batch_cfg.epochs,
            loss_history,
            final_loss,
        })
    }

    /// Compute final embeddings after training.
    pub fn embed(&self, graph: &Graph) -> Result<GraphSAGEEmbeddings> {
        self.model.embed(graph)
    }
}

824// ---------------------------------------------------------------------------
825// GraphSAGEEmbeddings
826// ---------------------------------------------------------------------------
827
828/// Node embeddings produced by `GraphSAGEModel`.
829#[derive(Debug, Clone)]
830pub struct GraphSAGEEmbeddings {
831    pub embeddings: Vec<Vec<f64>>,
832    pub num_nodes: usize,
833    pub dim: usize,
834}
835
836impl GraphSAGEEmbeddings {
837    /// Get embedding for node `v`.
838    pub fn get(&self, v: usize) -> Option<&[f64]> {
839        self.embeddings.get(v).map(|e| e.as_slice())
840    }
841
842    /// Cosine similarity between nodes `a` and `b`.
843    /// Returns `None` if either embedding is zero or out of bounds.
844    pub fn cosine_similarity(&self, a: usize, b: usize) -> Option<f64> {
845        let ea = self.embeddings.get(a)?;
846        let eb = self.embeddings.get(b)?;
847        Some(cosine_similarity_vecs(ea, eb))
848    }
849
850    /// Top-k nodes most similar to `query_node` (excludes `query_node` itself).
851    pub fn top_k_similar(&self, query_node: usize, k: usize) -> Vec<(usize, f64)> {
852        let query_emb = match self.embeddings.get(query_node) {
853            Some(e) => e,
854            None => return Vec::new(),
855        };
856        let mut sims: Vec<(usize, f64)> = self
857            .embeddings
858            .iter()
859            .enumerate()
860            .filter(|(i, _)| *i != query_node)
861            .map(|(i, e)| (i, cosine_similarity_vecs(query_emb, e)))
862            .collect();
863        sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
864        sims.truncate(k);
865        sims
866    }
867
868    /// Build a map from node index to label embedding pairs.
869    pub fn labeled_embeddings(&self, labels: &[usize]) -> HashMap<usize, (Vec<f64>, usize)> {
870        self.embeddings
871            .iter()
872            .enumerate()
873            .filter_map(|(i, emb)| labels.get(i).map(|&l| (i, (emb.clone(), l))))
874            .collect()
875    }
876}
877
878// ---------------------------------------------------------------------------
879// Utility functions
880// ---------------------------------------------------------------------------
881
/// Inner product of `a` and `b`; extra elements of the longer slice are ignored.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0f64;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
885
/// Numerically stable `log(sigmoid(x)) = -log(1 + exp(-x))`.
///
/// Branches on the sign of `x` so the exponential argument is never positive,
/// avoiding overflow for large |x|.
fn log_sigmoid(x: f64) -> f64 {
    match x >= 0.0 {
        true => -(1.0 + (-x).exp()).ln(),
        false => x - (1.0 + x.exp()).ln(),
    }
}
894
/// Cosine similarity of `a` and `b`, clamped to `[-1, 1]`.
/// A (near-)zero vector on either side yields `0.0` instead of NaN.
fn cosine_similarity_vecs(a: &[f64], b: &[f64]) -> f64 {
    let mut dot = 0.0f64;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    let mut sq_a = 0.0f64;
    for &x in a {
        sq_a += x * x;
    }
    let mut sq_b = 0.0f64;
    for &y in b {
        sq_b += y * y;
    }
    let na = sq_a.sqrt();
    let nb = sq_b.sqrt();
    if na < 1e-12 || nb < 1e-12 {
        0.0
    } else {
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }
}
904
/// Scale `v` in place to unit L2 norm; (near-)zero vectors are left untouched.
fn l2_normalize_inplace(v: &mut [f64]) {
    let mut sum_sq = 0.0f64;
    for &x in v.iter() {
        sum_sq += x * x;
    }
    let norm = sum_sq.sqrt();
    if norm > 1e-12 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
911
912// ---------------------------------------------------------------------------
913// Tests
914// ---------------------------------------------------------------------------
915
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an undirected ring (cycle) graph of `n` nodes whose features are
    /// `feat_dim` values drawn from a seeded `Lcg` for reproducibility.
    fn ring_graph(n: usize, feat_dim: usize, seed: u64) -> Graph {
        let mut rng = Lcg::new(seed);
        let features: Vec<Vec<f64>> = (0..n)
            .map(|_| (0..feat_dim).map(|_| rng.next_f64()).collect())
            .collect();
        let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
        for i in 0..n {
            let next = (i + 1) % n;
            adjacency[i].push(next);
            adjacency[next].push(i);
        }
        // Sort and dedup so tiny rings (e.g. n == 2, where i+1 and i-1
        // coincide) do not carry duplicate neighbor entries.
        for nbrs in &mut adjacency {
            nbrs.sort_unstable();
            nbrs.dedup();
        }
        Graph::new(features, adjacency).expect("ring graph construction should succeed")
    }

    #[test]
    fn test_graph_construction() {
        let g = ring_graph(6, 8, 1);
        assert_eq!(g.num_nodes(), 6);
        assert_eq!(g.feature_dim(), 8);
        // In a ring with n > 2 every node has exactly two neighbors.
        assert_eq!(g.neighbors(0).len(), 2);
    }

    /// Adjacency referencing a node index out of range must be rejected.
    #[test]
    fn test_graph_invalid_adjacency() {
        let feats = vec![vec![1.0f64; 4]; 3];
        let adj = vec![vec![1usize, 99], vec![0], vec![0]];
        assert!(Graph::new(feats, adj).is_err());
    }

    #[test]
    fn test_mean_aggregator() {
        let agg = MeanAggregator;
        let feats = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = agg.aggregate(&feats, 2);
        // Element-wise mean of the two neighbor feature vectors.
        assert_eq!(result, vec![2.0, 3.0]);
        assert_eq!(agg.output_dim(2), 2);
    }

    /// Mean over zero neighbors must yield the zero vector, not NaN.
    #[test]
    fn test_mean_aggregator_empty() {
        let agg = MeanAggregator;
        let result = agg.aggregate(&[], 4);
        assert_eq!(result, vec![0.0; 4]);
    }

    #[test]
    fn test_maxpool_aggregator() {
        let mut rng = Lcg::new(1);
        let agg = MaxPoolAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![1.0f64; 4], vec![-1.0f64; 4]];
        let result = agg.aggregate(&feats, 4);
        assert_eq!(result.len(), 8);
        // All values should be >= 0 (ReLU applied)
        for &v in &result {
            assert!(v >= 0.0, "MaxPool result should be non-negative after ReLU");
        }
    }

    #[test]
    fn test_meanpool_aggregator() {
        let mut rng = Lcg::new(2);
        let agg = MeanPoolAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![1.0f64; 4]; 3];
        let result = agg.aggregate(&feats, 4);
        // 4-dim inputs are projected into the 8-dim pooling space.
        assert_eq!(result.len(), 8);
    }

    #[test]
    fn test_lstm_aggregator() {
        let mut rng = Lcg::new(3);
        let agg = LSTMAggregator::new(4, 8, &mut rng);
        let feats = vec![vec![0.5f64; 4]; 5];
        let result = agg.aggregate(&feats, 4);
        assert_eq!(result.len(), 8);
        // Only finiteness is asserted here; the exact output range depends on
        // the recurrent gating used internally.
        for &v in &result {
            assert!(v.is_finite(), "LSTM output should be finite");
        }
    }

    #[test]
    fn test_graphsage_layer_mean() {
        let mut rng = Lcg::new(42);
        let layer = GraphSAGELayer::new(4, 8, 3, &AggregatorKind::Mean, &mut rng)
            .expect("layer should construct");
        let g = ring_graph(5, 4, 7);
        let embeddings = layer.forward(&g, &g.node_features, &mut rng);
        // One 8-dim output per node.
        assert_eq!(embeddings.len(), 5);
        for emb in &embeddings {
            assert_eq!(emb.len(), 8);
        }
    }

    /// Two-layer model (16 hidden -> 4 output) over a 6-node ring.
    #[test]
    fn test_graphsage_model_default() {
        let config = GraphSAGEConfig {
            input_dim: 8,
            hidden_dims: vec![16],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 1,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(6, 8, 5);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 6);
        assert_eq!(embs.dim, 4);
        for i in 0..6 {
            assert_eq!(embs.get(i).expect("embedding exists").len(), 4);
        }
    }

    #[test]
    fn test_graphsage_model_maxpool() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::MaxPool { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: false,
            seed: 2,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 2);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
    }

    #[test]
    fn test_graphsage_model_meanpool() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::MeanPool { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: false,
            seed: 3,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 3);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
    }

    #[test]
    fn test_graphsage_model_lstm() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Lstm { hidden_dim: 8 },
            num_samples_per_layer: vec![5],
            normalize_output: true,
            seed: 4,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(4, 4, 4);
        let embs = model.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 4);
        // Normalized output: each norm <= 1+eps
        for i in 0..4 {
            let emb = embs.get(i).expect("embedding exists");
            let norm: f64 = emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!(norm <= 1.0 + 1e-6, "norm {} should be <= 1", norm);
        }
    }

    /// `top_k_similar` must exclude the query node and sort descending.
    #[test]
    fn test_graphsage_top_k_similar() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 5,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(8, 4, 6);
        let embs = model.embed(&g).expect("embed should succeed");
        let top3 = embs.top_k_similar(0, 3);
        assert!(top3.len() <= 3);
        for window in top3.windows(2) {
            assert!(
                window[0].1 >= window[1].1 - 1e-10,
                "top_k should be sorted descending"
            );
        }
    }

    /// Inductive setting: embed a node that was not part of the training graph,
    /// using embeddings of existing nodes as its neighbor context.
    #[test]
    fn test_graphsage_inductive_embed_new_node() {
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 9,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let g = ring_graph(5, 4, 10);
        // Get embeddings of existing nodes to use as neighbor context
        let embs = model.embed(&g).expect("embed should succeed");
        let neighbor_embs: Vec<Vec<f64>> = vec![
            embs.get(0).expect("exists").to_vec(),
            embs.get(1).expect("exists").to_vec(),
        ];
        let new_node_features = vec![0.5f64; 4];
        let new_emb = model
            .embed_new_node(&new_node_features, &neighbor_embs)
            .expect("inductive embed should succeed");
        assert_eq!(new_emb.len(), 4);
        let norm: f64 = new_emb.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(
            norm <= 1.0 + 1e-6,
            "normalized embedding norm should be <= 1"
        );
    }

    /// Mini-batch training must run to completion with finite losses.
    #[test]
    fn test_minibatch_graphsage_train() {
        let sage_cfg = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![8],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3, 3],
            normalize_output: true,
            seed: 7,
        };
        let batch_cfg = MiniBatchConfig {
            epochs: 3,
            batch_size: 4,
            num_negative_samples: 2,
            learning_rate: 0.01,
            seed: 7,
        };
        let mut trainer =
            MiniBatchGraphSAGE::new(sage_cfg, batch_cfg).expect("trainer should construct");
        let g = ring_graph(8, 4, 8);
        let metrics = trainer.train(&g).expect("training should succeed");
        assert_eq!(metrics.epochs_completed, 3);
        assert_eq!(metrics.loss_history.len(), 3);
        for &loss in &metrics.loss_history {
            assert!(loss.is_finite(), "loss should be finite");
        }
    }

    #[test]
    fn test_minibatch_graphsage_embed_after_train() {
        let sage_cfg = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3],
            normalize_output: true,
            seed: 11,
        };
        let batch_cfg = MiniBatchConfig {
            epochs: 2,
            batch_size: 3,
            num_negative_samples: 1,
            learning_rate: 0.01,
            seed: 11,
        };
        let mut trainer =
            MiniBatchGraphSAGE::new(sage_cfg, batch_cfg).expect("trainer should construct");
        let g = ring_graph(5, 4, 12);
        trainer.train(&g).expect("training should succeed");
        let embs = trainer.embed(&g).expect("embed should succeed");
        assert_eq!(embs.num_nodes, 5);
        assert_eq!(embs.dim, 4);
    }

    /// Labels attached to the graph flow through to `labeled_embeddings`.
    #[test]
    fn test_graphsage_with_labels() {
        let g = ring_graph(4, 4, 20)
            .with_labels(vec![0, 1, 0, 1])
            .expect("labels should attach");
        assert!(g.labels.is_some());
        let config = GraphSAGEConfig {
            input_dim: 4,
            hidden_dims: vec![],
            output_dim: 4,
            aggregator_kind: AggregatorKind::Mean,
            num_samples_per_layer: vec![3],
            normalize_output: true,
            seed: 20,
        };
        let model = GraphSAGEModel::new(config).expect("model should construct");
        let embs = model.embed(&g).expect("embed should succeed");
        let labels = g.labels.as_ref().expect("labels exist");
        let labeled = embs.labeled_embeddings(labels);
        assert_eq!(labeled.len(), 4);
    }

    /// Two LCGs with the same seed must produce identical streams.
    #[test]
    fn test_lcg_reproducibility() {
        let mut a = Lcg::new(99);
        let mut b = Lcg::new(99);
        for _ in 0..200 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    /// Zero dimensions are invalid configurations.
    #[test]
    fn test_graphsage_invalid_config() {
        assert!(GraphSAGEModel::new(GraphSAGEConfig {
            input_dim: 0,
            ..Default::default()
        })
        .is_err());
        assert!(GraphSAGEModel::new(GraphSAGEConfig {
            output_dim: 0,
            ..Default::default()
        })
        .is_err());
    }
}