use crate::error::{Result, SimulatorError};
use quantrs2_circuit::builder::Circuit;
use quantrs2_core::qubit::QubitId;
mod codes;
pub use codes::*;
/// A quantum error-correcting code that can emit its encoding and decoding circuits.
///
/// NOTE(review): both circuit builders return a fixed-width `Circuit<16>`, so implementations
/// are presumably limited to qubit indices that fit a 16-qubit circuit — confirm against the
/// implementations in `codes`.
pub trait ErrorCorrection {
    /// Number of physical qubits reported by the code
    /// (presumably per encoded block — confirm against implementations).
    fn physical_qubits(&self) -> usize;
    /// Number of logical qubits reported by the code.
    fn logical_qubits(&self) -> usize;
    /// Code distance reported by the code.
    fn distance(&self) -> usize;
    /// Builds the encoding circuit for `logical_qubits`, using `ancilla_qubits` as the
    /// additional qubits of the code.
    ///
    /// # Errors
    /// Implementation-defined; callers such as `utils::create_error_corrected_circuit`
    /// propagate the error unchanged.
    fn encode_circuit(
        &self,
        logical_qubits: &[QubitId],
        ancilla_qubits: &[QubitId],
    ) -> Result<Circuit<16>>;
    /// Builds the decoding / correction circuit over `encoded_qubits`, using
    /// `syndrome_qubits` for syndrome information.
    ///
    /// `utils::create_error_corrected_circuit` passes the logical qubits followed by the
    /// ancilla qubits as `encoded_qubits`.
    ///
    /// # Errors
    /// Implementation-defined; callers propagate the error unchanged.
    fn decode_circuit(
        &self,
        encoded_qubits: &[QubitId],
        syndrome_qubits: &[QubitId],
    ) -> Result<Circuit<16>>;
}
/// Helpers for assembling error-corrected circuits and comparing state vectors.
pub mod utils {
    use super::{ErrorCorrection, QubitId, Result, SimulatorError};
    use quantrs2_circuit::builder::Circuit;

    /// Builds a new circuit consisting of `initial_circuit` followed by `code`'s encoding
    /// and decoding circuits.
    ///
    /// Gates are appended in this order:
    /// 1. the supported gates of `initial_circuit`;
    /// 2. the supported gates of `code.encode_circuit(logical_qubits, ancilla_qubits)`;
    /// 3. the supported gates of `code.decode_circuit(encoded, syndrome_qubits)`, where
    ///    `encoded` is `logical_qubits` followed by `ancilla_qubits`.
    ///
    /// Only gates named `H`, `X`, `Y`, `Z`, `S`, `T`, `CNOT`, `CZ`, `CY` and `SWAP` are
    /// copied. Any other gate, and any gate the circuit builder rejects, is dropped silently.
    ///
    /// NOTE(review): the gates of `initial_circuit` are placed *before* encoding, so they act
    /// on the unencoded qubits — confirm that is the intended placement.
    ///
    /// # Errors
    /// Propagates any error returned by `code.encode_circuit` or `code.decode_circuit`.
    pub fn create_error_corrected_circuit<T: ErrorCorrection, const N: usize>(
        initial_circuit: &Circuit<N>,
        code: &T,
        logical_qubits: &[QubitId],
        ancilla_qubits: &[QubitId],
        syndrome_qubits: &[QubitId],
    ) -> Result<Circuit<N>> {
        let mut result = Circuit::<N>::new();
        append_supported_gates(&mut result, initial_circuit);

        let encoder = code.encode_circuit(logical_qubits, ancilla_qubits)?;
        append_supported_gates(&mut result, &encoder);

        // The decoder operates on the whole encoded register: logical qubits, then ancillas.
        let encoded_qubits = [logical_qubits, ancilla_qubits].concat();
        let correction = code.decode_circuit(&encoded_qubits, syndrome_qubits)?;
        append_supported_gates(&mut result, &correction);

        Ok(result)
    }

    /// Copies every gate of `source` whose name is supported onto the end of `target`.
    ///
    /// Supported names are the single-qubit gates `H`, `X`, `Y`, `Z`, `S`, `T` (applied to
    /// the gate's first qubit) and the two-qubit gates `CNOT`, `CZ`, `CY`, `SWAP` (applied to
    /// its first two qubits). The following gates are skipped:
    /// - gates with any other name;
    /// - gates reporting fewer qubits than their name requires.
    ///
    /// Errors returned by the builder calls are discarded, so a rejected gate is simply
    /// dropped (best-effort copy, matching the original per-loop behaviour).
    fn append_supported_gates<const N: usize, const M: usize>(
        target: &mut Circuit<N>,
        source: &Circuit<M>,
    ) {
        for op in source.gates() {
            // Fetch the qubit list once per gate rather than re-calling `qubits()` per check.
            let qubits = op.qubits();
            let _ = match (op.name(), &qubits[..]) {
                ("H", &[q, ..]) => target.h(q),
                ("X", &[q, ..]) => target.x(q),
                ("Y", &[q, ..]) => target.y(q),
                ("Z", &[q, ..]) => target.z(q),
                ("S", &[q, ..]) => target.s(q),
                ("T", &[q, ..]) => target.t(q),
                ("CNOT", &[control, tgt, ..]) => target.cnot(control, tgt),
                ("CZ", &[control, tgt, ..]) => target.cz(control, tgt),
                ("CY", &[control, tgt, ..]) => target.cy(control, tgt),
                ("SWAP", &[a, b, ..]) => target.swap(a, b),
                // Unsupported name, or too few qubits for the named gate.
                _ => continue,
            };
        }
    }

    /// Compares a noisy and a corrected state against an ideal state.
    ///
    /// Returns `(fidelity_before, fidelity_after)`:
    /// - `fidelity_before` is [`calculate_fidelity`]`(ideal_state, noisy_state)`;
    /// - `fidelity_after` is [`calculate_fidelity`]`(ideal_state, corrected_state)`.
    ///
    /// # Errors
    /// Returns `SimulatorError::DimensionMismatch` if `noisy_state` or `corrected_state` has
    /// a different length from `ideal_state`.
    pub fn analyze_correction_quality(
        ideal_state: &[scirs2_core::Complex64],
        noisy_state: &[scirs2_core::Complex64],
        corrected_state: &[scirs2_core::Complex64],
    ) -> Result<(f64, f64)> {
        let fidelity_before = calculate_fidelity(ideal_state, noisy_state)?;
        let fidelity_after = calculate_fidelity(ideal_state, corrected_state)?;
        Ok((fidelity_before, fidelity_after))
    }

    /// Returns `|⟨state1|state2⟩|²`, the squared magnitude of the inner product of two
    /// state vectors.
    ///
    /// This equals the pure-state fidelity only when both inputs are normalized; no
    /// normalization is applied here. Empty inputs yield `0.0`.
    ///
    /// # Errors
    /// Returns `SimulatorError::DimensionMismatch` if the two states have different lengths.
    pub fn calculate_fidelity(
        state1: &[scirs2_core::Complex64],
        state2: &[scirs2_core::Complex64],
    ) -> Result<f64> {
        use scirs2_core::Complex64;
        if state1.len() != state2.len() {
            return Err(SimulatorError::DimensionMismatch(format!(
                "States have different dimensions: {} vs {}",
                state1.len(),
                state2.len()
            )));
        }
        // ⟨state1|state2⟩ = Σ conj(a1) · a2, accumulated left to right.
        let inner_product = state1
            .iter()
            .zip(state2)
            .fold(Complex64::new(0.0, 0.0), |acc, (a1, a2)| acc + a1.conj() * a2);
        Ok(inner_product.norm_sqr())
    }
}