use crate::error::{GraphError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{Rng, RngExt, SeedableRng};

/// Configuration for a single HGNN layer.
#[derive(Debug, Clone)]
pub struct HgnnLayerConfig {
    /// Input feature dimension.
    pub in_dim: usize,
    /// Output feature dimension.
    pub out_dim: usize,
    /// If true, `forward` uses attention-based aggregation instead of the
    /// standard spectral rule.
    pub use_attention: bool,
    /// Number of attention heads (currently unused; a single head is computed).
    pub n_heads: usize,
    /// Dropout rate (currently unused in `forward`).
    pub dropout: f64,
}

impl Default for HgnnLayerConfig {
    fn default() -> Self {
        Self {
            in_dim: 64,
            out_dim: 64,
            use_attention: false,
            n_heads: 1,
            dropout: 0.0,
        }
    }
}

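/// A single hypergraph convolution layer.
///
/// The standard (non-attention) path implements the HGNN propagation rule
/// of Feng et al., "Hypergraph Neural Networks" (AAAI 2019):
///
/// ```text
/// X' = D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2} X Θ
/// ```
///
/// where `H` is the `(n_nodes, n_edges)` incidence matrix, `W` the diagonal
/// hyperedge weights, and `D_v`, `D_e` the node and hyperedge degree
/// matrices.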
pub struct HgnnLayer {
    /// Linear projection weights, shape `(in_dim, out_dim)`.
    theta: Array2<f64>,
    /// Attention parameter vector of length `in_dim` (used only when
    /// `use_attention` is enabled).
    attn_vec: Array1<f64>,
    /// The layer configuration.
    config: HgnnLayerConfig,
}

impl HgnnLayer {
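    /// Creates a layer with seeded, reproducible Xavier-uniform
    /// initialization of `theta` (and of `attn_vec`, which is used by the
    /// attention path).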
    pub fn new(config: HgnnLayerConfig, seed: u64) -> Self {
        let mut rng = scirs2_core::random::ChaCha20Rng::seed_from_u64(seed);

        // Xavier/Glorot uniform initialization: U(-a, a) with
        // a = sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (config.in_dim + config.out_dim) as f64).sqrt();
        let theta = Array2::from_shape_fn((config.in_dim, config.out_dim), |_| {
            rng.random::<f64>() * 2.0 * scale - scale
        });

        let attn_scale = (6.0 / (config.in_dim + 1) as f64).sqrt();
        let attn_vec = Array1::from_shape_fn(config.in_dim, |_| {
            rng.random::<f64>() * 2.0 * attn_scale - attn_scale
        });

        HgnnLayer {
            theta,
            attn_vec,
            config,
        }
    }

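    /// Applies the layer to `node_feats` using the hypergraph `incidence`
    /// matrix of shape `(n_nodes, n_edges)`.
    ///
    /// `edge_weights`, if given, must have length `n_edges`; it defaults to
    /// all ones. Returns features of shape `(n_nodes, out_dim)`, or an
    /// `InvalidParameter` error when the input shapes are inconsistent.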
    pub fn forward(
        &self,
        incidence: &Array2<f64>,
        node_feats: &Array2<f64>,
        edge_weights: Option<&Array1<f64>>,
    ) -> Result<Array2<f64>> {
        let (n_nodes, n_edges) = incidence.dim();
        let (feat_n, in_dim) = node_feats.dim();

        if feat_n != n_nodes {
            return Err(GraphError::InvalidParameter {
                param: "node_feats".to_string(),
                value: format!("rows={feat_n}"),
                expected: format!("rows={n_nodes} (matching incidence rows)"),
                context: "HgnnLayer::forward".to_string(),
            });
        }
        if in_dim != self.config.in_dim {
            return Err(GraphError::InvalidParameter {
                param: "node_feats".to_string(),
                value: format!("cols={in_dim}"),
                expected: format!("cols={}", self.config.in_dim),
                context: "HgnnLayer::forward".to_string(),
            });
        }

        // Default to unit hyperedge weights when none are given.
        let default_w = Array1::ones(n_edges);
        let w: &Array1<f64> = edge_weights.unwrap_or(&default_w);

        if self.config.use_attention {
            self.forward_attention(incidence, node_feats, w)
        } else {
            self.forward_standard(incidence, node_feats, w)
        }
    }

    fn forward_standard(
        &self,
        incidence: &Array2<f64>,
        node_feats: &Array2<f64>,
        w: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let (n_nodes, n_edges) = incidence.dim();

        // Node degrees: dv[i] = sum_e H[i,e] * w[e].
        let mut dv: Array1<f64> = Array1::zeros(n_nodes);
        for i in 0..n_nodes {
            for e in 0..n_edges {
                dv[i] += incidence[[i, e]] * w[e];
            }
        }

        // Hyperedge degrees: de[e] = sum_i H[i,e].
        let mut de: Array1<f64> = Array1::zeros(n_edges);
        for e in 0..n_edges {
            for i in 0..n_nodes {
                de[e] += incidence[[i, e]];
            }
        }

        // Invert the degrees, guarding against isolated nodes and empty
        // hyperedges (zero degree maps to zero, not infinity).
        let dv_inv_sqrt: Array1<f64> =
            dv.mapv(|d: f64| if d > 1e-12 { 1.0 / d.sqrt() } else { 0.0 });
        let de_inv: Array1<f64> = de.mapv(|d: f64| if d > 1e-12 { 1.0 / d } else { 0.0 });

        // t1 = H W D_e^{-1}
        let mut t1: Array2<f64> = Array2::zeros((n_nodes, n_edges));
        for i in 0..n_nodes {
            for e in 0..n_edges {
                t1[[i, e]] = incidence[[i, e]] * w[e] * de_inv[e];
            }
        }

        // t2 = t1 H^T
        let mut t2: Array2<f64> = Array2::zeros((n_nodes, n_nodes));
        for i in 0..n_nodes {
            for j in 0..n_nodes {
                let mut val = 0.0_f64;
                for e in 0..n_edges {
                    val += t1[[i, e]] * incidence[[j, e]];
                }
                t2[[i, j]] = val;
            }
        }

        // a_tilde = D_v^{-1/2} t2 D_v^{-1/2}
        let mut a_tilde: Array2<f64> = Array2::zeros((n_nodes, n_nodes));
        for i in 0..n_nodes {
            for j in 0..n_nodes {
                a_tilde[[i, j]] = dv_inv_sqrt[i] * t2[[i, j]] * dv_inv_sqrt[j];
            }
        }

        // z = a_tilde X (propagate features over the normalized adjacency).
        let in_dim = node_feats.dim().1;
        let out_dim = self.config.out_dim;
        let mut z: Array2<f64> = Array2::zeros((n_nodes, in_dim));
        for i in 0..n_nodes {
            for k in 0..in_dim {
                let mut val = 0.0_f64;
                for j in 0..n_nodes {
                    val += a_tilde[[i, j]] * node_feats[[j, k]];
                }
                z[[i, k]] = val;
            }
        }

        // output = z Θ (linear projection).
        let mut output: Array2<f64> = Array2::zeros((n_nodes, out_dim));
        for i in 0..n_nodes {
            for k in 0..out_dim {
                let mut val = 0.0;
                for d in 0..in_dim {
                    val += z[[i, d]] * self.theta[[d, k]];
                }
                output[[i, k]] = val;
            }
        }

        Ok(output)
    }

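    /// Attention-based forward pass (single head).
    ///
    /// Each hyperedge's feature is the mean of its member nodes' features.
    /// For every node, GAT-style scores over its incident hyperedges are
    /// computed with a LeakyReLU (slope 0.2) and normalized with a softmax;
    /// the attention-weighted hyperedge features are then projected by
    /// `theta`. Nodes with no incident hyperedge keep an all-zero row.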
    fn forward_attention(
        &self,
        incidence: &Array2<f64>,
        node_feats: &Array2<f64>,
        w: &Array1<f64>,
    ) -> Result<Array2<f64>> {
        let (n_nodes, n_edges) = incidence.dim();
        let in_dim = node_feats.dim().1;
        let out_dim = self.config.out_dim;

        // Hyperedge degrees.
        let mut de: Array1<f64> = Array1::zeros(n_edges);
        for e in 0..n_edges {
            for i in 0..n_nodes {
                de[e] += incidence[[i, e]];
            }
        }
        let de_inv: Array1<f64> = de.mapv(|d: f64| if d > 1e-12 { 1.0 / d } else { 0.0 });

        // Hyperedge features: mean of the member nodes' features.
        let mut m_edge: Array2<f64> = Array2::zeros((n_edges, in_dim));
        for e in 0..n_edges {
            let mut sum: Array1<f64> = Array1::zeros(in_dim);
            for i in 0..n_nodes {
                if incidence[[i, e]] > 0.0 {
                    for d in 0..in_dim {
                        sum[d] += node_feats[[i, d]];
                    }
                }
            }
            for d in 0..in_dim {
                m_edge[[e, d]] = sum[d] * de_inv[e];
            }
        }

        let leaky_alpha = 0.2_f64;
        let mut output: Array2<f64> = Array2::zeros((n_nodes, out_dim));

        for v in 0..n_nodes {
            let edges_of_v: Vec<usize> =
                (0..n_edges).filter(|&e| incidence[[v, e]] > 0.0).collect();

            // Isolated nodes keep an all-zero output row.
            if edges_of_v.is_empty() {
                continue;
            }

            // LeakyReLU of the node's own attention score.
            let score_v: f64 = {
                let raw: f64 = (0..in_dim)
                    .map(|d| self.attn_vec[d] * node_feats[[v, d]])
                    .sum();
                if raw >= 0.0 {
                    raw
                } else {
                    leaky_alpha * raw
                }
            };

            // Raw score per incident hyperedge, scaled by its weight.
            let edge_scores: Vec<f64> = edges_of_v
                .iter()
                .map(|&e| {
                    let raw: f64 = (0..in_dim).map(|d| self.attn_vec[d] * m_edge[[e, d]]).sum();
                    let s = raw + score_v;
                    let leaky = if s >= 0.0 { s } else { leaky_alpha * s };
                    leaky * w[e]
                })
                .collect();

            // Numerically stable softmax over the incident hyperedges.
            let max_s = edge_scores
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = edge_scores.iter().map(|&s| (s - max_s).exp()).collect();
            let sum_exp: f64 = exps.iter().sum();
            let alphas: Vec<f64> = exps.iter().map(|&e| e / sum_exp).collect();

            // Attention-weighted aggregation of hyperedge features.
            let mut agg: Array1<f64> = Array1::zeros(in_dim);
            for (k, &e) in edges_of_v.iter().enumerate() {
                for d in 0..in_dim {
                    agg[d] += alphas[k] * m_edge[[e, d]];
                }
            }

            // Project by theta.
            for k in 0..out_dim {
                let mut val = 0.0_f64;
                for d in 0..in_dim {
                    val += agg[d] * self.theta[[d, k]];
                }
                output[[v, k]] = val;
            }
        }

        Ok(output)
    }

    /// Element-wise ReLU activation.
    pub fn relu(x: Array2<f64>) -> Array2<f64> {
        x.mapv(|v| if v > 0.0 { v } else { 0.0 })
    }

    /// The number of trainable parameters (`theta` plus `attn_vec`;
    /// `attn_vec` is counted even when attention is disabled).
    pub fn n_params(&self) -> usize {
        self.theta.len() + self.attn_vec.len()
    }
}

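/// A feed-forward stack of [`HgnnLayer`]s with ReLU applied between layers
/// (but not after the final one).
///
/// A minimal usage sketch, mirroring the shapes used in the tests below:
///
/// ```ignore
/// let h: Array2<f64> = /* (n_nodes, n_edges) incidence matrix */;
/// let x: Array2<f64> = /* (n_nodes, 16) node features */;
/// let net = HgnnNetwork::new(&[16, 32, 8], false, 42);
/// let out = net.forward(&h, &x, None)?; // shape (n_nodes, 8)
/// ```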
pub struct HgnnNetwork {
    /// The stacked layers, applied in order.
    layers: Vec<HgnnLayer>,
}

impl HgnnNetwork {
    /// Builds one layer per consecutive pair in `dims`; layer `i` is seeded
    /// with `seed + i` so initialization stays deterministic.
    pub fn new(dims: &[usize], use_attention: bool, seed: u64) -> Self {
        assert!(
            dims.len() >= 2,
            "dims must have at least 2 elements (in_dim, out_dim)"
        );
        let layers = dims
            .windows(2)
            .enumerate()
            .map(|(i, w)| {
                let cfg = HgnnLayerConfig {
                    in_dim: w[0],
                    out_dim: w[1],
                    use_attention,
                    n_heads: 1,
                    dropout: 0.0,
                };
                HgnnLayer::new(cfg, seed + i as u64)
            })
            .collect();
        HgnnNetwork { layers }
    }

    /// Applies each layer in sequence, inserting ReLU between layers but
    /// not after the last one.
    pub fn forward(
        &self,
        incidence: &Array2<f64>,
        node_feats: &Array2<f64>,
        edge_weights: Option<&Array1<f64>>,
    ) -> Result<Array2<f64>> {
        let n_layers = self.layers.len();
        let mut x = node_feats.clone();

        for (i, layer) in self.layers.iter().enumerate() {
            x = layer.forward(incidence, &x, edge_weights)?;
            if i + 1 < n_layers {
                x = HgnnLayer::relu(x);
            }
        }

        Ok(x)
    }

    /// Total parameter count across all layers.
    pub fn n_params(&self) -> usize {
        self.layers.iter().map(|l| l.n_params()).sum()
    }

    /// The number of layers in the network.
    pub fn depth(&self) -> usize {
        self.layers.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// 4 nodes, 3 hyperedges: e0 = {0, 1}, e1 = {1, 2}, e2 = {2, 3}.
    fn small_incidence() -> Array2<f64> {
        let mut h = Array2::zeros((4, 3));
        h[[0, 0]] = 1.0;
        h[[1, 0]] = 1.0;
        h[[1, 1]] = 1.0;
        h[[2, 1]] = 1.0;
        h[[2, 2]] = 1.0;
        h[[3, 2]] = 1.0;
        h
    }

    /// Each node is its own singleton hyperedge (H = I).
    fn identity_incidence(n: usize) -> Array2<f64> {
        let mut h = Array2::zeros((n, n));
        for i in 0..n {
            h[[i, i]] = 1.0;
        }
        h
    }

    #[test]
    fn test_output_shape() {
        let h = small_incidence();
        let x = Array2::ones((4, 8));
        let cfg = HgnnLayerConfig {
            in_dim: 8,
            out_dim: 16,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 42);
        let out = layer.forward(&h, &x, None).expect("forward ok");
        assert_eq!(out.dim(), (4, 16));
    }

    #[test]
    fn test_identity_incidence_recovers_theta() {
        // With H = I and unit weights, a_tilde = I, so the output is
        // exactly X * theta; with X = I it recovers theta itself.
        let n = 4;
        let h = identity_incidence(n);
        let in_dim = 4;
        let out_dim = 4;
        let x = Array2::eye(n);
        let cfg = HgnnLayerConfig {
            in_dim,
            out_dim,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 7);
        let out = layer.forward(&h, &x, None).expect("forward ok");
        for i in 0..n {
            for k in 0..out_dim {
                let diff = (out[[i, k]] - layer.theta[[i, k]]).abs();
                assert!(
                    diff < 1e-10,
                    "out[{i},{k}]={} != theta[{i},{k}]={}",
                    out[[i, k]],
                    layer.theta[[i, k]]
                );
            }
        }
    }

    #[test]
    fn test_output_finite_with_unit_features() {
        let h = small_incidence();
        let x = Array2::ones((4, 4));
        let cfg = HgnnLayerConfig {
            in_dim: 4,
            out_dim: 4,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 99);
        let out = layer.forward(&h, &x, None).expect("forward ok");
        for v in out.iter() {
            assert!(v.is_finite(), "output contains non-finite value: {v}");
        }
    }

    #[test]
    fn test_forward_deterministic_with_zero_dropout() {
        let h = small_incidence();
        let x = Array2::from_shape_fn((4, 8), |(i, j)| (i + j) as f64 * 0.1);
        let cfg = HgnnLayerConfig {
            in_dim: 8,
            out_dim: 4,
            dropout: 0.0,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 1);
        let out1 = layer.forward(&h, &x, None).expect("ok");
        let out2 = layer.forward(&h, &x, None).expect("ok");
        for (a, b) in out1.iter().zip(out2.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn test_n_params_counts_correctly() {
        let cfg = HgnnLayerConfig {
            in_dim: 8,
            out_dim: 16,
            use_attention: false,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 0);
        // theta (8 * 16) plus attn_vec (8), counted even without attention.
        assert_eq!(layer.n_params(), 8 * 16 + 8);
    }

    #[test]
    fn test_multi_layer_output_shape() {
        let h = small_incidence();
        let x = Array2::ones((4, 16));
        let net = HgnnNetwork::new(&[16, 32, 8], false, 42);
        let out = net.forward(&h, &x, None).expect("ok");
        assert_eq!(out.dim(), (4, 8));
    }

    #[test]
    fn test_network_n_params() {
        let net = HgnnNetwork::new(&[8, 16, 4], false, 0);
        // Layer 0: 8*16 theta + 8 attn_vec; layer 1: 16*4 theta + 16 attn_vec.
        assert_eq!(net.n_params(), 8 * 16 + 8 + 16 * 4 + 16);
    }

    #[test]
    fn test_theta_small_init() {
        let cfg = HgnnLayerConfig {
            in_dim: 64,
            out_dim: 64,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 1234);
        let scale = (6.0_f64 / (64.0 + 64.0)).sqrt();
        for v in layer.theta.iter() {
            assert!(
                v.abs() <= scale + 1e-9,
                "theta value {v} exceeds Xavier bound {scale}"
            );
        }
    }

    #[test]
    fn test_attention_output_shape() {
        let h = small_incidence();
        let x = Array2::ones((4, 8));
        let cfg = HgnnLayerConfig {
            in_dim: 8,
            out_dim: 4,
            use_attention: true,
            n_heads: 1,
            dropout: 0.0,
        };
        let layer = HgnnLayer::new(cfg, 5);
        let out = layer.forward(&h, &x, None).expect("ok");
        assert_eq!(out.dim(), (4, 4));
    }

    #[test]
    fn test_relu_zeros_negatives() {
        let x = Array2::from_shape_vec((2, 3), vec![-1.0, 0.0, 1.0, -0.5, 2.0, -3.0]).expect("ok");
        let r = HgnnLayer::relu(x);
        assert_eq!(r[[0, 0]], 0.0);
        assert_eq!(r[[0, 2]], 1.0);
        assert_eq!(r[[1, 1]], 2.0);
        assert_eq!(r[[1, 2]], 0.0);
    }

    #[test]
    fn test_edge_weights_change_output() {
        let h = small_incidence();
        let x = Array2::ones((4, 4));
        let cfg = HgnnLayerConfig {
            in_dim: 4,
            out_dim: 4,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 42);
        let w1 = Array1::ones(3);
        let w2 = Array1::from_vec(vec![2.0, 1.0, 0.5]);
        let out1 = layer.forward(&h, &x, Some(&w1)).expect("ok");
        let out2 = layer.forward(&h, &x, Some(&w2)).expect("ok");
        let diff: f64 = out1
            .iter()
            .zip(out2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-10,
            "different edge weights should produce different output"
        );
    }
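
    // An extra check (a sketch relying on the zero-row convention of
    // `forward_attention`): a node incident to no hyperedge keeps an
    // all-zero output row in attention mode.
    #[test]
    fn test_attention_isolated_node_outputs_zeros() {
        let mut h = Array2::zeros((3, 1));
        h[[0, 0]] = 1.0;
        h[[1, 0]] = 1.0;
        // node 2 belongs to no hyperedge
        let x = Array2::ones((3, 4));
        let cfg = HgnnLayerConfig {
            in_dim: 4,
            out_dim: 2,
            use_attention: true,
            ..Default::default()
        };
        let layer = HgnnLayer::new(cfg, 3);
        let out = layer.forward(&h, &x, None).expect("ok");
        for k in 0..2 {
            assert_eq!(out[[2, k]], 0.0);
        }
    }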
}