1use crate::types::CsrGraph;
8use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct GNNConfig {
18 pub num_layers: usize,
20 pub hidden_dim: usize,
22 pub output_dim: usize,
24 pub aggregation: AggregationType,
26 pub activation: ActivationType,
28 pub dropout: f64,
30 pub add_self_loops: bool,
32 pub layer_norm: bool,
34}
35
36impl Default for GNNConfig {
37 fn default() -> Self {
38 Self {
39 num_layers: 2,
40 hidden_dim: 64,
41 output_dim: 32,
42 aggregation: AggregationType::Mean,
43 activation: ActivationType::ReLU,
44 dropout: 0.0,
45 add_self_loops: true,
46 layer_norm: true,
47 }
48 }
49}
50
51#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
53pub enum AggregationType {
54 Sum,
56 Mean,
58 Max,
60 SAGE,
62}
63
64#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
66pub enum ActivationType {
67 ReLU,
69 LeakyReLU,
71 ELU,
73 Sigmoid,
75 Tanh,
77 None,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct GNNWeights {
84 pub layer_weights: Vec<Vec<Vec<f64>>>,
86 pub layer_biases: Vec<Vec<f64>>,
88}
89
90impl GNNWeights {
91 pub fn random(input_dim: usize, config: &GNNConfig) -> Self {
93 use rand::{Rng, rng};
94 let mut r = rng();
95
96 let mut layer_weights = Vec::new();
97 let mut layer_biases = Vec::new();
98
99 let mut prev_dim = input_dim;
100
101 for i in 0..config.num_layers {
102 let out_dim = if i == config.num_layers - 1 {
103 config.output_dim
104 } else {
105 config.hidden_dim
106 };
107
108 let scale = (2.0 / (prev_dim + out_dim) as f64).sqrt();
110
111 let weights: Vec<Vec<f64>> = (0..prev_dim)
112 .map(|_| {
113 (0..out_dim)
114 .map(|_| r.random_range(-scale..scale))
115 .collect()
116 })
117 .collect();
118
119 let biases: Vec<f64> = (0..out_dim).map(|_| 0.0).collect();
120
121 layer_weights.push(weights);
122 layer_biases.push(biases);
123 prev_dim = out_dim;
124 }
125
126 Self {
127 layer_weights,
128 layer_biases,
129 }
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct GNNResult {
136 pub embeddings: Vec<Vec<f64>>,
138 pub predictions: Option<Vec<usize>>,
140 pub probabilities: Option<Vec<Vec<f64>>>,
142}
143
144#[derive(Debug, Clone)]
150pub struct GNNInference {
151 metadata: KernelMetadata,
152}
153
154impl Default for GNNInference {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160impl GNNInference {
161 #[must_use]
163 pub fn new() -> Self {
164 Self {
165 metadata: KernelMetadata::batch("graph/gnn-inference", Domain::GraphAnalytics)
166 .with_description("Message passing neural network inference")
167 .with_throughput(10_000)
168 .with_latency_us(100.0),
169 }
170 }
171
172 #[allow(clippy::needless_range_loop)]
174 pub fn compute(
175 graph: &CsrGraph,
176 node_features: &[Vec<f64>],
177 weights: &GNNWeights,
178 config: &GNNConfig,
179 ) -> GNNResult {
180 if graph.num_nodes == 0 || node_features.is_empty() {
181 return GNNResult {
182 embeddings: Vec::new(),
183 predictions: None,
184 probabilities: None,
185 };
186 }
187
188 let n = graph.num_nodes;
189
190 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
192 for node in 0..n {
193 let start = graph.row_offsets[node] as usize;
194 let end = graph.row_offsets[node + 1] as usize;
195 for &neighbor in &graph.col_indices[start..end] {
196 adj[node].push(neighbor as usize);
197 if !adj[neighbor as usize].contains(&node) {
199 adj[neighbor as usize].push(node);
200 }
201 }
202 }
203
204 if config.add_self_loops {
205 for i in 0..n {
206 if !adj[i].contains(&i) {
207 adj[i].push(i);
208 }
209 }
210 }
211
212 let mut embeddings: Vec<Vec<f64>> = node_features.to_vec();
214
215 for layer_idx in 0..config.num_layers {
217 embeddings = Self::message_passing_layer(
218 &embeddings,
219 &adj,
220 &weights.layer_weights[layer_idx],
221 &weights.layer_biases[layer_idx],
222 config,
223 layer_idx == config.num_layers - 1,
224 );
225 }
226
227 let (predictions, probabilities) = if config.output_dim > 1 {
229 let probs: Vec<Vec<f64>> = embeddings.iter().map(|e| Self::softmax(e)).collect();
230 let preds: Vec<usize> = probs
231 .iter()
232 .map(|p| {
233 p.iter()
234 .enumerate()
235 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
236 .map(|(i, _)| i)
237 .unwrap_or(0)
238 })
239 .collect();
240 (Some(preds), Some(probs))
241 } else {
242 (None, None)
243 };
244
245 GNNResult {
246 embeddings,
247 predictions,
248 probabilities,
249 }
250 }
251
252 #[allow(clippy::needless_range_loop)]
254 fn message_passing_layer(
255 embeddings: &[Vec<f64>],
256 adj: &[Vec<usize>],
257 weights: &[Vec<f64>],
258 biases: &[f64],
259 config: &GNNConfig,
260 is_last: bool,
261 ) -> Vec<Vec<f64>> {
262 let n = embeddings.len();
263 let out_dim = biases.len();
264 let mut new_embeddings = vec![vec![0.0; out_dim]; n];
265
266 for i in 0..n {
267 let aggregated = Self::aggregate_neighbors(embeddings, &adj[i], config.aggregation);
269
270 for j in 0..out_dim {
272 let mut val = biases[j];
273 for (k, &agg_val) in aggregated.iter().enumerate() {
274 if k < weights.len() && j < weights[k].len() {
275 val += weights[k][j] * agg_val;
276 }
277 }
278
279 if !is_last {
281 val = Self::activate(val, config.activation);
282 }
283
284 new_embeddings[i][j] = val;
285 }
286
287 if config.layer_norm && !is_last {
289 let mean: f64 = new_embeddings[i].iter().sum::<f64>() / out_dim as f64;
290 let var: f64 = new_embeddings[i]
291 .iter()
292 .map(|x| (x - mean).powi(2))
293 .sum::<f64>()
294 / out_dim as f64;
295 let std = (var + 1e-5).sqrt();
296
297 for j in 0..out_dim {
298 new_embeddings[i][j] = (new_embeddings[i][j] - mean) / std;
299 }
300 }
301 }
302
303 new_embeddings
304 }
305
306 fn aggregate_neighbors(
308 embeddings: &[Vec<f64>],
309 neighbors: &[usize],
310 agg_type: AggregationType,
311 ) -> Vec<f64> {
312 if neighbors.is_empty() {
313 return vec![0.0; embeddings.first().map(|e| e.len()).unwrap_or(0)];
314 }
315
316 let dim = embeddings[neighbors[0]].len();
317
318 match agg_type {
319 AggregationType::Sum => {
320 let mut result = vec![0.0; dim];
321 for &n in neighbors {
322 for (i, &v) in embeddings[n].iter().enumerate() {
323 result[i] += v;
324 }
325 }
326 result
327 }
328 AggregationType::Mean => {
329 let mut result = vec![0.0; dim];
330 for &n in neighbors {
331 for (i, &v) in embeddings[n].iter().enumerate() {
332 result[i] += v;
333 }
334 }
335 let count = neighbors.len() as f64;
336 result.iter_mut().for_each(|v| *v /= count);
337 result
338 }
339 AggregationType::Max => {
340 let mut result = vec![f64::NEG_INFINITY; dim];
341 for &n in neighbors {
342 for (i, &v) in embeddings[n].iter().enumerate() {
343 result[i] = result[i].max(v);
344 }
345 }
346 result
347 }
348 AggregationType::SAGE => {
349 let mut result = vec![0.0; dim];
352 for &n in neighbors {
353 for (i, &v) in embeddings[n].iter().enumerate() {
354 result[i] += v;
355 }
356 }
357 let count = neighbors.len() as f64;
358 result.iter_mut().for_each(|v| *v /= count);
359 result
360 }
361 }
362 }
363
364 fn activate(x: f64, activation: ActivationType) -> f64 {
366 match activation {
367 ActivationType::ReLU => x.max(0.0),
368 ActivationType::LeakyReLU => {
369 if x > 0.0 {
370 x
371 } else {
372 0.01 * x
373 }
374 }
375 ActivationType::ELU => {
376 if x > 0.0 {
377 x
378 } else {
379 x.exp() - 1.0
380 }
381 }
382 ActivationType::Sigmoid => 1.0 / (1.0 + (-x).exp()),
383 ActivationType::Tanh => x.tanh(),
384 ActivationType::None => x,
385 }
386 }
387
388 fn softmax(x: &[f64]) -> Vec<f64> {
390 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
391 let exp_sum: f64 = x.iter().map(|v| (v - max_val).exp()).sum();
392 x.iter().map(|v| (v - max_val).exp() / exp_sum).collect()
393 }
394}
395
396impl GpuKernel for GNNInference {
397 fn metadata(&self) -> &KernelMetadata {
398 &self.metadata
399 }
400}
401
402#[derive(Debug, Clone, Serialize, Deserialize)]
408pub struct GraphAttentionConfig {
409 pub num_heads: usize,
411 pub head_dim: usize,
413 pub output_dim: usize,
415 pub attention_dropout: f64,
417 pub concat_heads: bool,
419 pub negative_slope: f64,
421}
422
423impl Default for GraphAttentionConfig {
424 fn default() -> Self {
425 Self {
426 num_heads: 4,
427 head_dim: 16,
428 output_dim: 64,
429 attention_dropout: 0.0,
430 concat_heads: true,
431 negative_slope: 0.2,
432 }
433 }
434}
435
436#[derive(Debug, Clone, Serialize, Deserialize)]
438pub struct GATWeights {
439 pub query_weights: Vec<Vec<Vec<f64>>>,
441 pub key_weights: Vec<Vec<Vec<f64>>>,
443 pub value_weights: Vec<Vec<Vec<f64>>>,
445 pub attention_vectors: Vec<Vec<f64>>,
447 pub output_weights: Vec<Vec<f64>>,
449}
450
451impl GATWeights {
452 pub fn random(input_dim: usize, config: &GraphAttentionConfig) -> Self {
454 use rand::{Rng, rng};
455 let mut r = rng();
456
457 let scale = (2.0 / (input_dim + config.head_dim) as f64).sqrt();
458
459 let mut query_weights = Vec::new();
460 let mut key_weights = Vec::new();
461 let mut value_weights = Vec::new();
462 let mut attention_vectors = Vec::new();
463
464 for _ in 0..config.num_heads {
465 let q: Vec<Vec<f64>> = (0..input_dim)
466 .map(|_| {
467 (0..config.head_dim)
468 .map(|_| r.random_range(-scale..scale))
469 .collect()
470 })
471 .collect();
472 let k: Vec<Vec<f64>> = (0..input_dim)
473 .map(|_| {
474 (0..config.head_dim)
475 .map(|_| r.random_range(-scale..scale))
476 .collect()
477 })
478 .collect();
479 let v: Vec<Vec<f64>> = (0..input_dim)
480 .map(|_| {
481 (0..config.head_dim)
482 .map(|_| r.random_range(-scale..scale))
483 .collect()
484 })
485 .collect();
486 let a: Vec<f64> = (0..config.head_dim * 2)
487 .map(|_| r.random_range(-scale..scale))
488 .collect();
489
490 query_weights.push(q);
491 key_weights.push(k);
492 value_weights.push(v);
493 attention_vectors.push(a);
494 }
495
496 let total_dim = if config.concat_heads {
497 config.num_heads * config.head_dim
498 } else {
499 config.head_dim
500 };
501
502 let out_scale = (2.0 / (total_dim + config.output_dim) as f64).sqrt();
503 let output_weights: Vec<Vec<f64>> = (0..total_dim)
504 .map(|_| {
505 (0..config.output_dim)
506 .map(|_| r.random_range(-out_scale..out_scale))
507 .collect()
508 })
509 .collect();
510
511 Self {
512 query_weights,
513 key_weights,
514 value_weights,
515 attention_vectors,
516 output_weights,
517 }
518 }
519}
520
521#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct GATResult {
524 pub embeddings: Vec<Vec<f64>>,
526 pub attention_weights: Vec<Vec<(usize, usize, f64)>>,
528}
529
530#[derive(Debug, Clone)]
536pub struct GraphAttention {
537 metadata: KernelMetadata,
538}
539
540impl Default for GraphAttention {
541 fn default() -> Self {
542 Self::new()
543 }
544}
545
546impl GraphAttention {
547 #[must_use]
549 pub fn new() -> Self {
550 Self {
551 metadata: KernelMetadata::batch("graph/graph-attention", Domain::GraphAnalytics)
552 .with_description("Graph attention networks with multi-head attention")
553 .with_throughput(5_000)
554 .with_latency_us(200.0),
555 }
556 }
557
558 #[allow(clippy::needless_range_loop)]
560 pub fn compute(
561 graph: &CsrGraph,
562 node_features: &[Vec<f64>],
563 weights: &GATWeights,
564 config: &GraphAttentionConfig,
565 ) -> GATResult {
566 if graph.num_nodes == 0 || node_features.is_empty() {
567 return GATResult {
568 embeddings: Vec::new(),
569 attention_weights: Vec::new(),
570 };
571 }
572
573 let n = graph.num_nodes;
574
575 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
577 for node in 0..n {
578 let start = graph.row_offsets[node] as usize;
579 let end = graph.row_offsets[node + 1] as usize;
580 for &neighbor in &graph.col_indices[start..end] {
581 adj[node].push(neighbor as usize);
582 if !adj[neighbor as usize].contains(&node) {
583 adj[neighbor as usize].push(node);
584 }
585 }
586 }
587 for i in 0..n {
588 if !adj[i].contains(&i) {
589 adj[i].push(i);
590 }
591 }
592
593 let mut head_outputs: Vec<Vec<Vec<f64>>> = Vec::new();
595 let mut all_attention_weights: Vec<Vec<(usize, usize, f64)>> = Vec::new();
596
597 for head in 0..config.num_heads {
598 let (output, attn_weights) = Self::compute_head(
599 node_features,
600 &adj,
601 &weights.query_weights[head],
602 &weights.key_weights[head],
603 &weights.value_weights[head],
604 &weights.attention_vectors[head],
605 config,
606 );
607 head_outputs.push(output);
608 all_attention_weights.push(attn_weights);
609 }
610
611 let combined: Vec<Vec<f64>> = if config.concat_heads {
613 (0..n)
614 .map(|i| head_outputs.iter().flat_map(|h| h[i].clone()).collect())
615 .collect()
616 } else {
617 (0..n)
619 .map(|i| {
620 let dim = head_outputs[0][i].len();
621 let mut avg = vec![0.0; dim];
622 for h in &head_outputs {
623 for (j, &v) in h[i].iter().enumerate() {
624 avg[j] += v;
625 }
626 }
627 avg.iter_mut().for_each(|v| *v /= config.num_heads as f64);
628 avg
629 })
630 .collect()
631 };
632
633 let embeddings: Vec<Vec<f64>> = combined
635 .iter()
636 .map(|c| Self::linear_transform(c, &weights.output_weights))
637 .collect();
638
639 GATResult {
640 embeddings,
641 attention_weights: all_attention_weights,
642 }
643 }
644
645 #[allow(clippy::type_complexity)]
647 fn compute_head(
648 features: &[Vec<f64>],
649 adj: &[Vec<usize>],
650 query_w: &[Vec<f64>],
651 key_w: &[Vec<f64>],
652 value_w: &[Vec<f64>],
653 attn_vec: &[f64],
654 config: &GraphAttentionConfig,
655 ) -> (Vec<Vec<f64>>, Vec<(usize, usize, f64)>) {
656 let n = features.len();
657 let head_dim = config.head_dim;
658
659 let queries: Vec<Vec<f64>> = features
661 .iter()
662 .map(|f| Self::linear_transform(f, query_w))
663 .collect();
664 let keys: Vec<Vec<f64>> = features
665 .iter()
666 .map(|f| Self::linear_transform(f, key_w))
667 .collect();
668 let values: Vec<Vec<f64>> = features
669 .iter()
670 .map(|f| Self::linear_transform(f, value_w))
671 .collect();
672
673 let mut output = vec![vec![0.0; head_dim]; n];
674 let mut attention_list: Vec<(usize, usize, f64)> = Vec::new();
675
676 for i in 0..n {
677 if adj[i].is_empty() {
678 continue;
679 }
680
681 let mut scores: Vec<f64> = Vec::with_capacity(adj[i].len());
683
684 for &j in &adj[i] {
685 let mut concat = queries[i].clone();
687 concat.extend(keys[j].iter().cloned());
688
689 let score: f64 = concat.iter().zip(attn_vec.iter()).map(|(c, a)| c * a).sum();
690
691 let score = if score > 0.0 {
693 score
694 } else {
695 config.negative_slope * score
696 };
697
698 scores.push(score);
699 }
700
701 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
703 let exp_scores: Vec<f64> = scores.iter().map(|s| (s - max_score).exp()).collect();
704 let sum_exp: f64 = exp_scores.iter().sum();
705 let attention: Vec<f64> = exp_scores.iter().map(|e| e / sum_exp).collect();
706
707 for (idx, &j) in adj[i].iter().enumerate() {
709 let attn = attention[idx];
710 attention_list.push((i, j, attn));
711
712 for (k, &v) in values[j].iter().enumerate() {
713 output[i][k] += attn * v;
714 }
715 }
716 }
717
718 (output, attention_list)
719 }
720
721 fn linear_transform(input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
723 if weights.is_empty() {
724 return Vec::new();
725 }
726
727 let out_dim = weights[0].len();
728 let mut output = vec![0.0; out_dim];
729
730 for (i, &x) in input.iter().enumerate() {
731 if i < weights.len() {
732 for (j, &w) in weights[i].iter().enumerate() {
733 output[j] += x * w;
734 }
735 }
736 }
737
738 output
739 }
740
741 pub fn node_importance(attention_weights: &[(usize, usize, f64)], n: usize) -> Vec<f64> {
743 let mut importance = vec![0.0; n];
744 let mut counts = vec![0usize; n];
745
746 for &(_, target, weight) in attention_weights {
747 if target < n {
748 importance[target] += weight;
749 counts[target] += 1;
750 }
751 }
752
753 for i in 0..n {
755 if counts[i] > 0 {
756 importance[i] /= counts[i] as f64;
757 }
758 }
759
760 importance
761 }
762}
763
764impl GpuKernel for GraphAttention {
765 fn metadata(&self) -> &KernelMetadata {
766 &self.metadata
767 }
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Three-node cycle 0-1-2-0 shared by most tests.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(3, &[(0, 1), (1, 2), (2, 0)])
    }

    #[test]
    fn test_gnn_inference_metadata() {
        let kernel = GNNInference::new();
        assert_eq!(kernel.metadata().id, "graph/gnn-inference");
    }

    #[test]
    fn test_gnn_inference_basic() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GNNConfig {
            num_layers: 2,
            hidden_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GNNWeights::random(2, &config);
        let result = GNNInference::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 2);

        let predictions = result
            .predictions
            .expect("output_dim > 1 should produce predictions");
        assert_eq!(predictions.len(), 3);
        assert!(predictions.iter().all(|&p| p < 2));

        let probabilities = result
            .probabilities
            .expect("output_dim > 1 should produce probabilities");
        for row in &probabilities {
            let total: f64 = row.iter().sum();
            assert!((total - 1.0).abs() < 1e-9, "softmax row sums to {total}");
        }
    }

    #[test]
    fn test_gnn_aggregation_types() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        for agg in [
            AggregationType::Sum,
            AggregationType::Mean,
            AggregationType::Max,
            AggregationType::SAGE,
        ] {
            let config = GNNConfig {
                aggregation: agg,
                num_layers: 1,
                hidden_dim: 4,
                output_dim: 2,
                ..Default::default()
            };

            let weights = GNNWeights::random(2, &config);
            let result = GNNInference::compute(&graph, &features, &weights, &config);

            assert_eq!(result.embeddings.len(), 3);
            assert!(result.embeddings.iter().all(|e| e.len() == 2));
        }
    }

    #[test]
    fn test_gnn_empty_graph() {
        let graph = CsrGraph::empty();
        let features: Vec<Vec<f64>> = vec![];
        let config = GNNConfig::default();
        let weights = GNNWeights::random(2, &config);

        let result = GNNInference::compute(&graph, &features, &weights, &config);
        assert!(result.embeddings.is_empty());
        assert!(result.predictions.is_none());
        assert!(result.probabilities.is_none());
    }

    #[test]
    fn test_graph_attention_metadata() {
        let kernel = GraphAttention::new();
        assert_eq!(kernel.metadata().id, "graph/graph-attention");
    }

    #[test]
    fn test_graph_attention_basic() {
        let graph = create_test_graph();
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let config = GraphAttentionConfig {
            num_heads: 2,
            head_dim: 4,
            output_dim: 3,
            ..Default::default()
        };

        let weights = GATWeights::random(4, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        assert_eq!(result.embeddings.len(), 3);
        assert_eq!(result.embeddings[0].len(), 3);
        assert_eq!(result.attention_weights.len(), 2);
        assert!(!result.attention_weights.is_empty());
    }

    #[test]
    fn test_attention_weights_sum_to_one() {
        let graph = create_test_graph();
        let features = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let config = GraphAttentionConfig {
            num_heads: 1,
            head_dim: 4,
            output_dim: 2,
            ..Default::default()
        };

        let weights = GATWeights::random(2, &config);
        let result = GraphAttention::compute(&graph, &features, &weights, &config);

        // Softmax is taken per source node, so each source's outgoing weights sum to 1.
        let mut sums: HashMap<usize, f64> = HashMap::new();
        for &(src, _, weight) in &result.attention_weights[0] {
            *sums.entry(src).or_insert(0.0) += weight;
        }

        assert_eq!(sums.len(), 3);
        for sum in sums.values() {
            assert!(
                (sum - 1.0).abs() < 0.01,
                "Attention should sum to 1, got {}",
                sum
            );
        }
    }

    #[test]
    fn test_node_importance() {
        let attn_weights = vec![
            (0, 1, 0.5),
            (0, 2, 0.5),
            (1, 0, 0.3),
            (1, 2, 0.7),
            (2, 0, 0.4),
            (2, 1, 0.6),
        ];

        let importance = GraphAttention::node_importance(&attn_weights, 3);

        assert_eq!(importance.len(), 3);
        // Every node is the target of exactly two triples, so its importance is the
        // mean of those two weights.
        let expected = [0.35, 0.55, 0.6];
        for (got, want) in importance.iter().zip(expected) {
            assert!((got - want).abs() < 1e-12, "expected {want}, got {got}");
        }
    }

    #[test]
    fn test_node_importance_ignores_out_of_range_targets() {
        let importance = GraphAttention::node_importance(&[(0, 5, 1.0), (0, 1, 0.25)], 2);
        assert_eq!(importance, vec![0.0, 0.25]);
    }

    #[test]
    fn test_activation_functions() {
        assert_eq!(GNNInference::activate(1.0, ActivationType::ReLU), 1.0);
        assert_eq!(GNNInference::activate(-1.0, ActivationType::ReLU), 0.0);
        assert_eq!(GNNInference::activate(2.0, ActivationType::LeakyReLU), 2.0);
        assert!((GNNInference::activate(-1.0, ActivationType::LeakyReLU) + 0.01).abs() < 1e-12);
        assert_eq!(GNNInference::activate(3.0, ActivationType::ELU), 3.0);
        assert!(
            (GNNInference::activate(-1.0, ActivationType::ELU) - ((-1.0f64).exp() - 1.0)).abs()
                < 1e-12
        );
        assert!((GNNInference::activate(0.0, ActivationType::Sigmoid) - 0.5).abs() < 0.001);
        assert_eq!(GNNInference::activate(0.0, ActivationType::Tanh), 0.0);
        assert_eq!(GNNInference::activate(1.0, ActivationType::None), 1.0);
    }

    #[test]
    fn test_softmax_normalises() {
        let probs = GNNInference::softmax(&[1.0, 2.0, 3.0]);
        assert_eq!(probs.len(), 3);
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
        assert!(probs[0] < probs[1] && probs[1] < probs[2]);
    }
}