Skip to main content

scirs2_graph/gnn/
sage.rs

1//! GraphSAGE (Sample and Aggregate) layers
2//!
3//! Implements the inductive node embedding framework from Hamilton, Ying &
4//! Leskovec (2017), "Inductive Representation Learning on Large Graphs".
5//!
6//! GraphSAGE learns to aggregate feature information from local neighborhoods,
7//! enabling inductive generalization to unseen nodes.
8//!
9//! Supported aggregation types:
10//! - **Mean** – element-wise mean of neighbor features
11//! - **Max** – element-wise max-pooling of neighbor features
12//! - **Sum** – element-wise sum of neighbor features
//! - **LSTM** – sequential gated aggregation over each node's neighbors, visited
//!   in the order they are stored in the adjacency row (full LSTM backprop is
//!   outside scope; approximated here with a simple, parameter-free gated
//!   recurrent update used for the forward pass only)
16
17use scirs2_core::ndarray::{Array1, Array2};
18use scirs2_core::random::{Rng, RngExt};
19
20use crate::error::{GraphError, Result};
21use crate::gnn::gcn::CsrMatrix;
22
23// ============================================================================
24// Aggregation type
25// ============================================================================
26
/// How a GraphSAGE layer combines the feature vectors of a node's neighbors.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SageAggregation {
    /// Per-dimension arithmetic mean of the neighbor features (the default).
    #[default]
    Mean,
    /// Per-dimension maximum of the neighbor features (max-pooling).
    Max,
    /// Per-dimension sum of the neighbor features.
    Sum,
    /// Sequential, gated (LSTM-style) aggregation over the neighbor list.
    Lstm,
}
40
41// ============================================================================
42// Neighborhood sampling
43// ============================================================================
44
45/// Sample up to `k` neighbors for each node, returning indices.
46///
47/// If a node has fewer than `k` neighbors, all neighbors are returned (no
48/// replacement sampling).  The sampling is deterministic per `seed`.
49///
50/// # Arguments
51/// * `adj` – Sparse adjacency (any direction; the function treats each row as
52///   the neighbor list of that node).
53/// * `k` – Maximum neighborhood size to sample.
54///
55/// # Returns
56/// `sampled[i]` contains up to `k` neighbor indices for node `i`.
57pub fn sample_neighbors(adj: &CsrMatrix, k: usize) -> Vec<Vec<usize>> {
58    let n = adj.n_rows;
59    let mut rng = scirs2_core::random::rng();
60
61    (0..n)
62        .map(|i| {
63            let start = adj.indptr[i];
64            let end = adj.indptr[i + 1];
65            let neighbors: Vec<usize> = adj.indices[start..end].to_vec();
66            if neighbors.len() <= k {
67                neighbors
68            } else {
69                // Reservoir sampling
70                let mut reservoir: Vec<usize> = neighbors[..k].to_vec();
71                for idx in k..neighbors.len() {
72                    let j = (rng.random::<f64>() * (idx + 1) as f64) as usize;
73                    if j < k {
74                        reservoir[j] = neighbors[idx];
75                    }
76                }
77                reservoir
78            }
79        })
80        .collect()
81}
82
83// ============================================================================
84// Neighborhood aggregation (functional API)
85// ============================================================================
86
87/// Aggregate neighbor features using the specified aggregation type.
88///
89/// # Arguments
90/// * `adj` – Sparse adjacency matrix (row i → neighbors of node i).
91/// * `features` – Node feature matrix `[n_nodes, feat_dim]`.
92/// * `aggr_type` – Which aggregation to apply.
93///
94/// # Returns
95/// Aggregated neighbor embeddings `[n_nodes, feat_dim]`.  Isolated nodes
96/// (no neighbors) receive a zero vector.
97pub fn sage_aggregate(
98    adj: &CsrMatrix,
99    features: &Array2<f64>,
100    aggr_type: &SageAggregation,
101) -> Result<Array2<f64>> {
102    let n = adj.n_rows;
103    let (feat_n, feat_dim) = features.dim();
104
105    if feat_n != n {
106        return Err(GraphError::InvalidParameter {
107            param: "features".to_string(),
108            value: format!("{feat_n} rows"),
109            expected: format!("{n} rows (matching adj.n_rows)"),
110            context: "sage_aggregate".to_string(),
111        });
112    }
113
114    let mut agg = Array2::<f64>::zeros((n, feat_dim));
115
116    match aggr_type {
117        SageAggregation::Mean | SageAggregation::Sum => {
118            let mut counts = vec![0usize; n];
119            for (row, col, _) in adj.iter() {
120                if col < feat_n {
121                    counts[row] += 1;
122                    for k in 0..feat_dim {
123                        agg[[row, k]] += features[[col, k]];
124                    }
125                }
126            }
127            if *aggr_type == SageAggregation::Mean {
128                for i in 0..n {
129                    if counts[i] > 0 {
130                        let inv = 1.0 / counts[i] as f64;
131                        for k in 0..feat_dim {
132                            agg[[i, k]] *= inv;
133                        }
134                    }
135                }
136            }
137        }
138
139        SageAggregation::Max => {
140            // Initialize to NEG_INFINITY, then reduce
141            let mut initialized = vec![false; n];
142            for (row, col, _) in adj.iter() {
143                if col < feat_n {
144                    if !initialized[row] {
145                        for k in 0..feat_dim {
146                            agg[[row, k]] = features[[col, k]];
147                        }
148                        initialized[row] = true;
149                    } else {
150                        for k in 0..feat_dim {
151                            if features[[col, k]] > agg[[row, k]] {
152                                agg[[row, k]] = features[[col, k]];
153                            }
154                        }
155                    }
156                }
157            }
158            // Nodes with no neighbors keep zero (already set)
159        }
160
161        SageAggregation::Lstm => {
162            // Gated sequential aggregation over neighbors (approximates LSTM
163            // forward pass without backprop).  For each node i we process its
164            // neighbor features in order; a hidden state h is updated via:
165            //   z = sigmoid(x + h)
166            //   h = z * h + (1 - z) * x   (simplified GRU-like update)
167            for i in 0..n {
168                let start = adj.indptr[i];
169                let end = adj.indptr[i + 1];
170                let neighbor_indices = &adj.indices[start..end];
171
172                if neighbor_indices.is_empty() {
173                    continue;
174                }
175
176                let mut h = vec![0.0f64; feat_dim];
177                for &nb in neighbor_indices {
178                    if nb < feat_n {
179                        for k in 0..feat_dim {
180                            let x = features[[nb, k]];
181                            // Sigmoid gate
182                            let z = 1.0 / (1.0 + (-(x + h[k])).exp());
183                            h[k] = z * h[k] + (1.0 - z) * x;
184                        }
185                    }
186                }
187                for k in 0..feat_dim {
188                    agg[[i, k]] = h[k];
189                }
190            }
191        }
192    }
193
194    Ok(agg)
195}
196
197// ============================================================================
198// GraphSAGE layer
199// ============================================================================
200
201/// A single GraphSAGE layer.
202///
203/// Concatenates each node's own representation with the aggregated neighbor
204/// representation, then applies a linear transformation:
205/// ```text
206///   h_v = σ( W · concat(h_v, AGG({h_u : u ∈ N(v)})) + b )
207/// ```
208/// The output is L2-normalized (per-node) following the original paper.
209#[derive(Debug, Clone)]
210pub struct GraphSageLayer {
211    /// Weight matrix `[2 * in_dim, out_dim]`
212    pub weights: Array2<f64>,
213    /// Optional bias `[out_dim]`
214    pub bias: Option<Array1<f64>>,
215    /// Input feature dimension
216    pub in_dim: usize,
217    /// Output feature dimension
218    pub out_dim: usize,
219    /// Aggregation strategy
220    pub aggregation: SageAggregation,
221    /// Apply ReLU activation
222    pub use_relu: bool,
223    /// Apply L2 normalization on output embeddings
224    pub normalize: bool,
225}
226
227impl GraphSageLayer {
228    /// Create a new GraphSAGE layer with Glorot-uniform initialization.
229    ///
230    /// # Arguments
231    /// * `in_dim` – Input feature dimension per node.
232    /// * `out_dim` – Output embedding dimension.
233    pub fn new(in_dim: usize, out_dim: usize) -> Self {
234        let concat_dim = 2 * in_dim;
235        let scale = (6.0_f64 / (concat_dim + out_dim) as f64).sqrt();
236        let mut rng = scirs2_core::random::rng();
237        let weights = Array2::from_shape_fn((concat_dim, out_dim), |_| {
238            rng.random::<f64>() * 2.0 * scale - scale
239        });
240        GraphSageLayer {
241            weights,
242            bias: None,
243            in_dim,
244            out_dim,
245            aggregation: SageAggregation::Mean,
246            use_relu: true,
247            normalize: true,
248        }
249    }
250
251    /// Use a specific aggregation strategy.
252    pub fn with_aggregation(mut self, aggr: SageAggregation) -> Self {
253        self.aggregation = aggr;
254        self
255    }
256
257    /// Disable L2 normalization on output.
258    pub fn without_normalize(mut self) -> Self {
259        self.normalize = false;
260        self
261    }
262
263    /// Disable ReLU activation.
264    pub fn without_activation(mut self) -> Self {
265        self.use_relu = false;
266        self
267    }
268
269    /// Forward pass.
270    ///
271    /// # Arguments
272    /// * `adj` – Sparse adjacency (row i → outgoing neighbors of i).
273    /// * `features` – Node feature matrix `[n_nodes, in_dim]`.
274    pub fn forward(&self, adj: &CsrMatrix, features: &Array2<f64>) -> Result<Array2<f64>> {
275        let n = adj.n_rows;
276        let (feat_n, feat_dim) = features.dim();
277
278        if feat_n != n {
279            return Err(GraphError::InvalidParameter {
280                param: "features".to_string(),
281                value: format!("{feat_n}"),
282                expected: format!("{n}"),
283                context: "GraphSageLayer::forward".to_string(),
284            });
285        }
286        if feat_dim != self.in_dim {
287            return Err(GraphError::InvalidParameter {
288                param: "features feat_dim".to_string(),
289                value: format!("{feat_dim}"),
290                expected: format!("{}", self.in_dim),
291                context: "GraphSageLayer::forward".to_string(),
292            });
293        }
294
295        // Step 1: aggregate neighbor features
296        let agg = sage_aggregate(adj, features, &self.aggregation)?;
297
298        // Step 2: concatenate [self_feat || agg_feat]  →  [n, 2*in_dim]
299        let concat_dim = 2 * self.in_dim;
300        let mut concat = Array2::<f64>::zeros((n, concat_dim));
301        for i in 0..n {
302            for k in 0..feat_dim {
303                concat[[i, k]] = features[[i, k]];
304                concat[[i, feat_dim + k]] = agg[[i, k]];
305            }
306        }
307
308        // Step 3: linear transform  concat @ weights  →  [n, out_dim]
309        let (_, out_dim) = self.weights.dim();
310        let mut output = Array2::<f64>::zeros((n, out_dim));
311        for i in 0..n {
312            for j in 0..out_dim {
313                let mut sum = 0.0;
314                for k in 0..concat_dim {
315                    sum += concat[[i, k]] * self.weights[[k, j]];
316                }
317                output[[i, j]] = sum;
318            }
319        }
320
321        // Add bias
322        if let Some(ref b) = self.bias {
323            for i in 0..n {
324                for j in 0..out_dim {
325                    output[[i, j]] += b[j];
326                }
327            }
328        }
329
330        // Activation
331        if self.use_relu {
332            output.mapv_inplace(|x| x.max(0.0));
333        }
334
335        // L2 normalize each row
336        if self.normalize {
337            for i in 0..n {
338                let norm = {
339                    let row = output.row(i);
340                    row.iter().map(|&x| x * x).sum::<f64>().sqrt()
341                };
342                if norm > 1e-10 {
343                    for j in 0..out_dim {
344                        output[[i, j]] /= norm;
345                    }
346                }
347            }
348        }
349
350        Ok(output)
351    }
352}
353
354// ============================================================================
355// Multi-layer GraphSAGE model
356// ============================================================================
357
358/// Multi-layer GraphSAGE model.
359///
360/// Stacks `GraphSageLayer`s with optional neighborhood sampling at each layer.
361pub struct GraphSage {
362    /// Layer stack
363    pub layers: Vec<GraphSageLayer>,
364    /// Optional max neighborhood size per layer (None = use all neighbors)
365    pub neighbor_samples: Vec<Option<usize>>,
366}
367
368impl GraphSage {
369    /// Build a GraphSAGE model from a list of `(in_dim, out_dim)` layer specs.
370    ///
371    /// # Arguments
372    /// * `dims` – Sequence `[d_0, d_1, …, d_L]`.
373    /// * `aggr` – Aggregation type applied to all layers.
374    pub fn new(dims: &[usize], aggr: SageAggregation) -> Result<Self> {
375        if dims.len() < 2 {
376            return Err(GraphError::InvalidParameter {
377                param: "dims".to_string(),
378                value: format!("len={}", dims.len()),
379                expected: "at least 2 elements".to_string(),
380                context: "GraphSage::new".to_string(),
381            });
382        }
383        let mut layers = Vec::with_capacity(dims.len() - 1);
384        for i in 0..(dims.len() - 1) {
385            let is_last = i == dims.len() - 2;
386            let mut layer =
387                GraphSageLayer::new(dims[i], dims[i + 1]).with_aggregation(aggr.clone());
388            if is_last {
389                layer = layer.without_activation();
390            }
391            layers.push(layer);
392        }
393        let neighbor_samples = vec![None; dims.len() - 1];
394        Ok(GraphSage {
395            layers,
396            neighbor_samples,
397        })
398    }
399
400    /// Set the maximum number of neighbors sampled at each layer.
401    ///
402    /// `sizes[i]` controls layer `i`.  Pass `None` for a layer to use all
403    /// neighbors.
404    pub fn with_neighbor_samples(mut self, sizes: Vec<Option<usize>>) -> Result<Self> {
405        if sizes.len() != self.layers.len() {
406            return Err(GraphError::InvalidParameter {
407                param: "sizes".to_string(),
408                value: format!("len={}", sizes.len()),
409                expected: format!("len={}", self.layers.len()),
410                context: "GraphSage::with_neighbor_samples".to_string(),
411            });
412        }
413        self.neighbor_samples = sizes;
414        Ok(self)
415    }
416
417    /// Forward pass through all layers.
418    ///
419    /// # Arguments
420    /// * `adj` – Sparse adjacency matrix.
421    /// * `features` – Initial feature matrix `[n_nodes, d_0]`.
422    pub fn forward(&self, adj: &CsrMatrix, features: &Array2<f64>) -> Result<Array2<f64>> {
423        let mut h = features.clone();
424        for (i, layer) in self.layers.iter().enumerate() {
425            // Optionally sub-sample the adjacency for mini-batch training
426            let sampled_adj = if let Some(k) = self.neighbor_samples[i] {
427                // Build a sub-sampled CSR from sampled neighbors
428                let sampled = sample_neighbors(adj, k);
429                let mut coo = Vec::new();
430                for (node_i, nbrs) in sampled.iter().enumerate() {
431                    for &nb in nbrs {
432                        coo.push((node_i, nb, 1.0f64));
433                    }
434                }
435                CsrMatrix::from_coo(adj.n_rows, adj.n_cols, &coo)?
436            } else {
437                adj.clone()
438            };
439            h = layer.forward(&sampled_adj, &h)?;
440        }
441        Ok(h)
442    }
443}
444
445// ============================================================================
446// Tests
447// ============================================================================
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path graph 0 – 1 – … – (n-1) as a symmetric CSR matrix.
    fn path_csr(n: usize) -> CsrMatrix {
        let mut coo = Vec::new();
        for i in 0..(n - 1) {
            coo.push((i, i + 1, 1.0));
            coo.push((i + 1, i, 1.0));
        }
        CsrMatrix::from_coo(n, n, &coo).expect("path CSR")
    }

    /// Deterministic features that grow strictly with the node index.
    fn features(n: usize, d: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, d), |(i, j)| (i * d + j) as f64 * 0.1)
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "actual={actual}, expected={expected}"
        );
    }

    #[test]
    fn test_mean_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Mean).expect("mean agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_mean_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Mean).expect("mean agg");
        for k in 0..3 {
            // Node 0 has the single neighbor 1; node 1 has neighbors 0 and 2.
            assert_close(agg[[0, k]], feats[[1, k]]);
            assert_close(agg[[1, k]], 0.5 * (feats[[0, k]] + feats[[2, k]]));
        }
    }

    #[test]
    fn test_max_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Max).expect("max agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_max_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Max).expect("max agg");
        for k in 0..3 {
            // Features grow with the node index, so the larger neighbor wins.
            assert_close(agg[[1, k]], feats[[2, k]]);
            assert_close(agg[[3, k]], feats[[2, k]]);
        }
    }

    #[test]
    fn test_sum_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Sum).expect("sum agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_sum_aggregate_values() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Sum).expect("sum agg");
        for k in 0..3 {
            assert_close(agg[[0, k]], feats[[1, k]]);
            assert_close(agg[[2, k]], feats[[1, k]] + feats[[3, k]]);
        }
    }

    #[test]
    fn test_lstm_aggregate_shape() {
        let adj = path_csr(4);
        let feats = features(4, 6);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Lstm).expect("lstm agg");
        assert_eq!(agg.dim(), (4, 6));
    }

    #[test]
    fn test_lstm_aggregate_single_neighbor() {
        let adj = path_csr(4);
        let feats = features(4, 3);
        let agg = sage_aggregate(&adj, &feats, &SageAggregation::Lstm).expect("lstm agg");
        for k in 0..3 {
            // Starting from h = 0, one update gives h = (1 - sigmoid(x)) * x.
            let x = feats[[1, k]];
            let z = 1.0 / (1.0 + (-x).exp());
            assert_close(agg[[0, k]], (1.0 - z) * x);
        }
    }

    #[test]
    fn test_sage_aggregate_rejects_row_mismatch() {
        let adj = path_csr(4);
        let feats = features(3, 2);
        assert!(sage_aggregate(&adj, &feats, &SageAggregation::Mean).is_err());
    }

    #[test]
    fn test_sage_layer_output_shape() {
        let adj = path_csr(5);
        let feats = features(5, 4);
        let layer = GraphSageLayer::new(4, 8);
        let out = layer.forward(&adj, &feats).expect("sage forward");
        assert_eq!(out.dim(), (5, 8));
    }

    #[test]
    fn test_sage_layer_rejects_wrong_feature_dim() {
        let adj = path_csr(5);
        let feats = features(5, 3);
        let layer = GraphSageLayer::new(4, 8);
        assert!(layer.forward(&adj, &feats).is_err());
    }

    #[test]
    fn test_sage_layer_l2_normalization() {
        let adj = path_csr(5);
        let feats = features(5, 4);
        let layer = GraphSageLayer::new(4, 8);
        let out = layer.forward(&adj, &feats).expect("sage forward");
        // Each row should have unit L2 norm (or be zero)
        for i in 0..5 {
            let norm: f64 = out.row(i).iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!(
                norm < 1e-10 || (norm - 1.0).abs() < 1e-9,
                "norm={norm} for row {i}"
            );
        }
    }

    #[test]
    fn test_graphsage_multilayer() {
        let adj = path_csr(6);
        let feats = features(6, 8);
        let model = GraphSage::new(&[8, 16, 4], SageAggregation::Mean).expect("sage model");
        let out = model.forward(&adj, &feats).expect("forward");
        assert_eq!(out.dim(), (6, 4));
    }

    #[test]
    fn test_graphsage_rejects_bad_config() {
        assert!(GraphSage::new(&[4], SageAggregation::Mean).is_err());
        let model = GraphSage::new(&[4, 8, 4], SageAggregation::Mean).expect("sage model");
        assert!(model.with_neighbor_samples(vec![Some(2)]).is_err());
    }

    #[test]
    fn test_neighbor_sampling() {
        let adj = path_csr(4);
        let sampled = sample_neighbors(&adj, 1);
        assert_eq!(sampled.len(), 4);
        // End nodes have a single neighbor, which is returned as-is.
        assert_eq!(sampled[0], vec![1]);
        assert_eq!(sampled[3], vec![2]);
        // Internal nodes have 2 neighbors; sampled down to exactly 1 of them.
        for node in 1..=2usize {
            assert_eq!(sampled[node].len(), 1);
            let nb = sampled[node][0];
            assert!(nb + 1 == node || nb == node + 1, "node {node} got {nb}");
        }
    }

    #[test]
    fn test_graphsage_with_sampling() {
        let adj = path_csr(6);
        let feats = features(6, 4);
        let model = GraphSage::new(&[4, 8, 4], SageAggregation::Mean)
            .expect("sage model")
            .with_neighbor_samples(vec![Some(2), Some(2)])
            .expect("samples");
        let out = model.forward(&adj, &feats).expect("forward");
        assert_eq!(out.dim(), (6, 4));
    }

    #[test]
    fn test_sage_aggregation_isolated_node() {
        // Node 0 has no neighbors
        let coo = vec![(1, 2, 1.0), (2, 1, 1.0)];
        let adj = CsrMatrix::from_coo(3, 3, &coo).expect("isolated CSR");
        let feats = features(3, 4);
        let kinds = [
            SageAggregation::Mean,
            SageAggregation::Max,
            SageAggregation::Sum,
            SageAggregation::Lstm,
        ];
        for kind in kinds.iter() {
            let agg = sage_aggregate(&adj, &feats, kind).expect("aggregate");
            // Node 0: no neighbors → zero aggregation
            for k in 0..4 {
                assert_eq!(agg[[0, k]], 0.0, "{kind:?} at column {k}");
            }
        }
    }
}