use crate::{
cartan::OptimizedCartanDecomposer,
controlled::make_controlled,
error::{QuantRS2Error, QuantRS2Result},
gate::{single::*, GateOp},
matrix_ops::{DenseMatrix, QuantumMatrix},
qubit::QubitId,
synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
};
use rustc_hash::FxHashMap;
use scirs2_core::ndarray::{s, Array2};
use scirs2_core::Complex;
use std::f64::consts::PI;
/// Result of decomposing a unitary into a gate sequence, together with
/// simple bookkeeping counters maintained by the decomposers in this file.
#[derive(Debug, Clone)]
pub struct ShannonDecomposition {
/// Emitted gates, in the order the decomposer produced them.
pub gates: Vec<Box<dyn GateOp>>,
/// Running CNOT tally as maintained by the decomposer (a controlled
/// rotation emitted by the recursive path is booked as two CNOTs).
pub cnot_count: usize,
/// Running tally of gates booked as single-qubit operations.
pub single_qubit_count: usize,
/// Set to `gates.len()` by every constructor site in this file, i.e. a
/// sequential gate count rather than a parallel-layer depth.
pub depth: usize,
}
/// Decomposes unitaries on an arbitrary number of qubits into gate sequences.
pub struct ShannonDecomposer {
/// Numerical threshold used for the unitarity check, the identity check
/// and for dropping rotations whose angle is negligible.
tolerance: f64,
/// Initialised empty by the constructors; no method visible in this file
/// reads from or writes to it.
cache: FxHashMap<u64, ShannonDecomposition>,
/// Upper bound on the recursion depth accepted by `decompose_recursive`.
max_depth: usize,
}
impl ShannonDecomposer {
pub fn new() -> Self {
Self {
tolerance: 1e-10,
cache: FxHashMap::default(),
max_depth: 20,
}
}
/// Creates a decomposer that uses `tolerance` for all numerical comparisons.
///
/// The cache starts empty and the maximum recursion depth is 20.
pub fn with_tolerance(tolerance: f64) -> Self {
    let cache = FxHashMap::default();
    let max_depth = 20;
    Self {
        tolerance,
        cache,
        max_depth,
    }
}
/// Decomposes `unitary`, acting on the qubits listed in `qubit_ids`, into a
/// gate sequence.
///
/// Dispatch depends on the number of qubits:
/// * 0 qubits: an empty decomposition;
/// * 1 qubit: a ZYZ decomposition converted by `single_qubit_to_gates`;
/// * 2 qubits: `decompose_two_qubit`;
/// * otherwise: `decompose_recursive` starting at depth 0.
///
/// # Errors
///
/// Returns `QuantRS2Error::InvalidInput` when the number of qubits is too
/// large for `2^n` to fit in `usize`, when `unitary` is not `2^n x 2^n`, or
/// when it is not unitary within the configured tolerance. Errors from the
/// underlying matrix and synthesis helpers are propagated unchanged.
pub fn decompose(
    &mut self,
    unitary: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
) -> QuantRS2Result<ShannonDecomposition> {
    let n = qubit_ids.len();
    // `1 << n` panics in debug builds and wraps in release builds once `n`
    // reaches the pointer width, so reject such inputs explicitly.
    if n >= usize::BITS as usize {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Cannot decompose a unitary on {} qubits: 2^{} does not fit in usize",
            n, n
        )));
    }
    let size = 1usize << n;
    if unitary.shape() != [size, size] {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Unitary size {} doesn't match {} qubits",
            unitary.shape()[0],
            n
        )));
    }

    let mat = DenseMatrix::new(unitary.clone())?;
    if !mat.is_unitary(self.tolerance)? {
        return Err(QuantRS2Error::InvalidInput(
            "Matrix is not unitary".to_string(),
        ));
    }

    match n {
        0 => Ok(ShannonDecomposition {
            gates: vec![],
            cnot_count: 0,
            single_qubit_count: 0,
            depth: 0,
        }),
        1 => {
            let decomp = decompose_single_qubit_zyz(&unitary.view())?;
            let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
            let count = gates.len();
            Ok(ShannonDecomposition {
                gates,
                cnot_count: 0,
                single_qubit_count: count,
                depth: count,
            })
        }
        2 => self.decompose_two_qubit(unitary, qubit_ids),
        _ => self.decompose_recursive(unitary, qubit_ids, 0),
    }
}
/// Recursive step used for three or more qubits.
///
/// The matrix is split into four equally sized blocks which are handed to
/// `block_diagonalize`, yielding factors `(v, w, u_diag)`. Gates are emitted
/// for `w` (recursively, on `qubit_ids[1..]`), then for `u_diag` via
/// `decompose_controlled_diagonal`, then for the conjugate transpose of `v`
/// (recursively). Factors that are the identity within tolerance are skipped.
///
/// Inputs on two or fewer qubits are forwarded to `decompose`, which handles
/// those sizes without recursing back here. Without this base case a
/// recursive call would eventually underflow `n - 1` or index past the end
/// of `qubit_ids`.
///
/// # Errors
///
/// Returns `QuantRS2Error::InvalidInput` when `depth` exceeds `max_depth`;
/// errors from the helpers are propagated.
fn decompose_recursive(
    &mut self,
    unitary: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
    depth: usize,
) -> QuantRS2Result<ShannonDecomposition> {
    if depth > self.max_depth {
        return Err(QuantRS2Error::InvalidInput(
            "Maximum recursion depth exceeded".to_string(),
        ));
    }

    let n = qubit_ids.len();
    if n <= 2 {
        // Base case: small systems have dedicated handling in `decompose`.
        return self.decompose(unitary, qubit_ids);
    }

    let half_size = 1 << (n - 1);
    let a = unitary.slice(s![..half_size, ..half_size]).to_owned();
    let b = unitary.slice(s![..half_size, half_size..]).to_owned();
    let c = unitary.slice(s![half_size.., ..half_size]).to_owned();
    let d = unitary.slice(s![half_size.., half_size..]).to_owned();
    let (v, w, u_diag) = self.block_diagonalize(&a, &b, &c, &d)?;

    let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
    let mut cnot_count = 0;
    let mut single_qubit_count = 0;

    if !self.is_identity(&w) {
        let w_decomp = self.decompose_recursive(&w, &qubit_ids[1..], depth + 1)?;
        gates.extend(w_decomp.gates);
        cnot_count += w_decomp.cnot_count;
        single_qubit_count += w_decomp.single_qubit_count;
    }

    let (diag_gates, diag_cnots, diag_singles) =
        self.decompose_controlled_diagonal(&u_diag, qubit_ids)?;
    cnot_count += diag_cnots;
    single_qubit_count += diag_singles;
    gates.extend(diag_gates);

    if !self.is_identity(&v) {
        // Conjugate transpose of `v`.
        let v_dag = v.mapv(|z| z.conj()).t().to_owned();
        let v_decomp = self.decompose_recursive(&v_dag, &qubit_ids[1..], depth + 1)?;
        gates.extend(v_decomp.gates);
        cnot_count += v_decomp.cnot_count;
        single_qubit_count += v_decomp.single_qubit_count;
    }

    // Must be read before `gates` is moved into the result.
    let gate_total = gates.len();
    Ok(ShannonDecomposition {
        gates,
        cnot_count,
        single_qubit_count,
        depth: gate_total,
    })
}
/// Returns the factors `(v, w, u_diag)` consumed by the recursive step.
///
/// As implemented here no factorisation is performed: `v` and `w` are
/// identity matrices the size of one block, and `u_diag` is the full matrix
/// reassembled from `a`, `b`, `c` and `d`. (An earlier form computed the
/// norms of the off-diagonal blocks but returned this same tuple on both
/// branches, so the check has been dropped.)
///
/// NOTE(review): callers treat the third element as a diagonal factor; that
/// only holds when the input already is diagonal -- confirm intended
/// algorithm.
fn block_diagonalize(
    &self,
    a: &Array2<Complex<f64>>,
    b: &Array2<Complex<f64>>,
    c: &Array2<Complex<f64>>,
    d: &Array2<Complex<f64>>,
) -> QuantRS2Result<(
    Array2<Complex<f64>>,
    Array2<Complex<f64>>,
    Array2<Complex<f64>>,
)> {
    let identity = Array2::eye(a.shape()[0]);
    let combined = self.combine_blocks(a, b, c, d);
    Ok((identity.clone(), identity, combined))
}
/// Assembles the block matrix `[[a, b], [c, d]]`.
///
/// The block edge length is taken from the row count of `a`; the result has
/// twice that many rows and columns.
fn combine_blocks(
    &self,
    a: &Array2<Complex<f64>>,
    b: &Array2<Complex<f64>>,
    c: &Array2<Complex<f64>>,
    d: &Array2<Complex<f64>>,
) -> Array2<Complex<f64>> {
    let half = a.shape()[0];
    let mut full = Array2::zeros((2 * half, 2 * half));
    // (block, first row, first column) for each quadrant, in a, b, c, d order.
    let placements = [(a, 0, 0), (b, 0, half), (c, half, 0), (d, half, half)];
    for &(block, row, col) in placements.iter() {
        full.slice_mut(s![row..row + half, col..col + half])
            .assign(block);
    }
    full
}
/// Emits rotation gates for the phases found on the main diagonal of
/// `diagonal`.
///
/// Returns `(gates, cnot_count, single_qubit_count)`. The phase at index 0
/// becomes a plain `RotationZ` on `qubit_ids[0]`, booked as one single-qubit
/// gate. Every other non-negligible phase becomes a `RotationZ` on
/// `qubit_ids[1]` controlled by `qubit_ids[0]`, booked as two CNOTs plus
/// three single-qubit gates. Phases whose magnitude does not exceed the
/// tolerance are skipped. Off-diagonal entries are never read.
///
/// NOTE(review): all phases with index > 0 target the same qubit pair
/// regardless of their index -- confirm this is intended for more than two
/// qubits.
fn decompose_controlled_diagonal(
    &self,
    diagonal: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
) -> QuantRS2Result<(Vec<Box<dyn GateOp>>, usize, usize)> {
    let dim = diagonal.shape()[0];
    let phases: Vec<f64> = (0..dim).map(|k| diagonal[[k, k]].arg()).collect();
    let control = qubit_ids[0];

    let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
    let mut cnots = 0;
    let mut singles = 0;

    for (index, theta) in phases.into_iter().enumerate() {
        let significant = theta.abs() > self.tolerance;
        if !significant {
            continue;
        }
        if index == 0 {
            gates.push(Box::new(RotationZ {
                target: control,
                theta,
            }));
            singles += 1;
        } else {
            let rotation = RotationZ {
                target: qubit_ids[1],
                theta,
            };
            gates.push(Box::new(make_controlled(vec![control], rotation)));
            cnots += 2;
            singles += 3;
        }
    }

    Ok((gates, cnots, singles))
}
/// Decomposes a two-qubit unitary through `OptimizedCartanDecomposer`.
///
/// An identity input (within tolerance) yields an empty decomposition.
/// Otherwise each resulting gate named `"CNOT"` is counted as a CNOT and
/// every other gate is counted as a single-qubit gate.
fn decompose_two_qubit(
    &self,
    unitary: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
) -> QuantRS2Result<ShannonDecomposition> {
    if self.is_identity(unitary) {
        return Ok(ShannonDecomposition {
            gates: Vec::new(),
            cnot_count: 0,
            single_qubit_count: 0,
            depth: 0,
        });
    }

    let mut cartan = OptimizedCartanDecomposer::new();
    let cartan_result = cartan.decompose(unitary)?;
    let gates = cartan.base.to_gates(&cartan_result, qubit_ids)?;

    let depth = gates.len();
    let cnot_count = gates.iter().filter(|gate| gate.name() == "CNOT").count();
    Ok(ShannonDecomposition {
        gates,
        cnot_count,
        // Everything that is not a CNOT is booked as single-qubit.
        single_qubit_count: depth - cnot_count,
        depth,
    })
}
/// Converts a single-qubit decomposition into up to three rotations on
/// `qubit`, pushed in the order `RotationZ(theta1)`, `RotationY(phi)`,
/// `RotationZ(theta2)`.
///
/// A rotation is omitted when the magnitude of its angle does not exceed
/// the tolerance. Only `theta1`, `phi` and `theta2` are read from `decomp`.
fn single_qubit_to_gates(
    &self,
    decomp: &SingleQubitDecomposition,
    qubit: QubitId,
) -> Vec<Box<dyn GateOp>> {
    let is_significant = |angle: f64| angle.abs() > self.tolerance;
    let mut sequence: Vec<Box<dyn GateOp>> = Vec::new();

    if is_significant(decomp.theta1) {
        sequence.push(Box::new(RotationZ {
            target: qubit,
            theta: decomp.theta1,
        }));
    }
    if is_significant(decomp.phi) {
        sequence.push(Box::new(RotationY {
            target: qubit,
            theta: decomp.phi,
        }));
    }
    if is_significant(decomp.theta2) {
        sequence.push(Box::new(RotationZ {
            target: qubit,
            theta: decomp.theta2,
        }));
    }

    sequence
}
/// Returns `true` when no entry of the leading `n x n` part of `matrix`
/// (with `n` the row count) deviates from the identity by more than the
/// tolerance. The comparison is exact: a global phase is not factored out.
fn is_identity(&self, matrix: &Array2<Complex<f64>>) -> bool {
    let dim = matrix.shape()[0];
    let one: Complex<f64> = Complex::new(1.0, 0.0);
    let zero: Complex<f64> = Complex::new(0.0, 0.0);

    (0..dim).all(|row| {
        (0..dim).all(|col| {
            let target = if row == col { one } else { zero };
            // Written as a negated `>` so that a NaN deviation is not
            // treated as a mismatch.
            !((matrix[[row, col]] - target).norm() > self.tolerance)
        })
    })
}
}
/// Wraps a `ShannonDecomposer` and post-processes its output with optional
/// optimization passes.
pub struct OptimizedShannonDecomposer {
/// Decomposer that produces the initial gate sequence.
base: ShannonDecomposer,
/// When set, `decompose` runs `apply_peephole_optimization`.
peephole: bool,
/// When set, `decompose` runs `apply_commutation_optimization`.
commutation: bool,
}
impl OptimizedShannonDecomposer {
/// Creates an optimizing decomposer around a default `ShannonDecomposer`
/// with both the peephole and the commutation pass enabled.
pub fn new() -> Self {
    let base = ShannonDecomposer::new();
    Self {
        base,
        peephole: true,
        commutation: true,
    }
}
/// Decomposes `unitary` with the wrapped decomposer and then applies the
/// enabled passes: peephole first, commutation second.
///
/// # Errors
///
/// Propagates any error from the base decomposition or from a pass.
pub fn decompose(
    &mut self,
    unitary: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
) -> QuantRS2Result<ShannonDecomposition> {
    let raw = self.base.decompose(unitary, qubit_ids)?;

    let after_peephole = if self.peephole {
        self.apply_peephole_optimization(raw)?
    } else {
        raw
    };

    if self.commutation {
        self.apply_commutation_optimization(after_peephole)
    } else {
        Ok(after_peephole)
    }
}
fn apply_peephole_optimization(
&self,
mut decomp: ShannonDecomposition,
) -> QuantRS2Result<ShannonDecomposition> {
let mut optimized_gates = Vec::new();
let mut i = 0;
while i < decomp.gates.len() {
if i + 1 < decomp.gates.len() {
if self.gates_cancel(&decomp.gates[i], &decomp.gates[i + 1]) {
i += 2;
decomp.cnot_count =
decomp
.cnot_count
.saturating_sub(if decomp.gates[i - 2].name() == "CNOT" {
2
} else {
0
});
decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(
if decomp.gates[i - 2].name() == "CNOT" {
0
} else {
2
},
);
continue;
}
if let Some(merged) =
self.try_merge_rotations(&decomp.gates[i], &decomp.gates[i + 1])
{
optimized_gates.push(merged);
i += 2;
decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(1);
continue;
}
}
optimized_gates.push(decomp.gates[i].clone());
i += 1;
}
decomp.gates = optimized_gates;
decomp.depth = decomp.gates.len();
Ok(decomp)
}
/// Commutation-based optimization pass.
///
/// Currently a no-op: the decomposition is returned unchanged and the call
/// never fails.
const fn apply_commutation_optimization(
&self,
decomp: ShannonDecomposition,
) -> QuantRS2Result<ShannonDecomposition> {
Ok(decomp)
}
/// Returns `true` when the two gates have the same name, act on the same
/// qubits, and the name is one of `X`, `Y`, `Z`, `H`, `CNOT` or `SWAP`,
/// i.e. the pair can be removed.
fn gates_cancel(&self, gate1: &Box<dyn GateOp>, gate2: &Box<dyn GateOp>) -> bool {
    let name = gate1.name();
    if name != gate2.name() || gate1.qubits() != gate2.qubits() {
        return false;
    }
    matches!(name, "X" | "Y" | "Z" | "H" | "CNOT" | "SWAP")
}
/// Fuses two rotations about the same axis on the same single qubit into
/// one rotation whose angle is the sum of both angles.
///
/// Returns `None` when the gates act on different qubits, act on more or
/// fewer than one qubit, are not both `RZ`, both `RX` or both `RY`, or when
/// a gate cannot be downcast to the rotation type its name announces.
fn try_merge_rotations(
    &self,
    gate1: &Box<dyn GateOp>,
    gate2: &Box<dyn GateOp>,
) -> Option<Box<dyn GateOp>> {
    let targets = gate1.qubits();
    if targets.len() != 1 || targets != gate2.qubits() {
        return None;
    }
    let target = targets[0];
    let (first, second) = (gate1.as_any(), gate2.as_any());

    let fused: Box<dyn GateOp> = match (gate1.name(), gate2.name()) {
        ("RZ", "RZ") => {
            let lhs = first.downcast_ref::<RotationZ>()?.theta;
            let rhs = second.downcast_ref::<RotationZ>()?.theta;
            Box::new(RotationZ {
                target,
                theta: lhs + rhs,
            })
        }
        ("RX", "RX") => {
            let lhs = first.downcast_ref::<RotationX>()?.theta;
            let rhs = second.downcast_ref::<RotationX>()?.theta;
            Box::new(RotationX {
                target,
                theta: lhs + rhs,
            })
        }
        ("RY", "RY") => {
            let lhs = first.downcast_ref::<RotationY>()?.theta;
            let rhs = second.downcast_ref::<RotationY>()?.theta;
            Box::new(RotationY {
                target,
                theta: lhs + rhs,
            })
        }
        _ => return None,
    };
    Some(fused)
}
}
/// Convenience wrapper: decomposes `unitary` with a default
/// `ShannonDecomposer` and returns only the resulting gate list.
///
/// # Errors
///
/// Propagates any error reported by `ShannonDecomposer::decompose`.
pub fn shannon_decompose(
    unitary: &Array2<Complex<f64>>,
    qubit_ids: &[QubitId],
) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
    ShannonDecomposer::new()
        .decompose(unitary, qubit_ids)
        .map(|decomposition| decomposition.gates)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::Complex;

    /// Builds a `dim x dim` complex matrix from row-major real entries,
    /// panicking with `failure` if the entry count does not fit the shape.
    fn real_matrix(dim: usize, entries: &[f64], failure: &str) -> Array2<Complex<f64>> {
        let data: Vec<Complex<f64>> = entries.iter().map(|&re| Complex::new(re, 0.0)).collect();
        Array2::from_shape_vec((dim, dim), data).expect(failure)
    }

    /// A Hadamard gate needs no CNOT and at most three rotations.
    #[test]
    fn test_shannon_single_qubit() {
        let hadamard = real_matrix(
            2,
            &[1.0, 1.0, 1.0, -1.0],
            "Failed to create Hadamard matrix",
        ) / Complex::new(2.0_f64.sqrt(), 0.0);

        let mut decomposer = ShannonDecomposer::new();
        let result = decomposer
            .decompose(&hadamard, &[QubitId(0)])
            .expect("Failed to decompose Hadamard gate");

        assert!(result.single_qubit_count <= 3);
        assert_eq!(result.cnot_count, 0);
    }

    /// A CNOT matrix decomposes into at most three CNOTs.
    #[test]
    fn test_shannon_two_qubit() {
        #[rustfmt::skip]
        let entries = [
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
            0.0, 0.0, 1.0, 0.0,
        ];
        let cnot = real_matrix(4, &entries, "Failed to create CNOT matrix");

        let mut decomposer = ShannonDecomposer::new();
        let result = decomposer
            .decompose(&cnot, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose CNOT gate");

        assert!(result.cnot_count <= 3);
    }

    /// The two-qubit identity produces no gates at all.
    #[test]
    fn test_optimized_decomposer() {
        let identity_complex = Array2::<f64>::eye(4).mapv(|x| Complex::new(x, 0.0));

        let mut decomposer = OptimizedShannonDecomposer::new();
        let result = decomposer
            .decompose(&identity_complex, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose identity matrix");

        assert_eq!(result.gates.len(), 0);
    }

    /// Two RZ rotations on the same qubit fuse into one with the summed angle.
    #[test]
    fn test_merge_rz_rotations() {
        let target = QubitId(0);
        let first: Box<dyn GateOp> = Box::new(RotationZ { target, theta: 0.3 });
        let second: Box<dyn GateOp> = Box::new(RotationZ { target, theta: 0.4 });

        let merged = OptimizedShannonDecomposer::new()
            .try_merge_rotations(&first, &second)
            .expect("should merge RZ+RZ");
        let rz = merged
            .as_any()
            .downcast_ref::<RotationZ>()
            .expect("merged gate must be RotationZ");

        assert!(
            (rz.theta - 0.7).abs() < 1e-10,
            "merged theta should be 0.7, got {}",
            rz.theta
        );
    }

    /// Two RX rotations on the same qubit fuse into one with the summed angle.
    #[test]
    fn test_merge_rx_rotations() {
        let target = QubitId(0);
        let first: Box<dyn GateOp> = Box::new(RotationX { target, theta: 0.5 });
        let second: Box<dyn GateOp> = Box::new(RotationX { target, theta: 0.3 });

        let merged = OptimizedShannonDecomposer::new()
            .try_merge_rotations(&first, &second)
            .expect("should merge RX+RX");
        let rx = merged
            .as_any()
            .downcast_ref::<RotationX>()
            .expect("merged gate must be RotationX");

        assert!(
            (rx.theta - 0.8).abs() < 1e-10,
            "merged theta should be 0.8, got {}",
            rx.theta
        );
    }

    /// Rotations about different axes are left alone.
    #[test]
    fn test_no_merge_different_axes() {
        let target = QubitId(0);
        let about_z: Box<dyn GateOp> = Box::new(RotationZ { target, theta: 0.3 });
        let about_x: Box<dyn GateOp> = Box::new(RotationX { target, theta: 0.4 });

        let outcome = OptimizedShannonDecomposer::new().try_merge_rotations(&about_z, &about_x);
        assert!(outcome.is_none(), "RZ and RX should not merge");
    }
}