// relay_bp/bp/min_sum.rs

// (C) Copyright IBM 2025
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.
10
11use crate::bipartite_graph::SparseBipartiteGraph;
12use crate::decoder::{BPExtraResult, DecodeResult, Decoder, DecoderRunner};
13use crate::decoder::{Bit, SparseBitMatrix};
14use itertools::izip;
15use log::debug;
16use ndarray::{Array1, ArrayView1};
17use num_traits::FromPrimitive;
18use num_traits::{Bounded, Signed, ToPrimitive};
19use sprs::CsMatView;
20use std::fmt::Debug;
21use std::sync::Arc;
22
23#[derive(Clone, Debug)]
24pub struct MinSumDecoderConfig {
25    pub error_priors: Array1<f64>,
26    pub max_iter: usize,
27    pub alpha: Option<f64>,
28    pub alpha_iteration_scaling_factor: f64,
29    pub gamma0: Option<f64>,
30    pub data_scale_value: Option<f64>,
31    pub max_data_value: Option<f64>,
32    pub int_bits: Option<isize>,
33    pub frac_bits: Option<isize>,
34}
35
36impl Default for MinSumDecoderConfig {
37    fn default() -> Self {
38        Self {
39            error_priors: Default::default(),
40            max_iter: 200,
41            alpha: None,
42            alpha_iteration_scaling_factor: 1.,
43            gamma0: None,
44            data_scale_value: None,
45            max_data_value: None,
46            int_bits: None,
47            frac_bits: None,
48        }
49    }
50}
51
52impl MinSumDecoderConfig {
53    pub fn prior_ratios(&self) -> Array1<f64> {
54        // A funky way of making (1-p)/p handle left to right type inference for arithmetic
55        (1.0 - &self.error_priors) / &self.error_priors
56    }
57
58    pub fn log_prior_ratios(&self) -> Array1<f64> {
59        self.prior_ratios().ln()
60    }
61
62    pub fn set_max_iter(&mut self, iterations: usize) {
63        self.max_iter = iterations;
64    }
65
66    pub fn set_fixed(&mut self, int_bits: isize, frac_bits: isize) {
67        self.int_bits = Some(int_bits);
68        self.frac_bits = Some(frac_bits);
69        self.max_data_value = Some((1 << (int_bits - 1)) as f64);
70    }
71}
72
73/// A fast min-sum implementation of BP implemented internally
74/// using a sparse bipartite graph.
75#[derive(Clone)]
76pub struct MinSumBPDecoder<N: PartialEq + Default + Clone + Copy> {
77    check_matrix: Arc<SparseBitMatrix>,
78    pub config: Arc<MinSumDecoderConfig>,
79    log_prior_ratios: Array1<N>,
80    check_to_variable: SparseBipartiteGraph<N>,
81    variable_to_check: SparseBipartiteGraph<N>,
82    // A cache of check to variable data mappings.
83    check_to_variable_nnz_map: Vec<usize>,
84    // A cache of variable to check data mappings.
85    variable_to_check_nnz_map: Vec<usize>,
86    posterior_ratios: Array1<N>,
87    memory_strengths: Array1<N>,
88    decoding: Array1<Bit>,
89    max_data_value: Option<N>,
90    data_scale_value: Option<N>,
91    pub current_iteration: usize,
92}
93
94impl<N> MinSumBPDecoder<N>
95where
96    N: PartialEq
97        + Debug
98        + Default
99        + Clone
100        + Copy
101        + Signed
102        + Bounded
103        + FromPrimitive
104        + ToPrimitive
105        + std::cmp::PartialOrd
106        + std::ops::Add
107        + std::ops::AddAssign
108        + std::ops::DivAssign
109        + std::ops::Mul<N>
110        + std::ops::MulAssign
111        + Send
112        + Sync
113        + std::fmt::Display
114        + 'static,
115{
116    pub fn new(
117        check_matrix: Arc<SparseBitMatrix>,
118        config: Arc<MinSumDecoderConfig>,
119    ) -> MinSumBPDecoder<N> {
120        let check_to_variable = MinSumBPDecoder::build_check_to_variable(check_matrix.clone());
121        let variable_to_check = MinSumBPDecoder::build_variable_to_check(check_matrix.clone());
122        let (check_to_variable_nnz_map, variable_to_check_nnz_map) =
123            MinSumBPDecoder::build_nnz_maps(check_to_variable.view(), variable_to_check.view());
124
125        let max_data_value = match config.max_data_value {
126            Some(val) => N::from_f64(val),
127            None => None,
128        };
129
130        let data_scale_value = match config.data_scale_value {
131            Some(val) => N::from_f64(val),
132            None => None,
133        };
134
135        let log_prior_ratios = config.log_prior_ratios().mapv_into_any(|val| {
136            let updated_val = match val {
137                f64::INFINITY => N::max_value(),
138                _ => {
139                    // Apply optional data scaling
140                    let prior = match config.data_scale_value {
141                        Some(scale_value) => scale_value * val,
142                        None => val,
143                    };
144                    N::from_f64(prior).unwrap()
145                }
146            };
147
148            // Bound prior values if necessary
149            match max_data_value {
150                Some(max_val) => Self::bound_value_magnitude(updated_val, max_val),
151                None => updated_val,
152            }
153        });
154
155        let memory_strengths = Array1::from_elem(check_matrix.cols(), N::zero());
156
157        let posterior_ratios = if config.gamma0.is_some() {
158            log_prior_ratios.clone()
159        } else {
160            Array1::zeros(check_matrix.cols())
161        };
162
163        let decoding = Array1::zeros(check_matrix.cols());
164
165        MinSumBPDecoder::<N> {
166            check_matrix,
167            config,
168            log_prior_ratios,
169            check_to_variable,
170            variable_to_check,
171            check_to_variable_nnz_map,
172            variable_to_check_nnz_map,
173            posterior_ratios,
174            memory_strengths,
175            decoding,
176            max_data_value,
177            data_scale_value,
178            current_iteration: 0,
179        }
180    }
181
182    pub fn set_log_prior_ratio(&mut self, mut log_prior_ratios: Array1<N>) {
183        self.log_prior_ratios = match self.data_scale_value {
184            Some(scale_val) => {
185                log_prior_ratios.iter_mut().for_each(|v| *v *= scale_val);
186                log_prior_ratios
187            }
188            None => log_prior_ratios,
189        };
190    }
191
192    pub fn set_log_prior_ratio_f64(&mut self, log_prior_ratios: Array1<f64>) {
193        self.log_prior_ratios = match self.config.data_scale_value {
194            Some(scale_val) => {
195                log_prior_ratios.mapv_into_any(|v| N::from_f64(scale_val * v).unwrap())
196            }
197            None => log_prior_ratios.mapv_into_any(|v| N::from_f64(v).unwrap()),
198        };
199    }
200
201    pub fn set_posterior_ratios_to_priors(&mut self) {
202        self.posterior_ratios = self.log_prior_ratios.clone();
203    }
204
205    /// Set external memory strengths from f64. Applies scaling if needed.
206    pub fn set_memory_strengths_f64(&mut self, memory_strengths: Array1<f64>) {
207        self.memory_strengths = match self.config.data_scale_value {
208            Some(scale_val) => {
209                memory_strengths.mapv_into_any(|v| N::from_f64(scale_val * v).unwrap())
210            }
211            None => memory_strengths.mapv_into_any(|v| N::from_f64(v).unwrap()),
212        };
213    }
214
215    /// Set external memory strengths from N. Applies scaling if needed.
216    pub fn set_memory_strengths(&mut self, mut memory_strengths: Array1<N>) {
217        self.memory_strengths = match self.data_scale_value {
218            Some(scale_val) => {
219                memory_strengths.iter_mut().for_each(|v| *v *= scale_val);
220                memory_strengths
221            }
222            None => memory_strengths,
223        };
224    }
225
226    // Construct a new check message graph
227    fn build_check_to_variable(check_matrix: Arc<SparseBitMatrix>) -> SparseBipartiteGraph<N> {
228        let check_matrix_csc = check_matrix.to_csc();
229
230        let default_messages: Vec<_> = vec![N::default(); check_matrix_csc.nnz()];
231
232        SparseBipartiteGraph::new_csc(
233            check_matrix_csc.shape(),
234            check_matrix_csc.indptr().raw_storage().to_vec(),
235            check_matrix_csc.indices().to_vec(),
236            default_messages,
237        )
238    }
239
240    // Construct a new variable message graph
241    fn build_variable_to_check(check_matrix: Arc<SparseBitMatrix>) -> SparseBipartiteGraph<N> {
242        let check_matrix_csr = check_matrix.to_csr();
243
244        let default_messages: Vec<_> = vec![N::default(); check_matrix_csr.nnz()];
245
246        SparseBipartiteGraph::new(
247            check_matrix_csr.shape(),
248            check_matrix_csr.indptr().raw_storage().to_vec(),
249            check_matrix_csr.indices().to_vec(),
250            default_messages,
251        )
252    }
253
254    fn build_nnz_maps(
255        check_to_variable: CsMatView<N>,
256        variable_to_check: CsMatView<N>,
257    ) -> (Vec<usize>, Vec<usize>) {
258        let mut check_to_variable_nnz_map: Vec<usize> = Vec::with_capacity(check_to_variable.nnz());
259        for (_, (row, col)) in check_to_variable.view().iter() {
260            check_to_variable_nnz_map.push(variable_to_check.nnz_index(row, col).unwrap().0);
261        }
262        let mut variable_to_check_nnz_map: Vec<usize> = Vec::with_capacity(variable_to_check.nnz());
263        for (_, (row, col)) in variable_to_check.view().iter() {
264            variable_to_check_nnz_map.push(check_to_variable.nnz_index(row, col).unwrap().0);
265        }
266        (check_to_variable_nnz_map, variable_to_check_nnz_map)
267    }
268
269    /// Initialize variable message state to the prior
270    pub fn initialize_variable_to_check(&mut self) {
271        for mut row_vec in self.variable_to_check.outer_iterator_mut() {
272            row_vec
273                .iter_mut()
274                .for_each(|(col_ind, val)| *val = self.log_prior_ratios[col_ind]);
275        }
276    }
277
278    pub fn initialize_check_to_variable(&mut self) {
279        for mut col_vec in self.check_to_variable.outer_iterator_mut() {
280            col_vec
281                .iter_mut()
282                .for_each(|(_row_ind, val)| *val = N::zero());
283        }
284    }
285
286    pub fn initialize_memory_strengths(&mut self) {
287        // Initialize Mem-BP and posteriors
288        let ewa_factor_float = self.config.gamma0.unwrap_or(0.);
289        let ewa_factor = N::from_f64(match self.config.data_scale_value {
290            Some(scale_value) => scale_value * ewa_factor_float,
291            None => ewa_factor_float,
292        })
293        .unwrap();
294        self.memory_strengths.fill(ewa_factor);
295    }
296
297    pub fn initialize_decoder(&mut self) {
298        self.current_iteration = 0;
299        self.initialize_memory_strengths();
300        self.initialize_check_to_variable();
301        self.initialize_variable_to_check();
302        // Initialize posteriors if needed for mem-BP
303        if self.config.gamma0.is_some() {
304            self.set_posterior_ratios_to_priors();
305        };
306    }
307
308    fn alpha(&self) -> N {
309        let mut alpha = match self.config.alpha {
310            Some(0.) => {
311                let iteration = (self.current_iteration + 1) as f64;
312                1.0 - (2_f64).powf(-(iteration / self.config.alpha_iteration_scaling_factor))
313            }
314            Some(val) => val,
315            None => 1.0,
316        };
317
318        // Handle case of alpha < 0 defaulting to 1.
319        // This aligns with integer case of ldpc-simulation
320        if alpha < 0. {
321            alpha = 1.
322        }
323        // Scale if needed
324        alpha = match self.config.data_scale_value {
325            Some(scale_val) => scale_val * alpha,
326            None => alpha,
327        };
328
329        N::from_f64(alpha).unwrap()
330    }
331
332    /// Compute check to bit message iteration
333    fn compute_check_to_variable(
334        &mut self,
335        detectors: ArrayView1<Bit>,
336    ) -> &mut SparseBipartiteGraph<N> {
337        let alpha = self.alpha();
338
339        for (var_check_row_ind, var_check_row_vec) in
340            self.variable_to_check.outer_iterator().enumerate()
341        {
342            let row_sign = if detectors[var_check_row_ind] == 1 {
343                N::one().neg()
344            } else {
345                N::one()
346            };
347            let mut accumulated_sign = row_sign.is_negative();
348            let mut min_ind: usize = 0;
349            // True min message
350            let mut min_message = N::max_value();
351            // Next lowest min message to be used for self-exlusive min value
352            let mut second_min_message = N::max_value();
353
354            for (var_check_col_ind, var_check_col_val) in var_check_row_vec.iter() {
355                accumulated_sign ^= var_check_col_val.is_negative();
356                let abs_msg = var_check_col_val.abs();
357                if abs_msg <= min_message {
358                    second_min_message = min_message;
359                    min_message = abs_msg;
360                    min_ind = var_check_col_ind;
361                } else if abs_msg <= second_min_message {
362                    second_min_message = abs_msg
363                }
364            }
365
366            debug!("Variable messages for row {var_check_row_ind:?}: {var_check_row_vec:?}");
367
368            // Iterate over the row's storage indices
369            let data_range = self
370                .variable_to_check
371                .indptr()
372                .outer_inds(var_check_row_ind);
373
374            for (ind, var_check_col_ind, var_check_col_val) in izip!(
375                data_range.clone(),
376                &self.variable_to_check.indices()[data_range.clone()],
377                &self.variable_to_check.data()[data_range.clone()]
378            ) {
379                // Extract the sign from the accumulated sign.
380                let check_to_variable_sign = accumulated_sign ^ var_check_col_val.is_negative();
381                let check_to_variable_min: N = if *var_check_col_ind != min_ind {
382                    min_message
383                } else {
384                    second_min_message
385                };
386                // Copy the sign to the variable. check_to_variable_min is guranteed to be positive.
387                let mut check_to_variable = alpha * check_to_variable_min;
388                if check_to_variable_sign {
389                    check_to_variable = check_to_variable.neg();
390                }
391
392                // We directly manipulate the indicies of the check_to_variable_matrix using
393                // the cached value map to avoid the need for a logarithmic insert
394                self.check_to_variable.data_mut()[self.variable_to_check_nnz_map[ind]] =
395                    check_to_variable;
396            }
397        }
398
399        if let Some(scale_val) = self.data_scale_value {
400            self.check_to_variable /= scale_val
401        }
402
403        &mut self.check_to_variable
404    }
405
406    fn compute_variable_prior(&self, variable: usize) -> N {
407        // Apply membp
408        if self.config.gamma0.is_some() {
409            if self.log_prior_ratios[variable] == N::max_value() {
410                return self.log_prior_ratios[variable];
411            }
412            let scaled_one = self.data_scale_value.unwrap_or(N::one());
413            // First divide through denominator before numerator to avoid overflow
414            let prior_component = (self.log_prior_ratios[variable] / scaled_one)
415                * (scaled_one - self.memory_strengths[variable]);
416            let posterior_component =
417                (self.posterior_ratios[variable] / scaled_one) * self.memory_strengths[variable];
418            return prior_component + posterior_component;
419        }
420        self.log_prior_ratios[variable]
421    }
422
423    /// Compute bit to check message iteration
424    fn compute_variable_to_check(&mut self) -> &mut SparseBipartiteGraph<N> {
425        for (check_var_col_ind, check_var_col_vec) in
426            self.check_to_variable.outer_iterator().enumerate()
427        {
428            // Accumulate messages
429            let mut check_to_var_row_sum = self.compute_variable_prior(check_var_col_ind);
430
431            debug!("Check messages for col {check_var_col_ind:?}: {check_var_col_vec:?}");
432
433            let data_range = self
434                .check_to_variable
435                .indptr()
436                .outer_inds(check_var_col_ind);
437
438            // Perform iteration in the forward direction to accumulate left to right
439            for (ind, check_var_row_val) in izip!(
440                data_range.clone(),
441                &self.check_to_variable.data()[data_range.clone()]
442            ) {
443                self.variable_to_check.data_mut()[self.check_to_variable_nnz_map[ind]] =
444                    check_to_var_row_sum;
445                check_to_var_row_sum += *check_var_row_val;
446            }
447
448            self.posterior_ratios[check_var_col_ind] = check_to_var_row_sum;
449
450            // Now perform iteration in the reverse direction to accumulate right to left
451            check_to_var_row_sum = N::zero();
452            // Remove each messages contribution
453            for (ind, check_var_row_val) in izip!(
454                data_range.clone(),
455                &self.check_to_variable.data()[data_range.clone()]
456            )
457            .rev()
458            {
459                let map_ind = self.check_to_variable_nnz_map[ind];
460                self.variable_to_check.data_mut()[map_ind] += check_to_var_row_sum;
461                check_to_var_row_sum += *check_var_row_val;
462
463                // We directly manipulate the indicies of the variable_to_check matrix using
464                // the cached value map to avoid the need for a logarithmic insert
465                debug!(
466                    "location ({:?}, {:?}), variable_to_check: {:.32}",
467                    self.check_to_variable.indices()[ind],
468                    check_var_col_ind,
469                    self.variable_to_check.data_mut()[self.check_to_variable_nnz_map[ind]]
470                );
471            }
472        }
473
474        self.bound_magnitudes();
475
476        &mut self.variable_to_check
477    }
478
479    pub fn run_iteration(&mut self, detectors: ArrayView1<Bit>) {
480        debug!("Iteration {:?} start", self.current_iteration);
481        self.compute_check_to_variable(detectors);
482        // Now compute variable to check messages
483        self.compute_variable_to_check();
484
485        self.compute_hard_decision();
486        debug!("Iteration {:?} end", self.current_iteration);
487    }
488
489    pub fn build_result(
490        &mut self,
491        success: bool,
492        decoded_detectors: Array1<Bit>,
493        max_iter: usize,
494    ) -> DecodeResult {
495        DecodeResult {
496            decoding: self.decoding.clone(),
497            decoded_detectors,
498            posterior_ratios: self.posterior_ratios.clone().mapv_into_any(|val| {
499                let posterior = N::to_f64(&val).unwrap();
500                match self.config.data_scale_value {
501                    Some(scale_val) => posterior / scale_val,
502                    None => posterior,
503                }
504            }),
505            success,
506            decoding_quality: if success {
507                self.get_decoding_quality(self.decoding.clone().view())
508            } else {
509                f64::MAX
510            },
511            iterations: self.current_iteration,
512            max_iter,
513            extra: BPExtraResult::None,
514        }
515    }
516    fn bound_magnitudes(&mut self) {
517        // Bound magnitudes
518        if self.max_data_value.is_some() {
519            let max_val = self.max_data_value.unwrap();
520            self.variable_to_check
521                .data_mut()
522                .iter_mut()
523                .for_each(|v| *v = Self::bound_value_magnitude(*v, max_val));
524            self.posterior_ratios
525                .iter_mut()
526                .for_each(|v| *v = Self::bound_value_magnitude(*v, max_val));
527        }
528    }
529
530    fn bound_value_magnitude(value: N, max_val: N) -> N
531    where
532        N: std::ops::Add,
533    {
534        if value < max_val.neg() {
535            max_val.neg()
536        } else if value > max_val {
537            max_val
538        } else {
539            value
540        }
541    }
542
543    fn compute_hard_decision(&mut self) {
544        for (idx, posterior) in self.posterior_ratios.iter().enumerate() {
545            self.decoding[idx] = Bit::from((*posterior) <= N::zero());
546        }
547        debug!("Posteriors: {:?}", self.posterior_ratios);
548        debug!("Hard decision: {:?}", self.decoding);
549    }
550
551    pub fn compute_decoded_detectors(&self) -> Array1<Bit> {
552        self.get_detectors(self.decoding.view())
553    }
554
555    // Check the convergence of the problem instance
556    pub fn check_convergence(
557        &self,
558        detectors: ArrayView1<Bit>,
559        decoded_detectors: ArrayView1<Bit>,
560    ) -> bool {
561        detectors == decoded_detectors
562    }
563}
564
565impl<N> Decoder for MinSumBPDecoder<N>
566where
567    N: PartialEq
568        + Debug
569        + Default
570        + Clone
571        + Copy
572        + FromPrimitive
573        + ToPrimitive
574        + Signed
575        + Bounded
576        + std::cmp::PartialOrd
577        + std::ops::Add
578        + std::ops::AddAssign
579        + std::ops::Mul<N>
580        + std::ops::MulAssign
581        + std::ops::DivAssign
582        + Send
583        + Sync
584        + std::fmt::Display
585        + 'static,
586{
587    fn check_matrix(&self) -> Arc<SparseBitMatrix> {
588        self.check_matrix.clone()
589    }
590
591    fn log_prior_ratios(&mut self) -> Array1<f64> {
592        self.config.log_prior_ratios()
593    }
594
595    fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
596        // Initialize probability ratios
597        self.initialize_decoder();
598        let mut success: bool = false;
599        let mut decoded_detectors = Array1::default(detectors.dim());
600
601        for _ in 0..self.config.max_iter {
602            self.run_iteration(detectors);
603            self.current_iteration += 1;
604            decoded_detectors = self.compute_decoded_detectors();
605            success = self.check_convergence(detectors, decoded_detectors.view());
606
607            // If we have converged may now exit
608            if success {
609                debug!("Succeeded on iteration {:?}", self.current_iteration);
610                break;
611            }
612        }
613
614        self.build_result(success, decoded_detectors, self.config.max_iter)
615    }
616}
617
618impl<N> DecoderRunner for MinSumBPDecoder<N> where
619    N: PartialEq
620        + Debug
621        + Default
622        + Clone
623        + Copy
624        + FromPrimitive
625        + ToPrimitive
626        + Signed
627        + Bounded
628        + std::cmp::PartialOrd
629        + std::ops::Add
630        + std::ops::AddAssign
631        + std::ops::Mul<N>
632        + std::ops::MulAssign
633        + std::ops::DivAssign
634        + Send
635        + Sync
636        + std::fmt::Display
637        + 'static
638{
639}
640
#[cfg(test)]
mod tests {
    use crate::bipartite_graph::BipartiteGraph;

    use super::*;
    use ndarray::prelude::*;

    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    /// Number of stored 144_12_12 shots decoded by each large test.
    const NUM_SHOTS: usize = 100;
    /// Length of a decoding (number of error mechanisms) for the 144_12_12 model.
    const CODE_144_12_12_LENGTH: usize = 8785;

    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Builds the 3-bit repetition code with two weight-2 checks.
    fn repetition_code() -> Arc<SparseBitMatrix> {
        let dense = array![[1, 1, 0], [0, 1, 1],];
        let sparse: SparseBipartiteGraph<_> = SparseBipartiteGraph::from_dense(dense);
        Arc::new(sparse)
    }

    /// Decodes the no-error case and every single-bit error of the repetition code in turn,
    /// reusing `decoder`, and checks each detailed result.
    fn check_repetition_code_cases(decoder: &mut impl Decoder, iterations: usize) {
        let cases: [(Array1<Bit>, Array1<Bit>); 4] = [
            (array![0, 0, 0], array![0, 0]),
            (array![1, 0, 0], array![1, 0]),
            (array![0, 1, 0], array![1, 1]),
            (array![0, 0, 1], array![0, 1]),
        ];

        for (error, detectors) in cases {
            let result = decoder.decode_detailed(detectors.view());

            assert_eq!(result.decoding, error);
            assert_eq!(result.decoded_detectors, detectors);
            assert_eq!(result.max_iter, iterations);
            assert!(result.success);
        }
    }

    /// Loads the 144_12_12 check matrix, its error priors and the stored detector samples.
    fn load_144_12_12() -> (Arc<SparseBitMatrix>, Array1<f64>, Array2<Bit>) {
        let resources = get_test_data_path();
        let code = DetectorErrorModel::load(resources.join("144_12_12"))
            .expect("Unable to load the code");
        let detectors: Array2<Bit> = read_npy(resources.join("144_12_12_detectors.npy"))
            .expect("Unable to open file");
        (Arc::new(code.detector_error_matrix), code.error_priors, detectors)
    }

    /// Counts the successful decodes in a batch of results.
    fn count_successes<'a>(results: impl IntoIterator<Item = &'a DecodeResult>) -> usize {
        results.into_iter().filter(|result| result.success).count()
    }

    /// Asserts that at least 93% of `num_shots` decodes succeeded.
    fn assert_success_rate(successes: usize, num_shots: usize) {
        assert!(successes as f64 >= (num_shots as f64) * 0.93);
    }

    #[test]
    fn decode_detailed_repetition_code() {
        init();

        let iterations = 10;
        let bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            ..Default::default()
        };

        let mut decoder: MinSumBPDecoder<f64> =
            MinSumBPDecoder::new(repetition_code(), Arc::new(bp_config));

        check_repetition_code_cases(&mut decoder, iterations);
    }

    #[test]
    fn decode_detailed_repetition_code_int() {
        init();

        let iterations = 10;
        let bits = 7;
        let scale = 4.0;

        let bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            ..Default::default()
        };

        let mut decoder: MinSumBPDecoder<isize> =
            MinSumBPDecoder::new(repetition_code(), Arc::new(bp_config));

        check_repetition_code_cases(&mut decoder, iterations);
    }

    #[test]
    fn decode_detailed_144_12_12() {
        let (check_matrix, error_priors, detectors) = load_144_12_12();
        let config = Arc::new(MinSumDecoderConfig {
            error_priors,
            alpha: Some(0.),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS, ..]);
        let results = decoder.par_decode_detailed_batch(detectors_slice);

        assert_success_rate(count_successes(&results), detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), CODE_144_12_12_LENGTH);
    }

    #[test]
    fn decode_144_12_12() {
        let (check_matrix, error_priors, detectors) = load_144_12_12();
        let config = Arc::new(MinSumDecoderConfig {
            error_priors,
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS, ..]);

        let decodings = decoder.par_decode_batch(detectors_slice);
        let detailed = decoder.par_decode_detailed_batch(detectors_slice);

        // The plain and detailed batch APIs must agree shot by shot.
        for shot in 0..decodings.shape()[0] {
            assert!(decodings.row(shot) == detailed[shot].decoding)
        }
    }

    #[test]
    fn decode_detailed_144_12_12_membp() {
        let (check_matrix, error_priors, detectors) = load_144_12_12();
        let config = Arc::new(MinSumDecoderConfig {
            error_priors,
            gamma0: Some(0.15),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS, ..]);

        let par_results = decoder.par_decode_detailed_batch(detectors_slice);
        let results = decoder.decode_detailed_batch(detectors_slice);

        // Parallel and sequential decoding must succeed on the same number of shots.
        let successes = count_successes(&results);
        assert_eq!(successes, count_successes(&par_results));
        assert_success_rate(successes, detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), CODE_144_12_12_LENGTH);
    }

    #[test]
    fn decode_detailed_144_12_12_int() {
        let (check_matrix, error_priors, detectors) = load_144_12_12();

        let bits = 16;
        let scale = 8.0;

        let config = Arc::new(MinSumDecoderConfig {
            error_priors,
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            alpha: Some(0.),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<isize> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS, ..]);
        let results = decoder.par_decode_detailed_batch(detectors_slice);

        assert_success_rate(count_successes(&results), detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), CODE_144_12_12_LENGTH);
    }
}