use super::min_matching::min_weight_perfect_matching;
use crate::error::{QuantRS2Error, QuantRS2Result};
use crate::error_correction::pauli::{Pauli, PauliString};
use crate::error_correction::rotated_surface_code::RotatedSurfaceCode;
use crate::error_correction::SyndromeDecoder;
// Offsets, relative to the number of stabilizers in one family, of the four virtual boundary
// vertices appended to the matching graph (vertex id = stabilizer count + offset).
const VIRT_TOP: usize = 0;
const VIRT_BOTTOM: usize = VIRT_TOP + 1;
const VIRT_LEFT: usize = VIRT_BOTTOM + 1;
const VIRT_RIGHT: usize = VIRT_LEFT + 1;
/// Minimum-weight perfect-matching (MWPM) decoder for a [`RotatedSurfaceCode`].
///
/// X and Z errors are decoded independently. For one stabilizer family, the matching graph has:
/// - one vertex per stabilizer, plus four virtual boundary vertices;
/// - one weight-1 edge per data qubit.
///
/// Fired stabilizers ("defects") are paired with [`min_weight_perfect_matching`]. Every matched
/// pair is then corrected along a shortest path in that graph.
pub struct MwpmSurfaceDecoder {
    /// Code whose distance and stabilizers define the matching graphs.
    pub code: RotatedSurfaceCode,
}

impl MwpmSurfaceDecoder {
    /// Largest number of matching vertices (defects plus an optional boundary vertex) that is
    /// handed to the matching backend.
    const MAX_MATCHING_VERTICES: usize = 24;

    /// Creates a decoder for `code`.
    pub fn new(code: RotatedSurfaceCode) -> Self {
        Self { code }
    }

    /// Decodes the outcomes of one stabilizer family into corrections of a single Pauli type.
    ///
    /// * `syndrome` – one flag per entry of `stabilizers`; `true` marks a defect.
    /// * `stabilizers` – the data-qubit indices each stabilizer acts on.
    /// * `correction_type` – the Pauli applied to every qubit on a matched path.
    ///
    /// Returns `(qubit, correction_type)` pairs. A qubit may appear more than once when matched
    /// paths overlap, so callers must compose repeated entries rather than overwrite them.
    ///
    /// # Errors
    ///
    /// Returns [`QuantRS2Error::InvalidInput`] in three cases:
    /// - `syndrome` and `stabilizers` differ in length;
    /// - the padded defect count exceeds the matching backend's limit;
    /// - the matching backend fails or finds no perfect matching.
    fn decode_one_type(
        &self,
        syndrome: &[bool],
        stabilizers: &[Vec<usize>],
        correction_type: Pauli,
    ) -> QuantRS2Result<Vec<(usize, Pauli)>> {
        let n_stabs = stabilizers.len();
        if syndrome.len() != n_stabs {
            // Indices at or past `n_stabs` would alias the virtual boundary vertices below.
            return Err(QuantRS2Error::InvalidInput(format!(
                "Syndrome length {} does not match stabilizer count {n_stabs}",
                syndrome.len()
            )));
        }

        let defects: Vec<usize> = syndrome
            .iter()
            .enumerate()
            .filter_map(|(i, &fired)| fired.then_some(i))
            .collect();
        if defects.is_empty() {
            return Ok(Vec::new());
        }

        let n_defects = defects.len();
        // An odd defect count is padded with one extra matching vertex standing for the boundary.
        let total_mwpm = n_defects + n_defects % 2;
        // Checked before the O(V^3) shortest-path pass so oversized inputs fail fast.
        if total_mwpm > Self::MAX_MATCHING_VERTICES {
            return Err(QuantRS2Error::InvalidInput(
                "Too many defects for bitmask-DP decoder: use d ≤ 7".to_string(),
            ));
        }

        let d = self.code.distance;
        let n_qubits = d * d;
        let virtual_nodes = [
            n_stabs + VIRT_TOP,
            n_stabs + VIRT_BOTTOM,
            n_stabs + VIRT_LEFT,
            n_stabs + VIRT_RIGHT,
        ];
        let total_nodes = n_stabs + virtual_nodes.len();

        let qubit_edge: Vec<(usize, usize)> = (0..n_qubits)
            .map(|q| self.qubit_graph_edge(q, d, stabilizers))
            .collect();
        let (dist, parent) = Self::shortest_paths(total_nodes, &qubit_edge, &virtual_nodes);

        // Matching vertices 0..n_defects are the defects; index n_defects is the boundary vertex,
        // present only when the defect count is odd.
        let boundary_vertex = n_defects;
        let needs_boundary = total_mwpm > n_defects;
        let mut mwpm_edges: Vec<(usize, usize, f64)> = Vec::new();
        for (i, &a) in defects.iter().enumerate() {
            for (j, &b) in defects.iter().enumerate().skip(i + 1) {
                let weight = dist[a][b];
                if weight.is_finite() {
                    mwpm_edges.push((i, j, weight));
                }
            }
            if needs_boundary {
                let d_boundary = virtual_nodes
                    .iter()
                    .map(|&v| dist[a][v])
                    .fold(f64::INFINITY, f64::min);
                // If no boundary is reachable, keep the vertex matchable with a large penalty.
                let weight = if d_boundary.is_finite() {
                    d_boundary
                } else {
                    (d as f64) * 2.0
                };
                mwpm_edges.push((i, boundary_vertex, weight));
            }
        }

        let matching = min_weight_perfect_matching(total_mwpm, &mwpm_edges)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("MWPM failed: {e}")))?
            .ok_or_else(|| QuantRS2Error::InvalidInput("No perfect matching found".to_string()))?;

        let mut corrections: Vec<(usize, Pauli)> = Vec::new();
        for (u, v) in &matching {
            let (u, v) = (*u, *v);
            let (real, partner) = if u == boundary_vertex {
                (v, None)
            } else if v == boundary_vertex {
                (u, None)
            } else {
                (u, Some(v))
            };
            let a = defects[real];
            let target = match partner {
                Some(p) => defects[p],
                // The virtual vertices are joined at weight 0, so any nearest one is equivalent.
                None => virtual_nodes
                    .iter()
                    .copied()
                    .min_by(|&x, &y| {
                        dist[a][x]
                            .partial_cmp(&dist[a][y])
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .unwrap_or(virtual_nodes[0]),
            };
            let path =
                self.path_between_nodes(a, target, &parent, &qubit_edge, n_qubits, total_nodes);
            corrections.extend(path.into_iter().map(|q| (q, correction_type)));
        }
        Ok(corrections)
    }

    /// Returns the matching-graph edge represented by data qubit `q`.
    ///
    /// The edge depends on how many stabilizers in `stabilizers` contain `q`:
    /// - two: the edge joins those two stabilizers;
    /// - one: the edge joins that stabilizer to the boundary vertex from
    ///   [`Self::qubit_boundary_node`];
    /// - none: the result is a self-loop on that boundary vertex, which the graph builder skips;
    /// - more than two: only the first two are used.
    fn qubit_graph_edge(&self, q: usize, d: usize, stabilizers: &[Vec<usize>]) -> (usize, usize) {
        let boundary = self.qubit_boundary_node(q, d, stabilizers.len());
        let mut containing = stabilizers
            .iter()
            .enumerate()
            .filter(|(_, s)| s.contains(&q))
            .map(|(i, _)| i);
        match (containing.next(), containing.next()) {
            (Some(first), Some(second)) => (first, second),
            (Some(only), None) => (only, boundary),
            _ => (boundary, boundary),
        }
    }

    /// Builds the matching graph and runs Floyd–Warshall over it.
    ///
    /// Every data-qubit edge has weight 1. The virtual boundary vertices are joined to each
    /// other at weight 0, so together they act as a single boundary.
    ///
    /// Returns `(dist, parent)`:
    /// - `dist[i][j]` is the shortest-path weight, infinite if `j` is unreachable from `i`;
    /// - `parent[i][j]` is the vertex following `i` on that path, or `usize::MAX` if none.
    fn shortest_paths(
        total_nodes: usize,
        qubit_edge: &[(usize, usize)],
        virtual_nodes: &[usize],
    ) -> (Vec<Vec<f64>>, Vec<Vec<usize>>) {
        let mut dist = vec![vec![f64::INFINITY; total_nodes]; total_nodes];
        let mut parent = vec![vec![usize::MAX; total_nodes]; total_nodes];
        for (i, row) in dist.iter_mut().enumerate() {
            row[i] = 0.0;
        }
        for &(a, b) in qubit_edge {
            // Self-loops come from qubits outside every stabilizer of this family.
            if a == b {
                continue;
            }
            if dist[a][b] > 1.0 {
                dist[a][b] = 1.0;
                dist[b][a] = 1.0;
                parent[a][b] = b;
                parent[b][a] = a;
            }
        }
        for &va in virtual_nodes {
            for &vb in virtual_nodes {
                if va != vb {
                    dist[va][vb] = 0.0;
                    parent[va][vb] = vb;
                }
            }
        }
        for k in 0..total_nodes {
            for i in 0..total_nodes {
                // dist[i][k] and parent[i][k] cannot change while relaxing through k.
                let d_ik = dist[i][k];
                if d_ik.is_infinite() {
                    continue;
                }
                let first_hop = parent[i][k];
                for j in 0..total_nodes {
                    let through_k = d_ik + dist[k][j];
                    if through_k < dist[i][j] {
                        dist[i][j] = through_k;
                        parent[i][j] = first_hop;
                    }
                }
            }
        }
        (dist, parent)
    }

    /// Picks the virtual boundary vertex for data qubit `q` on a `d × d` grid.
    ///
    /// The grid is indexed row-major: row `q / d`, column `q % d`. Top and bottom rows take
    /// priority over the left and right columns. Interior qubits fall back to the top vertex.
    /// Because the virtual vertices are joined at weight 0, this choice does not change any
    /// path length.
    fn qubit_boundary_node(&self, q: usize, d: usize, n_stabs: usize) -> usize {
        let r = q / d;
        let c = q % d;
        if r == 0 {
            n_stabs + VIRT_TOP
        } else if r == d - 1 {
            n_stabs + VIRT_BOTTOM
        } else if c == 0 {
            n_stabs + VIRT_LEFT
        } else if c == d - 1 {
            n_stabs + VIRT_RIGHT
        } else {
            n_stabs + VIRT_TOP
        }
    }

    /// Returns the data qubits whose edges lie on the shortest path from `a` to `b`.
    ///
    /// `parent` holds next-hop pointers: `parent[i][j]` is the vertex after `i` on the shortest
    /// `i → j` path, or `usize::MAX` if there is none.
    ///
    /// For each hop, the first qubit (index below `n_qubits`) whose edge matches is returned.
    /// Hops between virtual boundary vertices have no qubit and are skipped. If the pointers
    /// stop short of `b`, the qubits of the partial walk are returned.
    fn path_between_nodes(
        &self,
        a: usize,
        b: usize,
        parent: &[Vec<usize>],
        qubit_edge: &[(usize, usize)],
        n_qubits: usize,
        total_nodes: usize,
    ) -> Vec<usize> {
        if a == b {
            return Vec::new();
        }
        // A simple path visits each of the `total_nodes` vertices at most once. This bound
        // therefore never truncates a valid path, yet still stops corrupt pointer cycles.
        let mut path_nodes = vec![a];
        let mut cur = a;
        while cur != b && path_nodes.len() <= total_nodes {
            let next = parent[cur][b];
            if next == usize::MAX || next == cur {
                break;
            }
            path_nodes.push(next);
            cur = next;
        }
        path_nodes
            .windows(2)
            .filter_map(|hop| {
                let (u, v) = (hop[0], hop[1]);
                qubit_edge
                    .iter()
                    .take(n_qubits)
                    .position(|&(ea, eb)| (ea == u && eb == v) || (ea == v && eb == u))
            })
            .collect()
    }
}
impl SyndromeDecoder for MwpmSurfaceDecoder {
    /// Converts a full syndrome into a Pauli correction on the data qubits.
    ///
    /// `syndrome` holds the X-stabilizer outcomes followed by the Z-stabilizer outcomes. The
    /// Z-stabilizer half is decoded into X corrections first, then the X-stabilizer half into Z
    /// corrections. The two are merged qubit by qubit; corrections that fall outside the data
    /// qubits are ignored.
    ///
    /// # Errors
    ///
    /// Returns `QuantRS2Error::InvalidInput` if the syndrome length differs from the total
    /// stabilizer count. Errors from either decoding pass are propagated.
    fn decode(&self, syndrome: &[bool]) -> QuantRS2Result<PauliString> {
        let n = self.code.n_data_qubits();
        let x_stabs = self.code.x_stabilizers();
        let z_stabs = self.code.z_stabilizers();
        let (n_x, n_z) = (x_stabs.len(), z_stabs.len());
        let expected = n_x + n_z;
        if syndrome.len() != expected {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Syndrome length {} does not match expected {} (n_x={n_x}, n_z={n_z})",
                syndrome.len(),
                expected
            )));
        }

        let (x_syndrome, z_syndrome) = syndrome.split_at(n_x);
        let mut paulis = vec![Pauli::I; n];
        // Each half of the syndrome is matched on its own stabilizer family and yields the
        // opposite Pauli type; the X-correction pass runs first.
        let passes = [
            (z_syndrome, &z_stabs, Pauli::X),
            (x_syndrome, &x_stabs, Pauli::Z),
        ];
        for (part, stabs, kind) in passes {
            for (qubit, pauli) in self.decode_one_type(part, stabs, kind)? {
                if let Some(slot) = paulis.get_mut(qubit) {
                    combine_pauli(slot, pauli);
                }
            }
        }
        Ok(PauliString::new(paulis))
    }
}
/// Multiplies `existing` by `new_pauli` in place, discarding the global phase.
///
/// Each single-qubit Pauli is treated as a pair of (x, z) bits. Multiplication up to phase is
/// then a bitwise XOR of those pairs: for example X·Z gives Y, and P·P gives I.
fn combine_pauli(existing: &mut Pauli, new_pauli: Pauli) {
    fn to_bits(p: Pauli) -> (bool, bool) {
        match p {
            Pauli::I => (false, false),
            Pauli::X => (true, false),
            Pauli::Z => (false, true),
            Pauli::Y => (true, true),
        }
    }

    let (x_old, z_old) = to_bits(*existing);
    let (x_new, z_new) = to_bits(new_pauli);
    *existing = match (x_old ^ x_new, z_old ^ z_new) {
        (false, false) => Pauli::I,
        (true, false) => Pauli::X,
        (false, true) => Pauli::Z,
        (true, true) => Pauli::Y,
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error_correction::rotated_surface_code::RotatedSurfaceCode;

    /// Returns `true` when `residual` (the error composed with the correction) acts as a logical
    /// operator of `code`.
    fn is_logical_error(residual: &PauliString, code: &RotatedSurfaceCode) -> bool {
        let lx = code.logical_x_operator();
        let lz = code.logical_z_operator();
        // Anticommuting with logical Z reveals a logical X component, and vice versa.
        let has_logical_x = residual.commutes_with(&lz).is_ok_and(|c| !c);
        let has_logical_z = residual.commutes_with(&lx).is_ok_and(|c| !c);
        has_logical_x || has_logical_z
    }

    fn make_decoder(d: usize) -> MwpmSurfaceDecoder {
        MwpmSurfaceDecoder::new(RotatedSurfaceCode::new(d))
    }

    /// Places `pauli` on each data qubit of a distance-`d` code in turn and decodes it.
    ///
    /// For every qubit, checks that the residual has an all-clear syndrome and is not a logical
    /// operator. `label` names the error type in failure messages.
    fn check_single_qubit_errors(d: usize, pauli: Pauli, label: &str) {
        let code = RotatedSurfaceCode::new(d);
        let decoder = MwpmSurfaceDecoder::new(code.clone());
        let n = code.n_data_qubits();
        for qubit in 0..n {
            let mut paulis = vec![Pauli::I; n];
            paulis[qubit] = pauli;
            let error = PauliString::new(paulis);
            let syndrome = code.syndrome(&error).expect("syndrome ok");
            let correction = decoder.decode(&syndrome).expect("decode ok");
            let residual = error.multiply(&correction).expect("multiply ok");
            let residual_syndrome = code.syndrome(&residual).expect("syndrome ok");
            assert!(
                residual_syndrome.iter().all(|&s| !s),
                "MWPM (d={d}): {label} error on qubit {qubit} leaves a non-trivial syndrome"
            );
            assert!(
                !is_logical_error(&residual, &code),
                "MWPM (d={d}): {label} error on qubit {qubit} should not cause logical error"
            );
        }
    }

    #[test]
    fn test_mwpm_no_errors_d3() {
        let decoder = make_decoder(3);
        let n_x = decoder.code.x_stabilizers().len();
        let n_z = decoder.code.z_stabilizers().len();
        let syndrome = vec![false; n_x + n_z];
        let correction = decoder.decode(&syndrome).expect("Decoding should succeed");
        assert_eq!(
            correction.weight(),
            0,
            "No-error syndrome should yield identity correction"
        );
    }

    #[test]
    fn test_mwpm_single_x_error_all_qubits_d3() {
        check_single_qubit_errors(3, Pauli::X, "X");
    }

    #[test]
    fn test_mwpm_single_z_error_all_qubits_d3() {
        check_single_qubit_errors(3, Pauli::Z, "Z");
    }

    #[test]
    fn test_mwpm_single_errors_all_qubits_d5() {
        check_single_qubit_errors(5, Pauli::X, "X");
        check_single_qubit_errors(5, Pauli::Z, "Z");
    }

    #[test]
    fn test_mwpm_no_errors_d5() {
        let decoder = make_decoder(5);
        let n_x = decoder.code.x_stabilizers().len();
        let n_z = decoder.code.z_stabilizers().len();
        let syndrome = vec![false; n_x + n_z];
        let correction = decoder.decode(&syndrome).expect("decode ok");
        assert_eq!(correction.weight(), 0);
    }

    #[test]
    fn test_mwpm_wrong_syndrome_length() {
        let decoder = make_decoder(3);
        let result = decoder.decode(&[true, false]);
        assert!(result.is_err(), "Wrong syndrome length should return Err");
    }
}