use crate::error::{QuantRS2Error, QuantRS2Result};
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GateType {
PauliX,
PauliY,
PauliZ,
Hadamard,
Phase(u64), RX(u64), RY(u64), RZ(u64), S,
T,
CNOT,
CZ,
SWAP,
CRZ(u64),
Toffoli,
Fredkin,
}
#[derive(Debug, Clone)]
pub struct QuantumGate {
pub gate_type: GateType,
pub qubits: Vec<usize>,
pub matrix: Vec<Complex64>,
}
impl QuantumGate {
pub fn new(gate_type: GateType, qubits: Vec<usize>) -> QuantRS2Result<Self> {
let matrix = Self::compute_matrix(&gate_type)?;
Ok(Self {
gate_type,
qubits,
matrix,
})
}
fn compute_matrix(gate_type: &GateType) -> QuantRS2Result<Vec<Complex64>> {
use std::f64::consts::{FRAC_1_SQRT_2, PI};
let matrix = match gate_type {
GateType::PauliX => vec![
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
],
GateType::PauliY => vec![
Complex64::new(0.0, 0.0),
Complex64::new(0.0, -1.0),
Complex64::new(0.0, 1.0),
Complex64::new(0.0, 0.0),
],
GateType::PauliZ => vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(-1.0, 0.0),
],
GateType::Hadamard => vec![
Complex64::new(FRAC_1_SQRT_2, 0.0),
Complex64::new(FRAC_1_SQRT_2, 0.0),
Complex64::new(FRAC_1_SQRT_2, 0.0),
Complex64::new(-FRAC_1_SQRT_2, 0.0),
],
GateType::S => vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 1.0),
],
GateType::T => vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(FRAC_1_SQRT_2, FRAC_1_SQRT_2),
],
GateType::RX(quantized_angle) => {
let angle = (*quantized_angle as f64) / 1_000_000.0;
let cos_half = (angle / 2.0).cos();
let sin_half = (angle / 2.0).sin();
vec![
Complex64::new(cos_half, 0.0),
Complex64::new(0.0, -sin_half),
Complex64::new(0.0, -sin_half),
Complex64::new(cos_half, 0.0),
]
}
GateType::RY(quantized_angle) => {
let angle = (*quantized_angle as f64) / 1_000_000.0;
let cos_half = (angle / 2.0).cos();
let sin_half = (angle / 2.0).sin();
vec![
Complex64::new(cos_half, 0.0),
Complex64::new(-sin_half, 0.0),
Complex64::new(sin_half, 0.0),
Complex64::new(cos_half, 0.0),
]
}
GateType::RZ(quantized_angle) => {
let angle = (*quantized_angle as f64) / 1_000_000.0;
let cos_half = (angle / 2.0).cos();
let sin_half = (angle / 2.0).sin();
vec![
Complex64::new(cos_half, -sin_half),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(cos_half, sin_half),
]
}
GateType::Phase(quantized_angle) => {
let angle = (*quantized_angle as f64) / 1_000_000.0;
vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(angle.cos(), angle.sin()),
]
}
GateType::CNOT => vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
],
_ => return Err(QuantRS2Error::UnsupportedGate(format!("{gate_type:?}"))),
};
Ok(matrix)
}
pub const fn num_qubits(&self) -> usize {
match self.gate_type {
GateType::PauliX
| GateType::PauliY
| GateType::PauliZ
| GateType::Hadamard
| GateType::Phase(_)
| GateType::RX(_)
| GateType::RY(_)
| GateType::RZ(_)
| GateType::S
| GateType::T => 1,
GateType::CNOT | GateType::CZ | GateType::SWAP | GateType::CRZ(_) => 2,
GateType::Toffoli | GateType::Fredkin => 3,
}
}
}
#[derive(Debug, Clone)]
pub struct FusionRule {
pub pattern: Vec<GateType>,
pub replacement: Vec<GateType>,
pub efficiency_gain: f64, }
impl FusionRule {
pub fn common_rules() -> Vec<Self> {
vec![
Self {
pattern: vec![GateType::PauliX, GateType::PauliX],
replacement: vec![], efficiency_gain: 2.0,
},
Self {
pattern: vec![GateType::PauliY, GateType::PauliY],
replacement: vec![],
efficiency_gain: 2.0,
},
Self {
pattern: vec![GateType::PauliZ, GateType::PauliZ],
replacement: vec![],
efficiency_gain: 2.0,
},
Self {
pattern: vec![GateType::Hadamard, GateType::Hadamard],
replacement: vec![],
efficiency_gain: 2.0,
},
Self {
pattern: vec![GateType::S, GateType::S],
replacement: vec![GateType::PauliZ],
efficiency_gain: 2.0,
},
Self {
pattern: vec![GateType::T, GateType::T, GateType::T, GateType::T],
replacement: vec![],
efficiency_gain: 4.0,
},
]
}
}
#[derive(Debug, Clone)]
pub struct FusedGateSequence {
pub gates: Vec<QuantumGate>,
pub fused_matrix: Vec<Complex64>,
pub target_qubits: Vec<usize>,
pub efficiency_gain: f64,
}
impl FusedGateSequence {
pub fn from_gates(gates: Vec<QuantumGate>) -> QuantRS2Result<Self> {
if gates.is_empty() {
return Err(QuantRS2Error::InvalidInput(
"Empty gate sequence".to_string(),
));
}
let target_qubits = gates[0].qubits.clone();
for gate in &gates {
if gate.qubits != target_qubits {
return Err(QuantRS2Error::InvalidInput(
"All gates must act on the same qubits for fusion".to_string(),
));
}
}
let matrix_size = gates[0].matrix.len();
let sqrt_size = (matrix_size as f64).sqrt() as usize;
let mut fused_matrix = Self::identity_matrix(sqrt_size);
for gate in gates.iter().rev() {
fused_matrix = Self::matrix_multiply(&fused_matrix, &gate.matrix, sqrt_size)?;
}
let efficiency_gain = gates.len() as f64;
Ok(Self {
gates,
fused_matrix,
target_qubits,
efficiency_gain,
})
}
fn identity_matrix(size: usize) -> Vec<Complex64> {
let mut matrix = vec![Complex64::new(0.0, 0.0); size * size];
for i in 0..size {
matrix[i * size + i] = Complex64::new(1.0, 0.0);
}
matrix
}
fn is_identity_matrix(&self) -> bool {
let size = (self.fused_matrix.len() as f64).sqrt() as usize;
let identity = Self::identity_matrix(size);
for (a, b) in self.fused_matrix.iter().zip(identity.iter()) {
if (a - b).norm() > 1e-10 {
return false;
}
}
true
}
fn matrix_multiply(
a: &[Complex64],
b: &[Complex64],
size: usize,
) -> QuantRS2Result<Vec<Complex64>> {
if a.len() != size * size || b.len() != size * size {
return Err(QuantRS2Error::InvalidInput(
"Matrix size mismatch".to_string(),
));
}
let mut result = vec![Complex64::new(0.0, 0.0); size * size];
for i in 0..size {
for j in 0..size {
for k in 0..size {
result[i * size + j] += a[i * size + k] * b[k * size + j];
}
}
}
Ok(result)
}
}
pub struct GateFusionEngine {
rules: Vec<FusionRule>,
statistics: Arc<RwLock<FusionStatistics>>,
}
#[derive(Debug, Clone, Default)]
pub struct FusionStatistics {
pub total_fusions: u64,
pub gates_eliminated: u64,
pub total_efficiency_gain: f64,
pub fusion_types: HashMap<String, u64>,
}
impl GateFusionEngine {
pub fn new() -> Self {
Self {
rules: FusionRule::common_rules(),
statistics: Arc::new(RwLock::new(FusionStatistics::default())),
}
}
pub fn add_rule(&mut self, rule: FusionRule) {
self.rules.push(rule);
}
pub fn fuse_gates(&self, gates: Vec<QuantumGate>) -> QuantRS2Result<Vec<FusedGateSequence>> {
if gates.is_empty() {
return Ok(vec![]);
}
let mut fused_sequences = Vec::new();
let mut i = 0;
while i < gates.len() {
let gate = &gates[i];
if let Some(fusion_length) = self.find_fusion_pattern(&gates[i..]) {
let fusion_gates = gates[i..i + fusion_length].to_vec();
let fused_sequence = FusedGateSequence::from_gates(fusion_gates)?;
if fused_sequence.is_identity_matrix() {
if let Ok(mut stats) = self.statistics.write() {
stats.total_fusions += 1;
stats.gates_eliminated += fusion_length as u64; }
} else {
if let Ok(mut stats) = self.statistics.write() {
stats.total_fusions += 1;
stats.gates_eliminated += (fusion_length - 1) as u64;
stats.total_efficiency_gain += fused_sequence.efficiency_gain;
let fusion_type = format!("{:?}_fusion", gate.gate_type);
*stats.fusion_types.entry(fusion_type).or_insert(0) += 1;
}
fused_sequences.push(fused_sequence);
}
i += fusion_length;
} else {
let mut gate_group = vec![gate.clone()];
let mut j = i + 1;
while j < gates.len() && gates[j].qubits == gate.qubits {
gate_group.push(gates[j].clone());
j += 1;
}
let fused_sequence = FusedGateSequence::from_gates(gate_group)?;
fused_sequences.push(fused_sequence);
i = j;
}
}
Ok(fused_sequences)
}
fn find_fusion_pattern(&self, gates: &[QuantumGate]) -> Option<usize> {
for rule in &self.rules {
if gates.len() >= rule.pattern.len() {
let matches = gates[..rule.pattern.len()]
.iter()
.zip(&rule.pattern)
.all(|(gate, pattern_gate)| gate.gate_type == *pattern_gate);
let same_qubits = if rule.pattern.len() > 1 {
let first_qubits = &gates[0].qubits;
gates[1..rule.pattern.len()]
.iter()
.all(|gate| gate.qubits == *first_qubits)
} else {
true };
if matches && same_qubits {
return Some(rule.pattern.len());
}
}
}
if gates.len() >= 2 {
let first = &gates[0];
if first.num_qubits() == 1 {
let mut count = 1;
for gate in gates.iter().skip(1) {
if gate.gate_type == first.gate_type && gate.qubits == first.qubits {
count += 1;
} else {
break;
}
}
if count > 1 {
return Some(count); }
}
}
None
}
pub fn get_statistics(&self) -> FusionStatistics {
self.statistics
.read()
.map(|guard| guard.clone())
.unwrap_or_default()
}
pub fn get_global_statistics() -> FusionStatistics {
if let Some(engine) = GLOBAL_FUSION_ENGINE.get() {
engine.get_statistics()
} else {
FusionStatistics::default()
}
}
pub fn reset_statistics(&self) {
if let Ok(mut stats) = self.statistics.write() {
*stats = FusionStatistics::default();
}
}
}
impl Default for GateFusionEngine {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_FUSION_ENGINE: OnceLock<GateFusionEngine> = OnceLock::new();
pub fn get_global_fusion_engine() -> &'static GateFusionEngine {
GLOBAL_FUSION_ENGINE.get_or_init(GateFusionEngine::new)
}
pub fn apply_gate_fusion(gates: Vec<QuantumGate>) -> QuantRS2Result<Vec<FusedGateSequence>> {
let engine = get_global_fusion_engine();
engine.fuse_gates(gates)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pauli_x_fusion() {
let gates = vec![
QuantumGate::new(GateType::PauliX, vec![0]).expect("failed to create PauliX gate"),
QuantumGate::new(GateType::PauliX, vec![0]).expect("failed to create PauliX gate"),
];
let engine = GateFusionEngine::new();
let fused = engine.fuse_gates(gates).expect("failed to fuse gates");
assert_eq!(fused.len(), 0);
let stats = engine.get_statistics();
assert_eq!(stats.gates_eliminated, 2);
}
#[test]
fn test_hadamard_fusion() {
let gates = vec![
QuantumGate::new(GateType::Hadamard, vec![0]).expect("failed to create Hadamard gate"),
QuantumGate::new(GateType::Hadamard, vec![0]).expect("failed to create Hadamard gate"),
];
let engine = GateFusionEngine::new();
let fused = engine.fuse_gates(gates).expect("failed to fuse gates");
assert_eq!(fused.len(), 0);
}
#[test]
fn test_mixed_gate_fusion() {
let gates = vec![
QuantumGate::new(GateType::PauliX, vec![0]).expect("failed to create PauliX gate"),
QuantumGate::new(GateType::PauliY, vec![0]).expect("failed to create PauliY gate"),
QuantumGate::new(GateType::PauliZ, vec![0]).expect("failed to create PauliZ gate"),
];
let engine = GateFusionEngine::new();
let fused = engine.fuse_gates(gates).expect("failed to fuse gates");
assert_eq!(fused.len(), 1);
assert_eq!(fused[0].gates.len(), 3);
}
#[test]
fn test_no_fusion_different_qubits() {
let gates = vec![
QuantumGate::new(GateType::PauliX, vec![0]).expect("failed to create PauliX gate"),
QuantumGate::new(GateType::PauliX, vec![1]).expect("failed to create PauliX gate"), ];
let engine = GateFusionEngine::new();
let fused = engine.fuse_gates(gates).expect("failed to fuse gates");
assert_eq!(fused.len(), 2);
}
#[test]
fn test_matrix_multiplication() {
let identity = vec![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
];
let pauli_x = vec![
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
];
let result = FusedGateSequence::matrix_multiply(&identity, &pauli_x, 2)
.expect("matrix multiplication failed");
for (a, b) in result.iter().zip(pauli_x.iter()) {
assert!((a - b).norm() < 1e-10);
}
}
#[test]
fn test_efficiency_gain_calculation() {
let gates = vec![
QuantumGate::new(GateType::S, vec![0]).expect("failed to create S gate"),
QuantumGate::new(GateType::T, vec![0]).expect("failed to create T gate"),
QuantumGate::new(GateType::Hadamard, vec![0]).expect("failed to create Hadamard gate"),
];
let fused = FusedGateSequence::from_gates(gates).expect("failed to create fused sequence");
assert_eq!(fused.efficiency_gain, 3.0); }
}