Skip to main content

relay_bp/
decoder.rs

// (C) Copyright IBM 2025
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.
10
11use crate::bipartite_graph::SparseBipartiteGraph;
12use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis};
13
14use indicatif::{ParallelProgressIterator, ProgressFinish, ProgressIterator, ProgressStyle};
15use rayon::prelude::*;
16
17use serde::{Deserialize, Serialize};
18use std::fmt::Debug;
19
/// Element type used for binary vectors and matrices in this module.
/// Values are expected to be 0 or 1, but the type itself is a plain `u8`.
pub type Bit = u8;
21pub type SparseBitMatrix = SparseBipartiteGraph<Bit>;
22
23use dyn_clone::DynClone;
24
25use std::sync::Arc;
26
/// Multiplication whose result is reduced modulo 2 (arithmetic over GF(2)).
///
/// `Rhs` defaults to `Self`, mirroring `std::ops::Mul`.
pub trait Mod2Mul<Rhs = Self> {
    /// Type produced by the multiplication.
    type Output;

    /// Multiply `self` by `rhs` and return the product with every entry taken mod 2.
    fn mul_mod2(&self, rhs: Rhs) -> Self::Output;
}
33
34impl Mod2Mul<&Array1<Bit>> for SparseBitMatrix {
35    type Output = Array1<Bit>;
36
37    fn mul_mod2(&self, rhs: &Array1<Bit>) -> Self::Output {
38        let mut detectors_u8 = self * rhs;
39        detectors_u8.map_inplace(|x| *x %= 2);
40        detectors_u8
41    }
42}
43
44pub trait Decoder: DynClone + Sync {
45    fn check_matrix(&self) -> Arc<SparseBitMatrix>;
46    fn log_prior_ratios(&mut self) -> Array1<f64>;
47    /// Decode a single input problem
48    fn decode(&mut self, detectors: ArrayView1<Bit>) -> Array1<Bit> {
49        self.decode_detailed(detectors).decoding
50    }
51    fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult;
52    fn decode_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
53        let arrs: Vec<Array1<Bit>> = detectors
54            .axis_iter(Axis(0))
55            .map(|row| self.decode(row))
56            .collect();
57
58        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
59    }
60    fn decode_detailed_batch(&mut self, detectors: ArrayView2<Bit>) -> Vec<DecodeResult> {
61        detectors
62            .axis_iter(Axis(0))
63            .map(|row| self.decode_detailed(row))
64            .collect()
65    }
66    // Compute detectors from errors
67    fn get_detectors(&self, errors: ArrayView1<Bit>) -> Array1<Bit> {
68        self.check_matrix().mul_mod2(&errors.to_owned())
69    }
70
71    fn get_detectors_batch(&self, errors: ArrayView2<Bit>) -> Array2<Bit> {
72        let check_matrix = self.check_matrix();
73        let detectors: Vec<Array1<Bit>> = errors
74            .axis_iter(Axis(0))
75            .map(|row| check_matrix.mul_mod2(&row.to_owned()))
76            .collect();
77        stack(
78            Axis(0),
79            &detectors.iter().map(|a| a.view()).collect::<Vec<_>>(),
80        )
81        .unwrap()
82    }
83
84    fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
85        let log_prior_ratios = self.log_prior_ratios();
86        let mut decoding_quality: f64 = 0.0;
87        for i in 0..errors.len() {
88            if errors[i] == 1 && f64::is_finite(log_prior_ratios[i]) {
89                decoding_quality += log_prior_ratios[i];
90            }
91        }
92        decoding_quality
93    }
94}
95
96dyn_clone::clone_trait_object!(Decoder);
97
98pub trait DecoderRunner: Decoder + Clone + Sync {
99    fn par_decode_batch(&mut self, detectors: ArrayView2<Bit>) -> Array2<Bit> {
100        let arrs: Vec<Array1<Bit>> = detectors
101            .axis_iter(Axis(0))
102            .into_par_iter()
103            .map_with(|| self.clone(), |decoder, row| decoder().decode(row))
104            .collect();
105        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
106    }
107
108    fn par_decode_detailed_batch(&mut self, detectors: ArrayView2<Bit>) -> Vec<DecodeResult> {
109        detectors
110            .axis_iter(Axis(0))
111            .into_par_iter()
112            .map_with(
113                || self.clone(),
114                |decoder, row| decoder().decode_detailed(row),
115            )
116            .collect()
117    }
118
119    /// Decode a batch displaying a progress bar
120    fn decode_batch_progress_bar(
121        &mut self,
122        detectors: ArrayView2<Bit>,
123        leave_progress_bar_on_finish: bool,
124    ) -> Array2<Bit> {
125        let finish_mode = match leave_progress_bar_on_finish {
126            true => ProgressFinish::AndLeave,
127            false => ProgressFinish::AndClear,
128        };
129
130        let arrs: Vec<Array1<Bit>> = detectors
131            .axis_iter(Axis(0))
132            .progress_with_style(self.get_progress_bar_style())
133            .with_finish(finish_mode)
134            .map(|row| self.decode(row))
135            .collect();
136        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
137    }
138
139    /// Decode a batch displaying a progress bar
140    fn decode_detailed_batch_progress_bar(
141        &mut self,
142        detectors: ArrayView2<Bit>,
143        leave_progress_bar_on_finish: bool,
144    ) -> Vec<DecodeResult> {
145        let finish_mode = match leave_progress_bar_on_finish {
146            true => ProgressFinish::AndLeave,
147            false => ProgressFinish::AndClear,
148        };
149
150        detectors
151            .axis_iter(Axis(0))
152            .progress_with_style(self.get_progress_bar_style())
153            .with_finish(finish_mode)
154            .map(|row| self.decode_detailed(row))
155            .collect()
156    }
157
158    fn par_decode_batch_progress_bar(
159        &mut self,
160        detectors: ArrayView2<Bit>,
161        leave_progress_bar_on_finish: bool,
162    ) -> Array2<Bit> {
163        let finish_mode = match leave_progress_bar_on_finish {
164            true => ProgressFinish::AndLeave,
165            false => ProgressFinish::AndClear,
166        };
167
168        let arrs: Vec<Array1<Bit>> = detectors
169            .axis_iter(Axis(0))
170            .into_par_iter()
171            .progress_with_style(self.get_progress_bar_style())
172            .with_finish(finish_mode)
173            .map_with(|| self.clone(), |decoder, row| decoder().decode(row))
174            .collect();
175
176        stack(Axis(0), &arrs.iter().map(|a| a.view()).collect::<Vec<_>>()).unwrap()
177    }
178
179    fn par_decode_detailed_batch_progress_bar(
180        &mut self,
181        detectors: ArrayView2<Bit>,
182        leave_progress_bar_on_finish: bool,
183    ) -> Vec<DecodeResult> {
184        let finish_mode = match leave_progress_bar_on_finish {
185            true => ProgressFinish::AndLeave,
186            false => ProgressFinish::AndClear,
187        };
188
189        detectors
190            .axis_iter(Axis(0))
191            .into_par_iter()
192            .progress_with_style(self.get_progress_bar_style())
193            .with_finish(finish_mode)
194            .map_with(
195                || self.clone(),
196                |decoder, row| decoder().decode_detailed(row),
197            )
198            .collect()
199    }
200
201    fn get_progress_bar_style(&self) -> ProgressStyle {
202        ProgressStyle::default_bar().template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({per_sec}, {eta})").unwrap()
203    }
204}
205
206#[derive(Serialize, Deserialize, Debug, Clone)]
207pub struct DecodeResult {
208    pub decoding: Array1<Bit>,
209    pub decoded_detectors: Array1<Bit>,
210    pub posterior_ratios: Array1<f64>,
211    pub success: bool,
212    pub decoding_quality: f64,
213    pub iterations: usize,
214    pub max_iter: usize,
215    pub extra: BPExtraResult,
216}
217
218#[derive(Serialize, Deserialize, Debug, Clone)]
219pub enum BPExtraResult {
220    None,
221}