Skip to main content

relay_bp/bp/
relay.rs

1// (C) Copyright IBM 2025
2//
3// This code is licensed under the Apache License, Version 2.0. You may
4// obtain a copy of this license in the LICENSE.txt file in the root directory
5// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
6//
7// Any modifications or derivative works of this code must retain this
8// copyright notice, and modified files need to carry a notice indicating
9// that they have been altered from the originals.
10
11use super::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
12use crate::decoder::{Bit, SparseBitMatrix};
13use crate::decoder::{DecodeResult, Decoder, DecoderRunner};
14use log::debug;
15
16use ndarray::{Array1, Array2, ArrayView1};
17use num_traits::{Bounded, FromPrimitive, Signed, ToPrimitive};
18use std::fmt::Debug;
19use std::fs::File;
20use std::fs::OpenOptions;
21use std::io::{BufWriter, Write};
22//use std::string;
23use rand::distributions::{Distribution, Uniform};
24use rand::SeedableRng;
25use std::process::exit;
26use std::sync::Arc;
27
/// Controls when the Relay decoder stops launching further relay legs.
///
/// The pre-iteration run counts as a converged run whenever it converges.
#[derive(Clone, PartialEq, Debug)]
pub enum StoppingCriterion {
    /// Return immediately if the pre-iteration run converges. Otherwise run every relay leg.
    PreIter,
    /// Stop as soon as `stop_after` runs (pre-iteration run included) have converged.
    NConv { stop_after: usize },
    /// Never stop early: execute every relay leg.
    All,
}

impl Default for StoppingCriterion {
    /// Stop after the first converged run.
    fn default() -> Self {
        Self::NConv { stop_after: 1 }
    }
}
40
41#[derive(Clone, Debug)]
42pub struct RelayDecoderConfig {
43    pub pre_iter: usize,
44    pub num_sets: usize,
45    pub set_max_iter: usize,
46    pub gamma_dist_interval: (f64, f64),
47    pub explicit_gammas: Option<Array2<f64>>,
48    pub stopping_criterion: StoppingCriterion,
49    pub logging: bool,
50    pub seed: u64,
51}
52
53impl Default for RelayDecoderConfig {
54    fn default() -> Self {
55        Self {
56            pre_iter: 80,
57            num_sets: 300,
58            set_max_iter: 60,
59            gamma_dist_interval: (-0.24, 0.66),
60            explicit_gammas: None,
61            stopping_criterion: StoppingCriterion::default(),
62            logging: false,
63            seed: 0,
64        }
65    }
66}
67
68#[derive(Clone)]
69struct PosteriorUpdateState {
70    rng_std: rand::rngs::StdRng,
71    uniform: rand::distributions::Uniform<f64>,
72}
73
74/// An ensemble decoder which controls an inner BP min-sum decoder.
75#[derive(Clone)]
76pub struct RelayDecoder<N: PartialEq + Default + Clone + Copy> {
77    bp_decoder: MinSumBPDecoder<N>,
78    relay_config: Arc<RelayDecoderConfig>,
79    posterior_update_state: PosteriorUpdateState,
80    sets_quality: Array1<f64>,
81    sets_iter: Array1<usize>,
82    sets_conv: Array1<bool>,
83    sets_best: Array1<bool>,
84    num_executed_sets: usize,
85}
86
87impl<N> RelayDecoder<N>
88where
89    N: PartialEq
90        + Debug
91        + Default
92        + Clone
93        + Copy
94        + Signed
95        + Bounded
96        + FromPrimitive
97        + ToPrimitive
98        + std::cmp::PartialOrd
99        + std::ops::Add
100        + std::ops::AddAssign
101        + std::ops::DivAssign
102        + std::ops::Mul<N>
103        + std::ops::MulAssign
104        + Send
105        + Sync
106        + std::fmt::Display
107        + 'static,
108{
109    pub fn new(
110        check_matrix: Arc<SparseBitMatrix>,
111        min_sum_config: Arc<MinSumDecoderConfig>,
112        relay_config: Arc<RelayDecoderConfig>,
113    ) -> RelayDecoder<N> {
114        if relay_config.logging {
115            let log_line = format!(
116                "# pre_iter: {}: sets: {} set_max_iter: {}\n\
117                # gamma_distribution: {:?} # set_idx, num_iter, converged, unique_best_solution\n",
118                relay_config.pre_iter,
119                relay_config.num_sets,
120                relay_config.set_max_iter,
121                relay_config.gamma_dist_interval,
122            );
123            let mut file =
124                File::create("relay_logging.out").expect("Unable to create file for logging.");
125            file.write_all(log_line.as_bytes())
126                .expect("Unable to write Relay logging data.");
127        }
128
129        // Create logging variables if applicable
130        let (sets_quality, sets_iter, sets_conv, sets_best);
131        if relay_config.logging {
132            sets_quality = Array1::<f64>::from_elem(relay_config.num_sets + 1, f64::MAX);
133            sets_iter =
134                Array1::<usize>::from_elem(relay_config.num_sets + 1, relay_config.set_max_iter);
135            sets_conv = Array1::<bool>::from_elem(relay_config.num_sets + 1, false);
136            sets_best = Array1::<bool>::from_elem(relay_config.num_sets + 1, false);
137        } else {
138            sets_quality = Array1::<f64>::zeros(1);
139            sets_iter = Array1::<usize>::zeros(1);
140            sets_conv = Array1::<bool>::from_elem(1, false);
141            sets_best = Array1::<bool>::from_elem(1, false);
142        }
143
144        if relay_config.explicit_gammas.is_some() {
145            let gammas_shape = relay_config.explicit_gammas.as_ref().unwrap().shape();
146            let num_variable_nodes = check_matrix.cols();
147            if num_variable_nodes != gammas_shape[1] {
148                println!("ERROR: Number of specified gammas {} does not match the number of variable nodes {}.", gammas_shape[1], num_variable_nodes);
149                exit(-1);
150            };
151            if relay_config.num_sets > gammas_shape[0] {
152                println!("WARNING: Number of different gamma sets {} is smaller than the number of Relay legs {}. Legs will be reused.", gammas_shape[0], relay_config.num_sets)
153            }
154        }
155
156        // The actual number of sets Relay ran, depends on the stopping criterion
157        let num_executed_sets = 0;
158
159        let bp_decoder = MinSumBPDecoder::new(check_matrix, min_sum_config);
160
161        let posterior_update_state = Self::init_dismem_state(&relay_config);
162
163        RelayDecoder {
164            bp_decoder,
165            relay_config,
166            posterior_update_state,
167            sets_quality,
168            sets_iter,
169            sets_conv,
170            sets_best,
171            num_executed_sets,
172        }
173    }
174
175    fn init_dismem_state(relay_config: &RelayDecoderConfig) -> PosteriorUpdateState {
176        let rng_std: rand::prelude::StdRng = rand::rngs::StdRng::seed_from_u64(relay_config.seed);
177        let low = relay_config.gamma_dist_interval.0;
178        let high = relay_config.gamma_dist_interval.1;
179        let uniform: rand::distributions::Uniform<f64> = Uniform::new(low, high);
180        PosteriorUpdateState { rng_std, uniform }
181    }
182
183    fn init_next_set(&mut self, set_idx: usize) {
184        let mut gammas = Array1::zeros(self.check_matrix().cols());
185        if self.relay_config.explicit_gammas.is_some() {
186            let gammas_num_sets = self.relay_config.explicit_gammas.as_ref().unwrap().shape()[0];
187            for i in 0..gammas.len() {
188                gammas[i] = *self
189                    .relay_config
190                    .explicit_gammas
191                    .as_ref()
192                    .unwrap()
193                    .get((set_idx % gammas_num_sets, i))
194                    .unwrap();
195            }
196            self.bp_decoder.set_memory_strengths_f64(gammas);
197            return;
198        }
199        for i in 0..gammas.len() {
200            gammas[i] = self
201                .posterior_update_state
202                .uniform
203                .sample(&mut self.posterior_update_state.rng_std);
204        }
205        self.bp_decoder.set_memory_strengths_f64(gammas);
206    }
207
208    /// Decode with the inner decoder
209    fn decode_inner(&mut self, detectors: ArrayView1<Bit>, max_iter: usize) -> DecodeResult {
210        let mut success: bool = false;
211        let mut decoded_detectors = Array1::default(detectors.dim());
212
213        for _ in 0..max_iter {
214            self.bp_decoder.run_iteration(detectors);
215            decoded_detectors = self.bp_decoder.compute_decoded_detectors();
216            success = self
217                .bp_decoder
218                .check_convergence(detectors, decoded_detectors.view());
219
220            // If we have converged may now exit
221            if success {
222                debug!(
223                    "Succeeded on iteration {:?}",
224                    self.bp_decoder.current_iteration
225                );
226                break;
227            }
228            self.bp_decoder.current_iteration += 1;
229        }
230
231        self.bp_decoder
232            .build_result(success, decoded_detectors, max_iter)
233    }
234
235    fn write_log(&mut self, file: File) {
236        let mut buf_writer = BufWriter::new(file);
237        for set in 0..self.num_executed_sets {
238            let log_line = format!(
239                "{}, {}, {}, {}\n",
240                (set - 1) as i32,
241                self.sets_iter[set],
242                self.sets_conv[set] as u8,
243                self.sets_best[set] as u8
244            );
245            buf_writer
246                .write_all(log_line.as_bytes())
247                .expect("Unable to write Relay logging data.");
248        }
249        buf_writer
250            .flush()
251            .expect("Unable to write Relay logging data.");
252    }
253}
254
255impl<N> Decoder for RelayDecoder<N>
256where
257    N: PartialEq
258        + Debug
259        + Default
260        + Clone
261        + Copy
262        + Signed
263        + Bounded
264        + FromPrimitive
265        + ToPrimitive
266        + std::cmp::PartialOrd
267        + std::ops::Add
268        + std::ops::AddAssign
269        + std::ops::DivAssign
270        + std::ops::Mul<N>
271        + std::ops::MulAssign
272        + Send
273        + Sync
274        + std::fmt::Display
275        + 'static,
276{
277    fn check_matrix(&self) -> Arc<SparseBitMatrix> {
278        self.bp_decoder.check_matrix()
279    }
280
281    fn log_prior_ratios(&mut self) -> Array1<f64> {
282        self.bp_decoder.log_prior_ratios()
283    }
284
285    fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
286        // Initialization
287        let mut num_conv = 0;
288        let mut min_pm = f64::MAX;
289        let mut num_sets_best = 0;
290        let mut best_set_idx = 0;
291        let mut total_iterations: usize = 0;
292        self.num_executed_sets = 0;
293        let stopping_criterion = self.relay_config.stopping_criterion.clone();
294
295        // First Mem-BP
296        self.bp_decoder.initialize_decoder();
297        let mut result = self.decode_inner(detectors, self.relay_config.pre_iter);
298        self.num_executed_sets = 1;
299
300        // Create logging variables and log first set if applicable
301        if self.relay_config.logging {
302            self.sets_iter[0] = result.iterations;
303            self.sets_conv[0] = result.success;
304        }
305
306        // Check early stopping criteria
307        if result.success {
308            num_conv += 1;
309            min_pm = result.decoding_quality;
310            num_sets_best += 1;
311            if self.relay_config.logging {
312                self.sets_quality[0] = result.decoding_quality
313            };
314
315            let mut done = false;
316            if stopping_criterion == StoppingCriterion::PreIter {
317                done = true;
318            } else if let StoppingCriterion::NConv { stop_after } = stopping_criterion {
319                if num_conv >= stop_after {
320                    done = true;
321                }
322            }
323            // If stopping criterion has been met: Log (if applicable) and return
324            if done {
325                if self.relay_config.logging {
326                    self.sets_best[0] = true;
327                    let file = OpenOptions::new()
328                        .append(true)
329                        .open("relay_logging.out")
330                        .unwrap();
331                    self.write_log(file);
332                }
333                return result;
334            }
335        }
336
337        // Init and loop over all Relay sets
338        total_iterations += result.iterations;
339        for set in 1..=self.relay_config.num_sets {
340            // Do not completely initialize decoder as we wish to relay
341            // posterior marginals with new memory strengths.
342            self.init_next_set(set);
343            self.bp_decoder.current_iteration = 0;
344            self.bp_decoder.initialize_check_to_variable();
345            self.bp_decoder.initialize_variable_to_check();
346            let temp_result = self.decode_inner(detectors, self.relay_config.set_max_iter);
347
348            self.num_executed_sets += 1;
349            total_iterations += temp_result.iterations;
350            if temp_result.success {
351                num_conv += 1;
352                let pm = temp_result.decoding_quality;
353                if self.relay_config.logging {
354                    self.sets_conv[set] = true;
355                    self.sets_iter[set] = temp_result.iterations;
356                    self.sets_quality[set] = pm;
357                }
358                if pm == min_pm {
359                    // Count how often we found the best solution
360                    num_sets_best += 1;
361                }
362                if pm < min_pm {
363                    // Found a new best solution
364                    num_sets_best = 1;
365                    best_set_idx = set;
366                    min_pm = pm;
367                    result = temp_result;
368                }
369                if let StoppingCriterion::NConv { stop_after } = stopping_criterion {
370                    if num_conv >= stop_after {
371                        break;
372                    }
373                }
374            }
375        }
376        result.iterations = total_iterations;
377
378        // Rest of the function is just logging
379        if self.relay_config.logging {
380            if num_sets_best == 1 {
381                self.sets_best[best_set_idx] = true;
382            }
383            let file = OpenOptions::new()
384                .append(true)
385                .open("relay_logging.out")
386                .unwrap();
387            self.write_log(file);
388        }
389
390        result
391    }
392
393    fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
394        self.bp_decoder.get_decoding_quality(errors)
395    }
396}
397
398impl<N> DecoderRunner for RelayDecoder<N> where
399    N: PartialEq
400        + Debug
401        + Default
402        + Clone
403        + Copy
404        + Signed
405        + Bounded
406        + FromPrimitive
407        + ToPrimitive
408        + std::cmp::PartialOrd
409        + std::ops::Add
410        + std::ops::AddAssign
411        + std::ops::DivAssign
412        + std::ops::Mul<N>
413        + std::ops::MulAssign
414        + Send
415        + Sync
416        + std::fmt::Display
417        + 'static
418{
419}
420
#[cfg(test)]
mod tests {

    use super::*;

    use crate::bipartite_graph::{BipartiteGraph, SparseBipartiteGraph};
    use env_logger;
    use ndarray::prelude::*;

    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    /// Installs a test logger; repeated initialisation attempts are ignored.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Loads the `144_12_12` detector error model and its stored detector samples
    /// from the test-data directory.
    fn load_144_12_12() -> (DetectorErrorModel, Array2<Bit>) {
        let data_dir = get_test_data_path();
        let dem = DetectorErrorModel::load(data_dir.join("144_12_12"))
            .expect("Unable to load the code");
        let syndromes: Array2<Bit> = read_npy(data_dir.join("144_12_12_detectors.npy"))
            .expect("Unable to open file");
        (dem, syndromes)
    }

    // Relay with zero relay legs and the `PreIter` criterion: every shot is handled by the
    // pre-iteration BP run alone.
    #[test]
    fn min_sum_decode_repetition_code() {
        init();

        // Three-bit repetition code with two weight-2 checks on neighbouring bits.
        let parity_checks: SparseBipartiteGraph<_> =
            SparseBipartiteGraph::from_dense(array![[1, 1, 0], [0, 1, 1],]);

        let max_iter = 10;
        let bp_config = Arc::new(MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter,
            alpha: Some(1.),
            alpha_iteration_scaling_factor: 1.,
            gamma0: None,
            ..Default::default()
        });
        let relay_config = Arc::new(RelayDecoderConfig {
            pre_iter: max_iter,
            num_sets: 0,
            set_max_iter: 150,
            stopping_criterion: StoppingCriterion::PreIter,
            explicit_gammas: None,
            ..Default::default()
        });

        let mut decoder: RelayDecoder<f32> =
            RelayDecoder::new(Arc::new(parity_checks), bp_config, relay_config);

        // (expected error, syndrome) pairs: no error, then a single flip on each bit in turn.
        let cases: [(Array1<_>, Array1<Bit>); 4] = [
            (array![0, 0, 0], array![0, 0]),
            (array![1, 0, 0], array![1, 0]),
            (array![0, 1, 0], array![1, 1]),
            (array![0, 0, 1], array![0, 1]),
        ];
        for (expected_error, syndrome) in cases.iter() {
            let result = decoder.decode_detailed(syndrome.view());

            assert_eq!(result.decoding, *expected_error);
            assert_eq!(result.decoded_detectors, *syndrome);
            assert_eq!(result.max_iter, max_iter);
            assert!(result.success);
        }
    }

    // Full Relay run with the default configuration, in floating point. Every shot must converge.
    #[test]
    fn decode_144_12_12() {
        let (dem, syndromes) = load_144_12_12();
        let bp_config = MinSumDecoderConfig {
            error_priors: dem.error_priors,
            max_iter: 200,
            alpha: None,
            alpha_iteration_scaling_factor: 0.,
            gamma0: Some(0.9),
            ..Default::default()
        };
        let mut decoder: RelayDecoder<f64> = RelayDecoder::new(
            Arc::new(dem.detector_error_matrix),
            Arc::new(bp_config),
            Arc::new(RelayDecoderConfig::default()),
        );

        let num_shots = 100;
        let shots = syndromes.slice(s![..num_shots, ..]);
        let results = decoder.par_decode_detailed_batch(shots);

        let num_converged = results.iter().filter(|r| r.success).count();
        assert_eq!(num_converged, shots.shape()[0], "Relay should converge on every shot");

        assert_eq!(results[0].decoding.len(), 8785);
    }

    // Same as above, but with the integer (`isize`) decoder. `max_data_value` caps values at
    // 2^16 - 1 and `data_scale_value` is 8.
    #[test]
    fn decode_144_12_12_int() {
        let (dem, syndromes) = load_144_12_12();

        let bits = 16;
        let scale = 8.0;
        let bp_config = MinSumDecoderConfig {
            error_priors: dem.error_priors,
            max_iter: 200,
            alpha: None,
            alpha_iteration_scaling_factor: 0.,
            gamma0: Some(0.9),
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            ..Default::default()
        };
        let mut decoder: RelayDecoder<isize> = RelayDecoder::new(
            Arc::new(dem.detector_error_matrix),
            Arc::new(bp_config),
            Arc::new(RelayDecoderConfig::default()),
        );

        let num_shots = 100;
        let shots = syndromes.slice(s![..num_shots, ..]);
        let results = decoder.par_decode_detailed_batch(shots);

        let num_converged = results.iter().filter(|r| r.success).count();
        assert_eq!(num_converged, shots.shape()[0], "Relay should converge on every shot");

        assert_eq!(results[0].decoding.len(), 8785);
    }
}