Skip to main content

oxicuda_gnn/
lib.rs

//! `oxicuda-gnn` — Graph Neural Network primitives for OxiCUDA.
//!
//! Pure-Rust implementation of GNN building blocks suitable for CPU simulation
//! and PTX kernel generation for GPU execution.
//!
//! # Architecture
//!
//! ```text
//! oxicuda-gnn
//! ├── graph/          — Sparse graph representations (CSR, COO, Heterogeneous, Sampling)
//! ├── message_passing — Aggregate, Scatter, Update primitives
//! ├── conv/           — GCNII
//! ├── layers/         — GCN, GAT, GATv2, GraphSAGE, GIN, APPNP, ChebNet, GRAND,
//! │                     Graph Transformer, JK-Net, k-WL GNN, MixHop, R-GCN, RWSE,
//! │                     SGC, SIGN, GraphNorm / PairNorm
//! ├── ops             — Operator module (nothing re-exported through the prelude)
//! ├── pooling/        — Global pool, Top-K pool, SAG pool, DiffPool
//! ├── readout/        — Set2Set, DGI, SortPool
//! ├── sampling/       — Cluster-GCN, GraphSAINT
//! ├── error           — GnnError / GnnResult
//! ├── handle          — GnnHandle (SmVersion + LcgRng)
//! └── ptx_kernels     — GPU PTX kernel strings
//! ```
19
20// ─── Module declarations ─────────────────────────────────────────────────────
21
22pub mod conv;
23pub mod error;
24pub mod graph;
25pub mod handle;
26pub mod layers;
27pub mod message_passing;
28pub mod ops;
29pub mod pooling;
30pub mod ptx_kernels;
31pub mod readout;
32pub mod sampling;
33
34// ─── Prelude ─────────────────────────────────────────────────────────────────
35
36/// Convenience re-exports for common GNN types.
37pub mod prelude {
38    pub use crate::conv::gcnii::{Gcnii, GcniiConfig, gcnii_beta};
39    pub use crate::error::{GnnError, GnnResult};
40    pub use crate::graph::coo::CooGraph;
41    pub use crate::graph::csr::CsrGraph;
42    pub use crate::graph::heterogeneous::HeteroGraph;
43    pub use crate::graph::sampling::{NeighborhoodSampler, SampledGraph, biased_walk, random_walk};
44    pub use crate::handle::{GnnHandle, LcgRng, SmVersion};
45    pub use crate::layers::appnp::{AppnpConfig, AppnpLayer};
46    pub use crate::layers::chebnet::{ChebNetConfig, ChebNetLayer};
47    pub use crate::layers::gat::{GatConfig, GatLayer};
48    pub use crate::layers::gat_v2::{GatV2Config, GatV2Layer};
49    pub use crate::layers::gcn::{GcnConfig, GcnLayer};
50    pub use crate::layers::gin::{GinConfig, GinLayer};
51    pub use crate::layers::grand::{GrandConfig, GrandLayer};
52    pub use crate::layers::graph_transformer::{
53        GraphTransformerConfig, GraphTransformerLayer, GraphTransformerWeights,
54    };
55    pub use crate::layers::jk_net::{JkMode, JkNet, JkNetConfig};
56    pub use crate::layers::k_wl_gnn::{
57        KWlConfig, KWlGnn, PairOp, apply_pair_op, graph_readout_sum,
58    };
59    pub use crate::layers::mixhop::{MixHopConfig, MixHopLayer};
60    pub use crate::layers::norm::{GraphNorm, PairNorm, PairNormMode};
61    pub use crate::layers::rgcn::{RgcnConfig, RgcnLayer};
62    pub use crate::layers::rwse::{RwseConfig, RwseEncoder, random_walk_se};
63    pub use crate::layers::sage::{SageAggregator, SageConfig, SageLayer};
64    pub use crate::layers::sgc::{sgc_forward, sgc_linear, sgc_propagate};
65    pub use crate::layers::sign::{SignConfig, SignConv, sign_precompute};
66    pub use crate::message_passing::aggregate::{
67        AggregationType, aggregate, aggregate_degree_norm, aggregate_max, aggregate_mean,
68        aggregate_softmax, aggregate_sum,
69    };
70    pub use crate::message_passing::scatter::{
71        gather, scatter_add, scatter_max, scatter_min, scatter_mul, segment_softmax,
72    };
73    pub use crate::message_passing::update::{
74        LinearUpdate, MlpUpdate, elu, leaky_relu, prelu, relu,
75    };
76    pub use crate::pooling::diff_pool::{DiffPool, DiffPoolConfig, DiffPoolResult};
77    pub use crate::pooling::global_pool::{
78        GlobalPoolType, batched_global_pool, global_attention_pool, global_max_pool,
79        global_mean_pool, global_sum_pool,
80    };
81    pub use crate::pooling::sag_pool::{SagPool, SagPoolResult};
82    pub use crate::pooling::topk_pool::{TopKPool, TopKPoolResult};
83    pub use crate::ptx_kernels::{
84        aggregate_mean_ptx, csr_spmv_ptx, f32_hex, gat_attention_ptx, gin_combine_ptx,
85        scatter_add_ptx, softmax_edge_ptx, topk_score_ptx,
86    };
87    pub use crate::readout::dgi::{Dgi, DgiConfig, DgiLoss, DgiWeights};
88    pub use crate::readout::set2set::Set2Set;
89    pub use crate::readout::sort_pool::{SortPool, SortPoolConfig};
90    pub use crate::sampling::cluster_gcn::{BatchSubgraph, ClusterGcn, Partition};
91    pub use crate::sampling::graphsaint::{GraphSaint, SaintNorm, SaintSampler, SaintSubgraph};
92}
93
94// ─── Integration tests ───────────────────────────────────────────────────────
95
#[cfg(test)]
mod tests {
    // End-to-end smoke tests that reach the crate only through `crate::prelude`.
    use crate::prelude::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    ///
    /// On failure the panic message carries both values, and `#[track_caller]`
    /// attributes the panic to the calling test line rather than this helper.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    // ── Graph construction & SpMV ─────────────────────────────────────────────

    /// Builds a bidirectional triangle, checks its node/edge counts, and
    /// compares a one-feature SpMV against hand-computed neighbour sums.
    #[test]
    fn e2e_csr_graph_construction_and_spmv() {
        // Triangle graph: 0↔1↔2↔0
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)])
            .expect("test invariant: value must be valid");
        assert_eq!(g.n_nodes(), 3);
        assert_eq!(g.n_edges(), 6);

        // SpMV with feat_dim=1: x = [1, 2, 3]
        // y[0] = x[1] + x[2] = 5, y[1] = x[0] + x[2] = 4, y[2] = x[0] + x[1] = 3
        let x = vec![1.0_f32, 2.0, 3.0];
        let y = g.spmv(&x, 1).expect("test invariant: value must be valid");
        assert_close(y[0], 5.0, 1e-5);
        assert_close(y[1], 4.0, 1e-5);
        assert_close(y[2], 3.0, 1e-5);
    }

    // ── COO → CSR roundtrip ───────────────────────────────────────────────────

    /// Converts a 4-edge COO graph to CSR and checks that the counts survive
    /// and that every COO source node reports a positive degree.
    #[test]
    fn e2e_coo_to_csr_roundtrip() {
        let src = vec![0usize, 1, 2, 0];
        let dst = vec![1usize, 2, 0, 2];
        // `src` is cloned because it is read again below; `dst` is simply moved.
        let coo = CooGraph::new(3, src.clone(), dst).expect("test invariant: value must be valid");
        let csr = coo.to_csr().expect("test invariant: value must be valid");
        assert_eq!(csr.n_nodes(), 3);
        assert_eq!(csr.n_edges(), 4);

        // Every node that appears as a COO source must have a positive degree in CSR.
        for &s in &src {
            assert!(csr.degree(s).expect("test invariant: value must be valid") > 0);
        }
    }

    // ── Scatter-add ───────────────────────────────────────────────────────────

    /// Scatters four 2-feature messages onto two destinations and checks the sums.
    #[test]
    fn e2e_scatter_add_correctness() {
        // 4 messages → 2 destination nodes, feat_dim=2
        let messages = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let idx = vec![0usize, 0, 1, 1];
        let out = scatter_add(&messages, &idx, 2, 2).expect("test invariant: value must be valid");
        // dest 0 = [1+3, 2+4] = [4, 6]
        assert_close(out[0], 4.0, 1e-5);
        assert_close(out[1], 6.0, 1e-5);
        // dest 1 = [5+7, 6+8] = [12, 14]
        assert_close(out[2], 12.0, 1e-5);
        assert_close(out[3], 14.0, 1e-5);
    }

    // ── Aggregate mean ────────────────────────────────────────────────────────

    /// Averages two messages into node 0 and checks that node 1, which
    /// receives nothing, stays at zero in every feature.
    #[test]
    fn e2e_aggregate_mean_small_graph() {
        // Node 0 receives messages from edge 0 and edge 1
        let messages = vec![2.0_f32, 4.0, 6.0, 8.0]; // 2 messages × feat_dim=2
        let target_idx = vec![0usize, 0];
        let out = aggregate_mean(&messages, &target_idx, 2, 2)
            .expect("test invariant: value must be valid");
        // node 0: mean([2,4],[6,8]) = [4, 6]
        assert_close(out[0], 4.0, 1e-5);
        assert_close(out[1], 6.0, 1e-5);
        // node 1: no messages → 0 (every remaining entry, not just the first)
        for &v in &out[2..] {
            assert_close(v, 0.0, 1e-6);
        }
    }

    // ── GCN forward shape ─────────────────────────────────────────────────────

    /// Runs a bias-free, normalised GCN layer on a 4-node path graph and
    /// checks output length and finiteness.
    #[test]
    fn e2e_gcn_forward_shape() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
            .expect("test invariant: value must be valid");
        let layer = GcnLayer::new(GcnConfig {
            in_features: 4,
            out_features: 8,
            bias: false,
            normalize: true,
        })
        .expect("test invariant: value must be valid");
        let feats = vec![0.1_f32; 4 * 4];
        let w = vec![0.1_f32; 4 * 8];
        let out = layer
            .forward(&g, &feats, &w, None)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 4 * 8);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // ── GAT attention sums to one ─────────────────────────────────────────────

    /// Runs a single-head GAT layer on a 3-node directed ring with uniform inputs.
    ///
    /// NOTE(review): despite the name, only the output length and finiteness
    /// are asserted; no attention coefficients are inspected here.
    #[test]
    fn e2e_gat_attention_sums_to_one() {
        // 3-node ring with uniform features
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
            .expect("test invariant: value must be valid");
        let in_f = 4;
        let out_f = 4;
        let nh = 1;
        let hd = out_f;
        let layer = GatLayer::new(GatConfig {
            in_features: in_f,
            out_features: out_f,
            num_heads: nh,
            dropout: 0.0,
            leaky_relu_slope: 0.2,
            concat_heads: true,
        })
        .expect("test invariant: value must be valid");
        let x = vec![1.0_f32; 3 * in_f];
        let w = vec![1.0_f32; nh * hd * in_f];
        let aw = vec![0.1_f32; nh * 2 * hd];
        let out = layer
            .forward(&g, &x, &w, &aw)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 3 * out_f);
        // Outputs should be finite
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // ── GraphSAGE mean aggregator ─────────────────────────────────────────────

    /// Runs a mean-aggregating GraphSAGE layer on a 4-node DAG and checks
    /// output length and finiteness.
    #[test]
    fn e2e_sage_mean_aggregator() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (0, 2), (1, 3), (2, 3)])
            .expect("test invariant: value must be valid");
        let layer = SageLayer::new(SageConfig {
            in_features: 3,
            out_features: 3,
            aggregator: SageAggregator::Mean,
            normalize_output: false,
        })
        .expect("test invariant: value must be valid");
        let x = vec![0.5_f32; 4 * 3];
        let w = vec![0.1_f32; 3 * 6];
        let b = vec![0.0_f32; 3];
        let out = layer
            .forward(&g, &x, &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 4 * 3);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // ── GIN epsilon effect ────────────────────────────────────────────────────

    /// Runs two GIN layers that differ only in `epsilon` on the same inputs
    /// and checks that the outputs differ, unless the `epsilon = 0` output is
    /// all (near-)zero.
    #[test]
    fn e2e_gin_epsilon_effect() {
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 2)])
            .expect("test invariant: value must be valid");
        let x = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]; // 3×3 identity
        let make_gin = |eps: f32| {
            GinLayer::new(GinConfig {
                in_features: 3,
                hidden_features: 4,
                out_features: 3,
                epsilon: eps,
                train_epsilon: false,
            })
            .expect("test invariant: value must be valid")
        };
        let w1 = vec![0.1_f32; 4 * 3];
        let b1 = vec![0.0_f32; 4];
        let w2 = vec![0.1_f32; 3 * 4];
        let b2 = vec![0.0_f32; 3];
        let out_e0 = make_gin(0.0)
            .forward(&g, &x, &w1, &b1, &w2, &b2)
            .expect("test invariant: value must be valid");
        let out_e1 = make_gin(1.0)
            .forward(&g, &x, &w1, &b1, &w2, &b2)
            .expect("test invariant: value must be valid");
        assert_eq!(out_e0.len(), 9);
        assert_eq!(out_e1.len(), 9);
        // Outputs differ because epsilon changes the weighting
        let diff: f32 = out_e0
            .iter()
            .zip(out_e1.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 0.0 || out_e0.iter().all(|&v| v.abs() < 1e-8));
    }

    // ── Global mean pool ──────────────────────────────────────────────────────

    /// Mean-pools two 2-feature nodes and checks the per-feature averages.
    #[test]
    fn e2e_global_mean_pool() {
        let x = vec![2.0_f32, 4.0, 6.0, 8.0]; // 2 nodes × feat_dim=2
        let out = global_mean_pool(&x, 2, 2).expect("test invariant: value must be valid");
        assert_eq!(out.len(), 2);
        assert_close(out[0], 4.0, 1e-5); // (2+6)/2
        assert_close(out[1], 6.0, 1e-5); // (4+8)/2
    }

    // ── Top-K pool: k nodes selected ─────────────────────────────────────────

    /// Pools a 5-node path graph down to `k = 3` nodes and checks that the
    /// result's node count, feature buffer, and graph all agree on `k`.
    #[test]
    fn e2e_topk_pool_k_nodes_selected() {
        let g = CsrGraph::from_edges(
            5,
            &[
                (0, 1),
                (1, 0),
                (1, 2),
                (2, 1),
                (2, 3),
                (3, 2),
                (3, 4),
                (4, 3),
            ],
        )
        .expect("test invariant: value must be valid");
        let feat_dim = 3;
        let k = 3;
        let pool = TopKPool::new_k(feat_dim, k);
        let x: Vec<f32> = (0..5 * feat_dim).map(|i| i as f32 * 0.2).collect();
        let proj = vec![1.0_f32, 0.5, 0.25];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert_eq!(res.n_nodes(), k);
        assert_eq!(res.x.len(), k * feat_dim);
        assert_eq!(res.graph.n_nodes(), k);
    }

    // ── DiffPool assignment row-stochastic ────────────────────────────────────

    /// Runs DiffPool on a 4-node ring and checks that each node's row of the
    /// assignment matrix sums to one.
    #[test]
    fn e2e_diffpool_assignment_stochastic() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)])
            .expect("test invariant: value must be valid");
        let d = 3;
        let k = 2;
        let dp = DiffPool::new(DiffPoolConfig {
            in_features: d,
            n_clusters: k,
        })
        .expect("test invariant: value must be valid");
        let x = vec![1.0_f32; 4 * d];
        let logits: Vec<f32> = (0..4 * k).map(|i| i as f32 * 0.1).collect();
        let res = dp
            .forward(&g, &x, &logits)
            .expect("test invariant: value must be valid");
        for i in 0..4 {
            let row_sum: f32 = res.assignment[i * k..(i + 1) * k].iter().sum();
            assert_close(row_sum, 1.0, 1e-5);
        }
    }

    // ── PTX kernels for all SM versions ──────────────────────────────────────

    /// Generates every PTX kernel for a list of SM versions and checks that
    /// each string names its kernel; the SpMV kernel must also contain the
    /// `sm_<version>` target string.
    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for &sm in &[75u32, 80, 86, 90, 100, 120] {
            let ptx = csr_spmv_ptx(sm);
            assert!(ptx.contains("csr_spmv"));
            assert!(ptx.contains(&format!("sm_{sm}")));

            let ptx = scatter_add_ptx(sm);
            assert!(ptx.contains("scatter_add"));

            let ptx = gat_attention_ptx(sm);
            assert!(ptx.contains("gat_attention"));

            let ptx = softmax_edge_ptx(sm);
            assert!(ptx.contains("softmax_edge"));

            let ptx = aggregate_mean_ptx(sm);
            assert!(ptx.contains("aggregate_mean"));

            let ptx = gin_combine_ptx(sm);
            assert!(ptx.contains("gin_combine"));

            let ptx = topk_score_ptx(sm);
            assert!(ptx.contains("topk_score"));
        }
    }

    // ── Handle and RNG ────────────────────────────────────────────────────────

    /// Draws ten values from the RNGs of two default handles and checks that
    /// both sequences are identical.
    #[test]
    fn e2e_handle_rng_deterministic() {
        let mut h1 = GnnHandle::default_handle();
        let mut h2 = GnnHandle::default_handle();
        // Same seed → same sequence
        let r1: Vec<u32> = (0..10).map(|_| h1.rng_mut().next_u32()).collect();
        let r2: Vec<u32> = (0..10).map(|_| h2.rng_mut().next_u32()).collect();
        assert_eq!(r1, r2);
    }

    // ── Neighbourhood sampling ────────────────────────────────────────────────

    /// Samples a two-level neighbourhood (fan-outs `[2, 2]`) around seed
    /// node 0 and checks that the result is non-empty and still maps back to
    /// the seed.
    #[test]
    fn e2e_neighborhood_sampling() {
        let g = CsrGraph::from_edges(
            8,
            &[
                (0, 1),
                (0, 2),
                (1, 3),
                (1, 4),
                (2, 5),
                (2, 6),
                (3, 7),
                (4, 7),
            ],
        )
        .expect("test invariant: value must be valid");
        let sampler =
            NeighborhoodSampler::new(vec![2, 2]).expect("test invariant: value must be valid");
        let mut rng = LcgRng::new(42);
        let result = sampler
            .sample(&g, &[0], &mut rng)
            .expect("test invariant: value must be valid");
        assert!(result.n_nodes() >= 1);
        assert!(result.local_to_global.contains(&0));
    }
}