quantrs2_core/error_correction/
decoders.rs1use super::pauli::{Pauli, PauliString};
4use super::stabilizer::StabilizerCode;
5use super::surface_code::SurfaceCode;
6use super::SyndromeDecoder;
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use std::collections::HashMap;
9
10pub struct LookupDecoder {
12 syndrome_table: HashMap<Vec<bool>, PauliString>,
14}
15
16impl LookupDecoder {
17 pub fn new(code: &StabilizerCode) -> QuantRS2Result<Self> {
19 let mut syndrome_table = HashMap::new();
20
21 let max_weight = (code.d - 1) / 2;
23 let all_errors = Self::generate_pauli_errors(code.n, max_weight);
24
25 for error in all_errors {
26 let syndrome = code.syndrome(&error)?;
27
28 syndrome_table
30 .entry(syndrome)
31 .and_modify(|e: &mut PauliString| {
32 if error.weight() < e.weight() {
33 *e = error.clone();
34 }
35 })
36 .or_insert(error);
37 }
38
39 Ok(Self { syndrome_table })
40 }
41
42 fn generate_pauli_errors(n: usize, max_weight: usize) -> Vec<PauliString> {
44 let mut errors = vec![PauliString::identity(n)];
45
46 for weight in 1..=max_weight {
47 let weight_errors = Self::generate_weight_k_errors(n, weight);
48 errors.extend(weight_errors);
49 }
50
51 errors
52 }
53
54 fn generate_weight_k_errors(n: usize, k: usize) -> Vec<PauliString> {
56 let mut errors = Vec::new();
57 let paulis = [Pauli::X, Pauli::Y, Pauli::Z];
58
59 let positions = Self::combinations(n, k);
61
62 for pos_set in positions {
63 let pauli_combinations = Self::cartesian_power(&paulis, k);
65
66 for pauli_combo in pauli_combinations {
67 let mut error_paulis = vec![Pauli::I; n];
68 for (i, &pos) in pos_set.iter().enumerate() {
69 error_paulis[pos] = pauli_combo[i];
70 }
71 errors.push(PauliString::new(error_paulis));
72 }
73 }
74
75 errors
76 }
77
78 fn combinations(n: usize, k: usize) -> Vec<Vec<usize>> {
80 let mut result = Vec::new();
81 let mut combo = (0..k).collect::<Vec<_>>();
82
83 loop {
84 result.push(combo.clone());
85
86 let mut i = k;
88 while i > 0 && (i == k || combo[i] == n - k + i) {
89 i -= 1;
90 }
91
92 if i == 0 && combo[0] == n - k {
93 break;
94 }
95
96 combo[i] += 1;
98 for j in i + 1..k {
99 combo[j] = combo[j - 1] + 1;
100 }
101 }
102
103 result
104 }
105
106 fn cartesian_power<T: Clone>(set: &[T], k: usize) -> Vec<Vec<T>> {
108 if k == 0 {
109 return vec![vec![]];
110 }
111
112 let mut result = Vec::new();
113 let smaller = Self::cartesian_power(set, k - 1);
114
115 for item in set {
116 for mut combo in smaller.clone() {
117 combo.push(item.clone());
118 result.push(combo);
119 }
120 }
121
122 result
123 }
124}
125
126impl SyndromeDecoder for LookupDecoder {
127 fn decode(&self, syndrome: &[bool]) -> QuantRS2Result<PauliString> {
128 self.syndrome_table
129 .get(syndrome)
130 .cloned()
131 .ok_or_else(|| QuantRS2Error::InvalidInput("Unknown syndrome".to_string()))
132 }
133}
134
135pub struct MWPMDecoder {
137 surface_code: SurfaceCode,
138}
139
140impl MWPMDecoder {
141 pub const fn new(surface_code: SurfaceCode) -> Self {
143 Self { surface_code }
144 }
145
146 pub fn decode_syndrome(
148 &self,
149 x_syndrome: &[bool],
150 z_syndrome: &[bool],
151 ) -> QuantRS2Result<PauliString> {
152 let n = self.surface_code.qubit_map.len();
153 let mut error_paulis = vec![Pauli::I; n];
154
155 let z_defects = self.find_defects(z_syndrome, &self.surface_code.z_stabilizers);
157 let x_corrections = self.minimum_weight_matching(&z_defects, Pauli::X)?;
158
159 for (qubit, pauli) in x_corrections {
160 error_paulis[qubit] = pauli;
161 }
162
163 let x_defects = self.find_defects(x_syndrome, &self.surface_code.x_stabilizers);
165 let z_corrections = self.minimum_weight_matching(&x_defects, Pauli::Z)?;
166
167 for (qubit, pauli) in z_corrections {
168 if error_paulis[qubit] == Pauli::I {
169 error_paulis[qubit] = pauli;
170 } else {
171 error_paulis[qubit] = Pauli::Y;
173 }
174 }
175
176 Ok(PauliString::new(error_paulis))
177 }
178
179 fn find_defects(&self, syndrome: &[bool], _stabilizers: &[Vec<usize>]) -> Vec<usize> {
181 syndrome
182 .iter()
183 .enumerate()
184 .filter_map(|(i, &s)| if s { Some(i) } else { None })
185 .collect()
186 }
187
188 fn minimum_weight_matching(
190 &self,
191 defects: &[usize],
192 error_type: Pauli,
193 ) -> QuantRS2Result<Vec<(usize, Pauli)>> {
194 let mut corrections = Vec::new();
196
197 if defects.len() % 2 != 0 {
198 return Err(QuantRS2Error::InvalidInput(
199 "Odd number of defects".to_string(),
200 ));
201 }
202
203 let mut paired = vec![false; defects.len()];
205
206 for i in 0..defects.len() {
207 if paired[i] {
208 continue;
209 }
210
211 let mut min_dist = usize::MAX;
213 let mut min_j = i;
214
215 for j in i + 1..defects.len() {
216 if !paired[j] {
217 let dist = self.defect_distance(defects[i], defects[j]);
218 if dist < min_dist {
219 min_dist = dist;
220 min_j = j;
221 }
222 }
223 }
224
225 if min_j != i {
226 paired[i] = true;
227 paired[min_j] = true;
228
229 let path = self.shortest_path(defects[i], defects[min_j])?;
231 for qubit in path {
232 corrections.push((qubit, error_type));
233 }
234 }
235 }
236
237 Ok(corrections)
238 }
239
240 const fn defect_distance(&self, defect1: usize, defect2: usize) -> usize {
242 (defect1 as isize - defect2 as isize).unsigned_abs()
244 }
245
246 fn shortest_path(&self, start: usize, end: usize) -> QuantRS2Result<Vec<usize>> {
248 let path = if start < end {
250 (start..=end).collect()
251 } else {
252 (end..=start).rev().collect()
253 };
254
255 Ok(path)
256 }
257}