Skip to main content

relay_bp/
observable_decoder.rs

// (C) Copyright IBM 2025
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.
10
11use crate::decoder::{Bit, DecodeResult, Decoder, DecoderRunner, Mod2Mul, SparseBitMatrix};
12use serde::{Deserialize, Serialize};
13
14use indicatif::{ParallelProgressIterator, ProgressFinish, ProgressIterator, ProgressStyle};
15use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis};
16use rayon::prelude::*;
17
18use std::sync::Arc;
19
20pub trait ObservableDecoder: Decoder {
21    /// The logical action matrix of the underlying trait.
22    fn observable_error_matrix(&self) -> Arc<SparseBitMatrix>;
23
24    /// Compute the logical errors from the logical action matrix.
25    fn compute_observables(&self, errors: ArrayView1<Bit>) -> Array1<Bit> {
26        self.observable_error_matrix().mul_mod2(&errors.to_owned())
27    }
28
29    fn decode_observables(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
30        let errors = self.decode(detectors);
31        self.compute_observables(errors.view())
32    }
33}
34
/// Runs a boxed physical [`Decoder`] and maps its decodings onto observables.
///
/// `Clone` is relied on by the `par_*` batch methods, which decode rows on clones
/// of the runner instead of on `self`.
#[derive(Clone)]
pub struct ObservableDecoderRunner<'a> {
    /// The wrapped physical decoder; all `Decoder` calls are forwarded to it.
    decoder: Box<dyn Decoder + Send + 'a>,
    /// Matrix whose mod-2 product with an error vector gives the observables
    /// (see `ObservableDecoder::compute_observables`).
    observable_error_matrix: Arc<SparseBitMatrix>,
    /// When `true`, the `*_detailed` methods keep the physical `DecodeResult`
    /// in `ObservableDecodeResult::physical_decode_result`.
    include_decode_result: bool,
}
41
42impl<'a> ObservableDecoderRunner<'a> {
43    pub fn new(
44        decoder: Box<dyn Decoder + Send + 'a>,
45        observable_error_matrix: Arc<SparseBitMatrix>,
46        include_decode_result: bool,
47    ) -> Self {
48        ObservableDecoderRunner {
49            decoder,
50            observable_error_matrix,
51            include_decode_result,
52        }
53    }
54
55    pub fn get_decoder(&self) -> &dyn Decoder {
56        self.decoder.as_ref()
57    }
58
59    pub fn get_decoder_mut(&mut self) -> &mut dyn Decoder {
60        self.decoder.as_mut()
61    }
62
63    pub fn decode_observables(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
64        let decode_result = self.decode(detectors.view());
65        self.compute_observables(decode_result.view())
66    }
67
68    pub fn decode_observables_detailed(
69        &mut self,
70        detectors: ArrayView1<Bit>,
71    ) -> ObservableDecodeResult {
72        let decode_result = self.decode_detailed(detectors.view());
73        let observables = self.compute_observables(decode_result.decoding.view());
74
75        ObservableDecodeResult {
76            observables,
77            converged: decode_result.success,
78            iterations: decode_result.iterations,
79            true_decoding: None,
80            physical_decode_result: if self.include_decode_result {
81                Some(decode_result)
82            } else {
83                None
84            },
85        }
86    }
87
88    pub fn decode_observables_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
89        let arrs: Vec<Array1<Bit>> = detectors
90            .axis_iter(Axis(0))
91            .map(|row| self.decode_observables(row))
92            .collect();
93        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
94    }
95
96    pub fn decode_observables_batch_progress_bar(
97        &mut self,
98        detectors: ArrayView2<Bit>,
99        leave_progress_bar_on_finish: bool,
100    ) -> Array2<Bit> {
101        let finish_mode = match leave_progress_bar_on_finish {
102            true => ProgressFinish::AndLeave,
103            false => ProgressFinish::AndClear,
104        };
105
106        let arrs: Vec<Array1<Bit>> = detectors
107            .axis_iter(Axis(0))
108            .progress_with_style(self.get_progress_bar_style())
109            .with_finish(finish_mode)
110            .map(|row| self.decode_observables(row))
111            .collect();
112        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
113    }
114
115    pub fn decode_observables_detailed_batch(
116        &mut self,
117        detectors: ArrayView2<Bit>,
118    ) -> Vec<ObservableDecodeResult> {
119        detectors
120            .axis_iter(Axis(0))
121            .map(|row| self.decode_observables_detailed(row))
122            .collect()
123    }
124
125    pub fn par_decode_observables_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
126        let arrs: Vec<Array1<Bit>> = detectors
127            .axis_iter(Axis(0))
128            .into_par_iter()
129            .map_with(
130                || self.clone(),
131                |decoder, row| decoder().decode_observables(row),
132            )
133            .collect();
134        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
135    }
136
137    pub fn par_decode_observables_batch_progress_bar(
138        &mut self,
139        detectors: ArrayView2<Bit>,
140        leave_progress_bar_on_finish: bool,
141    ) -> Array2<Bit> {
142        let finish_mode = match leave_progress_bar_on_finish {
143            true => ProgressFinish::AndLeave,
144            false => ProgressFinish::AndClear,
145        };
146
147        let arrs: Vec<Array1<Bit>> = detectors
148            .axis_iter(Axis(0))
149            .into_par_iter()
150            .progress_with_style(self.get_progress_bar_style())
151            .with_finish(finish_mode)
152            .map_with(
153                || self.clone(),
154                |decoder, row| decoder().decode_observables(row),
155            )
156            .collect();
157        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
158    }
159
160    pub fn par_decode_observables_detailed_batch(
161        &mut self,
162        detectors: ArrayView2<Bit>,
163    ) -> Vec<ObservableDecodeResult> {
164        detectors
165            .axis_iter(Axis(0))
166            .into_par_iter()
167            .map_with(
168                || self.clone(),
169                |decoder, row| decoder().decode_observables_detailed(row),
170            )
171            .collect()
172    }
173
174    /// Decode a batch displaying a progress bar
175    pub fn decode_observables_detailed_batch_progress_bar(
176        &mut self,
177        detectors: ArrayView2<Bit>,
178        leave_progress_bar_on_finish: bool,
179    ) -> Vec<ObservableDecodeResult> {
180        let finish_mode = match leave_progress_bar_on_finish {
181            true => ProgressFinish::AndLeave,
182            false => ProgressFinish::AndClear,
183        };
184
185        detectors
186            .axis_iter(Axis(0))
187            .progress_with_style(self.get_progress_bar_style())
188            .with_finish(finish_mode)
189            .map(|row| self.decode_observables_detailed(row))
190            .collect()
191    }
192
193    pub fn par_decode_observables_detailed_batch_progress_bar(
194        &mut self,
195        detectors: ArrayView2<Bit>,
196        leave_progress_bar_on_finish: bool,
197    ) -> Vec<ObservableDecodeResult> {
198        let finish_mode = match leave_progress_bar_on_finish {
199            true => ProgressFinish::AndLeave,
200            false => ProgressFinish::AndClear,
201        };
202
203        detectors
204            .axis_iter(Axis(0))
205            .into_par_iter()
206            .progress_with_style(self.get_progress_bar_style())
207            .with_finish(finish_mode)
208            .map_with(
209                || self.clone(),
210                |decoder, row| decoder().decode_observables_detailed(row),
211            )
212            .collect()
213    }
214
215    pub fn from_errors_decode_observables(&mut self, errors: ArrayView1<Bit>) -> Array1<Bit> {
216        let detectors = self.get_detectors(errors);
217        let decode_result = self.decode(detectors.view());
218        self.compute_observables(decode_result.view())
219    }
220
221    pub fn from_errors_decode_observables_detailed(
222        &mut self,
223        errors: ArrayView1<Bit>,
224    ) -> ObservableDecodeResult {
225        let detectors = self.get_detectors(errors);
226        let decode_result = self.decode_detailed(detectors.view());
227        let observables = self.compute_observables(errors);
228        let decoded_observables = self.compute_observables(decode_result.decoding.view());
229
230        let error_detected: bool = observables != decoded_observables;
231        let error_mismatch_detected: bool = errors != decode_result.decoding;
232        let better_decoding_quality: bool =
233            decode_result.decoding_quality < self.get_decoding_quality(errors.clone().view());
234
235        let unconverged_no_error: bool = !error_detected && !decode_result.success;
236        let better_decoding_quality_error: bool = error_detected && better_decoding_quality;
237        let worse_decoding_quality_error: bool = error_detected && !better_decoding_quality;
238
239        ObservableDecodeResult {
240            observables,
241            converged: decode_result.success,
242            iterations: decode_result.iterations,
243            true_decoding: Some(TrueDecodingResults {
244                error_detected,
245                error_mismatch_detected,
246                better_decoding_quality_error,
247                worse_decoding_quality_error,
248                unconverged_no_error,
249            }),
250            physical_decode_result: if self.include_decode_result {
251                Some(decode_result)
252            } else {
253                None
254            },
255        }
256    }
257
258    pub fn from_errors_decode_observables_batch(&mut self, errors: ArrayView2<Bit>) -> Array2<Bit> {
259        let arrs: Vec<Array1<Bit>> = errors
260            .axis_iter(Axis(0))
261            .map(|row| self.from_errors_decode_observables(row))
262            .collect();
263        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
264    }
265
266    pub fn par_from_errors_decode_observables_batch(
267        &mut self,
268        errors: ArrayView2<Bit>,
269    ) -> Array2<Bit> {
270        let arrs: Vec<Array1<Bit>> = errors
271            .axis_iter(Axis(0))
272            .into_par_iter()
273            .map_with(
274                || self.clone(),
275                |decoder, row| decoder().from_errors_decode_observables(row),
276            )
277            .collect();
278        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
279    }
280
281    pub fn from_errors_decode_observables_batch_progress_bar(
282        &mut self,
283        errors: ArrayView2<Bit>,
284        leave_progress_bar_on_finish: bool,
285    ) -> Array2<Bit> {
286        let finish_mode = match leave_progress_bar_on_finish {
287            true => ProgressFinish::AndLeave,
288            false => ProgressFinish::AndClear,
289        };
290
291        let arrs: Vec<Array1<Bit>> = errors
292            .axis_iter(Axis(0))
293            .progress_with_style(self.get_progress_bar_style())
294            .with_finish(finish_mode)
295            .map(|row| self.from_errors_decode_observables(row))
296            .collect();
297        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
298    }
299
300    pub fn par_from_errors_decode_observables_batch_progress_bar(
301        &mut self,
302        errors: ArrayView2<Bit>,
303        leave_progress_bar_on_finish: bool,
304    ) -> Array2<Bit> {
305        let finish_mode = match leave_progress_bar_on_finish {
306            true => ProgressFinish::AndLeave,
307            false => ProgressFinish::AndClear,
308        };
309
310        let arrs: Vec<Array1<Bit>> = errors
311            .axis_iter(Axis(0))
312            .into_par_iter()
313            .progress_with_style(self.get_progress_bar_style())
314            .with_finish(finish_mode)
315            .map_with(
316                || self.clone(),
317                |decoder, row| decoder().from_errors_decode_observables(row),
318            )
319            .collect();
320        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
321    }
322
323    pub fn from_errors_decode_observables_detailed_batch(
324        &mut self,
325        errors: ArrayView2<Bit>,
326    ) -> Vec<ObservableDecodeResult> {
327        errors
328            .axis_iter(Axis(0))
329            .map(|row| self.from_errors_decode_observables_detailed(row))
330            .collect()
331    }
332
333    pub fn par_from_errors_decode_observables_detailed_batch(
334        &mut self,
335        errors: ArrayView2<Bit>,
336    ) -> Vec<ObservableDecodeResult> {
337        errors
338            .axis_iter(Axis(0))
339            .into_par_iter()
340            .map_with(
341                || self.clone(),
342                |decoder, row| decoder().from_errors_decode_observables_detailed(row),
343            )
344            .collect()
345    }
346
347    /// Decode a batch displaying a progress bar
348    pub fn from_errors_decode_observables_detailed_batch_progress_bar(
349        &mut self,
350        errors: ArrayView2<Bit>,
351        leave_progress_bar_on_finish: bool,
352    ) -> Vec<ObservableDecodeResult> {
353        let finish_mode = match leave_progress_bar_on_finish {
354            true => ProgressFinish::AndLeave,
355            false => ProgressFinish::AndClear,
356        };
357
358        errors
359            .axis_iter(Axis(0))
360            .progress_with_style(self.get_progress_bar_style())
361            .with_finish(finish_mode)
362            .map(|row| self.from_errors_decode_observables_detailed(row))
363            .collect()
364    }
365
366    pub fn par_from_errors_decode_observables_detailed_batch_progress_bar(
367        &mut self,
368        errors: ArrayView2<Bit>,
369        leave_progress_bar_on_finish: bool,
370    ) -> Vec<ObservableDecodeResult> {
371        let finish_mode = match leave_progress_bar_on_finish {
372            true => ProgressFinish::AndLeave,
373            false => ProgressFinish::AndClear,
374        };
375
376        errors
377            .axis_iter(Axis(0))
378            .into_par_iter()
379            .progress_with_style(self.get_progress_bar_style())
380            .with_finish(finish_mode)
381            .map_with(
382                || self.clone(),
383                |decoder, row| decoder().from_errors_decode_observables_detailed(row),
384            )
385            .collect()
386    }
387
388    fn get_progress_bar_style(&self) -> ProgressStyle {
389        ProgressStyle::default_bar().template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({per_sec}, {eta})").unwrap()
390    }
391}
392
393impl DecoderRunner for ObservableDecoderRunner<'_> {}
394
395impl Decoder for ObservableDecoderRunner<'_> {
396    fn check_matrix(&self) -> Arc<SparseBitMatrix> {
397        self.get_decoder().check_matrix()
398    }
399    fn log_prior_ratios(&mut self) -> Array1<f64> {
400        self.get_decoder_mut().log_prior_ratios()
401    }
402    fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
403        self.get_decoder_mut().decode_detailed(detectors)
404    }
405    fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
406        self.get_decoder_mut().get_decoding_quality(errors)
407    }
408}
409
410impl ObservableDecoder for ObservableDecoderRunner<'_> {
411    fn observable_error_matrix(&self) -> Arc<SparseBitMatrix> {
412        self.observable_error_matrix.clone()
413    }
414}
415
/// Outcome of decoding one shot with an `ObservableDecoderRunner`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ObservableDecodeResult {
    /// Observable vector reported by the runner method that produced this result.
    pub observables: Array1<Bit>,
    /// Copied from the physical `DecodeResult::success` flag.
    pub converged: bool,
    /// Copied from the physical `DecodeResult::iterations` count.
    pub iterations: usize,
    /// Comparison against the injected error: `Some` from
    /// `from_errors_decode_observables_detailed`, `None` from `decode_observables_detailed`.
    pub true_decoding: Option<TrueDecodingResults>,
    /// The full physical `DecodeResult`, kept only when the runner was created with
    /// `include_decode_result = true`.
    pub physical_decode_result: Option<DecodeResult>,
}
424
425#[derive(Serialize, Deserialize, Debug)]
426pub struct TrueDecodingResults {
427    pub error_detected: bool,
428    pub error_mismatch_detected: bool,
429    pub unconverged_no_error: bool,
430    pub better_decoding_quality_error: bool,
431    pub worse_decoding_quality_error: bool,
432}
433
#[cfg(test)]
mod tests {

    use super::*;
    use ndarray::prelude::*;

    use crate::bp::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    /// Number of `results` for which `flag` holds, as an `f64` for ratio checks.
    fn count_where(
        results: &[ObservableDecodeResult],
        flag: impl Fn(&ObservableDecodeResult) -> bool,
    ) -> f64 {
        results.iter().filter(|&result| flag(result)).count() as f64
    }

    /// Min-sum BP on the stored 144_12_12 model: decode the first 100 stored error
    /// samples in parallel, check convergence and logical-error rates, then cross-check
    /// shot 0 against the single-shot and detector-based batch APIs.
    #[test]
    fn min_sum_decode_144_12_12() {
        let resources = get_test_data_path();
        let code_144_12_12 =
            DetectorErrorModel::load(resources.join("144_12_12")).expect("Unable to load the code");
        let errors_144_12_12: Array2<Bit> =
            read_npy(resources.join("144_12_12_errors.npy")).expect("Unable to open file");
        let bp_config_144_12_12 = MinSumDecoderConfig {
            error_priors: code_144_12_12.error_priors,
            alpha: Some(0.),
            ..Default::default()
        };

        let check_matrix = Arc::new(code_144_12_12.detector_error_matrix);
        let bp_config = Arc::new(bp_config_144_12_12);
        let decoder_144_12_12: Box<MinSumBPDecoder<f64>> =
            Box::new(MinSumBPDecoder::new(check_matrix, bp_config));

        let obs_matrix = Arc::new(code_144_12_12.observable_error_matrix);
        let mut observable_decoder =
            ObservableDecoderRunner::new(decoder_144_12_12, obs_matrix, true);

        let num_errors = 100;
        let errors_slice = errors_144_12_12.slice(s![..num_errors, ..]);
        let results =
            observable_decoder.par_from_errors_decode_observables_detailed_batch(errors_slice);
        let num_shots = errors_slice.shape()[0] as f64;

        // At least 93% of the shots must converge. `physical_decode_result` is present
        // because the runner was built with `include_decode_result = true`.
        let converged = count_where(&results, |result| {
            result.physical_decode_result.as_ref().unwrap().success
        });
        assert!(
            converged >= num_shots * 0.93,
            "only {} of {} shots converged",
            converged,
            num_shots
        );

        // At least 96% of the decoded physical errors differ from the injected ones.
        // NOTE(review): presumably because BP usually returns a different error with the
        // same syndrome; confirm this is the intended direction of the check.
        let mismatched = count_where(&results, |result| {
            result.true_decoding.as_ref().unwrap().error_mismatch_detected
        });
        assert!(
            mismatched >= num_shots * 0.96,
            "only {} of {} decodings differ from the injected error",
            mismatched,
            num_shots
        );

        assert_eq!(
            results[0].observables,
            array![1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1]
        );

        // At most 7% of the shots may end with wrong observables.
        let logical_errors = count_where(&results, |result| {
            result.true_decoding.as_ref().unwrap().error_detected
        });
        assert!(
            logical_errors <= num_shots * 0.07,
            "{} of {} shots have wrong observables",
            logical_errors,
            num_shots
        );

        // Decoding shot 0 directly must reproduce the parallel batch result...
        let detectors_0 = observable_decoder.get_detectors(errors_144_12_12.row(0));
        let direct_results = observable_decoder.decode_observables(detectors_0.view());
        assert_eq!(results[0].observables, direct_results);

        // ...and so must the detector-based parallel batch API.
        let all_detectors = observable_decoder.get_detectors_batch(errors_144_12_12.view());
        let results2 = observable_decoder.par_decode_observables_batch(all_detectors.view());

        assert_eq!(results2.row(0), direct_results);
    }
}