1use crate::bipartite_graph::SparseBipartiteGraph;
12use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis};
13
14use indicatif::{ParallelProgressIterator, ProgressFinish, ProgressIterator, ProgressStyle};
15use rayon::prelude::*;
16
17use serde::{Deserialize, Serialize};
18use std::fmt::Debug;
19
/// A single binary value stored as a byte. Values are expected to be 0 or 1; products
/// computed through [`Mod2Mul`] are reduced modulo 2.
pub type Bit = u8;
/// A sparse binary matrix over [`Bit`], stored as a `SparseBipartiteGraph`. Used as the
/// decoder check matrix (see `Decoder::check_matrix`).
pub type SparseBitMatrix = SparseBipartiteGraph<Bit>;
22
23use dyn_clone::DynClone;
24
25use std::sync::Arc;
26
/// Multiplication whose result is reduced modulo 2, i.e. a product over GF(2).
///
/// `Rhs` is the right-hand operand type and defaults to `Self`; `Output` is the type of
/// the reduced product.
pub trait Mod2Mul<Rhs = Self> {
    /// Type produced by [`Mod2Mul::mul_mod2`].
    type Output;

    /// Returns `self * operand` with every entry taken modulo 2.
    fn mul_mod2(&self, operand: Rhs) -> Self::Output;
}
33
34impl Mod2Mul<&Array1<Bit>> for SparseBitMatrix {
35 type Output = Array1<Bit>;
36
37 fn mul_mod2(&self, rhs: &Array1<Bit>) -> Self::Output {
38 let mut detectors_u8 = self * rhs;
39 detectors_u8.map_inplace(|x| *x %= 2);
40 detectors_u8
41 }
42}
43
44pub trait Decoder: DynClone + Sync {
45 fn check_matrix(&self) -> Arc<SparseBitMatrix>;
46 fn log_prior_ratios(&mut self) -> Array1<f64>;
47 fn decode(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
49 self.decode_detailed(detectors).decoding
50 }
51 fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult;
52 fn decode_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
53 let arrs: Vec<Array1<Bit>> = detectors
54 .axis_iter(Axis(0))
55 .map(|row| self.decode(row))
56 .collect();
57
58 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
59 }
60 fn decode_detailed_batch(&mut self, detectors: ArrayView2<Bit>) -> Vec<DecodeResult> {
61 detectors
62 .axis_iter(Axis(0))
63 .map(|row| self.decode_detailed(row))
64 .collect()
65 }
66 fn get_detectors(&self, errors: ArrayView1<Bit>) -> Array1<Bit> {
68 self.check_matrix().mul_mod2(&errors.to_owned())
69 }
70
71 fn get_detectors_batch(&self, errors: ArrayView2<Bit>) -> Array2<Bit> {
72 let check_matrix = self.check_matrix();
73 let detectors: Vec<Array1<Bit>> = errors
74 .axis_iter(Axis(0))
75 .map(|row| check_matrix.mul_mod2(&row.to_owned()))
76 .collect();
77 stack(
78 Axis(0),
79 &detectors.iter().map(|a| a.view()).collect::<Vec<_>>(),
80 )
81 .unwrap()
82 }
83
84 fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
85 let log_prior_ratios = self.log_prior_ratios();
86 let mut decoding_quality: f64 = 0.0;
87 for i in 0..errors.len() {
88 if errors[i] == 1 && f64::is_finite(log_prior_ratios[i]) {
89 decoding_quality += log_prior_ratios[i];
90 }
91 }
92 decoding_quality
93 }
94}
95
96dyn_clone::clone_trait_object!(Decoder);
97
98pub trait DecoderRunner: Decoder + Clone + Sync {
99 fn par_decode_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
100 let arrs: Vec<Array1<Bit>> = detectors
101 .axis_iter(Axis(0))
102 .into_par_iter()
103 .map_with(|| self.clone(), |decoder, row| decoder().decode(row))
104 .collect();
105 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
106 }
107
108 fn par_decode_detailed_batch(&mut self, detectors: ArrayView2<Bit>) -> Vec<DecodeResult> {
109 detectors
110 .axis_iter(Axis(0))
111 .into_par_iter()
112 .map_with(
113 || self.clone(),
114 |decoder, row| decoder().decode_detailed(row),
115 )
116 .collect()
117 }
118
119 fn decode_batch_progress_bar(
121 &mut self,
122 detectors: ArrayView2<Bit>,
123 leave_progress_bar_on_finish: bool,
124 ) -> Array2<Bit> {
125 let finish_mode = match leave_progress_bar_on_finish {
126 true => ProgressFinish::AndLeave,
127 false => ProgressFinish::AndClear,
128 };
129
130 let arrs: Vec<Array1<Bit>> = detectors
131 .axis_iter(Axis(0))
132 .progress_with_style(self.get_progress_bar_style())
133 .with_finish(finish_mode)
134 .map(|row| self.decode(row))
135 .collect();
136 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
137 }
138
139 fn decode_detailed_batch_progress_bar(
141 &mut self,
142 detectors: ArrayView2<Bit>,
143 leave_progress_bar_on_finish: bool,
144 ) -> Vec<DecodeResult> {
145 let finish_mode = match leave_progress_bar_on_finish {
146 true => ProgressFinish::AndLeave,
147 false => ProgressFinish::AndClear,
148 };
149
150 detectors
151 .axis_iter(Axis(0))
152 .progress_with_style(self.get_progress_bar_style())
153 .with_finish(finish_mode)
154 .map(|row| self.decode_detailed(row))
155 .collect()
156 }
157
158 fn par_decode_batch_progress_bar(
159 &mut self,
160 detectors: ArrayView2<Bit>,
161 leave_progress_bar_on_finish: bool,
162 ) -> Array2<Bit> {
163 let finish_mode = match leave_progress_bar_on_finish {
164 true => ProgressFinish::AndLeave,
165 false => ProgressFinish::AndClear,
166 };
167
168 let arrs: Vec<Array1<Bit>> = detectors
169 .axis_iter(Axis(0))
170 .into_par_iter()
171 .progress_with_style(self.get_progress_bar_style())
172 .with_finish(finish_mode)
173 .map_with(|| self.clone(), |decoder, row| decoder().decode(row))
174 .collect();
175
176 stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
177 }
178
179 fn par_decode_detailed_batch_progress_bar(
180 &mut self,
181 detectors: ArrayView2<Bit>,
182 leave_progress_bar_on_finish: bool,
183 ) -> Vec<DecodeResult> {
184 let finish_mode = match leave_progress_bar_on_finish {
185 true => ProgressFinish::AndLeave,
186 false => ProgressFinish::AndClear,
187 };
188
189 detectors
190 .axis_iter(Axis(0))
191 .into_par_iter()
192 .progress_with_style(self.get_progress_bar_style())
193 .with_finish(finish_mode)
194 .map_with(
195 || self.clone(),
196 |decoder, row| decoder().decode_detailed(row),
197 )
198 .collect()
199 }
200
201 fn get_progress_bar_style(&self) -> ProgressStyle {
202 ProgressStyle::default_bar().template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({per_sec}, {eta})").unwrap()
203 }
204}
205
/// Full output of a single `Decoder::decode_detailed` call.
///
/// Derives serde `Serialize`/`Deserialize`, so the field names and their declaration
/// order are part of the serialized form.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DecodeResult {
    /// The decoded bit vector; this is exactly what `Decoder::decode` returns.
    pub decoding: Array1<Bit>,
    /// NOTE(review): presumably the detectors produced by `decoding`
    /// (check matrix · decoding mod 2) — confirm against implementors.
    pub decoded_detectors: Array1<Bit>,
    /// Per-bit posterior ratios reported by the decoder (log domain? — TODO confirm).
    pub posterior_ratios: Array1<f64>,
    /// Whether the decoder reports success; presumably true when `decoded_detectors`
    /// matches the input detectors — confirm.
    pub success: bool,
    /// Quality score of `decoding`; presumably the value `Decoder::get_decoding_quality`
    /// would return for it — confirm.
    pub decoding_quality: f64,
    /// Number of iterations the decoder ran (meaning for non-iterative decoders — TODO
    /// confirm).
    pub iterations: usize,
    /// Presumably the iteration limit in effect for this decode — confirm.
    pub max_iter: usize,
    /// Decoder-specific extra output.
    pub extra: BPExtraResult,
}
217
/// Decoder-specific extra data carried in `DecodeResult::extra`.
///
/// Only the `None` variant exists, so no extra data is carried yet. The `BP` prefix
/// presumably refers to belief propagation — confirm.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum BPExtraResult {
    /// No extra data.
    None,
}