1pub mod conv;
23pub mod error;
24pub mod graph;
25pub mod handle;
26pub mod layers;
27pub mod message_passing;
28pub mod ops;
29pub mod pooling;
30pub mod ptx_kernels;
31pub mod readout;
32pub mod sampling;
33
34pub mod prelude {
38 pub use crate::conv::gcnii::{Gcnii, GcniiConfig, gcnii_beta};
39 pub use crate::error::{GnnError, GnnResult};
40 pub use crate::graph::coo::CooGraph;
41 pub use crate::graph::csr::CsrGraph;
42 pub use crate::graph::heterogeneous::HeteroGraph;
43 pub use crate::graph::sampling::{NeighborhoodSampler, SampledGraph, biased_walk, random_walk};
44 pub use crate::handle::{GnnHandle, LcgRng, SmVersion};
45 pub use crate::layers::appnp::{AppnpConfig, AppnpLayer};
46 pub use crate::layers::chebnet::{ChebNetConfig, ChebNetLayer};
47 pub use crate::layers::gat::{GatConfig, GatLayer};
48 pub use crate::layers::gat_v2::{GatV2Config, GatV2Layer};
49 pub use crate::layers::gcn::{GcnConfig, GcnLayer};
50 pub use crate::layers::gin::{GinConfig, GinLayer};
51 pub use crate::layers::grand::{GrandConfig, GrandLayer};
52 pub use crate::layers::graph_transformer::{
53 GraphTransformerConfig, GraphTransformerLayer, GraphTransformerWeights,
54 };
55 pub use crate::layers::jk_net::{JkMode, JkNet, JkNetConfig};
56 pub use crate::layers::k_wl_gnn::{
57 KWlConfig, KWlGnn, PairOp, apply_pair_op, graph_readout_sum,
58 };
59 pub use crate::layers::mixhop::{MixHopConfig, MixHopLayer};
60 pub use crate::layers::norm::{GraphNorm, PairNorm, PairNormMode};
61 pub use crate::layers::rgcn::{RgcnConfig, RgcnLayer};
62 pub use crate::layers::rwse::{RwseConfig, RwseEncoder, random_walk_se};
63 pub use crate::layers::sage::{SageAggregator, SageConfig, SageLayer};
64 pub use crate::layers::sgc::{sgc_forward, sgc_linear, sgc_propagate};
65 pub use crate::layers::sign::{SignConfig, SignConv, sign_precompute};
66 pub use crate::message_passing::aggregate::{
67 AggregationType, aggregate, aggregate_degree_norm, aggregate_max, aggregate_mean,
68 aggregate_softmax, aggregate_sum,
69 };
70 pub use crate::message_passing::scatter::{
71 gather, scatter_add, scatter_max, scatter_min, scatter_mul, segment_softmax,
72 };
73 pub use crate::message_passing::update::{
74 LinearUpdate, MlpUpdate, elu, leaky_relu, prelu, relu,
75 };
76 pub use crate::pooling::diff_pool::{DiffPool, DiffPoolConfig, DiffPoolResult};
77 pub use crate::pooling::global_pool::{
78 GlobalPoolType, batched_global_pool, global_attention_pool, global_max_pool,
79 global_mean_pool, global_sum_pool,
80 };
81 pub use crate::pooling::sag_pool::{SagPool, SagPoolResult};
82 pub use crate::pooling::topk_pool::{TopKPool, TopKPoolResult};
83 pub use crate::ptx_kernels::{
84 aggregate_mean_ptx, csr_spmv_ptx, f32_hex, gat_attention_ptx, gin_combine_ptx,
85 scatter_add_ptx, softmax_edge_ptx, topk_score_ptx,
86 };
87 pub use crate::readout::dgi::{Dgi, DgiConfig, DgiLoss, DgiWeights};
88 pub use crate::readout::set2set::Set2Set;
89 pub use crate::readout::sort_pool::{SortPool, SortPoolConfig};
90 pub use crate::sampling::cluster_gcn::{BatchSubgraph, ClusterGcn, Partition};
91 pub use crate::sampling::graphsaint::{GraphSaint, SaintNorm, SaintSampler, SaintSubgraph};
92}
93
#[cfg(test)]
mod tests {
    //! End-to-end smoke tests that exercise the public API through [`crate::prelude`].

    use crate::prelude::*;

    /// Builds a fully connected 3-node graph and checks that `spmv` sums neighbour values.
    #[test]
    fn e2e_csr_graph_construction_and_spmv() {
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)])
            .expect("test invariant: value must be valid");
        assert_eq!(g.n_nodes(), 3);
        assert_eq!(g.n_edges(), 6);

        // Every node is adjacent to the other two, so y[i] is the sum of the other two values.
        let x = vec![1.0_f32, 2.0, 3.0];
        let y = g.spmv(&x, 1).expect("test invariant: value must be valid");
        assert!((y[0] - 5.0).abs() < 1e-5);
        assert!((y[1] - 4.0).abs() < 1e-5);
        assert!((y[2] - 3.0).abs() < 1e-5);
    }

    /// Converts a small COO graph to CSR and checks node/edge counts and source degrees.
    #[test]
    fn e2e_coo_to_csr_roundtrip() {
        let src = vec![0usize, 1, 2, 0];
        let dst = vec![1usize, 2, 0, 2];
        // `dst` is not needed afterwards, so it can be moved instead of cloned.
        let coo = CooGraph::new(3, src.clone(), dst).expect("test invariant: value must be valid");
        let csr = coo.to_csr().expect("test invariant: value must be valid");
        assert_eq!(csr.n_nodes(), 3);
        assert_eq!(csr.n_edges(), 4);

        // Every node that appears as a source must have at least one outgoing edge.
        for &s in &src {
            assert!(csr.degree(s).expect("test invariant: value must be valid") > 0);
        }
    }

    /// Scatters four 2-dimensional messages into two targets and checks the per-target sums.
    #[test]
    fn e2e_scatter_add_correctness() {
        // Four messages of width 2: [1,2], [3,4], [5,6], [7,8].
        let messages = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let idx = vec![0usize, 0, 1, 1];
        let out = scatter_add(&messages, &idx, 2, 2).expect("test invariant: value must be valid");
        // Target 0 receives [1,2] + [3,4].
        assert!((out[0] - 4.0).abs() < 1e-5);
        assert!((out[1] - 6.0).abs() < 1e-5);
        // Target 1 receives [5,6] + [7,8].
        assert!((out[2] - 12.0).abs() < 1e-5);
        assert!((out[3] - 14.0).abs() < 1e-5);
    }

    /// Averages two messages into node 0 and checks that node 1 (no messages) stays zero.
    #[test]
    fn e2e_aggregate_mean_small_graph() {
        // Two messages of width 2, both targeting node 0: [2,4] and [6,8].
        let messages = vec![2.0_f32, 4.0, 6.0, 8.0];
        let target_idx = vec![0usize, 0];
        let out = aggregate_mean(&messages, &target_idx, 2, 2)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 2 * 2);
        // Node 0: mean of [2,4] and [6,8].
        assert!((out[0] - 4.0).abs() < 1e-5);
        assert!((out[1] - 6.0).abs() < 1e-5);
        // Node 1 received nothing, so both of its features must be zero.
        assert!(out[2].abs() < 1e-6);
        assert!(out[3].abs() < 1e-6);
    }

    /// Runs a normalized GCN layer on a 4-node path and checks output shape and finiteness.
    #[test]
    fn e2e_gcn_forward_shape() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
            .expect("test invariant: value must be valid");
        let layer = GcnLayer::new(GcnConfig {
            in_features: 4,
            out_features: 8,
            bias: false,
            normalize: true,
        })
        .expect("test invariant: value must be valid");
        let feats = vec![0.1_f32; 4 * 4];
        let w = vec![0.1_f32; 4 * 8];
        let out = layer
            .forward(&g, &feats, &w, None)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 4 * 8);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    /// Checks that GAT attention forms a convex combination (weights sum to one).
    ///
    /// All node features and projection weights are 1, so every projected neighbour
    /// vector equals `in_f` in each component. An attention-weighted average of identical
    /// vectors reproduces that vector exactly iff the weights sum to one.
    #[test]
    fn e2e_gat_attention_sums_to_one() {
        // Directed 3-cycle: every node has exactly one in- and one out-neighbour.
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
            .expect("test invariant: value must be valid");
        let in_f = 4;
        let out_f = 4;
        let nh = 1;
        let hd = out_f;
        let layer = GatLayer::new(GatConfig {
            in_features: in_f,
            out_features: out_f,
            num_heads: nh,
            dropout: 0.0,
            leaky_relu_slope: 0.2,
            concat_heads: true,
        })
        .expect("test invariant: value must be valid");
        let x = vec![1.0_f32; 3 * in_f];
        let w = vec![1.0_f32; nh * hd * in_f];
        let aw = vec![0.1_f32; nh * 2 * hd];
        let out = layer
            .forward(&g, &x, &w, &aw)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 3 * out_f);
        assert!(out.iter().all(|v| v.is_finite()));

        // Each projected feature is sum_{k<in_f} 1 * 1 = in_f. With normalized attention
        // the aggregated output must equal that common value.
        let expected = in_f as f32;
        assert!(
            out.iter().all(|v| (v - expected).abs() < 1e-4),
            "attention weights do not sum to one: {out:?}"
        );
    }

    /// Runs a GraphSAGE layer with the mean aggregator and checks output shape and finiteness.
    #[test]
    fn e2e_sage_mean_aggregator() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (0, 2), (1, 3), (2, 3)])
            .expect("test invariant: value must be valid");
        let layer = SageLayer::new(SageConfig {
            in_features: 3,
            out_features: 3,
            aggregator: SageAggregator::Mean,
            normalize_output: false,
        })
        .expect("test invariant: value must be valid");
        let x = vec![0.5_f32; 4 * 3];
        // Weight width is 2 * in_features (self and neighbour parts).
        let w = vec![0.1_f32; 3 * 6];
        let b = vec![0.0_f32; 3];
        let out = layer
            .forward(&g, &x, &w, &b)
            .expect("test invariant: value must be valid");
        assert_eq!(out.len(), 4 * 3);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    /// Checks that changing GIN's epsilon changes the output.
    ///
    /// The only exception is the degenerate case where the output is identically zero.
    #[test]
    fn e2e_gin_epsilon_effect() {
        let g = CsrGraph::from_edges(3, &[(0, 1), (1, 2)])
            .expect("test invariant: value must be valid");
        // One-hot node features (3x3 identity).
        let x = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let make_gin = |eps: f32| {
            GinLayer::new(GinConfig {
                in_features: 3,
                hidden_features: 4,
                out_features: 3,
                epsilon: eps,
                train_epsilon: false,
            })
            .expect("test invariant: value must be valid")
        };
        let w1 = vec![0.1_f32; 4 * 3];
        let b1 = vec![0.0_f32; 4];
        let w2 = vec![0.1_f32; 3 * 4];
        let b2 = vec![0.0_f32; 3];
        let out_e0 = make_gin(0.0)
            .forward(&g, &x, &w1, &b1, &w2, &b2)
            .expect("test invariant: value must be valid");
        let out_e1 = make_gin(1.0)
            .forward(&g, &x, &w1, &b1, &w2, &b2)
            .expect("test invariant: value must be valid");
        assert_eq!(out_e0.len(), 9);
        assert_eq!(out_e1.len(), 9);
        // L1 distance between the two outputs.
        let diff: f32 = out_e0
            .iter()
            .zip(out_e1.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 0.0 || out_e0.iter().all(|&v| v.abs() < 1e-8));
    }

    /// Mean-pools two 2-dimensional node feature rows into a single graph vector.
    #[test]
    fn e2e_global_mean_pool() {
        // Two nodes with features [2,4] and [6,8].
        let x = vec![2.0_f32, 4.0, 6.0, 8.0];
        let out = global_mean_pool(&x, 2, 2).expect("test invariant: value must be valid");
        assert_eq!(out.len(), 2);
        assert!((out[0] - 4.0).abs() < 1e-5);
        assert!((out[1] - 6.0).abs() < 1e-5);
    }

    /// Applies top-k pooling to a 5-node path and checks that exactly `k` nodes remain.
    #[test]
    fn e2e_topk_pool_k_nodes_selected() {
        let g = CsrGraph::from_edges(
            5,
            &[
                (0, 1),
                (1, 0),
                (1, 2),
                (2, 1),
                (2, 3),
                (3, 2),
                (3, 4),
                (4, 3),
            ],
        )
        .expect("test invariant: value must be valid");
        let feat_dim = 3;
        let k = 3;
        let pool = TopKPool::new_k(feat_dim, k);
        let x: Vec<f32> = (0..5 * feat_dim).map(|i| i as f32 * 0.2).collect();
        let proj = vec![1.0_f32, 0.5, 0.25];
        let res = pool
            .forward(&g, &x, &proj)
            .expect("test invariant: value must be valid");
        assert_eq!(res.n_nodes(), k);
        assert_eq!(res.x.len(), k * feat_dim);
        assert_eq!(res.graph.n_nodes(), k);
    }

    /// Checks that every row of DiffPool's soft cluster assignment sums to one.
    #[test]
    fn e2e_diffpool_assignment_stochastic() {
        let g = CsrGraph::from_edges(4, &[(0, 1), (1, 2), (2, 3), (3, 0)])
            .expect("test invariant: value must be valid");
        let d = 3;
        let k = 2;
        let dp = DiffPool::new(DiffPoolConfig {
            in_features: d,
            n_clusters: k,
        })
        .expect("test invariant: value must be valid");
        let x = vec![1.0_f32; 4 * d];
        let logits: Vec<f32> = (0..4 * k).map(|i| i as f32 * 0.1).collect();
        let res = dp
            .forward(&g, &x, &logits)
            .expect("test invariant: value must be valid");
        for row in res.assignment.chunks(k).take(4) {
            let row_sum: f32 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-5);
        }
    }

    /// Checks that each PTX generator emits its kernel name for every supported SM version.
    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for &sm in &[75u32, 80, 86, 90, 100, 120] {
            let ptx = csr_spmv_ptx(sm);
            assert!(ptx.contains("csr_spmv"));
            assert!(ptx.contains(&format!("sm_{sm}")));

            assert!(scatter_add_ptx(sm).contains("scatter_add"));
            assert!(gat_attention_ptx(sm).contains("gat_attention"));
            assert!(softmax_edge_ptx(sm).contains("softmax_edge"));
            assert!(aggregate_mean_ptx(sm).contains("aggregate_mean"));
            assert!(gin_combine_ptx(sm).contains("gin_combine"));
            assert!(topk_score_ptx(sm).contains("topk_score"));
        }
    }

    /// Checks that two default handles yield identical RNG streams.
    #[test]
    fn e2e_handle_rng_deterministic() {
        let mut h1 = GnnHandle::default_handle();
        let mut h2 = GnnHandle::default_handle();
        let r1: Vec<u32> = (0..10).map(|_| h1.rng_mut().next_u32()).collect();
        let r2: Vec<u32> = (0..10).map(|_| h2.rng_mut().next_u32()).collect();
        assert_eq!(r1, r2);
    }

    /// Runs two-hop neighbourhood sampling from node 0 and checks that the seed is kept.
    #[test]
    fn e2e_neighborhood_sampling() {
        let g = CsrGraph::from_edges(
            8,
            &[
                (0, 1),
                (0, 2),
                (1, 3),
                (1, 4),
                (2, 5),
                (2, 6),
                (3, 7),
                (4, 7),
            ],
        )
        .expect("test invariant: value must be valid");
        let sampler =
            NeighborhoodSampler::new(vec![2, 2]).expect("test invariant: value must be valid");
        let mut rng = LcgRng::new(42);
        let result = sampler
            .sample(&g, &[0], &mut rng)
            .expect("test invariant: value must be valid");
        assert!(result.n_nodes() >= 1);
        assert!(result.local_to_global.contains(&0));
    }
}