Skip to main content

relay_bp/bp/
min_sum_fixed.rs

1// (C) Copyright IBM 2025
2//
3// This code is licensed under the Apache License, Version 2.0. You may
4// obtain a copy of this license in the LICENSE.txt file in the root directory
5// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
6//
7// Any modifications or derivative works of this code must retain this
8// copyright notice, and modified files need to carry a notice indicating
9// that they have been altered from the originals.
10
11use fixed::{types::extra as fbits, FixedI64};
12
13use crate::bp::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
14use crate::decoder::{Bit, DecodeResult, Decoder, DecoderRunner, SparseBitMatrix};
15use ndarray::{Array1, ArrayView1};
16
17use std::sync::Arc;
18
19/// Runtime implemetnation of fixed point precision which routes to the appropriate compile time
20/// implementation.
21#[derive(Clone)]
22pub struct MinSumBPDecoderFixed {
23    decoder: Box<dyn Decoder + Send + 'static>,
24}
25
26impl MinSumBPDecoderFixed {
27    pub fn new(
28        check_matrix: Arc<SparseBitMatrix>,
29        min_sum_config: Arc<MinSumDecoderConfig>,
30    ) -> Self {
31        MinSumBPDecoderFixed {
32            decoder: MinSumBPDecoderFixed::build_min_sum(check_matrix, min_sum_config),
33        }
34    }
35
36    fn build_min_sum(
37        check_matrix: Arc<SparseBitMatrix>,
38        min_sum_config: Arc<MinSumDecoderConfig>,
39    ) -> Box<dyn Decoder + Send + 'static> {
40        macro_rules! genmatch {
41            ($Variant:path) => {
42                return Box::new(MinSumBPDecoder::<FixedI64<$Variant>>::new(
43                    check_matrix,
44                    min_sum_config,
45                ))
46            };
47        }
48
49        match min_sum_config.frac_bits {
50            Some(0) => genmatch!(fbits::U0),
51            Some(1) => genmatch!(fbits::U1),
52            Some(2) => genmatch!(fbits::U2),
53            Some(3) => genmatch!(fbits::U3),
54            Some(4) => genmatch!(fbits::U4),
55            Some(5) => genmatch!(fbits::U5),
56            Some(6) => genmatch!(fbits::U6),
57            Some(7) => genmatch!(fbits::U7),
58            Some(8) => genmatch!(fbits::U8),
59            Some(9) => genmatch!(fbits::U9),
60            Some(10) => genmatch!(fbits::U10),
61            Some(11) => genmatch!(fbits::U11),
62            Some(12) => genmatch!(fbits::U12),
63            Some(13) => genmatch!(fbits::U13),
64            Some(14) => genmatch!(fbits::U14),
65            Some(15) => genmatch!(fbits::U15),
66            Some(16) => genmatch!(fbits::U16),
67            Some(17) => genmatch!(fbits::U17),
68            Some(18) => genmatch!(fbits::U18),
69            Some(19) => genmatch!(fbits::U19),
70            Some(20) => genmatch!(fbits::U20),
71            Some(21) => genmatch!(fbits::U21),
72            Some(22) => genmatch!(fbits::U22),
73            Some(23) => genmatch!(fbits::U23),
74            Some(24) => genmatch!(fbits::U24),
75            Some(25) => genmatch!(fbits::U25),
76            Some(26) => genmatch!(fbits::U26),
77            Some(27) => genmatch!(fbits::U27),
78            Some(28) => genmatch!(fbits::U28),
79            Some(29) => genmatch!(fbits::U29),
80            Some(30) => genmatch!(fbits::U30),
81            Some(31) => genmatch!(fbits::U31),
82            _ => panic!(
83                "Decoder does not support {:?} fractional bits",
84                min_sum_config.frac_bits
85            ),
86        }
87    }
88}
89
90impl Decoder for MinSumBPDecoderFixed {
91    fn check_matrix(&self) -> Arc<SparseBitMatrix> {
92        self.decoder.check_matrix()
93    }
94
95    fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
96        self.decoder.decode_detailed(detectors)
97    }
98
99    fn log_prior_ratios(&mut self) -> Array1<f64> {
100        self.decoder.log_prior_ratios()
101    }
102}
103
104impl DecoderRunner for MinSumBPDecoderFixed {}
105
#[cfg(test)]
mod tests {
    use super::*;

    use crate::bipartite_graph::{BipartiteGraph, SparseBipartiteGraph};
    use crate::decoder::{Bit, Decoder, DecoderRunner};
    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;

    use env_logger;
    use ndarray::prelude::*;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    use std::sync::Arc;

    /// Installs the test logger; the result is discarded so repeated calls from several
    /// tests do not fail once a logger has already been set.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Decodes the trivial syndrome and each single-bit error on a small repetition code
    /// using the fixed-point decoder configured through `set_fixed(5, 2)`.
    #[test]
    fn decode_repetition_code_int() {
        init();

        // Three-bit repetition code: two weight-2 checks, each acting on neighbouring bits.
        let parity_checks: SparseBipartiteGraph<_> =
            SparseBipartiteGraph::from_dense(array![[1, 1, 0], [0, 1, 1],]);

        let iterations = 10;
        let mut bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            ..Default::default()
        };
        bp_config.set_fixed(5, 2);

        let mut decoder = MinSumBPDecoderFixed::new(Arc::new(parity_checks), Arc::new(bp_config));

        // (expected error pattern, detector values handed to the decoder)
        let cases: [(Array1<_>, Array1<Bit>); 4] = [
            (array![0, 0, 0], array![0, 0]),
            (array![1, 0, 0], array![1, 0]),
            (array![0, 1, 0], array![1, 1]),
            (array![0, 0, 1], array![0, 1]),
        ];

        for (error, detectors) in cases {
            let result = decoder.decode_detailed(detectors.view());

            assert_eq!(result.decoding, error);
            assert_eq!(result.decoded_detectors, detectors);
            assert_eq!(result.max_iter, iterations);
            assert!(result.success);
        }
    }

    /// Runs the fixed-point decoder on the first recorded detector sample of the
    /// `144_12_12` test data set and checks the success rate and decoding length.
    #[test]
    fn decode_144_12_12() {
        let data_dir = get_test_data_path();
        let dem =
            DetectorErrorModel::load(data_dir.join("144_12_12")).expect("Unable to load the code");
        let all_detectors: Array2<Bit> =
            read_npy(data_dir.join("144_12_12_detectors.npy")).expect("Unable to open file");

        let mut config = MinSumDecoderConfig {
            error_priors: dem.error_priors,
            ..Default::default()
        };
        config.set_fixed(5, 2);

        let mut decoder =
            MinSumBPDecoderFixed::new(Arc::new(dem.detector_error_matrix), Arc::new(config));

        // Only the first `num_errors` rows of the detector data are decoded.
        let num_errors = 1;
        let batch = all_detectors.slice(s![..num_errors, ..]);
        let results = decoder.par_decode_detailed_batch(batch);

        let successes = results.iter().filter(|outcome| outcome.success).count();
        assert!(successes as f64 >= (batch.shape()[0] as f64) * 0.93);

        assert_eq!(results[0].decoding.len(), 8785);
    }
}
217}