// scirs2_graph/gnn/mod.rs
//! Graph Neural Network (GNN) layers and message-passing framework
//!
//! This module implements core GNN building blocks following the message-passing
//! neural network (MPNN) paradigm:
//!
//! - **Message Passing**: flexible aggregation over neighborhoods
//! - **GCNLayer** (legacy Vec-based API): Graph Convolutional Network (Kipf & Welling 2017)
//! - **GraphSAGELayer** (legacy Vec-based API): Sample-and-aggregate (Hamilton et al. 2017)
//! - **GATLayer** (legacy Vec-based API): Graph Attention Network (Veličković et al. 2018)
//!
//! ### Array2-based (ndarray) API (new in 0.3.0)
//!
//! - [`gcn`] – `GcnLayer`, `Gcn`, `gcn_forward`, `add_self_loops`, `CsrMatrix`
//! - [`sage`] – `GraphSageLayer`, `GraphSage`, `sage_aggregate`, `SageAggregation`
//! - [`gat`] – `GraphAttentionLayer`, `gat_forward`
17// --- Sub-modules with the new Array2-based API ---
18pub mod equivariant;
19pub mod gat;
20pub mod gcn;
21pub mod hgt;
22pub mod kg_completion;
23pub mod relation_message;
24pub mod rgcn;
25pub mod sage;
26pub mod transformers;
27
28// --- Re-export new Array2 API types ---
29pub use gat::{gat_forward, GraphAttentionLayer};
30pub use gcn::{add_self_loops, gcn_forward, symmetric_normalize, CsrMatrix, Gcn, GcnLayer};
31pub use sage::{sage_aggregate, sample_neighbors, GraphSage, GraphSageLayer, SageAggregation};
32
33// --- Legacy Vec-based message-passing API (kept for backward compatibility) ---
34// The following items are re-exported from the inline implementation below so
35// that the existing `lib.rs` re-exports continue to work without modification.
36
37use std::collections::HashMap;
38
39use scirs2_core::random::{Rng, RngExt};
40
41use crate::base::{EdgeWeight, Graph, Node};
42use crate::error::{GraphError, Result};
43
44// ============================================================================
45// Message aggregation types (legacy)
46// ============================================================================
47
/// Aggregation strategy for collecting neighbor messages (legacy Vec-based API)
///
/// The default strategy is [`MessagePassing::Mean`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MessagePassing {
    /// Sum all neighbor messages
    Sum,
    /// Arithmetic mean of neighbor messages (the default)
    #[default]
    Mean,
    /// Element-wise maximum
    Max,
    /// Element-wise minimum
    Min,
    /// Attention-weighted mean (weights computed internally)
    Attention,
}
63
64// ============================================================================
65// MessagePassingLayer trait (legacy)
66// ============================================================================
67
/// Core trait for GNN layers following the message-passing paradigm
///
/// Implementors must provide `aggregate` (neighbourhood → message) and
/// `update` (message + self → new embedding) methods; the provided `forward`
/// chains the two, taking the node count from `node_features.len()`.
pub trait MessagePassingLayer {
    /// Aggregate messages from the neighborhood of each node
    ///
    /// * `node_features` — per-node feature rows (all rows must share one width)
    /// * `adjacency` — sparse `(src, dst, weight)` triples
    /// * `n_nodes` — number of nodes; implementations expect
    ///   `node_features.len() == n_nodes`
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>>;

    /// Update node embeddings using aggregated messages and self-features
    fn update(&self, aggregated: &[Vec<f64>], node_features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>>;

    /// Run one full forward pass: aggregate + update
    fn forward(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
    ) -> Result<Vec<Vec<f64>>> {
        // n_nodes is derived from the feature matrix; adjacency triples that
        // reference larger indices are implementation-dependent (e.g. the GCN
        // degree pass below indexes deg[src] unchecked).
        let n = node_features.len();
        let aggregated = self.aggregate(node_features, adjacency, n)?;
        self.update(&aggregated, node_features)
    }
}
95
96// ============================================================================
97// Helper utilities (legacy)
98// ============================================================================
99
100fn validate_features(features: &[Vec<f64>]) -> Result<usize> {
101    if features.is_empty() {
102        return Ok(0);
103    }
104    let dim = features[0].len();
105    for (i, row) in features.iter().enumerate() {
106        if row.len() != dim {
107            return Err(GraphError::InvalidParameter {
108                param: "node_features".to_string(),
109                value: format!("row {} has {} dims, expected {}", i, row.len(), dim),
110                expected: format!("all rows must have {} dimensions", dim),
111                context: "GNN feature validation".to_string(),
112            });
113        }
114    }
115    Ok(dim)
116}
117
/// Rectified linear unit: clamps negative values to zero.
fn relu(x: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}
121
/// Dot product of two slices; zipping truncates to the shorter length.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
125
/// Numerically-stable softmax: shifts by the max before exponentiating and
/// clamps the normalizer to 1e-10. Returns an empty vector for empty input.
fn softmax_vec(xs: &[f64]) -> Vec<f64> {
    if xs.is_empty() {
        return Vec::new();
    }
    let mut max_val = f64::NEG_INFINITY;
    for &x in xs {
        max_val = max_val.max(x);
    }
    let mut out = Vec::with_capacity(xs.len());
    let mut total = 0.0;
    for &x in xs {
        let e = (x - max_val).exp();
        total += e;
        out.push(e);
    }
    let denom = total.max(1e-10);
    for e in out.iter_mut() {
        *e /= denom;
    }
    out
}
135
/// Matrix–vector product over row-major `w`; each output entry is the dot
/// product of a weight row with `x` (zip truncates to the shorter length).
fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(w.len());
    for row in w {
        out.push(row.iter().zip(x).map(|(a, b)| a * b).sum());
    }
    out
}
139
/// Convert graph to sparse adjacency (src, dst, weight) triples
///
/// Returns the node list (in `graph.nodes()` iteration order, which fixes the
/// dense index assignment) and the triples. Every edge is pushed twice — once
/// per direction — so the result is a directed expansion of an undirected
/// graph.
pub fn graph_to_adjacency<N, E, Ix>(graph: &Graph<N, E, Ix>) -> (Vec<N>, Vec<(usize, usize, f64)>)
where
    N: Node + Clone + std::fmt::Debug,
    E: EdgeWeight + Clone + Into<f64>,
    Ix: petgraph::graph::IndexType,
{
    // Dense index per node, keyed by node value.
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let node_to_idx: HashMap<N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, n)| (n.clone(), i))
        .collect();

    let mut adjacency = Vec::new();
    for edge in graph.edges() {
        // Edges whose endpoints are missing from the node map are skipped
        // silently — presumably cannot happen for a consistent Graph; verify
        // against the Graph API if this matters.
        if let (Some(&si), Some(&ti)) =
            (node_to_idx.get(&edge.source), node_to_idx.get(&edge.target))
        {
            let w: f64 = edge.weight.clone().into();
            adjacency.push((si, ti, w));
            adjacency.push((ti, si, w)); // undirected
        }
    }

    (nodes, adjacency)
}
167
168// ============================================================================
169// GCNLayer (legacy Vec-based)
170// ============================================================================
171
/// Graph Convolutional Network layer (Kipf & Welling, 2017) - legacy Vec API
///
/// Forward pass:
/// ```text
///   H' = σ( D̃^{-1/2} Ã D̃^{-1/2} H W )
/// ```
#[derive(Debug, Clone)]
pub struct GCNLayer {
    /// Weight matrix, row-major (out_dim × in_dim): `weights[o][i]`
    pub weights: Vec<Vec<f64>>,
    /// Bias vector (out_dim), added after the linear transform
    pub bias: Vec<f64>,
    /// Output dimension (must equal `weights.len()`)
    pub out_dim: usize,
    /// Aggregation strategy (note: the GCN `aggregate` never reads this field
    /// and always applies symmetric normalization)
    pub aggregation: MessagePassing,
    /// Whether to apply ReLU activation in `update`
    pub use_activation: bool,
}
191
192impl GCNLayer {
193    /// Create a new GCN layer
194    pub fn new(in_dim: usize, out_dim: usize) -> Self {
195        let scale = (2.0 / (in_dim + out_dim) as f64).sqrt();
196        let mut weights = vec![vec![0.0f64; in_dim]; out_dim];
197        for (i, row) in weights.iter_mut().enumerate() {
198            for (j, w) in row.iter_mut().enumerate() {
199                *w = if i == j {
200                    scale
201                } else {
202                    scale * 0.01 * ((i as f64 - j as f64).sin())
203                };
204            }
205        }
206        GCNLayer {
207            weights,
208            bias: vec![0.0; out_dim],
209            out_dim,
210            aggregation: MessagePassing::Mean,
211            use_activation: true,
212        }
213    }
214
215    /// Set custom weights
216    pub fn with_weights(mut self, weights: Vec<Vec<f64>>) -> Result<Self> {
217        if weights.len() != self.out_dim {
218            return Err(GraphError::InvalidParameter {
219                param: "weights".to_string(),
220                value: format!("rows={}", weights.len()),
221                expected: format!("rows={}", self.out_dim),
222                context: "GCNLayer::with_weights".to_string(),
223            });
224        }
225        self.weights = weights;
226        Ok(self)
227    }
228}
229
impl MessagePassingLayer for GCNLayer {
    /// Symmetrically-normalized aggregation D̃^{-1/2} Ã D̃^{-1/2} H with
    /// implicit self-loops (degrees are seeded at 1 and each node adds its own
    /// features scaled by 1/deg).
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>> {
        let in_dim = validate_features(node_features)?;
        if in_dim == 0 {
            // Empty feature matrix: nothing to aggregate.
            return Ok(Vec::new());
        }

        // Degree per node, seeded at 1.0 for the implicit self-loop. Only
        // `src` is counted: graph_to_adjacency emits both directions, so this
        // still equals undirected degree + 1.
        // NOTE(review): deg[src] is indexed without a bounds check — an
        // adjacency entry with src >= n_nodes panics here, while the neighbor
        // loop below does bounds-check.
        let mut deg = vec![1.0f64; n_nodes];
        for &(src, dst, _) in adjacency {
            deg[src] += 1.0;
            let _ = dst;
        }

        let mut agg: Vec<Vec<f64>> = (0..n_nodes).map(|_| vec![0.0f64; in_dim]).collect();

        // Self contribution: x_i / deg_i (the d^{-1/2} · d^{-1/2} product).
        for i in 0..n_nodes {
            let d_inv = 1.0 / deg[i].sqrt();
            for k in 0..in_dim {
                agg[i][k] += d_inv * node_features[i][k] * d_inv;
            }
        }

        // Neighbor contributions, normalized by w / sqrt(deg_src * deg_dst).
        for &(src, dst, w) in adjacency {
            if src < n_nodes && dst < n_nodes {
                let norm = w / (deg[src].sqrt() * deg[dst].sqrt());
                for k in 0..in_dim {
                    agg[dst][k] += norm * node_features[src][k];
                }
            }
        }

        Ok(agg)
    }

    /// Linear transform plus bias, with optional ReLU: h' = relu(W·a + b).
    fn update(
        &self,
        aggregated: &[Vec<f64>],
        _node_features: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        let mut result = Vec::with_capacity(aggregated.len());
        for agg in aggregated {
            let mut h = matvec(&self.weights, agg);
            for (hi, bi) in h.iter_mut().zip(self.bias.iter()) {
                *hi += bi;
                if self.use_activation {
                    *hi = relu(*hi);
                }
            }
            result.push(h);
        }
        Ok(result)
    }
}
288
289// ============================================================================
290// GraphSAGELayer (legacy Vec-based)
291// ============================================================================
292
/// GraphSAGE layer (Hamilton et al., 2017) - legacy Vec API
#[derive(Debug, Clone)]
pub struct GraphSAGELayer {
    /// Weight matrix for concatenated [self || neighbor_agg], row-major
    /// (out × 2*in)
    pub weights: Vec<Vec<f64>>,
    /// Bias (out_dim), added after the linear transform
    pub bias: Vec<f64>,
    /// Output dimension (must equal `weights.len()`)
    pub out_dim: usize,
    /// Aggregation strategy for the neighbor half of the concatenation
    /// (`Attention` falls back to mean in this legacy implementation)
    pub aggregation: MessagePassing,
    /// Whether to apply ReLU before the final L2 normalization
    pub use_activation: bool,
}
307
308impl GraphSAGELayer {
309    /// Create a new GraphSAGE layer
310    pub fn new(in_dim: usize, out_dim: usize) -> Self {
311        let concat_dim = 2 * in_dim;
312        let scale = (2.0 / (concat_dim + out_dim) as f64).sqrt();
313        let mut weights = vec![vec![0.0f64; concat_dim]; out_dim];
314        for (i, row) in weights.iter_mut().enumerate() {
315            for (j, w) in row.iter_mut().enumerate() {
316                *w = if i == j % out_dim {
317                    scale
318                } else {
319                    scale * 0.01 * ((i as f64 - j as f64).cos())
320                };
321            }
322        }
323        GraphSAGELayer {
324            weights,
325            bias: vec![0.0; out_dim],
326            out_dim,
327            aggregation: MessagePassing::Mean,
328            use_activation: true,
329        }
330    }
331}
332
impl MessagePassingLayer for GraphSAGELayer {
    /// Aggregate neighbor features, then concatenate `[self || neighbor_agg]`.
    ///
    /// Returned rows have width `2 * in_dim`, matching the `out × 2*in` weight
    /// matrix consumed by `update`. Nodes with no in-edges get a zero neighbor
    /// aggregate (the Max/Min infinity sentinels are mapped back to 0.0).
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>> {
        let in_dim = validate_features(node_features)?;
        if in_dim == 0 {
            return Ok(Vec::new());
        }

        // Per-destination sum / count / max / min, all filled in one pass so
        // any aggregation mode can be served below.
        let mut neighbor_sums: Vec<Vec<f64>> = (0..n_nodes).map(|_| vec![0.0f64; in_dim]).collect();
        let mut neighbor_counts: Vec<f64> = vec![0.0; n_nodes];
        let mut neighbor_max: Vec<Vec<f64>> = (0..n_nodes)
            .map(|_| vec![f64::NEG_INFINITY; in_dim])
            .collect();
        let mut neighbor_min: Vec<Vec<f64>> =
            (0..n_nodes).map(|_| vec![f64::INFINITY; in_dim]).collect();

        // Edge weight is ignored: this legacy SAGE aggregates unweighted.
        for &(src, dst, _) in adjacency {
            if src < n_nodes && dst < n_nodes {
                neighbor_counts[dst] += 1.0;
                for k in 0..in_dim {
                    neighbor_sums[dst][k] += node_features[src][k];
                    if node_features[src][k] > neighbor_max[dst][k] {
                        neighbor_max[dst][k] = node_features[src][k];
                    }
                    if node_features[src][k] < neighbor_min[dst][k] {
                        neighbor_min[dst][k] = node_features[src][k];
                    }
                }
            }
        }

        // Select the per-node aggregate. `count` is clamped to 1 so isolated
        // nodes divide by 1 rather than 0; Attention falls back to Mean.
        let agg_neighbor: Vec<Vec<f64>> = (0..n_nodes)
            .map(|i| {
                let count = neighbor_counts[i].max(1.0);
                match &self.aggregation {
                    MessagePassing::Sum => neighbor_sums[i].clone(),
                    MessagePassing::Mean => neighbor_sums[i].iter().map(|s| s / count).collect(),
                    MessagePassing::Max => neighbor_max[i]
                        .iter()
                        .map(|&v| if v == f64::NEG_INFINITY { 0.0 } else { v })
                        .collect(),
                    MessagePassing::Min => neighbor_min[i]
                        .iter()
                        .map(|&v| if v == f64::INFINITY { 0.0 } else { v })
                        .collect(),
                    MessagePassing::Attention => {
                        neighbor_sums[i].iter().map(|s| s / count).collect()
                    }
                }
            })
            .collect();

        // Concatenate [self_features || aggregated_neighbor_features].
        let concat: Vec<Vec<f64>> = node_features
            .iter()
            .zip(agg_neighbor.iter())
            .map(|(self_feat, nbr)| {
                let mut cat = self_feat.clone();
                cat.extend_from_slice(nbr);
                cat
            })
            .collect();

        Ok(concat)
    }

    /// Linear transform + bias (+ optional ReLU), then L2-normalize each row
    /// as in the original GraphSAGE formulation (norm clamped to 1e-10).
    fn update(
        &self,
        aggregated: &[Vec<f64>],
        _node_features: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        let mut result = Vec::with_capacity(aggregated.len());
        for agg in aggregated {
            let mut h = matvec(&self.weights, agg);
            for (hi, bi) in h.iter_mut().zip(self.bias.iter()) {
                *hi += bi;
                if self.use_activation {
                    *hi = relu(*hi);
                }
            }
            // L2-normalize the output embedding.
            let norm: f64 = h.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
            h.iter_mut().for_each(|x| *x /= norm);
            result.push(h);
        }
        Ok(result)
    }
}
423
424// ============================================================================
425// GATLayer (legacy Vec-based)
426// ============================================================================
427
/// Graph Attention Network layer (Veličković et al., 2018) - legacy Vec API
#[derive(Debug, Clone)]
pub struct GATLayer {
    /// Node transformation matrix W, row-major (out_dim × in_dim)
    pub weights: Vec<Vec<f64>>,
    /// Attention vector a, length 2 * out_dim, applied to `[Wh_i || Wh_j]`
    pub attention_weights: Vec<f64>,
    /// Output dimension
    pub out_dim: usize,
    /// LeakyReLU negative slope for attention logits (0.2 from `new`)
    pub negative_slope: f64,
    /// Whether to apply ELU activation on output
    pub use_activation: bool,
}
442
443impl GATLayer {
444    /// Create a new GAT layer
445    pub fn new(in_dim: usize, out_dim: usize) -> Self {
446        let scale = (2.0 / (in_dim + out_dim) as f64).sqrt();
447        let mut weights = vec![vec![0.0f64; in_dim]; out_dim];
448        for (i, row) in weights.iter_mut().enumerate() {
449            for (j, w) in row.iter_mut().enumerate() {
450                *w = if i == j { scale } else { scale * 0.01 };
451            }
452        }
453        let attention_weights: Vec<f64> = (0..2 * out_dim)
454            .map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
455            .collect();
456        GATLayer {
457            weights,
458            attention_weights,
459            out_dim,
460            negative_slope: 0.2,
461            use_activation: true,
462        }
463    }
464
465    fn leaky_relu(&self, x: f64) -> f64 {
466        if x >= 0.0 {
467            x
468        } else {
469            self.negative_slope * x
470        }
471    }
472}
473
impl MessagePassingLayer for GATLayer {
    /// Attention-weighted aggregation.
    ///
    /// Features are first transformed by W; for each node i, logits
    /// e_ij = LeakyReLU(a · [Wh_i || Wh_j]) over its in-neighbors (plus a
    /// self-loop if not already present) are softmax-normalized and used to
    /// mix the transformed neighbor features. Edge weights in `adjacency`
    /// are ignored.
    fn aggregate(
        &self,
        node_features: &[Vec<f64>],
        adjacency: &[(usize, usize, f64)],
        n_nodes: usize,
    ) -> Result<Vec<Vec<f64>>> {
        let _in_dim = validate_features(node_features)?;

        // Wh for every node.
        let transformed: Vec<Vec<f64>> = node_features
            .iter()
            .map(|h| matvec(&self.weights, h))
            .collect();

        // Per-destination in-neighbor lists; may contain duplicates if the
        // adjacency does.
        let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n_nodes];
        for &(src, dst, _) in adjacency {
            if src < n_nodes && dst < n_nodes {
                neighbors[dst].push(src);
            }
        }
        // Ensure every node attends to itself at least once.
        for i in 0..n_nodes {
            if !neighbors[i].contains(&i) {
                neighbors[i].push(i);
            }
        }

        let mut aggregated: Vec<Vec<f64>> = vec![vec![0.0; self.out_dim]; n_nodes];

        for i in 0..n_nodes {
            let nbrs = &neighbors[i];
            if nbrs.is_empty() {
                continue;
            }

            // Raw attention logits e_ij = LeakyReLU(a · [Wh_i || Wh_j]).
            let scores: Vec<f64> = nbrs
                .iter()
                .map(|&j| {
                    let mut concat = transformed[i].clone();
                    concat.extend_from_slice(&transformed[j]);
                    let e = dot(&self.attention_weights, &concat);
                    self.leaky_relu(e)
                })
                .collect();

            // Normalize to attention coefficients α_ij.
            let alphas = softmax_vec(&scores);

            // Weighted sum of transformed neighbor features.
            for (k, &j) in nbrs.iter().enumerate() {
                let alpha = alphas[k];
                for d in 0..self.out_dim {
                    aggregated[i][d] += alpha * transformed[j][d];
                }
            }
        }

        Ok(aggregated)
    }

    /// Optional ELU activation: x for x >= 0, exp(x) - 1 otherwise.
    fn update(
        &self,
        aggregated: &[Vec<f64>],
        _node_features: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        if !self.use_activation {
            return Ok(aggregated.to_vec());
        }
        let result: Vec<Vec<f64>> = aggregated
            .iter()
            .map(|row| {
                row.iter()
                    .map(|&x| if x >= 0.0 { x } else { x.exp() - 1.0 })
                    .collect()
            })
            .collect();
        Ok(result)
    }
}
550
551// ============================================================================
552// NodeEmbedding: high-level embedding container (legacy)
553// ============================================================================
554
/// Container for node feature embeddings
///
/// Rows of `embeddings` correspond 1:1 with entries of `node_names`; the
/// constructors default names to the stringified row index.
#[derive(Debug, Clone)]
pub struct NodeEmbeddings {
    /// Node index to name mapping (optional)
    pub node_names: Vec<String>,
    /// Feature matrix (n_nodes × embedding_dim), row-major
    pub embeddings: Vec<Vec<f64>>,
    /// Embedding dimensionality (width of every row)
    pub dim: usize,
}
565
566impl NodeEmbeddings {
567    /// Create embeddings from raw feature matrix
568    pub fn new(embeddings: Vec<Vec<f64>>) -> Result<Self> {
569        let dim = validate_features(&embeddings)?;
570        let n = embeddings.len();
571        Ok(NodeEmbeddings {
572            node_names: (0..n).map(|i| i.to_string()).collect(),
573            embeddings,
574            dim,
575        })
576    }
577
578    /// Create random initial embeddings
579    pub fn random(n_nodes: usize, dim: usize) -> Self {
580        let mut rng = scirs2_core::random::rng();
581        let embeddings: Vec<Vec<f64>> = (0..n_nodes)
582            .map(|_| (0..dim).map(|_| rng.random::<f64>() * 2.0 - 1.0).collect())
583            .collect();
584        NodeEmbeddings {
585            node_names: (0..n_nodes).map(|i| i.to_string()).collect(),
586            embeddings,
587            dim,
588        }
589    }
590
591    /// Create one-hot embeddings
592    pub fn one_hot(n_nodes: usize) -> Self {
593        let embeddings: Vec<Vec<f64>> = (0..n_nodes)
594            .map(|i| {
595                let mut v = vec![0.0f64; n_nodes];
596                v[i] = 1.0;
597                v
598            })
599            .collect();
600        NodeEmbeddings {
601            node_names: (0..n_nodes).map(|i| i.to_string()).collect(),
602            embeddings,
603            dim: n_nodes,
604        }
605    }
606
607    /// Get the number of nodes
608    pub fn n_nodes(&self) -> usize {
609        self.embeddings.len()
610    }
611
612    /// Get embedding for node i
613    pub fn get(&self, i: usize) -> Option<&Vec<f64>> {
614        self.embeddings.get(i)
615    }
616
617    /// Apply a GNN layer to these embeddings
618    pub fn apply_layer<L: MessagePassingLayer>(
619        &self,
620        layer: &L,
621        adjacency: &[(usize, usize, f64)],
622    ) -> Result<NodeEmbeddings> {
623        let new_embeddings = layer.forward(&self.embeddings, adjacency)?;
624        let dim = validate_features(&new_embeddings)?;
625        Ok(NodeEmbeddings {
626            node_names: self.node_names.clone(),
627            embeddings: new_embeddings,
628            dim,
629        })
630    }
631}
632
633/// Build a GNN pipeline and apply it to a graph
634pub fn run_gnn_pipeline<N, E, Ix, L>(
635    graph: &Graph<N, E, Ix>,
636    initial_features: Option<NodeEmbeddings>,
637    layers: &[L],
638) -> Result<NodeEmbeddings>
639where
640    N: Node + Clone + std::fmt::Debug,
641    E: EdgeWeight + Clone + Into<f64>,
642    Ix: petgraph::graph::IndexType,
643    L: MessagePassingLayer,
644{
645    let (_, adjacency) = graph_to_adjacency(graph);
646    let n = graph.nodes().len();
647
648    let mut embeddings = match initial_features {
649        Some(e) => e,
650        None => NodeEmbeddings::one_hot(n),
651    };
652
653    for layer in layers {
654        embeddings = embeddings.apply_layer(layer, &adjacency)?;
655    }
656
657    Ok(embeddings)
658}
659
660// ============================================================================
661// Tests (legacy API)
662// ============================================================================
663
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::Graph;

    /// A 3-node triangle graph plus its directed (src, dst, weight) triples.
    type TriangleGraph = (Graph<usize, f64>, Vec<(usize, usize, f64)>);

    // Builds the triangle 0-1-2; graph_to_adjacency yields 6 directed triples
    // (each undirected edge emitted in both directions).
    fn make_triangle_graph() -> TriangleGraph {
        let mut g: Graph<usize, f64> = Graph::new();
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(0, 2, 1.0);
        let (_, adj) = graph_to_adjacency(&g);
        (g, adj)
    }

    // Deterministic n×dim feature matrix: entry (i, j) = (i*dim + j) / 10.
    fn make_features(n: usize, dim: usize) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| (0..dim).map(|j| (i * dim + j) as f64 / 10.0).collect())
            .collect()
    }

    // Shape check: GCN maps (3 × 4) features to (3 × 8).
    #[test]
    fn test_gcn_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GCNLayer::new(4, 8);
        let out = layer.forward(&features, &adj).expect("GCN forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);
    }

    // Shape check: SAGE maps (3 × 4) features to (3 × 6).
    #[test]
    fn test_graphsage_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GraphSAGELayer::new(4, 6);
        let out = layer.forward(&features, &adj).expect("SAGE forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 6);
    }

    // Shape check: GAT maps (3 × 4) features to (3 × 8).
    #[test]
    fn test_gat_layer_output_shape() {
        let (_, adj) = make_triangle_graph();
        let features = make_features(3, 4);
        let layer = GATLayer::new(4, 8);
        let out = layer.forward(&features, &adj).expect("GAT forward failed");
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);
    }

    // one_hot(3) is the 3×3 identity matrix.
    #[test]
    fn test_node_embeddings_one_hot() {
        let emb = NodeEmbeddings::one_hot(3);
        assert_eq!(emb.n_nodes(), 3);
        assert_eq!(emb.dim, 3);
        let row0 = emb.get(0).expect("No embedding for node 0");
        assert!((row0[0] - 1.0).abs() < 1e-10);
        assert!((row0[1]).abs() < 1e-10);
    }

    // Two stacked GCN layers over a 4-node path preserve the node count.
    #[test]
    fn test_run_gnn_pipeline() {
        let mut g: Graph<usize, f64> = Graph::new();
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let layers = vec![GCNLayer::new(4, 4), GCNLayer::new(4, 4)];
        let features = NodeEmbeddings::new(make_features(4, 4)).expect("Features");
        let result = run_gnn_pipeline(&g, Some(features), &layers).expect("Pipeline");
        assert_eq!(result.n_nodes(), 4);
    }
}