use super::pauli::{Pauli, PauliString};
use super::stabilizer::StabilizerCode;
use super::surface_code::SurfaceCode;
use super::SyndromeDecoder;
use crate::error::{QuantRS2Error, QuantRS2Result};
use std::collections::HashMap;
/// Table-driven syndrome decoder.
///
/// Stores, for each syndrome produced by an enumerated low-weight Pauli error, the
/// lowest-weight error found with that syndrome; decoding is then a single table lookup.
pub struct LookupDecoder {
// Key: syndrome bit vector as returned by `StabilizerCode::syndrome`; value: the correction.
syndrome_table: HashMap<Vec<bool>, PauliString>,
}
impl LookupDecoder {
    /// Builds a lookup decoder for `code`.
    ///
    /// Every Pauli error of weight at most `(d - 1) / 2` is enumerated, and for each syndrome
    /// the lowest-weight error producing it is recorded. The enumeration visits
    /// `sum_{w=0}^{t} C(n, w) * 3^w` errors, so it is only practical for small codes.
    ///
    /// # Errors
    /// Propagates any error returned by `StabilizerCode::syndrome`.
    pub fn new(code: &StabilizerCode) -> QuantRS2Result<Self> {
        let mut syndrome_table = HashMap::new();
        // `saturating_sub` keeps a degenerate distance-0 code from underflowing; such a code
        // then only records the identity.
        let max_weight = code.d.saturating_sub(1) / 2;
        for error in Self::generate_pauli_errors(code.n, max_weight) {
            let syndrome = code.syndrome(&error)?;
            // Errors arrive in non-decreasing weight order, so the first entry per syndrome is
            // already minimal; the comparison keeps the table correct if that order changes.
            syndrome_table
                .entry(syndrome)
                .and_modify(|best: &mut PauliString| {
                    if error.weight() < best.weight() {
                        *best = error.clone();
                    }
                })
                .or_insert(error);
        }
        Ok(Self { syndrome_table })
    }

    /// Returns the identity on `n` qubits followed by every Pauli error of weight
    /// `1..=max_weight`, grouped by increasing weight.
    fn generate_pauli_errors(n: usize, max_weight: usize) -> Vec<PauliString> {
        std::iter::once(PauliString::identity(n))
            .chain((1..=max_weight).flat_map(|weight| Self::generate_weight_k_errors(n, weight)))
            .collect()
    }

    /// Returns every Pauli error on `n` qubits with exactly `k` non-identity factors.
    ///
    /// Yields nothing when `k > n`.
    fn generate_weight_k_errors(n: usize, k: usize) -> Vec<PauliString> {
        let paulis = [Pauli::X, Pauli::Y, Pauli::Z];
        // The 3^k factor assignments do not depend on which positions are chosen, so build
        // them once instead of once per position set.
        let pauli_combinations = Self::cartesian_power(&paulis, k);
        let mut errors = Vec::new();
        for pos_set in Self::combinations(n, k) {
            for pauli_combo in &pauli_combinations {
                let mut error_paulis = vec![Pauli::I; n];
                for (&pos, &pauli) in pos_set.iter().zip(pauli_combo) {
                    error_paulis[pos] = pauli;
                }
                errors.push(PauliString::new(error_paulis));
            }
        }
        errors
    }

    /// Returns every `k`-element subset of `0..n` as a strictly increasing index vector, in
    /// lexicographic order.
    ///
    /// Returns a single empty subset for `k == 0` and no subsets for `k > n`.
    fn combinations(n: usize, k: usize) -> Vec<Vec<usize>> {
        // The successor loop below cannot handle these cases: `combo[0]` does not exist when
        // `k == 0`, and `n - k` underflows when `k > n`.
        if k == 0 {
            return vec![Vec::new()];
        }
        if k > n {
            return Vec::new();
        }
        let mut result = Vec::new();
        let mut combo: Vec<usize> = (0..k).collect();
        loop {
            result.push(combo.clone());
            // Find the rightmost position that has not yet reached its maximum `n - k + i`.
            let mut i = k - 1;
            while i > 0 && combo[i] == n - k + i {
                i -= 1;
            }
            if i == 0 && combo[0] == n - k {
                // The final combination `[n - k, ..., n - 1]` has been emitted.
                break;
            }
            combo[i] += 1;
            for j in i + 1..k {
                combo[j] = combo[j - 1] + 1;
            }
        }
        result
    }

    /// Returns all `set.len()^k` sequences of length `k` drawn from `set`, in an order where
    /// the last element varies slowest.
    fn cartesian_power<T: Clone>(set: &[T], k: usize) -> Vec<Vec<T>> {
        if k == 0 {
            return vec![vec![]];
        }
        let shorter = Self::cartesian_power(set, k - 1);
        let mut result = Vec::with_capacity(shorter.len() * set.len());
        for item in set {
            for prefix in &shorter {
                let mut combo = prefix.clone();
                combo.push(item.clone());
                result.push(combo);
            }
        }
        result
    }
}
impl SyndromeDecoder for LookupDecoder {
    /// Returns a copy of the correction stored in the table for `syndrome`.
    ///
    /// # Errors
    /// Returns `QuantRS2Error::InvalidInput` when `syndrome` has no entry in the table.
    fn decode(&self, syndrome: &[bool]) -> QuantRS2Result<PauliString> {
        match self.syndrome_table.get(syndrome) {
            Some(correction) => Ok(correction.clone()),
            None => Err(QuantRS2Error::InvalidInput("Unknown syndrome".to_string())),
        }
    }
}
/// Surface-code decoder that pairs syndrome defects and applies a correction along the path
/// between each pair.
///
/// NOTE(review): despite the name, the pairing in this file is a greedy nearest-partner
/// heuristic rather than an exact minimum-weight perfect matching — confirm intent.
pub struct MWPMDecoder {
// The code whose stabilizers and qubit map the decoder reads.
surface_code: SurfaceCode,
}
impl MWPMDecoder {
    /// Creates a decoder for `surface_code`.
    pub const fn new(surface_code: SurfaceCode) -> Self {
        Self { surface_code }
    }

    /// Decodes a pair of syndromes into a Pauli correction on the code's qubits.
    ///
    /// Defects in `z_syndrome` are paired and corrected with `X`. Defects in `x_syndrome` are
    /// paired and corrected with `Z`. The corrections are combined per qubit as Pauli
    /// products: `X` and `Z` on the same qubit give `Y`, and a repeated correction cancels.
    ///
    /// # Errors
    /// Returns `QuantRS2Error::InvalidInput` in two cases:
    /// - either syndrome has an odd number of defects;
    /// - a correction would land outside the code's qubits, e.g. when a syndrome is longer
    ///   than expected.
    pub fn decode_syndrome(
        &self,
        x_syndrome: &[bool],
        z_syndrome: &[bool],
    ) -> QuantRS2Result<PauliString> {
        let n = self.surface_code.qubit_map.len();
        // Track the X and Z components separately so corrections compose as Pauli products
        // instead of overwriting each other.
        let mut x_component = vec![false; n];
        let mut z_component = vec![false; n];

        let z_defects = self.find_defects(z_syndrome, &self.surface_code.z_stabilizers);
        let x_corrections = self.minimum_weight_matching(&z_defects, Pauli::X)?;
        Self::toggle_component(&mut x_component, &x_corrections)?;

        let x_defects = self.find_defects(x_syndrome, &self.surface_code.x_stabilizers);
        let z_corrections = self.minimum_weight_matching(&x_defects, Pauli::Z)?;
        Self::toggle_component(&mut z_component, &z_corrections)?;

        let error_paulis: Vec<Pauli> = x_component
            .into_iter()
            .zip(z_component)
            .map(|(x, z)| match (x, z) {
                (false, false) => Pauli::I,
                (true, false) => Pauli::X,
                (false, true) => Pauli::Z,
                (true, true) => Pauli::Y,
            })
            .collect();
        Ok(PauliString::new(error_paulis))
    }

    /// Flips `component[q]` for every qubit index `q` listed in `corrections`.
    ///
    /// # Errors
    /// Returns `QuantRS2Error::InvalidInput` for an out-of-range qubit index, rather than
    /// panicking on the out-of-bounds write.
    fn toggle_component(
        component: &mut [bool],
        corrections: &[(usize, Pauli)],
    ) -> QuantRS2Result<()> {
        let n = component.len();
        for &(qubit, _) in corrections {
            let bit = component.get_mut(qubit).ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!(
                    "Correction targets qubit {qubit}, but the code has only {n} qubits"
                ))
            })?;
            *bit = !*bit;
        }
        Ok(())
    }

    /// Returns the indices of the set bits in `syndrome`, in ascending order.
    ///
    /// `_stabilizers` is currently unused.
    fn find_defects(&self, syndrome: &[bool], _stabilizers: &[Vec<usize>]) -> Vec<usize> {
        syndrome
            .iter()
            .enumerate()
            .filter(|&(_, &fired)| fired)
            .map(|(index, _)| index)
            .collect()
    }

    /// Pairs `defects` and returns an `error_type` correction for every index on the path
    /// between each pair.
    ///
    /// Each still-unpaired defect, taken in input order, is paired with the nearest later
    /// unpaired defect; ties go to the earliest candidate. This is a greedy heuristic, not an
    /// exact minimum-weight perfect matching.
    ///
    /// # Errors
    /// Returns `QuantRS2Error::InvalidInput` if `defects` has odd length, since every defect
    /// must have a partner.
    fn minimum_weight_matching(
        &self,
        defects: &[usize],
        error_type: Pauli,
    ) -> QuantRS2Result<Vec<(usize, Pauli)>> {
        if defects.len() % 2 != 0 {
            return Err(QuantRS2Error::InvalidInput(
                "Odd number of defects".to_string(),
            ));
        }
        let mut paired = vec![false; defects.len()];
        let mut corrections = Vec::new();
        for i in 0..defects.len() {
            if paired[i] {
                continue;
            }
            // `min_by_key` returns the first of several equal minima, matching a strict `<` scan.
            let partner = (i + 1..defects.len())
                .filter(|&j| !paired[j])
                .min_by_key(|&j| self.defect_distance(defects[i], defects[j]));
            // With an even count, every index before `i` is already paired, so an unpaired
            // defect always has a later unpaired partner.
            if let Some(j) = partner {
                paired[i] = true;
                paired[j] = true;
                let path = self.shortest_path(defects[i], defects[j])?;
                corrections.extend(path.into_iter().map(|qubit| (qubit, error_type)));
            }
        }
        Ok(corrections)
    }

    /// Distance between two defects, measured as the absolute difference of their indices.
    const fn defect_distance(&self, defect1: usize, defect2: usize) -> usize {
        defect1.abs_diff(defect2)
    }

    /// Returns every index from `start` to `end` inclusive, ordered from `start` towards `end`.
    ///
    /// NOTE(review): the inputs are defect (syndrome) positions, and `decode_syndrome` uses
    /// the returned values directly as qubit indices. The code's lattice geometry is not
    /// consulted; confirm this placeholder against the intended qubit layout.
    fn shortest_path(&self, start: usize, end: usize) -> QuantRS2Result<Vec<usize>> {
        let path = if start < end {
            (start..=end).collect()
        } else {
            (end..=start).rev().collect()
        };
        Ok(path)
    }
}