use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::f64::consts::PI;
use crate::autodiff::DifferentiableParam;
use crate::error::{MLError, Result};
use crate::utils::VariationalCircuit;
use quantrs2_circuit::prelude::*;
use quantrs2_core::gate::{multi::*, single::*, GateOp};
/// Element-wise non-linearity applied to each output feature of a [`QuantumGCNLayer`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivationType {
    /// Identity: the combined value passes through unchanged.
    Linear,
    /// Rectified linear unit, `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid, `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
}
/// An undirected graph with a dense adjacency matrix, per-node features and optional
/// edge- and graph-level features.
#[derive(Debug, Clone)]
pub struct QuantumGraph {
    /// Number of vertices; also the side length of `adjacency`.
    num_nodes: usize,
    /// Dense `num_nodes x num_nodes` adjacency matrix; a strictly positive entry marks an edge.
    adjacency: Array2<f64>,
    /// Node feature matrix, one row per node.
    node_features: Array2<f64>,
    /// Optional per-edge feature vectors keyed by `(source, destination)` node index.
    edge_features: Option<HashMap<(usize, usize), Array1<f64>>>,
    /// Optional feature vector describing the graph as a whole.
    graph_features: Option<Array1<f64>>,
}
impl QuantumGraph {
    /// Builds an undirected graph from an edge list.
    ///
    /// Every `(src, dst)` pair is stored in both directions with weight `1.0`, so the
    /// adjacency matrix is symmetric. `node_features` is stored as given (one row per node).
    ///
    /// # Panics
    ///
    /// Panics if an edge endpoint is not smaller than `num_nodes`.
    pub fn new(num_nodes: usize, edges: Vec<(usize, usize)>, node_features: Array2<f64>) -> Self {
        let mut adjacency = Array2::zeros((num_nodes, num_nodes));
        for (src, dst) in edges {
            assert!(
                src < num_nodes && dst < num_nodes,
                "edge ({}, {}) references a node outside 0..{}",
                src,
                dst,
                num_nodes
            );
            adjacency[[src, dst]] = 1.0;
            // Undirected graph: mirror every edge.
            adjacency[[dst, src]] = 1.0;
        }
        Self {
            num_nodes,
            adjacency,
            node_features,
            edge_features: None,
            graph_features: None,
        }
    }

    /// Attaches per-edge feature vectors keyed by `(source, destination)` and returns the graph.
    pub fn with_edge_features(
        mut self,
        edge_features: HashMap<(usize, usize), Array1<f64>>,
    ) -> Self {
        self.edge_features = Some(edge_features);
        self
    }

    /// Attaches a graph-level feature vector and returns the graph.
    pub fn with_graph_features(mut self, graph_features: Array1<f64>) -> Self {
        self.graph_features = Some(graph_features);
        self
    }

    /// Number of strictly positive entries in `node`'s adjacency row (a self-loop counts once).
    ///
    /// # Panics
    ///
    /// Panics if `node >= num_nodes`.
    pub fn degree(&self, node: usize) -> usize {
        self.adjacency
            .row(node)
            .iter()
            .filter(|&&weight| weight > 0.0)
            .count()
    }

    /// Indices of all nodes adjacent to `node`, in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if `node >= num_nodes`.
    pub fn neighbors(&self, node: usize) -> Vec<usize> {
        self.adjacency
            .row(node)
            .iter()
            .enumerate()
            .filter(|(_, &weight)| weight > 0.0)
            .map(|(idx, _)| idx)
            .collect()
    }

    /// Combinatorial graph Laplacian `L = D - A`, where `D` holds the node degrees.
    pub fn laplacian(&self) -> Array2<f64> {
        let degrees: Array1<f64> = (0..self.num_nodes)
            .map(|i| self.degree(i) as f64)
            .collect();
        &Array2::from_diag(&degrees) - &self.adjacency
    }

    /// Symmetric normalised Laplacian `I - D^{-1/2} A D^{-1/2}`.
    ///
    /// Isolated nodes (degree 0) contribute no off-diagonal terms, so their diagonal entry
    /// stays `1.0`.
    pub fn normalized_laplacian(&self) -> Array2<f64> {
        let inv_sqrt_degree: Vec<f64> = (0..self.num_nodes)
            .map(|i| match self.degree(i) {
                0 => 0.0,
                d => 1.0 / (d as f64).sqrt(),
            })
            .collect();
        let mut norm_laplacian: Array2<f64> = Array2::eye(self.num_nodes);
        for ((i, j), &a_ij) in self.adjacency.indexed_iter() {
            if a_ij > 0.0 {
                norm_laplacian[[i, j]] -= inv_sqrt_degree[i] * a_ij * inv_sqrt_degree[j];
            }
        }
        norm_laplacian
    }
}
/// Graph-convolution layer that combines each node's features with the mean of its
/// neighbours' features and applies an element-wise activation.
#[derive(Debug)]
pub struct QuantumGCNLayer {
    /// Width of the node feature vectors the layer consumes.
    input_dim: usize,
    /// Width of the node feature vectors the layer produces.
    output_dim: usize,
    /// Qubit count of a single register in the circuit templates.
    num_qubits: usize,
    /// Variational template acting on one `num_qubits`-wide register.
    node_circuit: VariationalCircuit,
    /// Variational template over two `num_qubits`-wide registers coupled by CZ gates.
    aggregation_circuit: VariationalCircuit,
    /// Named trainable parameter values.
    parameters: HashMap<String, f64>,
    /// Non-linearity applied to every output feature.
    activation: ActivationType,
}
impl QuantumGCNLayer {
pub fn new(input_dim: usize, output_dim: usize, activation: ActivationType) -> Self {
let num_qubits = ((input_dim.max(output_dim)) as f64).log2().ceil() as usize;
let node_circuit = Self::build_node_circuit(num_qubits);
let aggregation_circuit = Self::build_aggregation_circuit(num_qubits);
Self {
input_dim,
output_dim,
num_qubits,
node_circuit,
aggregation_circuit,
parameters: HashMap::new(),
activation,
}
}
fn build_node_circuit(num_qubits: usize) -> VariationalCircuit {
let mut circuit = VariationalCircuit::new(num_qubits);
for q in 0..num_qubits {
circuit.add_gate("RY", vec![q], vec![format!("node_encode_{}", q)]);
}
for layer in 0..2 {
for q in 0..num_qubits - 1 {
circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
}
if num_qubits > 2 {
circuit.add_gate("CNOT", vec![num_qubits - 1, 0], vec![]);
}
for q in 0..num_qubits {
circuit.add_gate("RX", vec![q], vec![format!("node_rx_{}_{}", layer, q)]);
circuit.add_gate("RZ", vec![q], vec![format!("node_rz_{}_{}", layer, q)]);
}
}
circuit
}
fn build_aggregation_circuit(num_qubits: usize) -> VariationalCircuit {
let mut circuit = VariationalCircuit::new(num_qubits * 2);
for q in 0..num_qubits {
circuit.add_gate("CZ", vec![q, q + num_qubits], vec![]);
}
for q in 0..num_qubits * 2 {
circuit.add_gate("RY", vec![q], vec![format!("agg_ry_{}", q)]);
}
for q in 0..num_qubits * 2 - 1 {
circuit.add_gate("CNOT", vec![q, q + 1], vec![]);
}
for q in 0..num_qubits {
circuit.add_gate("RX", vec![q], vec![format!("agg_final_{}", q)]);
}
circuit
}
pub fn forward(&self, graph: &QuantumGraph) -> Result<Array2<f64>> {
let mut output_features = Array2::zeros((graph.num_nodes, self.output_dim));
for node in 0..graph.num_nodes {
let node_feat = graph.node_features.row(node);
let neighbors = graph.neighbors(node);
let mut aggregated = Array1::zeros(self.input_dim);
for &neighbor in &neighbors {
let neighbor_feat = graph.node_features.row(neighbor);
aggregated = &aggregated + &neighbor_feat.to_owned();
}
let degree = neighbors.len().max(1) as f64;
aggregated = aggregated / degree;
let transformed = self.quantum_transform(&node_feat.to_owned(), &aggregated)?;
for i in 0..self.output_dim {
output_features[[node, i]] = transformed[i];
}
}
Ok(output_features)
}
fn quantum_transform(
&self,
node_features: &Array1<f64>,
aggregated_features: &Array1<f64>,
) -> Result<Array1<f64>> {
let node_encoded = self.encode_features(node_features)?;
let agg_encoded = self.encode_features(aggregated_features)?;
let mut output = Array1::zeros(self.output_dim);
for i in 0..self.output_dim {
let idx_node = i % node_features.len();
let idx_agg = i % aggregated_features.len();
output[i] = match self.activation {
ActivationType::ReLU => {
(0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]).max(0.0)
}
ActivationType::Tanh => {
(0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]).tanh()
}
ActivationType::Sigmoid => {
let x = 0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg];
1.0 / (1.0 + (-x).exp())
}
ActivationType::Linear => {
0.5 * node_features[idx_node] + 0.5 * aggregated_features[idx_agg]
}
};
}
Ok(output)
}
fn encode_features(&self, features: &Array1<f64>) -> Result<Vec<Complex64>> {
let state_dim = 2_usize.pow(self.num_qubits as u32);
let mut quantum_state = vec![Complex64::new(0.0, 0.0); state_dim];
let norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm < 1e-10 {
quantum_state[0] = Complex64::new(1.0, 0.0);
} else {
for (i, &val) in features.iter().enumerate() {
if i < state_dim {
quantum_state[i] = Complex64::new(val / norm, 0.0);
}
}
}
Ok(quantum_state)
}
}
/// Multi-head graph-attention layer; each head attends over a node and its neighbours, and
/// the head outputs are concatenated along the feature axis.
#[derive(Debug)]
pub struct QuantumGATLayer {
    /// Input feature width; attention logits are scaled by `1 / sqrt(input_dim)`.
    input_dim: usize,
    /// Total output width, split evenly (`output_dim / num_heads`) across heads.
    output_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// One attention circuit template per head.
    attention_circuits: Vec<VariationalCircuit>,
    /// One feature-transform circuit template per head.
    transform_circuits: Vec<VariationalCircuit>,
    /// Dropout probability supplied at construction; the forward pass does not read it.
    dropout_rate: f64,
}
impl QuantumGATLayer {
    /// Creates a graph-attention layer with `num_heads` heads of width `output_dim / num_heads`.
    ///
    /// Each head gets its own attention and transform circuit template. Templates are sized to
    /// `ceil(log2(output_dim / num_heads))` qubits, with a minimum of one.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero.
    pub fn new(input_dim: usize, output_dim: usize, num_heads: usize, dropout_rate: f64) -> Self {
        assert!(num_heads > 0, "QuantumGATLayer requires at least one attention head");
        let qubits_per_head = Self::qubits_for(output_dim / num_heads);
        let attention_circuits = (0..num_heads)
            .map(|_| Self::build_attention_circuit(qubits_per_head))
            .collect();
        let transform_circuits = (0..num_heads)
            .map(|_| Self::build_transform_circuit(qubits_per_head))
            .collect();
        Self {
            input_dim,
            output_dim,
            num_heads,
            attention_circuits,
            transform_circuits,
            dropout_rate,
        }
    }

    /// Smallest qubit count whose state space holds `dim` amplitudes, never less than one.
    fn qubits_for(dim: usize) -> usize {
        dim.max(2).next_power_of_two().trailing_zeros() as usize
    }

    /// Feature width produced by each head.
    fn head_dim(&self) -> usize {
        self.output_dim / self.num_heads
    }

    /// Source and destination registers encoded with RY, coupled with CZ, then fanned out
    /// from qubit 0 with H + CNOTs.
    fn build_attention_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits * 2);
        for q in 0..num_qubits {
            circuit.add_gate("RY", vec![q], vec![format!("att_src_{}", q)]);
            circuit.add_gate("RY", vec![q + num_qubits], vec![format!("att_dst_{}", q)]);
        }
        for q in 0..num_qubits {
            circuit.add_gate("CZ", vec![q, q + num_qubits], vec![]);
        }
        circuit.add_gate("H", vec![0], vec![]);
        for q in 1..num_qubits * 2 {
            circuit.add_gate("CNOT", vec![0, q], vec![]);
        }
        circuit
    }

    /// Two layers of RY/RZ rotations, each followed by a CX chain.
    fn build_transform_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);
        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("trans_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("trans_rz_{}_{}", layer, q)]);
            }
            // `saturating_sub` keeps a single-qubit register from underflowing.
            for q in 0..num_qubits.saturating_sub(1) {
                circuit.add_gate("CX", vec![q, q + 1], vec![]);
            }
        }
        circuit
    }

    /// Runs every head and concatenates their outputs into a `(num_nodes, output_dim)` matrix.
    ///
    /// Head `h` fills columns `h * head_dim .. (h + 1) * head_dim`. When `output_dim` is not
    /// divisible by `num_heads`, the trailing columns stay zero.
    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array2<f64>> {
        let head_dim = self.head_dim();
        let mut output: Array2<f64> = Array2::zeros((graph.num_nodes, self.output_dim));
        for head in 0..self.num_heads {
            let head_output = self.process_attention_head(graph, head)?;
            for (d, column) in head_output.columns().into_iter().enumerate() {
                output.column_mut(head * head_dim + d).assign(&column);
            }
        }
        Ok(output)
    }

    /// Nodes that `node` attends to: itself first, then its neighbours.
    ///
    /// A self-loop is counted only once.
    fn attention_support(graph: &QuantumGraph, node: usize) -> Vec<usize> {
        std::iter::once(node)
            .chain(graph.neighbors(node).into_iter().filter(|&j| j != node))
            .collect()
    }

    /// Attention-weighted aggregation for one head, followed by the head's feature transform.
    fn process_attention_head(&self, graph: &QuantumGraph, head: usize) -> Result<Array2<f64>> {
        let head_dim = self.head_dim();
        let attention_scores = self.compute_attention_scores(graph, head)?;
        let mut output: Array2<f64> = Array2::zeros((graph.num_nodes, head_dim));
        for node in 0..graph.num_nodes {
            let mut weighted_features: Array1<f64> =
                Array1::zeros(graph.node_features.ncols());
            for j in Self::attention_support(graph, node) {
                weighted_features
                    .scaled_add(attention_scores[[node, j]], &graph.node_features.row(j));
            }
            let transformed = self.transform_features(&weighted_features, head)?;
            output.row_mut(node).assign(&transformed);
        }
        Ok(output)
    }

    /// Row-wise softmax attention weights over each node's support (itself + neighbours).
    ///
    /// Entries outside a node's support are zero.
    fn compute_attention_scores(&self, graph: &QuantumGraph, head: usize) -> Result<Array2<f64>> {
        let mut scores: Array2<f64> = Array2::zeros((graph.num_nodes, graph.num_nodes));
        for i in 0..graph.num_nodes {
            let support = Self::attention_support(graph, i);
            let feat_i = graph.node_features.row(i).to_owned();
            // Raw scores are tanh-bounded to [-1, 1], so a plain softmax cannot overflow.
            let mut exp_scores = Vec::with_capacity(support.len());
            for &j in &support {
                let raw = self.quantum_attention_score(
                    &feat_i,
                    &graph.node_features.row(j).to_owned(),
                    head,
                )?;
                exp_scores.push(raw.exp());
            }
            let sum_exp: f64 = exp_scores.iter().sum();
            for (&j, e) in support.iter().zip(exp_scores) {
                scores[[i, j]] = e / sum_exp;
            }
        }
        Ok(scores)
    }

    /// Scaled dot-product attention logit squashed into `[-1, 1]` with `tanh`.
    ///
    /// The `1 / sqrt(input_dim)` scale is clamped so a zero `input_dim` does not yield NaN.
    fn quantum_attention_score(
        &self,
        feat_i: &Array1<f64>,
        feat_j: &Array1<f64>,
        _head: usize,
    ) -> Result<f64> {
        let dot_product: f64 = feat_i.iter().zip(feat_j.iter()).map(|(a, b)| a * b).sum();
        Ok((dot_product / (self.input_dim.max(1) as f64).sqrt()).tanh())
    }

    /// Maps aggregated features to `head_dim` outputs.
    ///
    /// Output `i` is `features[i] * (1 + 0.1 sin i)`, and is zero when `i` is past the end of
    /// `features`.
    fn transform_features(&self, features: &Array1<f64>, _head: usize) -> Result<Array1<f64>> {
        let output: Array1<f64> = (0..self.head_dim())
            .map(|i| {
                features
                    .get(i)
                    .map_or(0.0, |&f| f * (1.0 + 0.1 * (i as f64).sin()))
            })
            .collect();
        Ok(output)
    }
}
/// Message-passing neural network that repeatedly updates node states from neighbour
/// messages and then reduces all node states to one graph-level vector.
#[derive(Debug)]
pub struct QuantumMPNN {
    /// Circuit template over three registers used for message computation.
    message_circuit: VariationalCircuit,
    /// Circuit template over two registers used for node-state updates.
    update_circuit: VariationalCircuit,
    /// Circuit template over a single register used for the readout.
    readout_circuit: VariationalCircuit,
    /// Width of every node's hidden state and of the readout vector.
    hidden_dim: usize,
    /// Number of message-passing rounds per forward pass.
    num_steps: usize,
}
impl QuantumMPNN {
    /// Creates an MPNN with `hidden_dim`-wide node states and `num_steps` rounds.
    ///
    /// `input_dim` and `output_dim` are accepted but not read. Node features are truncated or
    /// zero-padded to `hidden_dim`, and [`QuantumMPNN::forward`] returns `hidden_dim` values.
    pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize, num_steps: usize) -> Self {
        let num_qubits = Self::qubits_for(hidden_dim);
        Self {
            message_circuit: Self::build_message_circuit(num_qubits),
            update_circuit: Self::build_update_circuit(num_qubits),
            readout_circuit: Self::build_readout_circuit(num_qubits),
            hidden_dim,
            num_steps,
        }
    }

    /// Smallest qubit count whose state space holds `dim` amplitudes, never less than one.
    fn qubits_for(dim: usize) -> usize {
        dim.max(2).next_power_of_two().trailing_zeros() as usize
    }

    /// Three registers (source, destination, message) with RY encoding.
    ///
    /// This is followed by two rounds of CZ couplings into the third register plus RX
    /// rotations.
    fn build_message_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits * 3);
        for q in 0..num_qubits * 3 {
            circuit.add_gate("RY", vec![q], vec![format!("msg_encode_{}", q)]);
        }
        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("CZ", vec![q, q + num_qubits * 2], vec![]);
            }
            for q in 0..num_qubits {
                circuit.add_gate("CZ", vec![q + num_qubits, q + num_qubits * 2], vec![]);
            }
            for q in 0..num_qubits * 3 {
                circuit.add_gate("RX", vec![q], vec![format!("msg_rx_{}_{}", layer, q)]);
            }
        }
        circuit
    }

    /// Two registers coupled with CNOTs, then two layers of RY/RZ rotations and a CX chain.
    fn build_update_circuit(num_qubits: usize) -> VariationalCircuit {
        let width = num_qubits * 2;
        let mut circuit = VariationalCircuit::new(width);
        for q in 0..num_qubits {
            circuit.add_gate("CNOT", vec![q, q + num_qubits], vec![]);
        }
        for layer in 0..2 {
            for q in 0..width {
                circuit.add_gate("RY", vec![q], vec![format!("upd_ry_{}_{}", layer, q)]);
                circuit.add_gate("RZ", vec![q], vec![format!("upd_rz_{}_{}", layer, q)]);
            }
            // `saturating_sub` keeps an empty register from underflowing.
            for q in 0..width.saturating_sub(1) {
                circuit.add_gate("CX", vec![q, q + 1], vec![]);
            }
        }
        circuit
    }

    /// Three layers of RY rotations, each followed by all-to-all CZ entanglement.
    fn build_readout_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);
        for layer in 0..3 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("read_ry_{}_{}", layer, q)]);
            }
            for i in 0..num_qubits {
                for j in i + 1..num_qubits {
                    circuit.add_gate("CZ", vec![i, j], vec![]);
                }
            }
        }
        circuit
    }

    /// Runs `num_steps` message-passing rounds and returns the `hidden_dim`-long graph readout.
    pub fn forward(&self, graph: &QuantumGraph) -> Result<Array1<f64>> {
        // Initial node states: input features truncated or zero-padded to `hidden_dim`.
        let copy_dim = self.hidden_dim.min(graph.node_features.ncols());
        let mut hidden_states: Array2<f64> = Array2::zeros((graph.num_nodes, self.hidden_dim));
        for node in 0..graph.num_nodes {
            for d in 0..copy_dim {
                hidden_states[[node, d]] = graph.node_features[[node, d]];
            }
        }
        for _ in 0..self.num_steps {
            hidden_states = self.message_passing_step(graph, &hidden_states)?;
        }
        self.readout(graph, &hidden_states)
    }

    /// One synchronous round: every node sums the messages from its neighbours and updates.
    ///
    /// The graph is undirected, so an edge feature stored under either `(neighbor, node)` or
    /// `(node, neighbor)` is used. The key in the message's own direction wins if both exist.
    fn message_passing_step(
        &self,
        graph: &QuantumGraph,
        hidden_states: &Array2<f64>,
    ) -> Result<Array2<f64>> {
        let mut new_hidden: Array2<f64> = Array2::zeros((graph.num_nodes, self.hidden_dim));
        for node in 0..graph.num_nodes {
            let node_state = hidden_states.row(node).to_owned();
            let mut messages: Array1<f64> = Array1::zeros(self.hidden_dim);
            for neighbor in graph.neighbors(node) {
                let edge_feature = graph.edge_features.as_ref().and_then(|ef| {
                    ef.get(&(neighbor, node))
                        .or_else(|| ef.get(&(node, neighbor)))
                });
                let message = self.compute_message(
                    &hidden_states.row(neighbor).to_owned(),
                    &node_state,
                    edge_feature,
                )?;
                messages += &message;
            }
            let updated = self.update_node(&node_state, &messages)?;
            new_hidden.row_mut(node).assign(&updated);
        }
        Ok(new_hidden)
    }

    /// Message `m[i] = 0.5 * (src[i] + dst[i]) * edge[i]`.
    ///
    /// Missing state entries count as `0.0`. Missing edge features (or entries) count as `1.0`.
    fn compute_message(
        &self,
        source_hidden: &Array1<f64>,
        dest_hidden: &Array1<f64>,
        edge_features: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        let message: Array1<f64> = (0..self.hidden_dim)
            .map(|i| {
                let src_val = source_hidden.get(i).copied().unwrap_or(0.0);
                let dst_val = dest_hidden.get(i).copied().unwrap_or(0.0);
                let edge_val = edge_features
                    .and_then(|ef| ef.get(i))
                    .copied()
                    .unwrap_or(1.0);
                (src_val + dst_val) * edge_val * 0.5
            })
            .collect();
        Ok(message)
    }

    /// GRU-style gated update of one node state from its aggregated messages.
    fn update_node(&self, hidden: &Array1<f64>, messages: &Array1<f64>) -> Result<Array1<f64>> {
        let new_hidden: Array1<f64> = (0..self.hidden_dim)
            .map(|i| {
                let h = hidden.get(i).copied().unwrap_or(0.0);
                let m = messages.get(i).copied().unwrap_or(0.0);
                // Update gate.
                let z = (h + m).tanh();
                // Reset gate.
                let r = 1.0 / (1.0 + (-(h * m)).exp());
                // Candidate state.
                let h_tilde = ((r * h) + m).tanh();
                (1.0 - z) * h + z * h_tilde
            })
            .collect();
        Ok(new_hidden)
    }

    /// Mean of all node states passed through `tanh`.
    ///
    /// An empty graph yields the zero vector instead of NaN from `0 / 0`.
    fn readout(&self, graph: &QuantumGraph, hidden_states: &Array2<f64>) -> Result<Array1<f64>> {
        if graph.num_nodes == 0 {
            return Ok(Array1::zeros(self.hidden_dim));
        }
        let mut global_state: Array1<f64> = Array1::zeros(self.hidden_dim);
        for node in 0..graph.num_nodes {
            global_state += &hidden_states.row(node);
        }
        global_state /= graph.num_nodes as f64;
        Ok(global_state.mapv(f64::tanh))
    }
}
/// Graph pooling operator that shrinks a graph to a fraction of its nodes.
#[derive(Debug)]
pub struct QuantumGraphPool {
    /// Fraction of nodes (or clusters, for DiffPool) to keep; the count is rounded up.
    pool_ratio: f64,
    /// Strategy used to pick the surviving nodes.
    method: PoolingMethod,
    /// Circuit template for node scoring.
    score_circuit: VariationalCircuit,
}
/// Strategy used by [`QuantumGraphPool`] to decide which nodes survive pooling.
#[derive(Debug, Clone)]
pub enum PoolingMethod {
    /// Keep the highest-scoring nodes.
    TopK,
    /// Sample nodes without replacement with softmax-of-score probabilities, scaling each
    /// kept node's features by its probability.
    SelfAttention,
    /// Softly assign nodes to clusters and pool features per cluster.
    DiffPool,
}
impl QuantumGraphPool {
    /// Creates a pooling operator that keeps `ceil(num_nodes * pool_ratio)` nodes (capped at
    /// `num_nodes`) using `method`.
    ///
    /// `feature_dim` sizes the score circuit to `ceil(log2(feature_dim))` qubits, with a
    /// minimum of one.
    pub fn new(pool_ratio: f64, method: PoolingMethod, feature_dim: usize) -> Self {
        let num_qubits = Self::qubits_for(feature_dim);
        Self {
            pool_ratio,
            method,
            score_circuit: Self::build_score_circuit(num_qubits),
        }
    }

    /// Smallest qubit count whose state space holds `dim` amplitudes, never less than one.
    fn qubits_for(dim: usize) -> usize {
        dim.max(2).next_power_of_two().trailing_zeros() as usize
    }

    /// Two layers of RY rotations with nearest-neighbour CZ, then RX rotations.
    fn build_score_circuit(num_qubits: usize) -> VariationalCircuit {
        let mut circuit = VariationalCircuit::new(num_qubits);
        for layer in 0..2 {
            for q in 0..num_qubits {
                circuit.add_gate("RY", vec![q], vec![format!("pool_ry_{}_{}", layer, q)]);
            }
            // `saturating_sub` keeps a single-qubit register from underflowing.
            for q in 0..num_qubits.saturating_sub(1) {
                circuit.add_gate("CZ", vec![q, q + 1], vec![]);
            }
        }
        for q in 0..num_qubits {
            circuit.add_gate("RX", vec![q], vec![format!("pool_measure_{}", q)]);
        }
        circuit
    }

    /// Pools `graph`, returning the retained (or representative) node indices and the pooled
    /// feature rows. The two outputs always have the same length.
    ///
    /// `node_features` is expected to hold one row per node of `graph`.
    pub fn pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        match self.method {
            PoolingMethod::TopK => self.topk_pool(graph, node_features),
            PoolingMethod::SelfAttention => self.attention_pool(graph, node_features),
            PoolingMethod::DiffPool => self.diff_pool(graph, node_features),
        }
    }

    /// Number of nodes/clusters to keep: `ceil(num_nodes * pool_ratio)`, capped at `num_nodes`.
    ///
    /// The `as usize` cast saturates, so negative or NaN ratios yield 0.
    fn pooled_size(&self, num_nodes: usize) -> usize {
        (((num_nodes as f64) * self.pool_ratio).ceil() as usize).min(num_nodes)
    }

    /// Keeps the `k` highest-scoring nodes, in descending score order (ties keep index order).
    fn topk_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let mut scored: Vec<(usize, f64)> = (0..graph.num_nodes)
            .map(|node| {
                let score = self.compute_node_score(&node_features.row(node).to_owned())?;
                // NaN ranks last so the comparator below stays a total order.
                Ok((node, if score.is_nan() { f64::NEG_INFINITY } else { score }))
            })
            .collect::<Result<Vec<(usize, f64)>>>()?;
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));

        let k = self.pooled_size(graph.num_nodes);
        let selected_nodes: Vec<usize> = scored.iter().take(k).map(|&(node, _)| node).collect();
        let mut pooled_features: Array2<f64> =
            Array2::zeros((selected_nodes.len(), node_features.ncols()));
        for (i, &node) in selected_nodes.iter().enumerate() {
            pooled_features.row_mut(i).assign(&node_features.row(node));
        }
        Ok((selected_nodes, pooled_features))
    }

    /// Samples `k` distinct nodes with softmax-of-score probabilities.
    ///
    /// Each kept node's features are scaled by its probability.
    fn attention_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let num_nodes = graph.num_nodes;
        let mut attention_scores: Array1<f64> = Array1::zeros(num_nodes);
        for node in 0..num_nodes {
            attention_scores[node] =
                self.compute_node_score(&node_features.row(node).to_owned())?;
        }
        // Numerically stable softmax.
        let max_score = attention_scores
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_scores = attention_scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        let normalized_scores = exp_scores / sum_exp;

        // Sampling without replacement: a picked node's weight is zeroed.
        let k = self.pooled_size(num_nodes);
        let mut remaining_scores = normalized_scores.clone();
        let mut taken = vec![false; num_nodes];
        let mut selected_nodes = Vec::with_capacity(k);
        for _ in 0..k {
            let node = if remaining_scores.iter().any(|&s| s > 0.0) {
                self.sample_node(&remaining_scores)
            } else {
                // All remaining mass underflowed to zero: take the first unpicked node.
                taken.iter().position(|&t| !t).unwrap_or(0)
            };
            taken[node] = true;
            remaining_scores[node] = 0.0;
            selected_nodes.push(node);
        }

        let mut pooled_features: Array2<f64> =
            Array2::zeros((selected_nodes.len(), node_features.ncols()));
        for (i, &node) in selected_nodes.iter().enumerate() {
            let mut row = pooled_features.row_mut(i);
            row.assign(&node_features.row(node));
            row *= normalized_scores[node];
        }
        Ok((selected_nodes, pooled_features))
    }

    /// Softly assigns nodes to `k` clusters and pools features as `S^T X`.
    ///
    /// Each cluster's representative is the node with the largest assignment to it.
    fn diff_pool(
        &self,
        graph: &QuantumGraph,
        node_features: &Array2<f64>,
    ) -> Result<(Vec<usize>, Array2<f64>)> {
        let num_nodes = graph.num_nodes;
        let k = self.pooled_size(num_nodes);
        let mut assignments: Array2<f64> = Array2::zeros((num_nodes, k));
        for node in 0..num_nodes {
            let features = node_features.row(node).to_owned();
            for cluster in 0..k {
                assignments[[node, cluster]] =
                    self.compute_cluster_assignment(&features, cluster)?;
            }
            // Row-normalise so each node distributes unit mass across clusters.
            let row_sum: f64 = assignments.row(node).sum();
            if row_sum > 0.0 {
                assignments.row_mut(node).mapv_inplace(|a| a / row_sum);
            }
        }
        let pooled_features = assignments.t().dot(node_features);

        let selected_nodes: Vec<usize> = (0..k)
            .map(|cluster| {
                let mut best = (0, 0.0);
                for (node, &weight) in assignments.column(cluster).iter().enumerate() {
                    if weight > best.1 {
                        best = (node, weight);
                    }
                }
                best.0
            })
            .collect();
        Ok((selected_nodes, pooled_features))
    }

    /// L2 norm of `features` multiplied by a random factor in `[1, 1.1)`.
    fn compute_node_score(&self, features: &Array1<f64>) -> Result<f64> {
        let norm: f64 = features.iter().map(|x| x * x).sum::<f64>().sqrt();
        Ok(norm * (1.0 + 0.1 * fastrand::f64()))
    }

    /// Sigmoid of the feature mean plus a `0.1 * cluster` bias.
    ///
    /// Empty features count as mean `0.0`.
    fn compute_cluster_assignment(&self, features: &Array1<f64>, cluster: usize) -> Result<f64> {
        let mean = if features.is_empty() {
            0.0
        } else {
            features.iter().sum::<f64>() / features.len() as f64
        };
        let logit = mean + (cluster as f64) * 0.1;
        // This sigmoid form avoids inf / inf = NaN for large logits.
        Ok(1.0 / (1.0 + (-logit).exp()))
    }

    /// Draws an index with probability proportional to its (non-negative) score.
    ///
    /// Zero-weight entries are never returned while any positive weight remains.
    fn sample_node(&self, scores: &Array1<f64>) -> usize {
        let total: f64 = scores.iter().sum();
        let threshold = fastrand::f64() * total;
        let mut cumulative = 0.0;
        for (i, &s) in scores.iter().enumerate() {
            cumulative += s;
            if s > 0.0 && threshold < cumulative {
                return i;
            }
        }
        // Rounding can leave `threshold` just past the final cumulative sum.
        scores
            .iter()
            .enumerate()
            .filter(|(_, &s)| s > 0.0)
            .map(|(i, _)| i)
            .last()
            .unwrap_or_else(|| scores.len().saturating_sub(1))
    }
}
/// Stack of graph layers with optional pooling after each layer, followed by a graph-level
/// readout.
#[derive(Debug)]
pub struct QuantumGNN {
    /// Layers applied in order.
    layers: Vec<GNNLayer>,
    /// Optional pooling step applied after the layer with the same index.
    pooling: Vec<Option<QuantumGraphPool>>,
    /// Reduction from final node features to a single vector.
    readout: ReadoutType,
    /// Length of the returned vector (readout truncated or zero-padded to this size).
    output_dim: usize,
}
/// One layer of a [`QuantumGNN`].
#[derive(Debug)]
enum GNNLayer {
    /// Graph convolution producing per-node features.
    GCN(QuantumGCNLayer),
    /// Multi-head graph attention producing per-node features.
    GAT(QuantumGATLayer),
    /// Message passing producing a single graph-level vector.
    MPNN(QuantumMPNN),
}
/// Reduction from per-node features to one graph-level vector.
#[derive(Debug, Clone)]
pub enum ReadoutType {
    /// Column-wise mean over nodes.
    Mean,
    /// Column-wise maximum over nodes.
    Max,
    /// Column-wise sum over nodes.
    Sum,
    /// Softmax-weighted sum of node rows, weighted by each row's feature sum.
    Attention,
}
impl QuantumGNN {
pub fn new(
layer_configs: Vec<(String, usize, usize)>, pooling_configs: Vec<Option<(f64, PoolingMethod)>>,
readout: ReadoutType,
output_dim: usize,
) -> Result<Self> {
let mut layers = Vec::new();
let mut pooling = Vec::new();
for (layer_type, input_dim, output_dim) in layer_configs {
let layer = match layer_type.as_str() {
"gcn" => GNNLayer::GCN(QuantumGCNLayer::new(
input_dim,
output_dim,
ActivationType::ReLU,
)),
"gat" => GNNLayer::GAT(QuantumGATLayer::new(
input_dim, output_dim, 4, 0.1, )),
"mpnn" => GNNLayer::MPNN(QuantumMPNN::new(
input_dim, output_dim, output_dim, 3, )),
_ => {
return Err(MLError::InvalidConfiguration(format!(
"Unknown layer type: {}",
layer_type
)))
}
};
layers.push(layer);
}
for pool_config in pooling_configs {
let pool_layer = pool_config.map(|(ratio, method)| {
QuantumGraphPool::new(ratio, method, 64) });
pooling.push(pool_layer);
}
Ok(Self {
layers,
pooling,
readout,
output_dim,
})
}
pub fn forward(&self, graph: &QuantumGraph) -> Result<Array1<f64>> {
let mut current_graph = graph.clone();
let mut current_features = graph.node_features.clone();
let mut selected_nodes: Vec<usize> = (0..graph.num_nodes).collect();
for (i, layer) in self.layers.iter().enumerate() {
current_features = match layer {
GNNLayer::GCN(gcn) => gcn.forward(¤t_graph)?,
GNNLayer::GAT(gat) => gat.forward(¤t_graph)?,
GNNLayer::MPNN(mpnn) => {
let graph_features = mpnn.forward(¤t_graph)?;
let mut node_features =
Array2::zeros((current_graph.num_nodes, graph_features.len()));
for node in 0..current_graph.num_nodes {
node_features.row_mut(node).assign(&graph_features);
}
node_features
}
};
if let Some(Some(pool)) = self.pooling.get(i) {
let (new_selected, pooled_features) =
pool.pool(¤t_graph, ¤t_features)?;
current_graph =
self.create_subgraph(¤t_graph, &new_selected, &pooled_features);
current_features = pooled_features;
selected_nodes = new_selected;
}
}
self.apply_readout(¤t_features)
}
fn create_subgraph(
&self,
graph: &QuantumGraph,
selected_nodes: &[usize],
pooled_features: &Array2<f64>,
) -> QuantumGraph {
let num_nodes = selected_nodes.len();
let mut new_adjacency = Array2::zeros((num_nodes, num_nodes));
let index_map: HashMap<usize, usize> = selected_nodes
.iter()
.enumerate()
.map(|(new_idx, &old_idx)| (old_idx, new_idx))
.collect();
for (i, &old_i) in selected_nodes.iter().enumerate() {
for (j, &old_j) in selected_nodes.iter().enumerate() {
new_adjacency[[i, j]] = graph.adjacency[[old_i, old_j]];
}
}
let mut edges = Vec::new();
for i in 0..num_nodes {
for j in i + 1..num_nodes {
if new_adjacency[[i, j]] > 0.0 {
edges.push((i, j));
}
}
}
QuantumGraph::new(num_nodes, edges, pooled_features.clone())
}
fn apply_readout(&self, node_features: &Array2<f64>) -> Result<Array1<f64>> {
let readout_features = match self.readout {
ReadoutType::Mean => node_features
.mean_axis(scirs2_core::ndarray::Axis(0))
.ok_or_else(|| {
MLError::InvalidInput("Cannot compute mean of empty array".to_string())
})?,
ReadoutType::Max => {
let mut max_features = Array1::from_elem(node_features.ncols(), f64::NEG_INFINITY);
for row in node_features.rows() {
for (i, &val) in row.iter().enumerate() {
max_features[i] = max_features[i].max(val);
}
}
max_features
}
ReadoutType::Sum => node_features.sum_axis(scirs2_core::ndarray::Axis(0)),
ReadoutType::Attention => {
let mut weights = Array1::zeros(node_features.nrows());
for (i, row) in node_features.rows().into_iter().enumerate() {
weights[i] = row.sum(); }
let max_weight = weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exp_weights = weights.mapv(|x| (x - max_weight).exp());
let weights_norm = exp_weights.clone() / exp_weights.sum();
let mut result = Array1::zeros(node_features.ncols());
for (i, row) in node_features.rows().into_iter().enumerate() {
result = &result + &(&row.to_owned() * weights_norm[i]);
}
result
}
};
let mut output = Array1::zeros(self.output_dim);
for i in 0..self.output_dim {
if i < readout_features.len() {
output[i] = readout_features[i];
}
}
Ok(output)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Graph whose nodes all carry a feature vector of ones of length `dim`.
    fn ones_graph(num_nodes: usize, edges: Vec<(usize, usize)>, dim: usize) -> QuantumGraph {
        QuantumGraph::new(num_nodes, edges, Array2::ones((num_nodes, dim)))
    }

    #[test]
    fn test_quantum_graph() {
        // 5-cycle 0-1-2-3-4-0: node 0 touches exactly nodes 1 and 4.
        let graph = ones_graph(5, vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], 4);
        assert_eq!(graph.num_nodes, 5);
        assert_eq!(graph.degree(0), 2);
        assert_eq!(graph.neighbors(0), vec![1, 4]);
    }

    #[test]
    fn test_quantum_gcn_layer() {
        // Path 0-1-2 with one-hot features in the first three of four columns.
        let one_hot = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0,
        ];
        let features =
            Array2::from_shape_vec((3, 4), one_hot).expect("Failed to create node features");
        let graph = QuantumGraph::new(3, vec![(0, 1), (1, 2)], features);
        let layer = QuantumGCNLayer::new(4, 8, ActivationType::ReLU);
        let output = layer.forward(&graph).expect("Forward pass failed");
        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_quantum_gat_layer() {
        let graph = ones_graph(4, vec![(0, 1), (1, 2), (2, 3), (3, 0)], 8);
        let layer = QuantumGATLayer::new(8, 16, 4, 0.1);
        let output = layer.forward(&graph).expect("Forward pass failed");
        assert_eq!(output.shape(), &[4, 16]);
    }

    #[test]
    fn test_quantum_mpnn() {
        let graph = QuantumGraph::new(3, vec![(0, 1), (1, 2)], Array2::zeros((3, 4)));
        let network = QuantumMPNN::new(4, 8, 16, 2);
        let output = network.forward(&graph).expect("Forward pass failed");
        // The readout is `hidden_dim` wide.
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_graph_pooling() {
        // Two disjoint 3-node paths; a 0.5 ratio keeps 3 of the 6 nodes.
        let graph = ones_graph(6, vec![(0, 1), (1, 2), (3, 4), (4, 5)], 4);
        let pooler = QuantumGraphPool::new(0.5, PoolingMethod::TopK, 4);
        let (selected, pooled) = pooler
            .pool(&graph, &graph.node_features)
            .expect("Pooling failed");
        assert_eq!(selected.len(), 3);
        assert_eq!(pooled.shape(), &[3, 4]);
    }

    #[test]
    fn test_complete_gnn() {
        let layer_configs = vec![("gcn".to_string(), 4, 8), ("gat".to_string(), 8, 16)];
        let pooling_configs = vec![None, Some((0.5, PoolingMethod::TopK))];
        let model = QuantumGNN::new(layer_configs, pooling_configs, ReadoutType::Mean, 10)
            .expect("Failed to create GNN");
        let graph = ones_graph(5, vec![(0, 1), (1, 2), (2, 3), (3, 4)], 4);
        let output = model.forward(&graph).expect("Forward pass failed");
        assert_eq!(output.len(), 10);
    }
}