use ndarray::{Array1, Array2, Array3};
use num_complex::Complex64;
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::autodiff::DifferentiableParam;
use crate::error::{MLError, Result};
use crate::utils::VariationalCircuit;
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::{multi::*, single::*, GateOp};

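/// Classical activation functions applied to the combined node/neighbor features.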
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationType {
    Linear,
    ReLU,
    Sigmoid,
    Tanh,
}

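/// Graph with a dense adjacency matrix, per-node features, and optional
/// per-edge and graph-level features; the input type for all layers below.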
#[derive(Debug, Clone)]
pub struct QuantumGraph {
    num_nodes: usize,
    adjacency: Array2<f64>,
    node_features: Array2<f64>,
    edge_features: Option<HashMap<(usize, usize), Array1<f64>>>,
    graph_features: Option<Array1<f64>>,
}

impl QuantumGraph {
    pub fn new(num_nodes: usize, edges: Vec<(usize, usize)>, node_features: Array2<f64>) -> Self {
        let mut adjacency = Array2::zeros((num_nodes, num_nodes));

        for (src, dst) in edges {
            adjacency[[src, dst]] = 1.0;
            adjacency[[dst, src]] = 1.0;
        }

        Self {
            num_nodes,
            adjacency,
            node_features,
            edge_features: None,
            graph_features: None,
        }
    }

    pub fn with_edge_features(
        mut self,
        edge_features: HashMap<(usize, usize), Array1<f64>>,
    ) -> Self {
        self.edge_features = Some(edge_features);
        self
    }

    pub fn with_graph_features(mut self, graph_features: Array1<f64>) -> Self {
        self.graph_features = Some(graph_features);
        self
    }

    pub fn degree(&self, node: usize) -> usize {
        self.adjacency
            .row(node)
            .iter()
            .filter(|&&x| x > 0.0)
            .count()
    }

    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency
            .row(node)
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > 0.0)
            .map(|(idx, _)| idx)
            .collect()
    }

    pub fn laplacian(&self) -> Array2<f64> {
        let mut degree_matrix = Array2::zeros((self.num_nodes, self.num_nodes));
        for i in 0..self.num_nodes {
            degree_matrix[[i, i]] = self.degree(i) as f64;
        }
        &degree_matrix - &self.adjacency
    }

    pub fn normalized_laplacian(&self) -> Array2<f64> {
        let mut degree_matrix = Array2::zeros((self.num_nodes, self.num_nodes));
        let mut degree_sqrt_inv = Array1::zeros(self.num_nodes);

        for i in 0..self.num_nodes {
            let degree = self.degree(i) as f64;
            degree_matrix[[i, i]] = degree;
            if degree > 0.0 {
                degree_sqrt_inv[i] = 1.0 / degree.sqrt();
            }
        }

        let mut norm_laplacian = Array2::eye(self.num_nodes);
        for i in 0..self.num_nodes {
            for j in 0..self.num_nodes {
                if self.adjacency[[i, j]] > 0.0 {
                    norm_laplacian[[i, j]] -=
                        degree_sqrt_inv[i] * self.adjacency[[i, j]] * degree_sqrt_inv[j];
                }
            }
        }

        norm_laplacian
    }
}

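/// Quantum graph convolution layer that mean-aggregates neighbor features and
/// combines them with each node's own features; holds variational node and
/// aggregation circuits.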
#[derive(Debug)]
pub struct QuantumGCNLayer {
    input_dim: usize,
    output_dim: usize,
    num_qubits: usize,
    node_circuit: VariationalCircuit,
    aggregation_circuit: VariationalCircuit,
    parameters: HashMap<String, f64>,
    activation: ActivationType,
}

impl QuantumGCNLayer {
    pub fn new(input_dim: usize, output_dim: usize, activation: ActivationType) -> Self {
        let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
        let node_circuit = Self::build_node_circuit(num_qubits);
        let aggregation_circuit = Self::build_aggregation_circuit(num_qubits);

        Self {
            input_dim,
            output_dim,
            num_qubits,
            node_circuit,
            aggregation_circuit,
            parameters: HashMap::new(),
            activation,
        }
    }

    fn build_node_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        // Feature-encoding rotations.
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("node_encode_{}", q)]);
        }

        // Two entangling layers followed by parameterized rotations.
        for layer in 0..2 {
            for q in 0..num_qubits - 1 {
                circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
            }
            if num_qubits > 2 {
                circuit.add_gate("CNOT", vec![num_qubits - 1, 0], vec![]);
            }

            for q in 0..num_qubits {
                circuit.add_gate("RX", vec![q], vec![format!("node_rx_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("node_rz_{}_{}", layer, q)]);
            }
        }

        circuit
    }

    fn build_aggregation_circuit(num_qubits: usize) -> VariationalCircuit {
        // Two registers of num_qubits each, coupled pairwise with CZ, then
        // rotation and entangling layers.
        let mut circuit = VariationalCircuit::new(num_qubits * 2);

        for q in 0..num_qubits {
            circuit.add_gate("CZ", vec![q, q + num_qubits], vec![]);
        }

        for q in 0..num_qubits * 2 {
            circuit.add_gate("RY", vec![q], vec![format!("agg_ry_{}", q)]);
        }

        for q in 0..num_qubits * 2 - 1 {
            circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
        }

        for q in 0..num_qubits {
            circuit.add_gate("RX", vec![q], vec![format!("agg_final_{}", q)]);
        }

        circuit
    }

    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array2<f64>> {
        let mut output_features = Array2::zeros((graph.num_nodes, self.output_dim));

        for node in 0..graph.num_nodes {
            let node_feat = graph.node_features.row(node);

            // Mean-aggregate the features of the node's neighbors.
            let neighbors = graph.neighbors(node);
            let mut aggregated = Array1::zeros(self.input_dim);

            for &neighbor in &neighbors {
                let neighbor_feat = graph.node_features.row(neighbor);
                aggregated = &aggregated + &neighbor_feat.to_owned();
            }

            let degree = neighbors.len().max(1) as f64;
            aggregated = aggregated / degree;

            let transformed = self.quantum_transform(&node_feat.to_owned(), &aggregated)?;

            for i in 0..self.output_dim {
                output_features[[node, i]] = transformed[i];
            }
        }

        Ok(output_features)
    }

    fn quantum_transform(
        &self,
        node_features: &Array1<f64>,
        aggregated_features: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        // The amplitude encodings are computed here, but the combination below
        // is purely classical.
        let node_encoded = self.encode_features(node_features)?;
        let agg_encoded = self.encode_features(aggregated_features)?;

        let mut output = Array1::zeros(self.output_dim);

        for i in 0..self.output_dim {
            let idx_node = i % node_features.len();
            let idx_agg = i % aggregated_features.len();

            output[i] = match self.activation {
                ActivationType::ReLU => {
                    (0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]).max(0.0)
                }
                ActivationType::Tanh => {
                    (0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]).tanh()
                }
                ActivationType::Sigmoid => {
                    let x = 0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg];
                    1.0 / (1.0 + (-x).exp())
                }
                ActivationType::Linear => {
                    0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]
                }
            };
        }

        Ok(output)
    }

    fn encode_features(&self, features: &Array1<f64>) -> Result<Vec<Complex64>> {
        let state_dim = 2_usize.pow(self.num_qubits as u32);
        let mut quantum_state = vec![Complex64::new(0.0, 0.0); state_dim];

        // L2-normalize the features into state amplitudes; fall back to |0...0> for zero input.
        let norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            quantum_state[0] = Complex64::new(1.0, 0.0);
        } else {
            for (i, &val) in features.iter().enumerate() {
                if i < state_dim {
                    quantum_state[i] = Complex64::new(val / norm, 0.0);
                }
            }
        }

        Ok(quantum_state)
    }
}

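/// Multi-head graph attention layer; each head has its own attention and
/// transform circuits, and the per-head outputs are concatenated.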
#[derive(Debug)]
pub struct QuantumGATLayer {
    input_dim: usize,
    output_dim: usize,
    num_heads: usize,
    attention_circuits: Vec<VariationalCircuit>,
    transform_circuits: Vec<VariationalCircuit>,
    dropout_rate: f64,
}

impl QuantumGATLayer {
    pub fn new(input_dim: usize, output_dim: usize, num_heads: usize, dropout_rate: f64) -> Self {
        let mut attention_circuits = Vec::new();
        let mut transform_circuits = Vec::new();

        let qubits_per_head = ((output_dim / num_heads) as f64).log2().ceil() as usize;

        for _ in 0..num_heads {
            attention_circuits.push(Self::build_attention_circuit(qubits_per_head));
            transform_circuits.push(Self::build_transform_circuit(qubits_per_head));
        }

        Self {
            input_dim,
            output_dim,
            num_heads,
            attention_circuits,
            transform_circuits,
            dropout_rate,
        }
    }

    fn build_attention_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits * 2);

        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("att_src_{}", q)]);
            circuit.add_gate("RY", vec![q + num_qubits], vec![format!("att_dst_{}", q)]);
        }

        for q in 0..num_qubits {
            circuit.add_gate("CZ", vec![q, q + num_qubits], vec![]);
        }

        circuit.add_gate("H", vec![0], vec![]);
        for q in 1..num_qubits * 2 {
            circuit.add_gate("CNOT", vec![0, q], vec![]);
        }

        circuit
    }

    fn build_transform_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("trans_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("trans_rz_{}_{}", layer, q)]);
            }

            for q in 0..num_qubits - 1 {
                circuit.add_gate("CX", vec![q, q + 1], vec![]);
            }
        }

        circuit
    }

    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array2<f64>> {
        let head_dim = self.output_dim / self.num_heads;
        let mut all_head_outputs = Vec::new();

        for head in 0..self.num_heads {
            let head_output = self.process_attention_head(graph, head)?;
            all_head_outputs.push(head_output);
        }

        // Concatenate the per-head outputs along the feature dimension.
        let mut output = Array2::zeros((graph.num_nodes, self.output_dim));
        for (h, head_output) in all_head_outputs.iter().enumerate() {
            for node in 0..graph.num_nodes {
                for d in 0..head_dim {
                    output[[node, h * head_dim + d]] = head_output[[node, d]];
                }
            }
        }

        Ok(output)
    }

    fn process_attention_head(&self, graph: &QuantumGraph, head: usize) -> Result<Array2<f64>> {
        let head_dim = self.output_dim / self.num_heads;
        let mut output = Array2::zeros((graph.num_nodes, head_dim));

        let attention_scores = self.compute_attention_scores(graph, head)?;

        for node in 0..graph.num_nodes {
            let neighbors = graph.neighbors(node);
            let feature_dim = graph.node_features.ncols();
            let mut weighted_features = Array1::zeros(feature_dim);

            let self_score = attention_scores[[node, node]];
            weighted_features =
                &weighted_features + &(&graph.node_features.row(node).to_owned() * self_score);

            for &neighbor in &neighbors {
                let score = attention_scores[[node, neighbor]];
                weighted_features =
                    &weighted_features + &(&graph.node_features.row(neighbor).to_owned() * score);
            }

            let transformed = self.transform_features(&weighted_features, head)?;

            for d in 0..head_dim {
                output[[node, d]] = transformed[d];
            }
        }

        Ok(output)
    }

    fn compute_attention_scores(&self, graph: &QuantumGraph, head: usize) -> Result<Array2<f64>> {
        let mut scores = Array2::zeros((graph.num_nodes, graph.num_nodes));

        for i in 0..graph.num_nodes {
            for j in 0..graph.num_nodes {
                if i == j || graph.adjacency[[i, j]] > 0.0 {
                    let score = self.quantum_attention_score(
                        &graph.node_features.row(i).to_owned(),
                        &graph.node_features.row(j).to_owned(),
                        head,
                    )?;
                    scores[[i, j]] = score;
                }
            }

            // Softmax-normalize the scores over the node itself and its neighbors.
            let neighbors = graph.neighbors(i);
            if !neighbors.is_empty() {
                let mut sum_exp = (scores[[i, i]]).exp();
                for &j in &neighbors {
                    sum_exp += scores[[i, j]].exp();
                }

                scores[[i, i]] = scores[[i, i]].exp() / sum_exp;
                for &j in &neighbors {
                    scores[[i, j]] = scores[[i, j]].exp() / sum_exp;
                }
            } else {
                scores[[i, i]] = 1.0;
            }
        }

        Ok(scores)
    }

    fn quantum_attention_score(
        &self,
        feat_i: &Array1<f64>,
        feat_j: &Array1<f64>,
        head: usize,
    ) -> Result<f64> {
        let dot_product: f64 = feat_i.iter().zip(feat_j.iter()).map(|(a, b)| a * b).sum();

        Ok((dot_product / (self.input_dim as f64).sqrt()).tanh())
    }

    fn transform_features(&self, features: &Array1<f64>, head: usize) -> Result<Array1<f64>> {
        let head_dim = self.output_dim / self.num_heads;
        let mut output = Array1::zeros(head_dim);

        for i in 0..head_dim {
            if i < features.len() {
                output[i] = features[i] * (1.0 + 0.1 * (i as f64).sin());
            }
        }

        Ok(output)
    }
}

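/// Message-passing network with variational message, update, and readout
/// circuits; runs `num_steps` rounds of message passing and returns a
/// graph-level embedding.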
#[derive(Debug)]
pub struct QuantumMPNN {
    message_circuit: VariationalCircuit,
    update_circuit: VariationalCircuit,
    readout_circuit: VariationalCircuit,
    hidden_dim: usize,
    num_steps: usize,
}

impl QuantumMPNN {
    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, num_steps: usize) -> Self {
        let num_qubits = (hidden_dim as f64).log2().ceil() as usize;

        Self {
            message_circuit: Self::build_message_circuit(num_qubits),
            update_circuit: Self::build_update_circuit(num_qubits),
            readout_circuit: Self::build_readout_circuit(num_qubits),
            hidden_dim,
            num_steps,
        }
    }

    fn build_message_circuit(num_qubits: usize) -> VariationalCircuit {
        // Three registers of num_qubits each; the first two are each coupled
        // to the third via CZ layers.
        let mut circuit = VariationalCircuit::new(num_qubits * 3);

        for q in 0..num_qubits * 3 {
            circuit.add_gate("RY", vec![q], vec![format!("msg_encode_{}", q)]);
        }

        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("CZ", vec![q, q + num_qubits * 2], vec![]);
            }

            for q in 0..num_qubits {
                circuit.add_gate("CZ", vec![q + num_qubits, q + num_qubits * 2], vec![]);
            }

            for q in 0..num_qubits * 3 {
                circuit.add_gate("RX", vec![q], vec![format!("msg_rx_{}_{}", layer, q)]);
            }
        }

        circuit
    }

    fn build_update_circuit(num_qubits: usize) -> VariationalCircuit {
        // Two registers coupled by CNOTs, followed by rotation and entangling layers.
        let mut circuit = VariationalCircuit::new(num_qubits * 2);

        for q in 0..num_qubits {
            circuit.add_gate("CNOT", vec![q, q + num_qubits], vec![]);
        }

        for layer in 0..2 {
            for q in 0..num_qubits * 2 {
                circuit.add_gate("RY", vec![q], vec![format!("upd_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("upd_rz_{}_{}", layer, q)]);
            }

            for q in 0..num_qubits * 2 - 1 {
                circuit.add_gate("CX", vec![q, q + 1], vec![]);
            }
        }

        circuit
    }

    fn build_readout_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        for layer in 0..3 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("read_ry_{}_{}", layer, q)]);
            }

            for i in 0..num_qubits {
                for j in i + 1..num_qubits {
                    circuit.add_gate("CZ", vec![i, j], vec![]);
                }
            }
        }

        circuit
    }

    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array1<f64>> {
        // Initialize hidden states from the node features.
        let mut hidden_states = Array2::zeros((graph.num_nodes, self.hidden_dim));

        for node in 0..graph.num_nodes {
            for d in 0..self.hidden_dim.min(graph.node_features.ncols()) {
                hidden_states[[node, d]] = graph.node_features[[node, d]];
            }
        }

        for _ in 0..self.num_steps {
            hidden_states = self.message_passing_step(graph, &hidden_states)?;
        }

        self.readout(graph, &hidden_states)
    }

    fn message_passing_step(
        &self,
        graph: &QuantumGraph,
        hidden_states: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let mut new_hidden = Array2::zeros((graph.num_nodes, self.hidden_dim));

        for node in 0..graph.num_nodes {
            let neighbors = graph.neighbors(node);
            let mut messages = Array1::zeros(self.hidden_dim);

            for &neighbor in &neighbors {
                let message = self.compute_message(
                    &hidden_states.row(neighbor).to_owned(),
                    &hidden_states.row(node).to_owned(),
                    graph
                        .edge_features
                        .as_ref()
                        .and_then(|ef| ef.get(&(neighbor, node))),
                )?;
                messages = &messages + &message;
            }

            let updated = self.update_node(&hidden_states.row(node).to_owned(), &messages)?;

            new_hidden.row_mut(node).assign(&updated);
        }

        Ok(new_hidden)
    }

    fn compute_message(
        &self,
        source_hidden: &Array1<f64>,
        dest_hidden: &Array1<f64>,
        edge_features: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        let mut message = Array1::zeros(self.hidden_dim);

        for i in 0..self.hidden_dim {
            let src_val = if i < source_hidden.len() {
                source_hidden[i]
            } else {
                0.0
            };
            let dst_val = if i < dest_hidden.len() {
                dest_hidden[i]
            } else {
                0.0
            };
            let edge_val = edge_features
                .and_then(|ef| ef.get(i))
                .copied()
                .unwrap_or(1.0);

            message[i] = (src_val + dst_val) * edge_val * 0.5;
        }

        Ok(message)
    }

    fn update_node(&self, hidden: &Array1<f64>, messages: &Array1<f64>) -> Result<Array1<f64>> {
        let mut new_hidden = Array1::zeros(self.hidden_dim);

        for i in 0..self.hidden_dim {
            let h = if i < hidden.len() { hidden[i] } else { 0.0 };
            let m = if i < messages.len() { messages[i] } else { 0.0 };

            // GRU-style update: z gates how much of the candidate state h_tilde
            // replaces the previous hidden value.
            let z = (h + m).tanh();
            let r = 1.0 / (1.0 + (-(h * m)).exp());
            let h_tilde = ((r * h) + m).tanh();
            new_hidden[i] = (1.0 - z) * h + z * h_tilde;
        }

        Ok(new_hidden)
    }

    fn readout(&self, graph: &QuantumGraph, hidden_states: &Array2<f64>) -> Result<Array1<f64>> {
        // Mean-pool the node hidden states into a single graph-level vector.
        let mut global_state: Array1<f64> = Array1::zeros(self.hidden_dim);

        for node in 0..graph.num_nodes {
            global_state = &global_state + &hidden_states.row(node).to_owned();
        }
        global_state = global_state / (graph.num_nodes as f64);

        let mut output = Array1::zeros(self.hidden_dim);
        for i in 0..self.hidden_dim {
            output[i] = global_state[i].tanh();
        }

        Ok(output)
    }
}

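/// Graph pooling operator that coarsens a graph by keeping a `pool_ratio`
/// fraction of its nodes (or soft clusters), according to the configured
/// `PoolingMethod`.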
#[derive(Debug)]
pub struct QuantumGraphPool {
    pool_ratio: f64,
    method: PoolingMethod,
    score_circuit: VariationalCircuit,
}

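/// Strategy used by `QuantumGraphPool` to select the retained nodes.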
#[derive(Debug, Clone)]
pub enum PoolingMethod {
    TopK,
    SelfAttention,
    DiffPool,
}

impl QuantumGraphPool {
    pub fn new(pool_ratio: f64, method: PoolingMethod, feature_dim: usize) -> Self {
        let num_qubits = (feature_dim as f64).log2().ceil() as usize;

        Self {
            pool_ratio,
            method,
            score_circuit: Self::build_score_circuit(num_qubits),
        }
    }

    fn build_score_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);

        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("pool_ry_{}_{}", layer, q)]);
            }

            for q in 0..num_qubits - 1 {
                circuit.add_gate("CZ", vec![q, q + 1], vec![]);
            }
        }

        for q in 0..num_qubits {
            circuit.add_gate("RX", vec![q], vec![format!("pool_measure_{}", q)]);
        }

        circuit
    }

    pub fn pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        match self.method {
            PoolingMethod::TopK => self.topk_pool(graph, node_features),
            PoolingMethod::SelfAttention => self.attention_pool(graph, node_features),
            PoolingMethod::DiffPool => self.diff_pool(graph, node_features),
        }
    }

    fn topk_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let mut scores = Vec::new();
        for node in 0..graph.num_nodes {
            let score = self.compute_node_score(&node_features.row(node).to_owned())?;
            scores.push((node, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let k = ((graph.num_nodes as f64) * self.pool_ratio).ceil() as usize;
        let selected_nodes: Vec<usize> = scores.iter().take(k).map(|(idx, _)| *idx).collect();

        let mut pooled_features = Array2::zeros((k, node_features.ncols()));
        for (i, &node) in selected_nodes.iter().enumerate() {
            pooled_features.row_mut(i).assign(&node_features.row(node));
        }

        Ok((selected_nodes, pooled_features))
    }

    fn attention_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let mut attention_scores = Array1::zeros(graph.num_nodes);
        for node in 0..graph.num_nodes {
            attention_scores[node] =
                self.compute_node_score(&node_features.row(node).to_owned())?;
        }

        // Numerically stable softmax over the node scores.
        let max_score = attention_scores
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_scores: Array1<f64> = attention_scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        let normalized_scores = exp_scores / sum_exp;

        let k = ((graph.num_nodes as f64) * self.pool_ratio).ceil() as usize;
        let mut selected_nodes = Vec::new();
        let mut remaining_scores = normalized_scores.clone();

        for _ in 0..k {
            let node = self.sample_node(&remaining_scores);
            selected_nodes.push(node);
            remaining_scores[node] = 0.0;
        }

        let mut pooled_features = Array2::zeros((k, node_features.ncols()));
        for (i, &node) in selected_nodes.iter().enumerate() {
            let weighted_feature = &node_features.row(node).to_owned() * normalized_scores[node];
            pooled_features.row_mut(i).assign(&weighted_feature);
        }

        Ok((selected_nodes, pooled_features))
    }

    fn diff_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let k = ((graph.num_nodes as f64) * self.pool_ratio).ceil() as usize;
        let mut assignments = Array2::zeros((graph.num_nodes, k));

        for node in 0..graph.num_nodes {
            for cluster in 0..k {
                let score =
                    self.compute_cluster_assignment(&node_features.row(node).to_owned(), cluster)?;
                assignments[[node, cluster]] = score;
            }
        }

        // Row-normalize the soft cluster assignments.
        for node in 0..graph.num_nodes {
            let row_sum: f64 = assignments.row(node).sum();
            if row_sum > 0.0 {
                for cluster in 0..k {
                    assignments[[node, cluster]] /= row_sum;
                }
            }
        }

        let pooled_features = assignments.t().dot(node_features);

        // Report the most strongly assigned node for each cluster.
        let mut selected_nodes = Vec::new();
        for cluster in 0..k {
            let mut best_node = 0;
            let mut best_score = 0.0;

            for node in 0..graph.num_nodes {
                if assignments[[node, cluster]] > best_score {
                    best_score = assignments[[node, cluster]];
                    best_node = node;
                }
            }

            selected_nodes.push(best_node);
        }

        Ok((selected_nodes, pooled_features))
    }

    fn compute_node_score(&self, features: &Array1<f64>) -> Result<f64> {
        let norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt();
        Ok(norm * (1.0 + 0.1 * fastrand::f64()))
    }

    fn compute_cluster_assignment(&self, features: &Array1<f64>, cluster: usize) -> Result<f64> {
        let base_score = features.iter().sum::<f64>() / features.len() as f64;
        let cluster_bias = (cluster as f64) * 0.1;
        Ok((base_score + cluster_bias).exp() / (1.0 + (base_score + cluster_bias).exp()))
    }

    fn sample_node(&self, scores: &Array1<f64>) -> usize {
        // Sample an index proportionally to its score (roulette-wheel sampling).
        let cumsum: Vec<f64> = scores
            .iter()
            .scan(0.0, |acc, &x| {
                *acc += x;
                Some(*acc)
            })
            .collect();

        let r = fastrand::f64() * cumsum.last().unwrap();

        for (i, &cs) in cumsum.iter().enumerate() {
            if r <= cs {
                return i;
            }
        }

        scores.len() - 1
    }
}

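/// Complete quantum GNN: a stack of GCN/GAT/MPNN layers, optional pooling
/// after each layer, and a final graph-level readout.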
#[derive(Debug)]
pub struct QuantumGNN {
    layers: Vec<GNNLayer>,
    pooling: Vec<Option<QuantumGraphPool>>,
    readout: ReadoutType,
    output_dim: usize,
}

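/// Layer variants supported inside a `QuantumGNN` stack.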
#[derive(Debug)]
enum GNNLayer {
    GCN(QuantumGCNLayer),
    GAT(QuantumGATLayer),
    MPNN(QuantumMPNN),
}

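/// How node features are reduced to a single graph-level vector.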
#[derive(Debug, Clone)]
pub enum ReadoutType {
    Mean,
    Max,
    Sum,
    Attention,
}

impl QuantumGNN {
    pub fn new(
        layer_configs: Vec<(String, usize, usize)>,
        pooling_configs: Vec<Option<(f64, PoolingMethod)>>,
        readout: ReadoutType,
        output_dim: usize,
    ) -> Result<Self> {
        let mut layers = Vec::new();
        let mut pooling = Vec::new();

        for (layer_type, input_dim, output_dim) in layer_configs {
            let layer = match layer_type.as_str() {
                "gcn" => GNNLayer::GCN(QuantumGCNLayer::new(
                    input_dim,
                    output_dim,
                    ActivationType::ReLU,
                )),
                "gat" => GNNLayer::GAT(QuantumGATLayer::new(
                    input_dim, output_dim, 4, 0.1, // 4 attention heads, dropout 0.1
                )),
                "mpnn" => GNNLayer::MPNN(QuantumMPNN::new(
                    input_dim, output_dim, output_dim, 3, // 3 message-passing steps
                )),
                _ => {
                    return Err(MLError::InvalidConfiguration(format!(
                        "Unknown layer type: {}",
                        layer_type
                    )))
                }
            };
            layers.push(layer);
        }

        for pool_config in pooling_configs {
            let pool_layer = pool_config.map(|(ratio, method)| {
                QuantumGraphPool::new(ratio, method, 64) // fixed feature dimension for scoring
            });
            pooling.push(pool_layer);
        }

        Ok(Self {
            layers,
            pooling,
            readout,
            output_dim,
        })
    }

    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array1<f64>> {
        let mut current_graph = graph.clone();
        let mut current_features = graph.node_features.clone();
        let mut selected_nodes: Vec<usize> = (0..graph.num_nodes).collect();

        for (i, layer) in self.layers.iter().enumerate() {
            current_features = match layer {
                GNNLayer::GCN(gcn) => gcn.forward(&current_graph)?,
                GNNLayer::GAT(gat) => gat.forward(&current_graph)?,
                GNNLayer::MPNN(mpnn) => {
                    // The MPNN returns a graph-level vector; broadcast it back to every node.
                    let graph_features = mpnn.forward(&current_graph)?;
                    let mut node_features =
                        Array2::zeros((current_graph.num_nodes, graph_features.len()));
                    for node in 0..current_graph.num_nodes {
                        node_features.row_mut(node).assign(&graph_features);
                    }
                    node_features
                }
            };

            if let Some(Some(pool)) = self.pooling.get(i) {
                let (new_selected, pooled_features) =
                    pool.pool(&current_graph, &current_features)?;

                current_graph =
                    self.create_subgraph(&current_graph, &new_selected, &pooled_features);
                current_features = pooled_features;
                selected_nodes = new_selected;
            }
        }

        self.apply_readout(&current_features)
    }

    fn create_subgraph(
        &self,
        graph: &QuantumGraph,
        selected_nodes: &[usize],
        pooled_features: &Array2<f64>,
    ) -> QuantumGraph {
        let num_nodes = selected_nodes.len();
        let mut new_adjacency = Array2::zeros((num_nodes, num_nodes));

        let index_map: HashMap<usize, usize> = selected_nodes
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        for (i, &old_i) in selected_nodes.iter().enumerate() {
            for (j, &old_j) in selected_nodes.iter().enumerate() {
                new_adjacency[[i, j]] = graph.adjacency[[old_i, old_j]];
            }
        }

        let mut edges = Vec::new();
        for i in 0..num_nodes {
            for j in i + 1..num_nodes {
                if new_adjacency[[i, j]] > 0.0 {
                    edges.push((i, j));
                }
            }
        }

        QuantumGraph::new(num_nodes, edges, pooled_features.clone())
    }

    fn apply_readout(&self, node_features: &Array2<f64>) -> Result<Array1<f64>> {
        let readout_features = match self.readout {
            ReadoutType::Mean => node_features.mean_axis(ndarray::Axis(0)).unwrap(),
            ReadoutType::Max => {
                let mut max_features = Array1::from_elem(node_features.ncols(), f64::NEG_INFINITY);
                for row in node_features.rows() {
                    for (i, &val) in row.iter().enumerate() {
                        max_features[i] = max_features[i].max(val);
                    }
                }
                max_features
            }
            ReadoutType::Sum => node_features.sum_axis(ndarray::Axis(0)),
            ReadoutType::Attention => {
                // Softmax over per-node weights (row sums), then a weighted sum of node features.
                let mut weights = Array1::zeros(node_features.nrows());
                for (i, row) in node_features.rows().into_iter().enumerate() {
                    weights[i] = row.sum();
                }

                let max_weight = weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let exp_weights = weights.mapv(|x| (x - max_weight).exp());
                let weights_norm = exp_weights.clone() / exp_weights.sum();

                let mut result = Array1::zeros(node_features.ncols());
                for (i, row) in node_features.rows().into_iter().enumerate() {
                    result = &result + &(&row.to_owned() * weights_norm[i]);
                }
                result
            }
        };

        // Truncate or zero-pad the readout vector to the configured output dimension.
        let mut output = Array1::zeros(self.output_dim);
        for i in 0..self.output_dim {
            if i < readout_features.len() {
                output[i] = readout_features[i];
            }
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_graph() {
        let nodes = 5;
        let edges = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)];
        let features = Array2::ones((nodes, 4));

        let graph = QuantumGraph::new(nodes, edges, features);

        assert_eq!(graph.num_nodes, 5);
        assert_eq!(graph.degree(0), 2);
        assert_eq!(graph.neighbors(0), vec![1, 4]);
    }

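    // Sanity check for the Laplacian helper on a 3-node path graph (0-1-2);
    // the expected entries follow directly from L = D - A.
    #[test]
    fn test_graph_laplacian() {
        let graph = QuantumGraph::new(3, vec![(0, 1), (1, 2)], Array2::ones((3, 4)));

        let lap = graph.laplacian();
        assert_eq!(lap[[0, 0]], 1.0);
        assert_eq!(lap[[1, 1]], 2.0);
        assert_eq!(lap[[2, 2]], 1.0);
        assert_eq!(lap[[0, 1]], -1.0);
        assert_eq!(lap[[1, 2]], -1.0);
        // Every row of a graph Laplacian sums to zero.
        for i in 0..3 {
            assert!(lap.row(i).sum().abs() < 1e-12);
        }
    }
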
    #[test]
    fn test_quantum_gcn_layer() {
        let graph = QuantumGraph::new(
            3,
            vec![(0, 1), (1, 2)],
            Array2::from_shape_vec(
                (3, 4),
                vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            )
            .unwrap(),
        );

        let gcn = QuantumGCNLayer::new(4, 8, ActivationType::ReLU);
        let output = gcn.forward(&graph).unwrap();

        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_quantum_gat_layer() {
        let graph = QuantumGraph::new(
            4,
            vec![(0, 1), (1, 2), (2, 3), (3, 0)],
            Array2::ones((4, 8)),
        );

        let gat = QuantumGATLayer::new(8, 16, 4, 0.1);
        let output = gat.forward(&graph).unwrap();

        assert_eq!(output.shape(), &[4, 16]);
    }

    #[test]
    fn test_quantum_mpnn() {
        let graph = QuantumGraph::new(3, vec![(0, 1), (1, 2)], Array2::zeros((3, 4)));

        let mpnn = QuantumMPNN::new(4, 8, 16, 2);
        let output = mpnn.forward(&graph).unwrap();

        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_graph_pooling() {
        let graph = QuantumGraph::new(
            6,
            vec![(0, 1), (1, 2), (3, 4), (4, 5)],
            Array2::ones((6, 4)),
        );

        let pool = QuantumGraphPool::new(0.5, PoolingMethod::TopK, 4);
        let (selected, pooled) = pool.pool(&graph, &graph.node_features).unwrap();

        assert_eq!(selected.len(), 3);
        assert_eq!(pooled.shape(), &[3, 4]);
    }

    #[test]
    fn test_complete_gnn() {
        let layer_configs = vec![("gcn".to_string(), 4, 8), ("gat".to_string(), 8, 16)];
        let pooling_configs = vec![None, Some((0.5, PoolingMethod::TopK))];

        let gnn = QuantumGNN::new(layer_configs, pooling_configs, ReadoutType::Mean, 10).unwrap();

        let graph = QuantumGraph::new(
            5,
            vec![(0, 1), (1, 2), (2, 3), (3, 4)],
            Array2::ones((5, 4)),
        );

        let output = gnn.forward(&graph).unwrap();
        assert_eq!(output.len(), 10);
    }
}