1use fixed::{types::extra as fbits, FixedI64};
12
13use crate::bp::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
14use crate::decoder::{Bit, DecodeResult, Decoder, DecoderRunner, SparseBitMatrix};
15use ndarray::{Array1, ArrayView1};
16
17use std::sync::Arc;
18
19#[derive(Clone)]
22pub struct MinSumBPDecoderFixed {
23 decoder: Box<dyn Decoder + Send + 'static>,
24}
25
26impl MinSumBPDecoderFixed {
27 pub fn new(
28 check_matrix: Arc<SparseBitMatrix>,
29 min_sum_config: Arc<MinSumDecoderConfig>,
30 ) -> Self {
31 MinSumBPDecoderFixed {
32 decoder: MinSumBPDecoderFixed::build_min_sum(check_matrix, min_sum_config),
33 }
34 }
35
36 fn build_min_sum(
37 check_matrix: Arc<SparseBitMatrix>,
38 min_sum_config: Arc<MinSumDecoderConfig>,
39 ) -> Box<dyn Decoder + Send + 'static> {
40 macro_rules! genmatch {
41 ($Variant:path) => {
42 return Box::new(MinSumBPDecoder::<FixedI64<$Variant>>::new(
43 check_matrix,
44 min_sum_config,
45 ))
46 };
47 }
48
49 match min_sum_config.frac_bits {
50 Some(0) => genmatch!(fbits::U0),
51 Some(1) => genmatch!(fbits::U1),
52 Some(2) => genmatch!(fbits::U2),
53 Some(3) => genmatch!(fbits::U3),
54 Some(4) => genmatch!(fbits::U4),
55 Some(5) => genmatch!(fbits::U5),
56 Some(6) => genmatch!(fbits::U6),
57 Some(7) => genmatch!(fbits::U7),
58 Some(8) => genmatch!(fbits::U8),
59 Some(9) => genmatch!(fbits::U9),
60 Some(10) => genmatch!(fbits::U10),
61 Some(11) => genmatch!(fbits::U11),
62 Some(12) => genmatch!(fbits::U12),
63 Some(13) => genmatch!(fbits::U13),
64 Some(14) => genmatch!(fbits::U14),
65 Some(15) => genmatch!(fbits::U15),
66 Some(16) => genmatch!(fbits::U16),
67 Some(17) => genmatch!(fbits::U17),
68 Some(18) => genmatch!(fbits::U18),
69 Some(19) => genmatch!(fbits::U19),
70 Some(20) => genmatch!(fbits::U20),
71 Some(21) => genmatch!(fbits::U21),
72 Some(22) => genmatch!(fbits::U22),
73 Some(23) => genmatch!(fbits::U23),
74 Some(24) => genmatch!(fbits::U24),
75 Some(25) => genmatch!(fbits::U25),
76 Some(26) => genmatch!(fbits::U26),
77 Some(27) => genmatch!(fbits::U27),
78 Some(28) => genmatch!(fbits::U28),
79 Some(29) => genmatch!(fbits::U29),
80 Some(30) => genmatch!(fbits::U30),
81 Some(31) => genmatch!(fbits::U31),
82 _ => panic!(
83 "Decoder does not support {:?} fractional bits",
84 min_sum_config.frac_bits
85 ),
86 }
87 }
88}
89
90impl Decoder for MinSumBPDecoderFixed {
91 fn check_matrix(&self) -> Arc<SparseBitMatrix> {
92 self.decoder.check_matrix()
93 }
94
95 fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
96 self.decoder.decode_detailed(detectors)
97 }
98
99 fn log_prior_ratios(&mut self) -> Array1<f64> {
100 self.decoder.log_prior_ratios()
101 }
102}
103
104impl DecoderRunner for MinSumBPDecoderFixed {}
105
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use env_logger;
    use ndarray::prelude::*;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    use crate::bipartite_graph::BipartiteGraph;
    use crate::bipartite_graph::SparseBipartiteGraph;
    use crate::decoder::{Bit, Decoder, DecoderRunner};
    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;

    /// Sets up `env_logger` in test mode.
    ///
    /// The outcome of `try_init` is deliberately discarded, so a failed
    /// initialisation never fails the calling test.
    fn init() {
        let mut builder = env_logger::builder();
        let _ = builder.is_test(true).try_init();
    }

    /// Decodes four syndromes of a small 2x3 check matrix (no error, and one error
    /// on each of the three bits) with a fixed-point decoder and checks that each
    /// one is recovered exactly.
    #[test]
    fn decode_repetition_code_int() {
        init();

        let dense = array![[1, 1, 0], [0, 1, 1]];
        let sparse: SparseBipartiteGraph<_> = SparseBipartiteGraph::from_dense(dense);

        let iterations = 10;
        let mut bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            ..Default::default()
        };
        bp_config.set_fixed(5, 2);

        let mut decoder: MinSumBPDecoderFixed =
            MinSumBPDecoderFixed::new(Arc::new(sparse), Arc::new(bp_config));

        // Runs one decode on the shared `decoder` and checks every reported field
        // against the expected error pattern and the input syndrome.
        macro_rules! assert_decodes {
            ($error:expr, $detectors:expr) => {{
                let expected = $error;
                let syndrome: Array1<Bit> = $detectors;

                let result = decoder.decode_detailed(syndrome.view());

                assert_eq!(result.decoding, expected);
                assert_eq!(result.decoded_detectors, syndrome);
                assert_eq!(result.max_iter, iterations);
                assert!(result.success);
            }};
        }

        // The cases run in this order on the same decoder instance.
        assert_decodes!(array![0, 0, 0], array![0, 0]);
        assert_decodes!(array![1, 0, 0], array![1, 0]);
        assert_decodes!(array![0, 1, 0], array![1, 1]);
        assert_decodes!(array![0, 0, 1], array![0, 1]);
    }

    /// Loads the `144_12_12` detector error model and its recorded detector data
    /// from the test resources, decodes the first `num_errors` rows in a batch,
    /// and checks the success rate and the length of the returned decoding.
    #[test]
    fn decode_144_12_12() {
        let data_dir = get_test_data_path();
        let model =
            DetectorErrorModel::load(data_dir.join("144_12_12")).expect("Unable to load the code");
        let all_detectors: Array2<Bit> =
            read_npy(data_dir.join("144_12_12_detectors.npy")).expect("Unable to open file");

        let mut config = MinSumDecoderConfig {
            error_priors: model.error_priors,
            ..Default::default()
        };
        config.set_fixed(5, 2);

        let mut decoder: MinSumBPDecoderFixed = MinSumBPDecoderFixed::new(
            Arc::new(model.detector_error_matrix),
            Arc::new(config),
        );

        // Only the leading `num_errors` rows of the detector data are decoded.
        let num_errors = 1;
        let batch = all_detectors.slice(s![..num_errors, ..]);
        let results = decoder.par_decode_detailed_batch(batch);

        // At least 93% of the decoded rows must be reported as successful.
        let successes = results.iter().filter(|outcome| outcome.success).count();
        let required = (batch.shape()[0] as f64) * 0.93;
        assert!(successes as f64 >= required);

        assert_eq!(results[0].decoding.len(), 8785);
    }
}