1use crate::bipartite_graph::SparseBipartiteGraph;
12use crate::decoder::{BPExtraResult, DecodeResult, Decoder, DecoderRunner};
13use crate::decoder::{Bit, SparseBitMatrix};
14use itertools::izip;
15use log::debug;
16use ndarray::{Array1, ArrayView1};
17use num_traits::FromPrimitive;
18use num_traits::{Bounded, Signed, ToPrimitive};
19use sprs::CsMatView;
20use std::fmt::Debug;
21use std::sync::Arc;
22
23#[derive(Clone, Debug)]
24pub struct MinSumDecoderConfig {
25 pub error_priors: Array1<f64>,
26 pub max_iter: usize,
27 pub alpha: Option<f64>,
28 pub alpha_iteration_scaling_factor: f64,
29 pub gamma0: Option<f64>,
30 pub data_scale_value: Option<f64>,
31 pub max_data_value: Option<f64>,
32 pub int_bits: Option<isize>,
33 pub frac_bits: Option<isize>,
34}
35
36impl Default for MinSumDecoderConfig {
37 fn default() -> Self {
38 Self {
39 error_priors: Default::default(),
40 max_iter: 200,
41 alpha: None,
42 alpha_iteration_scaling_factor: 1.,
43 gamma0: None,
44 data_scale_value: None,
45 max_data_value: None,
46 int_bits: None,
47 frac_bits: None,
48 }
49 }
50}
51
52impl MinSumDecoderConfig {
53 pub fn prior_ratios(&self) -> Array1<f64> {
54 (1.0 - &self.error_priors) / &self.error_priors
56 }
57
58 pub fn log_prior_ratios(&self) -> Array1<f64> {
59 self.prior_ratios().ln()
60 }
61
62 pub fn set_max_iter(&mut self, iterations: usize) {
63 self.max_iter = iterations;
64 }
65
66 pub fn set_fixed(&mut self, int_bits: isize, frac_bits: isize) {
67 self.int_bits = Some(int_bits);
68 self.frac_bits = Some(frac_bits);
69 self.max_data_value = Some((1 << (int_bits - 1)) as f64);
70 }
71}
72
73#[derive(Clone)]
76pub struct MinSumBPDecoder<N: PartialEq + Default + Clone + Copy> {
77 check_matrix: Arc<SparseBitMatrix>,
78 pub config: Arc<MinSumDecoderConfig>,
79 log_prior_ratios: Array1<N>,
80 check_to_variable: SparseBipartiteGraph<N>,
81 variable_to_check: SparseBipartiteGraph<N>,
82 check_to_variable_nnz_map: Vec<usize>,
84 variable_to_check_nnz_map: Vec<usize>,
86 posterior_ratios: Array1<N>,
87 memory_strengths: Array1<N>,
88 decoding: Array1<Bit>,
89 max_data_value: Option<N>,
90 data_scale_value: Option<N>,
91 pub current_iteration: usize,
92}
93
94impl<N> MinSumBPDecoder<N>
95where
96 N: PartialEq
97 + Debug
98 + Default
99 + Clone
100 + Copy
101 + Signed
102 + Bounded
103 + FromPrimitive
104 + ToPrimitive
105 + std::cmp::PartialOrd
106 + std::ops::Add
107 + std::ops::AddAssign
108 + std::ops::DivAssign
109 + std::ops::Mul<N>
110 + std::ops::MulAssign
111 + Send
112 + Sync
113 + std::fmt::Display
114 + 'static,
115{
116 pub fn new(
117 check_matrix: Arc<SparseBitMatrix>,
118 config: Arc<MinSumDecoderConfig>,
119 ) -> MinSumBPDecoder<N> {
120 let check_to_variable = MinSumBPDecoder::build_check_to_variable(check_matrix.clone());
121 let variable_to_check = MinSumBPDecoder::build_variable_to_check(check_matrix.clone());
122 let (check_to_variable_nnz_map, variable_to_check_nnz_map) =
123 MinSumBPDecoder::build_nnz_maps(check_to_variable.view(), variable_to_check.view());
124
125 let max_data_value = match config.max_data_value {
126 Some(val) => N::from_f64(val),
127 None => None,
128 };
129
130 let data_scale_value = match config.data_scale_value {
131 Some(val) => N::from_f64(val),
132 None => None,
133 };
134
135 let log_prior_ratios = config.log_prior_ratios().mapv_into_any(|val| {
136 let updated_val = match val {
137 f64::INFINITY => N::max_value(),
138 _ => {
139 let prior = match config.data_scale_value {
141 Some(scale_value) => scale_value * val,
142 None => val,
143 };
144 N::from_f64(prior).unwrap()
145 }
146 };
147
148 match max_data_value {
150 Some(max_val) => Self::bound_value_magnitude(updated_val, max_val),
151 None => updated_val,
152 }
153 });
154
155 let memory_strengths = Array1::from_elem(check_matrix.cols(), N::zero());
156
157 let posterior_ratios = if config.gamma0.is_some() {
158 log_prior_ratios.clone()
159 } else {
160 Array1::zeros(check_matrix.cols())
161 };
162
163 let decoding = Array1::zeros(check_matrix.cols());
164
165 MinSumBPDecoder::<N> {
166 check_matrix,
167 config,
168 log_prior_ratios,
169 check_to_variable,
170 variable_to_check,
171 check_to_variable_nnz_map,
172 variable_to_check_nnz_map,
173 posterior_ratios,
174 memory_strengths,
175 decoding,
176 max_data_value,
177 data_scale_value,
178 current_iteration: 0,
179 }
180 }
181
182 pub fn set_log_prior_ratio(&mut self, mut log_prior_ratios: Array1<N>) {
183 self.log_prior_ratios = match self.data_scale_value {
184 Some(scale_val) => {
185 log_prior_ratios.iter_mut().for_each(|v| *v *= scale_val);
186 log_prior_ratios
187 }
188 None => log_prior_ratios,
189 };
190 }
191
192 pub fn set_log_prior_ratio_f64(&mut self, log_prior_ratios: Array1<f64>) {
193 self.log_prior_ratios = match self.config.data_scale_value {
194 Some(scale_val) => {
195 log_prior_ratios.mapv_into_any(|v| N::from_f64(scale_val * v).unwrap())
196 }
197 None => log_prior_ratios.mapv_into_any(|v| N::from_f64(v).unwrap()),
198 };
199 }
200
201 pub fn set_posterior_ratios_to_priors(&mut self) {
202 self.posterior_ratios = self.log_prior_ratios.clone();
203 }
204
205 pub fn set_memory_strengths_f64(&mut self, memory_strengths: Array1<f64>) {
207 self.memory_strengths = match self.config.data_scale_value {
208 Some(scale_val) => {
209 memory_strengths.mapv_into_any(|v| N::from_f64(scale_val * v).unwrap())
210 }
211 None => memory_strengths.mapv_into_any(|v| N::from_f64(v).unwrap()),
212 };
213 }
214
215 pub fn set_memory_strengths(&mut self, mut memory_strengths: Array1<N>) {
217 self.memory_strengths = match self.data_scale_value {
218 Some(scale_val) => {
219 memory_strengths.iter_mut().for_each(|v| *v *= scale_val);
220 memory_strengths
221 }
222 None => memory_strengths,
223 };
224 }
225
226 fn build_check_to_variable(check_matrix: Arc<SparseBitMatrix>) -> SparseBipartiteGraph<N> {
228 let check_matrix_csc = check_matrix.to_csc();
229
230 let default_messages: Vec<_> = vec![N::default(); check_matrix_csc.nnz()];
231
232 SparseBipartiteGraph::new_csc(
233 check_matrix_csc.shape(),
234 check_matrix_csc.indptr().raw_storage().to_vec(),
235 check_matrix_csc.indices().to_vec(),
236 default_messages,
237 )
238 }
239
240 fn build_variable_to_check(check_matrix: Arc<SparseBitMatrix>) -> SparseBipartiteGraph<N> {
242 let check_matrix_csr = check_matrix.to_csr();
243
244 let default_messages: Vec<_> = vec![N::default(); check_matrix_csr.nnz()];
245
246 SparseBipartiteGraph::new(
247 check_matrix_csr.shape(),
248 check_matrix_csr.indptr().raw_storage().to_vec(),
249 check_matrix_csr.indices().to_vec(),
250 default_messages,
251 )
252 }
253
254 fn build_nnz_maps(
255 check_to_variable: CsMatView<N>,
256 variable_to_check: CsMatView<N>,
257 ) -> (Vec<usize>, Vec<usize>) {
258 let mut check_to_variable_nnz_map: Vec<usize> = Vec::with_capacity(check_to_variable.nnz());
259 for (_, (row, col)) in check_to_variable.view().iter() {
260 check_to_variable_nnz_map.push(variable_to_check.nnz_index(row, col).unwrap().0);
261 }
262 let mut variable_to_check_nnz_map: Vec<usize> = Vec::with_capacity(variable_to_check.nnz());
263 for (_, (row, col)) in variable_to_check.view().iter() {
264 variable_to_check_nnz_map.push(check_to_variable.nnz_index(row, col).unwrap().0);
265 }
266 (check_to_variable_nnz_map, variable_to_check_nnz_map)
267 }
268
269 pub fn initialize_variable_to_check(&mut self) {
271 for mut row_vec in self.variable_to_check.outer_iterator_mut() {
272 row_vec
273 .iter_mut()
274 .for_each(|(col_ind, val)| *val = self.log_prior_ratios[col_ind]);
275 }
276 }
277
278 pub fn initialize_check_to_variable(&mut self) {
279 for mut col_vec in self.check_to_variable.outer_iterator_mut() {
280 col_vec
281 .iter_mut()
282 .for_each(|(_row_ind, val)| *val = N::zero());
283 }
284 }
285
286 pub fn initialize_memory_strengths(&mut self) {
287 let ewa_factor_float = self.config.gamma0.unwrap_or(0.);
289 let ewa_factor = N::from_f64(match self.config.data_scale_value {
290 Some(scale_value) => scale_value * ewa_factor_float,
291 None => ewa_factor_float,
292 })
293 .unwrap();
294 self.memory_strengths.fill(ewa_factor);
295 }
296
297 pub fn initialize_decoder(&mut self) {
298 self.current_iteration = 0;
299 self.initialize_memory_strengths();
300 self.initialize_check_to_variable();
301 self.initialize_variable_to_check();
302 if self.config.gamma0.is_some() {
304 self.set_posterior_ratios_to_priors();
305 };
306 }
307
308 fn alpha(&self) -> N {
309 let mut alpha = match self.config.alpha {
310 Some(0.) => {
311 let iteration = (self.current_iteration + 1) as f64;
312 1.0 - (2_f64).powf(-(iteration / self.config.alpha_iteration_scaling_factor))
313 }
314 Some(val) => val,
315 None => 1.0,
316 };
317
318 if alpha < 0. {
321 alpha = 1.
322 }
323 alpha = match self.config.data_scale_value {
325 Some(scale_val) => scale_val * alpha,
326 None => alpha,
327 };
328
329 N::from_f64(alpha).unwrap()
330 }
331
332 fn compute_check_to_variable(
334 &mut self,
335 detectors: ArrayView1<Bit>,
336 ) -> &mut SparseBipartiteGraph<N> {
337 let alpha = self.alpha();
338
339 for (var_check_row_ind, var_check_row_vec) in
340 self.variable_to_check.outer_iterator().enumerate()
341 {
342 let row_sign = if detectors[var_check_row_ind] == 1 {
343 N::one().neg()
344 } else {
345 N::one()
346 };
347 let mut accumulated_sign = row_sign.is_negative();
348 let mut min_ind: usize = 0;
349 let mut min_message = N::max_value();
351 let mut second_min_message = N::max_value();
353
354 for (var_check_col_ind, var_check_col_val) in var_check_row_vec.iter() {
355 accumulated_sign ^= var_check_col_val.is_negative();
356 let abs_msg = var_check_col_val.abs();
357 if abs_msg <= min_message {
358 second_min_message = min_message;
359 min_message = abs_msg;
360 min_ind = var_check_col_ind;
361 } else if abs_msg <= second_min_message {
362 second_min_message = abs_msg
363 }
364 }
365
366 debug!("Variable messages for row {var_check_row_ind:?}: {var_check_row_vec:?}");
367
368 let data_range = self
370 .variable_to_check
371 .indptr()
372 .outer_inds(var_check_row_ind);
373
374 for (ind, var_check_col_ind, var_check_col_val) in izip!(
375 data_range.clone(),
376 &self.variable_to_check.indices()[data_range.clone()],
377 &self.variable_to_check.data()[data_range.clone()]
378 ) {
379 let check_to_variable_sign = accumulated_sign ^ var_check_col_val.is_negative();
381 let check_to_variable_min: N = if *var_check_col_ind != min_ind {
382 min_message
383 } else {
384 second_min_message
385 };
386 let mut check_to_variable = alpha * check_to_variable_min;
388 if check_to_variable_sign {
389 check_to_variable = check_to_variable.neg();
390 }
391
392 self.check_to_variable.data_mut()[self.variable_to_check_nnz_map[ind]] =
395 check_to_variable;
396 }
397 }
398
399 if let Some(scale_val) = self.data_scale_value {
400 self.check_to_variable /= scale_val
401 }
402
403 &mut self.check_to_variable
404 }
405
406 fn compute_variable_prior(&self, variable: usize) -> N {
407 if self.config.gamma0.is_some() {
409 if self.log_prior_ratios[variable] == N::max_value() {
410 return self.log_prior_ratios[variable];
411 }
412 let scaled_one = self.data_scale_value.unwrap_or(N::one());
413 let prior_component = (self.log_prior_ratios[variable] / scaled_one)
415 * (scaled_one - self.memory_strengths[variable]);
416 let posterior_component =
417 (self.posterior_ratios[variable] / scaled_one) * self.memory_strengths[variable];
418 return prior_component + posterior_component;
419 }
420 self.log_prior_ratios[variable]
421 }
422
423 fn compute_variable_to_check(&mut self) -> &mut SparseBipartiteGraph<N> {
425 for (check_var_col_ind, check_var_col_vec) in
426 self.check_to_variable.outer_iterator().enumerate()
427 {
428 let mut check_to_var_row_sum = self.compute_variable_prior(check_var_col_ind);
430
431 debug!("Check messages for col {check_var_col_ind:?}: {check_var_col_vec:?}");
432
433 let data_range = self
434 .check_to_variable
435 .indptr()
436 .outer_inds(check_var_col_ind);
437
438 for (ind, check_var_row_val) in izip!(
440 data_range.clone(),
441 &self.check_to_variable.data()[data_range.clone()]
442 ) {
443 self.variable_to_check.data_mut()[self.check_to_variable_nnz_map[ind]] =
444 check_to_var_row_sum;
445 check_to_var_row_sum += *check_var_row_val;
446 }
447
448 self.posterior_ratios[check_var_col_ind] = check_to_var_row_sum;
449
450 check_to_var_row_sum = N::zero();
452 for (ind, check_var_row_val) in izip!(
454 data_range.clone(),
455 &self.check_to_variable.data()[data_range.clone()]
456 )
457 .rev()
458 {
459 let map_ind = self.check_to_variable_nnz_map[ind];
460 self.variable_to_check.data_mut()[map_ind] += check_to_var_row_sum;
461 check_to_var_row_sum += *check_var_row_val;
462
463 debug!(
466 "location ({:?}, {:?}), variable_to_check: {:.32}",
467 self.check_to_variable.indices()[ind],
468 check_var_col_ind,
469 self.variable_to_check.data_mut()[self.check_to_variable_nnz_map[ind]]
470 );
471 }
472 }
473
474 self.bound_magnitudes();
475
476 &mut self.variable_to_check
477 }
478
479 pub fn run_iteration(&mut self, detectors: ArrayView1<Bit>) {
480 debug!("Iteration {:?} start", self.current_iteration);
481 self.compute_check_to_variable(detectors);
482 self.compute_variable_to_check();
484
485 self.compute_hard_decision();
486 debug!("Iteration {:?} end", self.current_iteration);
487 }
488
489 pub fn build_result(
490 &mut self,
491 success: bool,
492 decoded_detectors: Array1<Bit>,
493 max_iter: usize,
494 ) -> DecodeResult {
495 DecodeResult {
496 decoding: self.decoding.clone(),
497 decoded_detectors,
498 posterior_ratios: self.posterior_ratios.clone().mapv_into_any(|val| {
499 let posterior = N::to_f64(&val).unwrap();
500 match self.config.data_scale_value {
501 Some(scale_val) => posterior / scale_val,
502 None => posterior,
503 }
504 }),
505 success,
506 decoding_quality: if success {
507 self.get_decoding_quality(self.decoding.clone().view())
508 } else {
509 f64::MAX
510 },
511 iterations: self.current_iteration,
512 max_iter,
513 extra: BPExtraResult::None,
514 }
515 }
516 fn bound_magnitudes(&mut self) {
517 if self.max_data_value.is_some() {
519 let max_val = self.max_data_value.unwrap();
520 self.variable_to_check
521 .data_mut()
522 .iter_mut()
523 .for_each(|v| *v = Self::bound_value_magnitude(*v, max_val));
524 self.posterior_ratios
525 .iter_mut()
526 .for_each(|v| *v = Self::bound_value_magnitude(*v, max_val));
527 }
528 }
529
530 fn bound_value_magnitude(value: N, max_val: N) -> N
531 where
532 N: std::ops::Add,
533 {
534 if value < max_val.neg() {
535 max_val.neg()
536 } else if value > max_val {
537 max_val
538 } else {
539 value
540 }
541 }
542
543 fn compute_hard_decision(&mut self) {
544 for (idx, posterior) in self.posterior_ratios.iter().enumerate() {
545 self.decoding[idx] = Bit::from((*posterior) <= N::zero());
546 }
547 debug!("Posteriors: {:?}", self.posterior_ratios);
548 debug!("Hard decision: {:?}", self.decoding);
549 }
550
551 pub fn compute_decoded_detectors(&self) -> Array1<Bit> {
552 self.get_detectors(self.decoding.view())
553 }
554
555 pub fn check_convergence(
557 &self,
558 detectors: ArrayView1<Bit>,
559 decoded_detectors: ArrayView1<Bit>,
560 ) -> bool {
561 detectors == decoded_detectors
562 }
563}
564
565impl<N> Decoder for MinSumBPDecoder<N>
566where
567 N: PartialEq
568 + Debug
569 + Default
570 + Clone
571 + Copy
572 + FromPrimitive
573 + ToPrimitive
574 + Signed
575 + Bounded
576 + std::cmp::PartialOrd
577 + std::ops::Add
578 + std::ops::AddAssign
579 + std::ops::Mul<N>
580 + std::ops::MulAssign
581 + std::ops::DivAssign
582 + Send
583 + Sync
584 + std::fmt::Display
585 + 'static,
586{
587 fn check_matrix(&self) -> Arc<SparseBitMatrix> {
588 self.check_matrix.clone()
589 }
590
591 fn log_prior_ratios(&mut self) -> Array1<f64> {
592 self.config.log_prior_ratios()
593 }
594
595 fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
596 self.initialize_decoder();
598 let mut success: bool = false;
599 let mut decoded_detectors = Array1::default(detectors.dim());
600
601 for _ in 0..self.config.max_iter {
602 self.run_iteration(detectors);
603 self.current_iteration += 1;
604 decoded_detectors = self.compute_decoded_detectors();
605 success = self.check_convergence(detectors, decoded_detectors.view());
606
607 if success {
609 debug!("Succeeded on iteration {:?}", self.current_iteration);
610 break;
611 }
612 }
613
614 self.build_result(success, decoded_detectors, self.config.max_iter)
615 }
616}
617
618impl<N> DecoderRunner for MinSumBPDecoder<N> where
619 N: PartialEq
620 + Debug
621 + Default
622 + Clone
623 + Copy
624 + FromPrimitive
625 + ToPrimitive
626 + Signed
627 + Bounded
628 + std::cmp::PartialOrd
629 + std::ops::Add
630 + std::ops::AddAssign
631 + std::ops::Mul<N>
632 + std::ops::MulAssign
633 + std::ops::DivAssign
634 + Send
635 + Sync
636 + std::fmt::Display
637 + 'static
638{
639}
640
#[cfg(test)]
mod tests {
    use crate::bipartite_graph::BipartiteGraph;

    use super::*;
    use env_logger;
    use ndarray::prelude::*;

    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    /// Shots of the 144_12_12 fixture decoded by each test that uses it.
    const NUM_SHOTS_144_12_12: usize = 100;
    /// Minimum fraction of those shots that must converge.
    const MIN_SUCCESS_FRACTION: f64 = 0.93;
    /// Expected decoding length for the 144_12_12 fixture (columns of its check matrix).
    const DECODING_LEN_144_12_12: usize = 8785;

    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Check matrix of the length-3 repetition code: two checks on neighbouring bit pairs.
    fn repetition_code_check_matrix() -> Arc<SparseBitMatrix> {
        let dense = array![[1, 1, 0], [0, 1, 1],];
        let sparse: SparseBipartiteGraph<_> = SparseBipartiteGraph::from_dense(dense);
        Arc::new(sparse)
    }

    /// Decodes `detectors` and asserts the decoder converges to exactly `error`.
    fn assert_decodes_to<D: Decoder>(
        decoder: &mut D,
        detectors: Array1<Bit>,
        error: Array1<Bit>,
        max_iter: usize,
    ) {
        let result = decoder.decode_detailed(detectors.view());

        assert_eq!(result.decoding, error);
        assert_eq!(result.decoded_detectors, detectors);
        assert_eq!(result.max_iter, max_iter);
        assert!(result.success);
    }

    /// Feeds the error-free and every single-error syndrome of the repetition code through
    /// one reused `decoder`, in a fixed order.
    fn assert_repetition_code_cases<D: Decoder>(decoder: &mut D, max_iter: usize) {
        let cases: Vec<(Array1<Bit>, Array1<Bit>)> = vec![
            (array![0, 0], array![0, 0, 0]),
            (array![1, 0], array![1, 0, 0]),
            (array![1, 1], array![0, 1, 0]),
            (array![0, 1], array![0, 0, 1]),
        ];
        for (detectors, error) in cases {
            assert_decodes_to(decoder, detectors, error, max_iter);
        }
    }

    /// Loads the 144_12_12 detector error model and its recorded detector samples.
    fn load_144_12_12() -> (DetectorErrorModel, Array2<Bit>) {
        let resources = get_test_data_path();
        let code = DetectorErrorModel::load(resources.join("144_12_12"))
            .expect("Unable to load the code");
        let detectors: Array2<Bit> = read_npy(resources.join("144_12_12_detectors.npy"))
            .expect("Unable to open file");
        (code, detectors)
    }

    /// Number of results that report convergence.
    fn count_successes<'a>(results: impl IntoIterator<Item = &'a DecodeResult>) -> usize {
        results.into_iter().filter(|result| result.success).count()
    }

    /// Asserts at least `MIN_SUCCESS_FRACTION` of `shots` converged.
    fn assert_success_rate<'a>(results: impl IntoIterator<Item = &'a DecodeResult>, shots: usize) {
        let successes = count_successes(results);
        assert!(
            successes as f64 >= shots as f64 * MIN_SUCCESS_FRACTION,
            "only {successes} of {shots} shots converged"
        );
    }

    #[test]
    fn decode_detailed_repetition_code() {
        init();

        let iterations = 10;
        let bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            ..Default::default()
        };

        let mut decoder: MinSumBPDecoder<f64> =
            MinSumBPDecoder::new(repetition_code_check_matrix(), Arc::new(bp_config));

        assert_repetition_code_cases(&mut decoder, iterations);
    }

    #[test]
    fn decode_detailed_repetition_code_int() {
        init();

        let iterations = 10;
        let bits = 7;
        let scale = 4.0;

        let bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            ..Default::default()
        };

        let mut decoder: MinSumBPDecoder<isize> =
            MinSumBPDecoder::new(repetition_code_check_matrix(), Arc::new(bp_config));

        assert_repetition_code_cases(&mut decoder, iterations);
    }

    #[test]
    fn decode_detailed_144_12_12() {
        let (code, detectors) = load_144_12_12();
        let check_matrix = Arc::new(code.detector_error_matrix);
        let config = Arc::new(MinSumDecoderConfig {
            error_priors: code.error_priors,
            alpha: Some(0.),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS_144_12_12, ..]);
        let results = decoder.par_decode_detailed_batch(detectors_slice);

        assert_success_rate(results.iter(), detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), DECODING_LEN_144_12_12);
    }

    #[test]
    fn decode_144_12_12() {
        let (code, detectors) = load_144_12_12();
        let check_matrix = Arc::new(code.detector_error_matrix);
        let config = Arc::new(MinSumDecoderConfig {
            error_priors: code.error_priors,
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS_144_12_12, ..]);

        let results = decoder.par_decode_batch(detectors_slice);
        let results_detailed = decoder.par_decode_detailed_batch(detectors_slice);

        // The plain and detailed batch paths must agree shot by shot.
        for i in 0..results.shape()[0] {
            assert_eq!(results.row(i), results_detailed[i].decoding);
        }
    }

    #[test]
    fn decode_detailed_144_12_12_membp() {
        let (code, detectors) = load_144_12_12();
        let check_matrix = Arc::new(code.detector_error_matrix);
        let config = Arc::new(MinSumDecoderConfig {
            error_priors: code.error_priors,
            gamma0: Some(0.15),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<f64> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS_144_12_12, ..]);
        let par_results = decoder.par_decode_detailed_batch(detectors_slice);
        let results = decoder.decode_detailed_batch(detectors_slice);

        // Sequential and parallel batch decoding must converge on the same number of shots.
        assert_eq!(
            count_successes(results.iter()),
            count_successes(par_results.iter())
        );
        assert_success_rate(results.iter(), detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), DECODING_LEN_144_12_12);
    }

    #[test]
    fn decode_detailed_144_12_12_int() {
        let (code, detectors) = load_144_12_12();
        let check_matrix = Arc::new(code.detector_error_matrix);

        let bits = 16;
        let scale = 8.0;

        let config = Arc::new(MinSumDecoderConfig {
            error_priors: code.error_priors,
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            alpha: Some(0.),
            ..Default::default()
        });

        let mut decoder: MinSumBPDecoder<isize> = MinSumBPDecoder::new(check_matrix, config);
        let detectors_slice = detectors.slice(s![..NUM_SHOTS_144_12_12, ..]);
        let results = decoder.par_decode_detailed_batch(detectors_slice);

        assert_success_rate(results.iter(), detectors_slice.shape()[0]);
        assert_eq!(results[0].decoding.len(), DECODING_LEN_144_12_12);
    }
}