1use crate::decoder::{Bit, DecodeResult, Decoder, DecoderRunner, Mod2Mul, SparseBitMatrix};
12use serde::{Deserialize, Serialize};
13
14use indicatif::{ParallelProgressIterator, ProgressFinish, ProgressIterator, ProgressStyle};
15use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis};
16use rayon::prelude::*;
17
18use std::sync::Arc;
19
20pub trait ObservableDecoder: Decoder {
21 fn observable_error_matrix(&self) -> Arc<SparseBitMatrix>;
23
24 fn compute_observables(&self, errors: ArrayView1<Bit>) -> Array1<Bit> {
26 self.observable_error_matrix().mul_mod2(&errors.to_owned())
27 }
28
29 fn decode_observables(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
30 let errors = self.decode(detectors);
31 self.compute_observables(errors.view())
32 }
33}
34
/// Pairs a boxed [`Decoder`] with an observable error matrix so that decoded
/// errors can be converted into observables.
#[derive(Clone)]
pub struct ObservableDecoderRunner<'a> {
    /// Wrapped decoder; the runner's `Decoder` implementation forwards to it.
    decoder: Box<dyn Decoder + Send + 'a>,
    /// Matrix handed out by `observable_error_matrix`.
    observable_error_matrix: Arc<SparseBitMatrix>,
    /// When `true`, detailed results also carry the wrapped decoder's
    /// `DecodeResult`.
    include_decode_result: bool,
}
41
42impl<'a> ObservableDecoderRunner<'a> {
43 pub fn new(
44 decoder: Box<dyn Decoder + Send + 'a>,
45 observable_error_matrix: Arc<SparseBitMatrix>,
46 include_decode_result: bool,
47 ) -> Self {
48 ObservableDecoderRunner {
49 decoder,
50 observable_error_matrix,
51 include_decode_result,
52 }
53 }
54
55 pub fn get_decoder(&self) -> &dyn Decoder {
56 self.decoder.as_ref()
57 }
58
59 pub fn get_decoder_mut(&mut self) -> &mut dyn Decoder {
60 self.decoder.as_mut()
61 }
62
63 pub fn decode_observables(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
64 let decode_result = self.decode(detectors.view());
65 self.compute_observables(decode_result.view())
66 }
67
68 pub fn decode_observables_detailed(
69 &mut self,
70 detectors: ArrayView1<Bit>,
71 ) -> ObservableDecodeResult {
72 let decode_result = self.decode_detailed(detectors.view());
73 let observables = self.compute_observables(decode_result.decoding.view());
74
75 ObservableDecodeResult {
76 observables,
77 converged: decode_result.success,
78 iterations: decode_result.iterations,
79 true_decoding: None,
80 physical_decode_result: if self.include_decode_result {
81 Some(decode_result)
82 } else {
83 None
84 },
85 }
86 }
87
88 pub fn decode_observables_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
89 let arrs: Vec<Array1<Bit>> = detectors
90 .axis_iter(Axis(0))
91 .map(|row| self.decode_observables(row))
92 .collect();
93 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
94 }
95
96 pub fn decode_observables_batch_progress_bar(
97 &mut self,
98 detectors: ArrayView2<Bit>,
99 leave_progress_bar_on_finish: bool,
100 ) -> Array2<Bit> {
101 let finish_mode = match leave_progress_bar_on_finish {
102 true => ProgressFinish::AndLeave,
103 false => ProgressFinish::AndClear,
104 };
105
106 let arrs: Vec<Array1<Bit>> = detectors
107 .axis_iter(Axis(0))
108 .progress_with_style(self.get_progress_bar_style())
109 .with_finish(finish_mode)
110 .map(|row| self.decode_observables(row))
111 .collect();
112 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
113 }
114
115 pub fn decode_observables_detailed_batch(
116 &mut self,
117 detectors: ArrayView2<Bit>,
118 ) -> Vec<ObservableDecodeResult> {
119 detectors
120 .axis_iter(Axis(0))
121 .map(|row| self.decode_observables_detailed(row))
122 .collect()
123 }
124
125 pub fn par_decode_observables_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
126 let arrs: Vec<Array1<Bit>> = detectors
127 .axis_iter(Axis(0))
128 .into_par_iter()
129 .map_with(
130 || self.clone(),
131 |decoder, row| decoder().decode_observables(row),
132 )
133 .collect();
134 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
135 }
136
137 pub fn par_decode_observables_batch_progress_bar(
138 &mut self,
139 detectors: ArrayView2<Bit>,
140 leave_progress_bar_on_finish: bool,
141 ) -> Array2<Bit> {
142 let finish_mode = match leave_progress_bar_on_finish {
143 true => ProgressFinish::AndLeave,
144 false => ProgressFinish::AndClear,
145 };
146
147 let arrs: Vec<Array1<Bit>> = detectors
148 .axis_iter(Axis(0))
149 .into_par_iter()
150 .progress_with_style(self.get_progress_bar_style())
151 .with_finish(finish_mode)
152 .map_with(
153 || self.clone(),
154 |decoder, row| decoder().decode_observables(row),
155 )
156 .collect();
157 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
158 }
159
160 pub fn par_decode_observables_detailed_batch(
161 &mut self,
162 detectors: ArrayView2<Bit>,
163 ) -> Vec<ObservableDecodeResult> {
164 detectors
165 .axis_iter(Axis(0))
166 .into_par_iter()
167 .map_with(
168 || self.clone(),
169 |decoder, row| decoder().decode_observables_detailed(row),
170 )
171 .collect()
172 }
173
174 pub fn decode_observables_detailed_batch_progress_bar(
176 &mut self,
177 detectors: ArrayView2<Bit>,
178 leave_progress_bar_on_finish: bool,
179 ) -> Vec<ObservableDecodeResult> {
180 let finish_mode = match leave_progress_bar_on_finish {
181 true => ProgressFinish::AndLeave,
182 false => ProgressFinish::AndClear,
183 };
184
185 detectors
186 .axis_iter(Axis(0))
187 .progress_with_style(self.get_progress_bar_style())
188 .with_finish(finish_mode)
189 .map(|row| self.decode_observables_detailed(row))
190 .collect()
191 }
192
193 pub fn par_decode_observables_detailed_batch_progress_bar(
194 &mut self,
195 detectors: ArrayView2<Bit>,
196 leave_progress_bar_on_finish: bool,
197 ) -> Vec<ObservableDecodeResult> {
198 let finish_mode = match leave_progress_bar_on_finish {
199 true => ProgressFinish::AndLeave,
200 false => ProgressFinish::AndClear,
201 };
202
203 detectors
204 .axis_iter(Axis(0))
205 .into_par_iter()
206 .progress_with_style(self.get_progress_bar_style())
207 .with_finish(finish_mode)
208 .map_with(
209 || self.clone(),
210 |decoder, row| decoder().decode_observables_detailed(row),
211 )
212 .collect()
213 }
214
215 pub fn from_errors_decode_observables(&mut self, errors: ArrayView1<Bit>) -> Array1<Bit> {
216 let detectors = self.get_detectors(errors);
217 let decode_result = self.decode(detectors.view());
218 self.compute_observables(decode_result.view())
219 }
220
221 pub fn from_errors_decode_observables_detailed(
222 &mut self,
223 errors: ArrayView1<Bit>,
224 ) -> ObservableDecodeResult {
225 let detectors = self.get_detectors(errors);
226 let decode_result = self.decode_detailed(detectors.view());
227 let observables = self.compute_observables(errors);
228 let decoded_observables = self.compute_observables(decode_result.decoding.view());
229
230 let error_detected: bool = observables != decoded_observables;
231 let error_mismatch_detected: bool = errors != decode_result.decoding;
232 let better_decoding_quality: bool =
233 decode_result.decoding_quality < self.get_decoding_quality(errors.clone().view());
234
235 let unconverged_no_error: bool = !error_detected && !decode_result.success;
236 let better_decoding_quality_error: bool = error_detected && better_decoding_quality;
237 let worse_decoding_quality_error: bool = error_detected && !better_decoding_quality;
238
239 ObservableDecodeResult {
240 observables,
241 converged: decode_result.success,
242 iterations: decode_result.iterations,
243 true_decoding: Some(TrueDecodingResults {
244 error_detected,
245 error_mismatch_detected,
246 better_decoding_quality_error,
247 worse_decoding_quality_error,
248 unconverged_no_error,
249 }),
250 physical_decode_result: if self.include_decode_result {
251 Some(decode_result)
252 } else {
253 None
254 },
255 }
256 }
257
258 pub fn from_errors_decode_observables_batch(&mut self, errors: ArrayView2<Bit>) -> Array2<Bit> {
259 let arrs: Vec<Array1<Bit>> = errors
260 .axis_iter(Axis(0))
261 .map(|row| self.from_errors_decode_observables(row))
262 .collect();
263 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
264 }
265
266 pub fn par_from_errors_decode_observables_batch(
267 &mut self,
268 errors: ArrayView2<Bit>,
269 ) -> Array2<Bit> {
270 let arrs: Vec<Array1<Bit>> = errors
271 .axis_iter(Axis(0))
272 .into_par_iter()
273 .map_with(
274 || self.clone(),
275 |decoder, row| decoder().from_errors_decode_observables(row),
276 )
277 .collect();
278 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
279 }
280
281 pub fn from_errors_decode_observables_batch_progress_bar(
282 &mut self,
283 errors: ArrayView2<Bit>,
284 leave_progress_bar_on_finish: bool,
285 ) -> Array2<Bit> {
286 let finish_mode = match leave_progress_bar_on_finish {
287 true => ProgressFinish::AndLeave,
288 false => ProgressFinish::AndClear,
289 };
290
291 let arrs: Vec<Array1<Bit>> = errors
292 .axis_iter(Axis(0))
293 .progress_with_style(self.get_progress_bar_style())
294 .with_finish(finish_mode)
295 .map(|row| self.from_errors_decode_observables(row))
296 .collect();
297 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
298 }
299
300 pub fn par_from_errors_decode_observables_batch_progress_bar(
301 &mut self,
302 errors: ArrayView2<Bit>,
303 leave_progress_bar_on_finish: bool,
304 ) -> Array2<Bit> {
305 let finish_mode = match leave_progress_bar_on_finish {
306 true => ProgressFinish::AndLeave,
307 false => ProgressFinish::AndClear,
308 };
309
310 let arrs: Vec<Array1<Bit>> = errors
311 .axis_iter(Axis(0))
312 .into_par_iter()
313 .progress_with_style(self.get_progress_bar_style())
314 .with_finish(finish_mode)
315 .map_with(
316 || self.clone(),
317 |decoder, row| decoder().from_errors_decode_observables(row),
318 )
319 .collect();
320 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
321 }
322
323 pub fn from_errors_decode_observables_detailed_batch(
324 &mut self,
325 errors: ArrayView2<Bit>,
326 ) -> Vec<ObservableDecodeResult> {
327 errors
328 .axis_iter(Axis(0))
329 .map(|row| self.from_errors_decode_observables_detailed(row))
330 .collect()
331 }
332
333 pub fn par_from_errors_decode_observables_detailed_batch(
334 &mut self,
335 errors: ArrayView2<Bit>,
336 ) -> Vec<ObservableDecodeResult> {
337 errors
338 .axis_iter(Axis(0))
339 .into_par_iter()
340 .map_with(
341 || self.clone(),
342 |decoder, row| decoder().from_errors_decode_observables_detailed(row),
343 )
344 .collect()
345 }
346
347 pub fn from_errors_decode_observables_detailed_batch_progress_bar(
349 &mut self,
350 errors: ArrayView2<Bit>,
351 leave_progress_bar_on_finish: bool,
352 ) -> Vec<ObservableDecodeResult> {
353 let finish_mode = match leave_progress_bar_on_finish {
354 true => ProgressFinish::AndLeave,
355 false => ProgressFinish::AndClear,
356 };
357
358 errors
359 .axis_iter(Axis(0))
360 .progress_with_style(self.get_progress_bar_style())
361 .with_finish(finish_mode)
362 .map(|row| self.from_errors_decode_observables_detailed(row))
363 .collect()
364 }
365
366 pub fn par_from_errors_decode_observables_detailed_batch_progress_bar(
367 &mut self,
368 errors: ArrayView2<Bit>,
369 leave_progress_bar_on_finish: bool,
370 ) -> Vec<ObservableDecodeResult> {
371 let finish_mode = match leave_progress_bar_on_finish {
372 true => ProgressFinish::AndLeave,
373 false => ProgressFinish::AndClear,
374 };
375
376 errors
377 .axis_iter(Axis(0))
378 .into_par_iter()
379 .progress_with_style(self.get_progress_bar_style())
380 .with_finish(finish_mode)
381 .map_with(
382 || self.clone(),
383 |decoder, row| decoder().from_errors_decode_observables_detailed(row),
384 )
385 .collect()
386 }
387
388 fn get_progress_bar_style(&self) -> ProgressStyle {
389 ProgressStyle::default_bar().template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({per_sec}, {eta})").unwrap()
390 }
391}
392
// Empty impl: the runner overrides nothing and relies on whatever
// `DecoderRunner` provides by default.
impl DecoderRunner for ObservableDecoderRunner<'_> {}
394
395impl Decoder for ObservableDecoderRunner<'_> {
396 fn check_matrix(&self) -> Arc<SparseBitMatrix> {
397 self.get_decoder().check_matrix()
398 }
399 fn log_prior_ratios(&mut self) -> Array1<f64> {
400 self.get_decoder_mut().log_prior_ratios()
401 }
402 fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
403 self.get_decoder_mut().decode_detailed(detectors)
404 }
405 fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
406 self.get_decoder_mut().get_decoding_quality(errors)
407 }
408}
409
410impl ObservableDecoder for ObservableDecoderRunner<'_> {
411 fn observable_error_matrix(&self) -> Arc<SparseBitMatrix> {
412 self.observable_error_matrix.clone()
413 }
414}
415
/// Result of a detailed observable decode, as produced by the
/// `*_detailed` methods of `ObservableDecoderRunner`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ObservableDecodeResult {
    /// Observable bits reported for this decode.
    pub observables: Array1<Bit>,
    /// Copied from the wrapped decoder's `DecodeResult::success`.
    pub converged: bool,
    /// Copied from the wrapped decoder's `DecodeResult::iterations`.
    pub iterations: usize,
    /// Comparison against known errors; `None` when decoding from detectors
    /// only, `Some` when decoding via the `from_errors_*` path.
    pub true_decoding: Option<TrueDecodingResults>,
    /// The wrapped decoder's full result, kept only when the runner was
    /// created with `include_decode_result = true`.
    pub physical_decode_result: Option<DecodeResult>,
}
424
/// Flags comparing a decode against the known error vector it came from.
#[derive(Serialize, Deserialize, Debug)]
pub struct TrueDecodingResults {
    /// Observables of the true errors differ from the decoded observables.
    pub error_detected: bool,
    /// The decoded error vector differs from the true error vector.
    pub error_mismatch_detected: bool,
    /// Observables agree but the decoder did not report success.
    pub unconverged_no_error: bool,
    /// Observables disagree and the decoded solution's `decoding_quality` is
    /// strictly smaller than the quality computed for the true errors.
    pub better_decoding_quality_error: bool,
    /// Observables disagree and the decoded solution's `decoding_quality` is
    /// not smaller than the quality computed for the true errors.
    pub worse_decoding_quality_error: bool,
}
433
#[cfg(test)]
mod tests {

    use super::*;
    use ndarray::prelude::*;

    use crate::bp::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    /// Decodes the first 100 stored error patterns of the `144_12_12` test
    /// code with min-sum BP and checks convergence / logical-error rates as
    /// well as agreement between the single, batch and parallel entry points.
    #[test]
    fn min_sum_decode_144_12_12() {
        let data_dir = get_test_data_path();
        let dem =
            DetectorErrorModel::load(data_dir.join("144_12_12")).expect("Unable to load the code");
        let all_errors: Array2<Bit> =
            read_npy(data_dir.join("144_12_12_errors.npy")).expect("Unable to open file");

        let config = Arc::new(MinSumDecoderConfig {
            error_priors: dem.error_priors,
            alpha: Some(0.),
            ..Default::default()
        });
        let bp_decoder: Box<MinSumBPDecoder<f64>> = Box::new(MinSumBPDecoder::new(
            Arc::new(dem.detector_error_matrix),
            config,
        ));
        let mut runner =
            ObservableDecoderRunner::new(bp_decoder, Arc::new(dem.observable_error_matrix), true);

        let num_errors = 100;
        let sample = all_errors.slice(s![..num_errors, ..]);
        let results = runner.par_from_errors_decode_observables_detailed_batch(sample);
        let num_shots = sample.shape()[0] as f64;

        // At least 93% of the shots must converge.
        let converged = results
            .iter()
            .filter(|r| r.physical_decode_result.as_ref().unwrap().success)
            .count() as f64;
        assert!(converged >= num_shots * 0.93);

        // At least 96% of the decoded error vectors differ from the true ones.
        let mismatched = results
            .iter()
            .filter(|r| r.true_decoding.as_ref().unwrap().error_mismatch_detected)
            .count() as f64;
        assert!(mismatched >= num_shots * 0.96);

        assert_eq!(
            results[0].observables,
            array![1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1]
        );

        // At most 7% of the shots may end in an observable mismatch.
        let logical_errors = results
            .iter()
            .filter(|r| r.true_decoding.as_ref().unwrap().error_detected)
            .count() as f64;
        assert!(logical_errors <= num_shots * 0.07);

        // Decoding the first shot directly from its detectors must agree.
        let first_syndrome = runner.get_detectors(all_errors.row(0));
        let direct_results = runner.decode_observables(first_syndrome.view());
        assert_eq!(results[0].observables, direct_results);

        // ... and so must the first row of the parallel detector batch.
        let all_syndromes = runner.get_detectors_batch(all_errors.view());
        let batch_results = runner.par_decode_observables_batch(all_syndromes.view());

        assert_eq!(batch_results.row(0), direct_results);
    }
}