use ndarray::{Array2, s};
use num_complex::Complex;
use std::f64::consts::PI;
use super::error::{QuantumError, QuantumResult};
use super::qubits::QubitRegister;
/// A quantum gate represented by a dense complex matrix.
///
/// The `Send + Sync` supertraits mean every implementor (and every
/// `Box<dyn QuantumGate>`) can be shared across threads.
pub trait QuantumGate: Send + Sync {
    /// Short mnemonic for the gate (for example `"H"` or `"CNOT"`).
    fn name(&self) -> &str;
    /// Number of qubits the gate acts on.
    fn n_qubits(&self) -> usize;
    /// The gate's matrix. [`apply_gate`] indexes it as a `2^n x 2^n` matrix,
    /// where `n == self.n_qubits()`.
    fn matrix(&self) -> Array2<Complex<f64>>;
}
/// Builds the 2x2 matrix `[[a, b], [c, d]]` (entries given in row-major order).
fn mat2(a: Complex<f64>, b: Complex<f64>, c: Complex<f64>, d: Complex<f64>) -> Array2<Complex<f64>> {
    let entries = vec![a, b, c, d];
    let shaped = Array2::from_shape_vec((2, 2), entries);
    // Four entries always fill a (2, 2) shape exactly, so this never fails.
    shaped.expect("2x2 matrix construction is infallible")
}
fn c(re: f64, im: f64) -> Complex<f64> {
Complex::new(re, im)
}
fn cr(re: f64) -> Complex<f64> {
Complex::new(re, 0.0)
}
fn ci(im: f64) -> Complex<f64> {
Complex::new(0.0, im)
}
/// Single-qubit identity gate `I`: leaves the qubit unchanged.
pub struct Identity;
impl QuantumGate for Identity {
    fn name(&self) -> &str { "I" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let one = cr(1.0);
        let zero = cr(0.0);
        mat2(one, zero, zero, one)
    }
}
/// Pauli-X (bit flip) gate: `[[0, 1], [1, 0]]`.
pub struct PauliX;
impl QuantumGate for PauliX {
    fn name(&self) -> &str { "X" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let (zero, one) = (cr(0.0), cr(1.0));
        mat2(zero, one, one, zero)
    }
}
/// Pauli-Y gate: `[[0, -i], [i, 0]]`.
pub struct PauliY;
impl QuantumGate for PauliY {
    fn name(&self) -> &str { "Y" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        let (minus_i, plus_i) = (ci(-1.0), ci(1.0));
        mat2(zero, minus_i, plus_i, zero)
    }
}
/// Pauli-Z (phase flip) gate: `diag(1, -1)`.
pub struct PauliZ;
impl QuantumGate for PauliZ {
    fn name(&self) -> &str { "Z" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        mat2(cr(1.0), zero, zero, cr(-1.0))
    }
}
/// Hadamard gate: `(1/√2)·[[1, 1], [1, -1]]`.
pub struct Hadamard;
impl QuantumGate for Hadamard {
    fn name(&self) -> &str { "H" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        // `recip` is exactly `1.0 / x`, so this matches `1.0 / 2.0_f64.sqrt()` bit-for-bit.
        let amp = 2.0_f64.sqrt().recip();
        let (plus, minus) = (cr(amp), cr(-amp));
        mat2(plus, plus, plus, minus)
    }
}
/// S gate (quarter-turn phase): `diag(1, i)`.
pub struct PhaseS;
impl QuantumGate for PhaseS {
    fn name(&self) -> &str { "S" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        let lower_right = ci(1.0);
        mat2(cr(1.0), zero, zero, lower_right)
    }
}
/// S† gate (inverse of S): `diag(1, -i)`.
pub struct PhaseSdg;
impl QuantumGate for PhaseSdg {
    fn name(&self) -> &str { "Sdg" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        let lower_right = ci(-1.0);
        mat2(cr(1.0), zero, zero, lower_right)
    }
}
/// T gate: `diag(1, e^{iπ/4})`.
pub struct PhaseT;
impl QuantumGate for PhaseT {
    fn name(&self) -> &str { "T" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let angle = PI / 4.0;
        let zero = cr(0.0);
        mat2(cr(1.0), zero, zero, Complex::from_polar(1.0, angle))
    }
}
/// T† gate (inverse of T): `diag(1, e^{-iπ/4})`.
pub struct PhaseTdg;
impl QuantumGate for PhaseTdg {
    fn name(&self) -> &str { "Tdg" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let angle = -PI / 4.0;
        let zero = cr(0.0);
        mat2(cr(1.0), zero, zero, Complex::from_polar(1.0, angle))
    }
}
/// Rotation about the X axis:
/// `Rx(θ) = [[cos(θ/2), -i·sin(θ/2)], [-i·sin(θ/2), cos(θ/2)]]`.
pub struct RotX {
    /// Rotation angle θ in radians.
    pub theta: f64,
}
impl QuantumGate for RotX {
    fn name(&self) -> &str { "Rx" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let half_angle = self.theta / 2.0;
        let (sin_half, cos_half) = half_angle.sin_cos();
        let diagonal = cr(cos_half);
        let off_diagonal = ci(-sin_half);
        mat2(diagonal, off_diagonal, off_diagonal, diagonal)
    }
}
/// Rotation about the Y axis:
/// `Ry(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]`.
pub struct RotY {
    /// Rotation angle θ in radians.
    pub theta: f64,
}
impl QuantumGate for RotY {
    fn name(&self) -> &str { "Ry" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let half_angle = self.theta / 2.0;
        let (sin_half, cos_half) = half_angle.sin_cos();
        let diagonal = cr(cos_half);
        mat2(diagonal, cr(-sin_half), cr(sin_half), diagonal)
    }
}
/// Rotation about the Z axis: `Rz(θ) = diag(e^{-iθ/2}, e^{iθ/2})`.
pub struct RotZ {
    /// Rotation angle θ in radians.
    pub theta: f64,
}
impl QuantumGate for RotZ {
    fn name(&self) -> &str { "Rz" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        // Negation and halving commute exactly in IEEE arithmetic, so
        // `-(θ/2)` equals `(-θ)/2` bit-for-bit.
        let half_angle = self.theta / 2.0;
        let upper_left = Complex::from_polar(1.0, -half_angle);
        let lower_right = Complex::from_polar(1.0, half_angle);
        let zero = cr(0.0);
        mat2(upper_left, zero, zero, lower_right)
    }
}
/// Phase-shift gate `P(λ) = diag(1, e^{iλ})`.
pub struct PhaseShift {
    /// Phase λ in radians applied to the |1⟩ component.
    pub lambda: f64,
}
impl QuantumGate for PhaseShift {
    fn name(&self) -> &str { "P" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        let lower_right = Complex::from_polar(1.0, self.lambda);
        mat2(cr(1.0), zero, zero, lower_right)
    }
}
/// General single-qubit gate `U(θ, φ, λ)`:
/// `[[cos(θ/2), -e^{iλ}·sin(θ/2)], [e^{iφ}·sin(θ/2), e^{i(φ+λ)}·cos(θ/2)]]`.
pub struct Unitary1Q {
    /// Polar angle θ in radians.
    pub theta: f64,
    /// Phase φ in radians.
    pub phi: f64,
    /// Phase λ in radians.
    pub lambda: f64,
}
impl QuantumGate for Unitary1Q {
    fn name(&self) -> &str { "U" }
    fn n_qubits(&self) -> usize { 1 }
    fn matrix(&self) -> Array2<Complex<f64>> {
        let (theta, phi, lambda) = (self.theta, self.phi, self.lambda);
        let (sin_half, cos_half) = (theta / 2.0).sin_cos();
        let top_right = -Complex::from_polar(1.0, lambda) * sin_half;
        let bottom_left = Complex::from_polar(1.0, phi) * sin_half;
        let bottom_right = Complex::from_polar(1.0, phi + lambda) * cos_half;
        mat2(cr(cos_half), top_right, bottom_left, bottom_right)
    }
}
/// Controlled-NOT. In the gate's local basis, the most significant bit is the
/// control and the least significant bit is the target, so |10⟩ and |11⟩ are exchanged.
pub struct CNOT;
impl QuantumGate for CNOT {
    fn matrix(&self) -> Array2<Complex<f64>> {
        // Permutation matrix: column `col` has its single 1 in row `DEST[col]`.
        const DEST: [usize; 4] = [0, 1, 3, 2];
        let entries: Vec<Complex<f64>> = (0..16usize)
            .map(|k| if DEST[k % 4] == k / 4 { cr(1.0) } else { cr(0.0) })
            .collect();
        Array2::from_shape_vec((4, 4), entries)
            .expect("4x4 matrix construction is infallible")
    }
    fn n_qubits(&self) -> usize { 2 }
    fn name(&self) -> &str { "CNOT" }
}
/// Controlled-Z: `diag(1, 1, 1, -1)`, i.e. only |11⟩ acquires a sign flip.
pub struct CZ;
impl QuantumGate for CZ {
    fn matrix(&self) -> Array2<Complex<f64>> {
        let diagonal = [cr(1.0), cr(1.0), cr(1.0), cr(-1.0)];
        let entries: Vec<Complex<f64>> = (0..16usize)
            .map(|k| {
                let (row, col) = (k / 4, k % 4);
                if row == col { diagonal[row] } else { cr(0.0) }
            })
            .collect();
        Array2::from_shape_vec((4, 4), entries)
            .expect("4x4 matrix construction is infallible")
    }
    fn n_qubits(&self) -> usize { 2 }
    fn name(&self) -> &str { "CZ" }
}
/// SWAP gate: exchanges the two qubits, i.e. swaps |01⟩ and |10⟩.
pub struct SWAP;
impl QuantumGate for SWAP {
    fn matrix(&self) -> Array2<Complex<f64>> {
        // Permutation matrix: column `col` has its single 1 in row `SWAPPED[col]`.
        const SWAPPED: [usize; 4] = [0, 2, 1, 3];
        let entries: Vec<Complex<f64>> = (0..16usize)
            .map(|k| {
                let (row, col) = (k / 4, k % 4);
                if SWAPPED[col] == row { cr(1.0) } else { cr(0.0) }
            })
            .collect();
        Array2::from_shape_vec((4, 4), entries)
            .expect("4x4 matrix construction is infallible")
    }
    fn n_qubits(&self) -> usize { 2 }
    fn name(&self) -> &str { "SWAP" }
}
/// iSWAP gate: swaps |01⟩ and |10⟩ and multiplies both swapped amplitudes by `i`.
pub struct ISWAP;
impl QuantumGate for ISWAP {
    fn matrix(&self) -> Array2<Complex<f64>> {
        let zero = cr(0.0);
        let one = cr(1.0);
        let phase_i = ci(1.0);
        let rows = [
            [one, zero, zero, zero],
            [zero, zero, phase_i, zero],
            [zero, phase_i, zero, zero],
            [zero, zero, zero, one],
        ];
        let entries: Vec<Complex<f64>> = rows.iter().flatten().copied().collect();
        Array2::from_shape_vec((4, 4), entries)
            .expect("4x4 matrix construction is infallible")
    }
    fn n_qubits(&self) -> usize { 2 }
    fn name(&self) -> &str { "iSWAP" }
}
/// Controlled version of an arbitrary single-qubit gate `U`.
///
/// In the gate's local basis the control is the most significant bit and the
/// target the least significant bit, so the matrix is block-diagonal `diag(I, U)`.
pub struct CU {
    inner: Box<dyn QuantumGate>,
}
impl CU {
    /// Wraps the single-qubit `gate` as a controlled two-qubit gate.
    ///
    /// # Errors
    /// - `GateArityMismatch` if `gate.n_qubits() != 1`.
    /// - `DimensionMismatch` if `gate.matrix()` is not 2x2. Without this check,
    ///   [`QuantumGate::matrix`] below would either panic on indexing or silently
    ///   use only the top-left 2x2 block of a larger matrix.
    pub fn new(gate: impl QuantumGate + 'static) -> QuantumResult<Self> {
        if gate.n_qubits() != 1 {
            return Err(QuantumError::GateArityMismatch {
                gate_qubits: gate.n_qubits(),
                supplied: 1,
            });
        }
        let (rows, cols) = gate.matrix().dim();
        if rows != 2 || cols != 2 {
            return Err(QuantumError::DimensionMismatch {
                expected: 2,
                actual: if rows != 2 { rows } else { cols },
            });
        }
        Ok(Self { inner: Box::new(gate) })
    }
}
impl QuantumGate for CU {
    fn matrix(&self) -> Array2<Complex<f64>> {
        let u = self.inner.matrix();
        let o = cr(0.0);
        let i = cr(1.0);
        let u00 = u[[0, 0]];
        let u01 = u[[0, 1]];
        let u10 = u[[1, 0]];
        let u11 = u[[1, 1]];
        // Identity on the control=0 subspace, U on the control=1 subspace.
        Array2::from_shape_vec(
            (4, 4),
            vec![
                i, o, o, o,
                o, i, o, o,
                o, o, u00, u01,
                o, o, u10, u11,
            ],
        )
        .expect("4x4 matrix construction is infallible")
    }
    fn n_qubits(&self) -> usize { 2 }
    fn name(&self) -> &str { "CU" }
}
/// Toffoli (CCNOT) gate: flips the least significant local qubit when the two
/// more significant ones are both |1⟩, i.e. exchanges |110⟩ (6) and |111⟩ (7).
pub struct Toffoli;
impl QuantumGate for Toffoli {
    fn matrix(&self) -> Array2<Complex<f64>> {
        let mut m = Array2::<Complex<f64>>::from_elem((8, 8), cr(0.0));
        for col in 0..8usize {
            // Basis state that input `col` is mapped to.
            let row = match col {
                6 => 7,
                7 => 6,
                unchanged => unchanged,
            };
            m[[row, col]] = cr(1.0);
        }
        m
    }
    fn n_qubits(&self) -> usize { 3 }
    fn name(&self) -> &str { "Toffoli" }
}
/// Fredkin (controlled-SWAP) gate: when the most significant local qubit is |1⟩,
/// swaps the other two, i.e. exchanges |101⟩ (5) and |110⟩ (6).
pub struct Fredkin;
impl QuantumGate for Fredkin {
    fn matrix(&self) -> Array2<Complex<f64>> {
        let mut m = Array2::<Complex<f64>>::from_elem((8, 8), cr(0.0));
        for col in 0..8usize {
            // Basis state that input `col` is mapped to.
            let row = match col {
                5 => 6,
                6 => 5,
                unchanged => unchanged,
            };
            m[[row, col]] = cr(1.0);
        }
        m
    }
    fn n_qubits(&self) -> usize { 3 }
    fn name(&self) -> &str { "Fredkin" }
}
/// Applies `gate` to the register qubits listed in `target_qubits`, updating `state` in place.
///
/// Register qubit `q` is bit `q` of an amplitude index. `target_qubits[0]` maps to the
/// most significant bit of the gate's local basis index (see `gate_idx_to_full_idx`).
/// For example, `apply_gate(reg, &CNOT, &[c, t])` uses `c` as control and `t` as target.
///
/// # Errors
/// - `GateArityMismatch` if `target_qubits.len() != gate.n_qubits()`.
/// - `QubitIndexOutOfRange` for the first target `>= state.n_qubits()`.
/// - `DuplicateQubitIndex` for the first target that appears more than once.
/// - `DimensionMismatch` if `gate.matrix()` is not `2^k x 2^k`, where `k = gate.n_qubits()`.
pub fn apply_gate(
    state: &mut QubitRegister,
    gate: &dyn QuantumGate,
    target_qubits: &[usize],
) -> QuantumResult<()> {
    let gate_qubits = gate.n_qubits();
    if target_qubits.len() != gate_qubits {
        return Err(QuantumError::GateArityMismatch {
            gate_qubits,
            supplied: target_qubits.len(),
        });
    }
    let total_qubits = state.n_qubits();
    if let Some(&index) = target_qubits.iter().find(|&&q| q >= total_qubits) {
        return Err(QuantumError::QubitIndexOutOfRange {
            index,
            n_qubits: total_qubits,
        });
    }
    for (i, &q) in target_qubits.iter().enumerate() {
        if target_qubits[i + 1..].contains(&q) {
            return Err(QuantumError::DuplicateQubitIndex { index: q });
        }
    }
    let gate_mat = gate.matrix();
    let gate_dim = 1usize << gate_qubits;
    // A user-defined gate could return a matrix of the wrong size. Reject it here
    // instead of panicking on out-of-bounds indexing or silently truncating.
    let (rows, cols) = gate_mat.dim();
    for actual in [rows, cols] {
        if actual != gate_dim {
            return Err(QuantumError::DimensionMismatch {
                expected: gate_dim,
                actual,
            });
        }
    }
    let non_target_dim = state.dim() / gate_dim;
    // Buffers reused across iterations instead of reallocated per `outer`.
    let mut indices = vec![0usize; gate_dim];
    let mut sub_amps = vec![Complex::new(0.0, 0.0); gate_dim];
    for outer in 0..non_target_dim {
        for (g, slot) in indices.iter_mut().enumerate() {
            *slot = gate_idx_to_full_idx(g, outer, target_qubits, total_qubits);
        }
        // Snapshot this subspace before writing. The (gate_idx, outer) -> full index
        // map is a bijection, so different `outer` values touch disjoint amplitudes and
        // updating `state` in place is safe (no full-vector clone needed).
        for (amp, &full_idx) in sub_amps.iter_mut().zip(&indices) {
            *amp = state.amplitudes[full_idx];
        }
        for (row, &full_idx) in indices.iter().enumerate() {
            let mut acc = Complex::new(0.0, 0.0);
            for col in 0..gate_dim {
                acc += gate_mat[[row, col]] * sub_amps[col];
            }
            state.amplitudes[full_idx] = acc;
        }
    }
    Ok(())
}
/// Combines a gate-local basis index with a spectator index into a full-register basis index.
///
/// Register qubit `q` becomes bit `q` of the result.
/// - Target `target_qubits[i]` receives bit `target_qubits.len() - 1 - i` of `gate_idx`,
///   so `target_qubits[0]` is the gate's most significant bit.
/// - The non-target qubits, in ascending order, receive successive bits of `outer`,
///   starting from bit 0.
fn gate_idx_to_full_idx(
    gate_idx: usize,
    outer: usize,
    target_qubits: &[usize],
    total_qubits: usize,
) -> usize {
    let gate_qubits = target_qubits.len();
    let mut full = 0usize;
    let mut outer_idx = 0usize;
    for bit_pos in 0..total_qubits {
        // `position` replaces the old fixed 64-slot table and its `usize::MAX` sentinel.
        let bit = match target_qubits.iter().position(|&t| t == bit_pos) {
            Some(local) => (gate_idx >> (gate_qubits - 1 - local)) & 1,
            None => {
                let spectator_bit = (outer >> outer_idx) & 1;
                outer_idx += 1;
                spectator_bit
            }
        };
        full |= bit << bit_pos;
    }
    full
}
/// Kronecker (tensor) product `u1 ⊗ u2`.
///
/// The result has shape `(u1.nrows() * u2.nrows(), u1.ncols() * u2.ncols())`, and
/// element `(row, col)` equals `u1[row / r2, col / c2] * u2[row % r2, col % c2]`,
/// where `(r2, c2)` is the shape of `u2`.
pub fn tensor_product_matrices(
    u1: &Array2<Complex<f64>>,
    u2: &Array2<Complex<f64>>,
) -> Array2<Complex<f64>> {
    let (inner_rows, inner_cols) = u2.dim();
    let shape = (u1.nrows() * inner_rows, u1.ncols() * inner_cols);
    // If `inner_rows` or `inner_cols` is 0 the result is empty, so the closure is
    // never invoked and never divides by zero.
    Array2::from_shape_fn(shape, |(row, col)| {
        u1[[row / inner_rows, col / inner_cols]] * u2[[row % inner_rows, col % inner_cols]]
    })
}
/// Checks that `u` is unitary, i.e. `U†U ≈ I` element-wise within `tol`.
///
/// The deviation is the largest complex modulus `|(U†U)[i, j] - δ_ij|` over all entries.
///
/// # Errors
/// - `DimensionMismatch` if `u` is not square.
/// - `NonUnitaryGate` if the deviation exceeds `tol`, or if any entry of `U†U` is NaN.
///   NaN entries (for example from NaN or infinite matrix elements) used to be ignored,
///   because `NaN > x` is always false, which let such matrices pass.
pub fn check_unitary(u: &Array2<Complex<f64>>, tol: f64) -> QuantumResult<()> {
    let n = u.nrows();
    if u.ncols() != n {
        return Err(QuantumError::DimensionMismatch {
            expected: n,
            actual: u.ncols(),
        });
    }
    let mut max_dev: f64 = 0.0;
    for i in 0..n {
        for j in 0..n {
            // (U†U)[i, j] = Σ_k conj(U[k, i]) · U[k, j]
            let val: Complex<f64> = (0..n).map(|k| u[[k, i]].conj() * u[[k, j]]).sum();
            let expected = if i == j { Complex::new(1.0, 0.0) } else { Complex::new(0.0, 0.0) };
            let dev = (val - expected).norm();
            if dev.is_nan() {
                return Err(QuantumError::NonUnitaryGate { deviation: dev });
            }
            if dev > max_dev {
                max_dev = dev;
            }
        }
    }
    if max_dev > tol {
        return Err(QuantumError::NonUnitaryGate { deviation: max_dev });
    }
    Ok(())
}
/// Multiplies two square matrices of equal size: returns `a · b`.
///
/// # Errors
/// Returns `DimensionMismatch` with `expected = a.nrows()` when `a.ncols()`,
/// `b.nrows()` or `b.ncols()` differs from it. `actual` is the first offending
/// dimension, checked in that order. Previously `actual` was always `b.nrows()`,
/// which produced misleading errors such as "expected 2, actual 2".
pub fn matrix_product(
    a: &Array2<Complex<f64>>,
    b: &Array2<Complex<f64>>,
) -> QuantumResult<Array2<Complex<f64>>> {
    let n = a.nrows();
    for actual in [a.ncols(), b.nrows(), b.ncols()] {
        if actual != n {
            return Err(QuantumError::DimensionMismatch { expected: n, actual });
        }
    }
    // Same summation order as a textbook triple loop: k ascending from zero.
    Ok(Array2::from_shape_fn((n, n), |(i, j)| {
        (0..n).map(|k| a[[i, k]] * b[[k, j]]).sum::<Complex<f64>>()
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::qubits::QubitRegister;
    const TOL: f64 = 1e-12;
    /// Panics with `msg` unless `a` and `b` differ by less than `tol` in complex modulus.
    fn assert_complex_close(a: Complex<f64>, b: Complex<f64>, tol: f64, msg: &str) {
        assert!(
            (a - b).norm() < tol,
            "{}: expected {:?}, got {:?}",
            msg,
            b,
            a
        );
    }
    #[test]
    fn test_pauli_x_unitary() {
        check_unitary(&PauliX.matrix(), 1e-12).expect("PauliX should be unitary");
    }
    #[test]
    fn test_hadamard_unitary() {
        check_unitary(&Hadamard.matrix(), 1e-12).expect("H should be unitary");
    }
    #[test]
    fn test_cnot_unitary() {
        check_unitary(&CNOT.matrix(), 1e-12).expect("CNOT should be unitary");
    }
    #[test]
    fn test_toffoli_unitary() {
        check_unitary(&Toffoli.matrix(), 1e-12).expect("Toffoli should be unitary");
    }
    #[test]
    fn test_all_gates_unitary_with_consistent_dims() {
        let gates: Vec<Box<dyn QuantumGate>> = vec![
            Box::new(Identity),
            Box::new(PauliX),
            Box::new(PauliY),
            Box::new(PauliZ),
            Box::new(Hadamard),
            Box::new(PhaseS),
            Box::new(PhaseSdg),
            Box::new(PhaseT),
            Box::new(PhaseTdg),
            Box::new(RotX { theta: 0.7 }),
            Box::new(RotY { theta: -1.3 }),
            Box::new(RotZ { theta: 2.1 }),
            Box::new(PhaseShift { lambda: 0.4 }),
            Box::new(Unitary1Q { theta: 0.3, phi: 1.1, lambda: -0.7 }),
            Box::new(CNOT),
            Box::new(CZ),
            Box::new(SWAP),
            Box::new(ISWAP),
            Box::new(CU::new(Hadamard).expect("H is a 1-qubit gate")),
            Box::new(Toffoli),
            Box::new(Fredkin),
        ];
        for gate in &gates {
            let m = gate.matrix();
            let dim = 1usize << gate.n_qubits();
            assert_eq!(m.dim(), (dim, dim), "{} has wrong matrix shape", gate.name());
            check_unitary(&m, 1e-12)
                .unwrap_or_else(|e| panic!("{} should be unitary: {:?}", gate.name(), e));
        }
    }
    #[test]
    fn test_x_flips_zero() {
        let mut reg = QubitRegister::new_zero_state(1).expect("valid");
        apply_gate(&mut reg, &PauliX, &[0]).expect("apply ok");
        assert!((reg.probability(1).expect("ok") - 1.0).abs() < TOL);
    }
    #[test]
    fn test_x_flips_one() {
        let mut reg = QubitRegister::new_basis_state(1, 1).expect("valid");
        apply_gate(&mut reg, &PauliX, &[0]).expect("apply ok");
        assert!((reg.probability(0).expect("ok") - 1.0).abs() < TOL);
    }
    #[test]
    fn test_hadamard_superposition() {
        let mut reg = QubitRegister::new_zero_state(1).expect("valid");
        apply_gate(&mut reg, &Hadamard, &[0]).expect("apply ok");
        let p0 = reg.probability(0).expect("ok");
        let p1 = reg.probability(1).expect("ok");
        assert!((p0 - 0.5).abs() < TOL);
        assert!((p1 - 0.5).abs() < TOL);
    }
    #[test]
    fn test_cnot_creates_bell_state() {
        let mut reg = QubitRegister::new_zero_state(2).expect("valid");
        apply_gate(&mut reg, &Hadamard, &[0]).expect("H ok");
        apply_gate(&mut reg, &CNOT, &[0, 1]).expect("CNOT ok");
        let p00 = reg.probability(0).expect("ok");
        let p11 = reg.probability(3).expect("ok");
        let p01 = reg.probability(1).expect("ok");
        let p10 = reg.probability(2).expect("ok");
        assert!((p00 - 0.5).abs() < TOL, "p00={}", p00);
        assert!((p11 - 0.5).abs() < TOL, "p11={}", p11);
        assert!(p01.abs() < TOL, "p01={}", p01);
        assert!(p10.abs() < TOL, "p10={}", p10);
    }
    #[test]
    fn test_z_phase_flip() {
        let mut reg = QubitRegister::new_zero_state(1).expect("valid");
        apply_gate(&mut reg, &Hadamard, &[0]).expect("H ok");
        apply_gate(&mut reg, &PauliZ, &[0]).expect("Z ok");
        let amp1 = reg.amplitude(1).expect("ok");
        assert!(amp1.re < 0.0);
    }
    #[test]
    fn test_duplicate_qubit_error() {
        let mut reg = QubitRegister::new_zero_state(2).expect("valid");
        let err = apply_gate(&mut reg, &CNOT, &[0, 0]);
        assert!(matches!(err, Err(QuantumError::DuplicateQubitIndex { .. })));
    }
    #[test]
    fn test_arity_error() {
        let mut reg = QubitRegister::new_zero_state(2).expect("valid");
        let err = apply_gate(&mut reg, &PauliX, &[0, 1]);
        assert!(matches!(err, Err(QuantumError::GateArityMismatch { .. })));
    }
    #[test]
    fn test_out_of_range_error() {
        let mut reg = QubitRegister::new_zero_state(2).expect("valid");
        let err = apply_gate(&mut reg, &PauliX, &[2]);
        assert!(matches!(err, Err(QuantumError::QubitIndexOutOfRange { .. })));
    }
    #[test]
    fn test_swap_swaps_qubits() {
        let mut reg = QubitRegister::new_basis_state(2, 2).expect("valid");
        apply_gate(&mut reg, &SWAP, &[0, 1]).expect("SWAP ok");
        let p1 = reg.probability(1).expect("ok");
        assert!((p1 - 1.0).abs() < TOL, "SWAP should move |10⟩ to |01⟩, got p1={}", p1);
    }
    #[test]
    fn test_rot_x_pi_equals_x() {
        // Rx(π) = -i·X: equal to X up to a global phase, checked entry by entry.
        let mx = RotX { theta: PI }.matrix();
        assert_complex_close(mx[[0, 0]], cr(0.0), 1e-10, "Rx(pi)[0,0]");
        assert_complex_close(mx[[0, 1]], ci(-1.0), 1e-10, "Rx(pi)[0,1]");
        assert_complex_close(mx[[1, 0]], ci(-1.0), 1e-10, "Rx(pi)[1,0]");
        assert_complex_close(mx[[1, 1]], cr(0.0), 1e-10, "Rx(pi)[1,1]");
    }
    #[test]
    fn test_toffoli_flips_when_both_controls_set() {
        let mut reg = QubitRegister::new_basis_state(3, 3).expect("valid");
        apply_gate(&mut reg, &Toffoli, &[0, 1, 2]).expect("ok");
        let p7 = reg.probability(7).expect("ok");
        assert!((p7 - 1.0).abs() < TOL, "p7={}", p7);
    }
    #[test]
    fn test_fredkin_swaps_targets_when_control_set() {
        // Qubits 0 (control) and 1 set -> qubits 0 and 2 set.
        let mut reg = QubitRegister::new_basis_state(3, 3).expect("valid");
        apply_gate(&mut reg, &Fredkin, &[0, 1, 2]).expect("ok");
        let p5 = reg.probability(5).expect("ok");
        assert!((p5 - 1.0).abs() < TOL, "p5={}", p5);
    }
    #[test]
    fn test_fredkin_idle_when_control_clear() {
        let mut reg = QubitRegister::new_basis_state(3, 2).expect("valid");
        apply_gate(&mut reg, &Fredkin, &[0, 1, 2]).expect("ok");
        let p2 = reg.probability(2).expect("ok");
        assert!((p2 - 1.0).abs() < TOL, "p2={}", p2);
    }
    #[test]
    fn test_cu_of_x_matches_cnot() {
        let cx = CU::new(PauliX).expect("PauliX is a 1-qubit gate");
        assert_eq!(cx.matrix(), CNOT.matrix());
    }
    #[test]
    fn test_cu_rejects_multi_qubit_gate() {
        assert!(matches!(CU::new(CNOT), Err(QuantumError::GateArityMismatch { .. })));
    }
    #[test]
    fn test_tensor_product_x_identity() {
        let m = tensor_product_matrices(&PauliX.matrix(), &Identity.matrix());
        assert_eq!(m.dim(), (4, 4));
        // X ⊗ I has identity blocks in the top-right and bottom-left quadrants.
        for (row, col) in [(0, 2), (1, 3), (2, 0), (3, 1)] {
            assert_complex_close(m[[row, col]], cr(1.0), TOL, "X⊗I one");
        }
        assert_complex_close(m[[0, 0]], cr(0.0), TOL, "X⊗I zero");
    }
    #[test]
    fn test_matrix_product_h_squared_is_identity() {
        let h = Hadamard.matrix();
        let hh = matrix_product(&h, &h).expect("square 2x2 operands");
        let id = Identity.matrix();
        for row in 0..2 {
            for col in 0..2 {
                assert_complex_close(hh[[row, col]], id[[row, col]], TOL, "H*H");
            }
        }
    }
    #[test]
    fn test_matrix_product_rejects_non_square() {
        let a = Array2::<Complex<f64>>::zeros((2, 3));
        let b = Array2::<Complex<f64>>::zeros((3, 2));
        assert!(matches!(
            matrix_product(&a, &b),
            Err(QuantumError::DimensionMismatch { .. })
        ));
    }
}