1use super::min_sum::{MinSumBPDecoder, MinSumDecoderConfig};
12use crate::decoder::{Bit, SparseBitMatrix};
13use crate::decoder::{DecodeResult, Decoder, DecoderRunner};
14use log::debug;
15
16use ndarray::{Array1, Array2, ArrayView1};
17use num_traits::{Bounded, FromPrimitive, Signed, ToPrimitive};
18use std::fmt::Debug;
19use std::fs::File;
20use std::fs::OpenOptions;
21use std::io::{BufWriter, Write};
22use rand::distributions::{Distribution, Uniform};
24use rand::SeedableRng;
25use std::process::exit;
26use std::sync::Arc;
27
/// Rule deciding when `RelayDecoder::decode_detailed` stops launching further relay legs.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum StoppingCriterion {
    /// Return right after the pre-iteration run if it converged; if it did not converge,
    /// every relay leg is executed.
    PreIter,
    /// Stop as soon as `stop_after` runs have converged (a converged pre-iteration run counts).
    NConv { stop_after: usize },
    /// Always execute every relay leg.
    All,
}
34
35impl Default for StoppingCriterion {
36 fn default() -> StoppingCriterion {
37 StoppingCriterion::NConv { stop_after: 1 }
38 }
39}
40
/// Settings of the Relay stage that drives the underlying min-sum BP decoder.
#[derive(Clone, Debug)]
pub struct RelayDecoderConfig {
    /// Maximum number of BP iterations of the initial (pre-iteration) run.
    pub pre_iter: usize,
    /// Maximum number of relay legs executed after the pre-iteration run.
    pub num_sets: usize,
    /// Maximum number of BP iterations of each relay leg.
    pub set_max_iter: usize,
    /// `(low, high)` bounds handed to `rand::distributions::Uniform::new`, from which one
    /// memory strength per variable node is drawn for every leg when `explicit_gammas` is `None`.
    pub gamma_dist_interval: (f64, f64),
    /// Optional fixed memory strengths: one row per gamma set, one column per variable node.
    /// The column count is checked against the check matrix in `RelayDecoder::new`; relay leg
    /// `k` uses row `k % rows`.
    pub explicit_gammas: Option<Array2<f64>>,
    /// When to stop launching further relay legs.
    pub stopping_criterion: StoppingCriterion,
    /// Write per-run statistics to `relay_logging.out` in the current working directory.
    pub logging: bool,
    /// Seed of the RNG used to draw random memory strengths.
    pub seed: u64,
}
52
53impl Default for RelayDecoderConfig {
54 fn default() -> Self {
55 Self {
56 pre_iter: 80,
57 num_sets: 300,
58 set_max_iter: 60,
59 gamma_dist_interval: (-0.24, 0.66),
60 explicit_gammas: None,
61 stopping_criterion: StoppingCriterion::default(),
62 logging: false,
63 seed: 0,
64 }
65 }
66}
67
/// Random state used to draw memory strengths (gammas) for relay legs when no explicit
/// gammas are configured.
///
/// Cloning copies the RNG state, so a clone draws the same gamma sequence as the original
/// from the point of cloning on.
#[derive(Clone)]
struct PosteriorUpdateState {
    // Seeded once from `RelayDecoderConfig::seed` when the decoder is built; it is not
    // reseeded between decodes.
    rng_std: rand::rngs::StdRng,
    // Uniform distribution over `RelayDecoderConfig::gamma_dist_interval`.
    uniform: rand::distributions::Uniform<f64>,
}
73
/// Relay-BP decoder: runs a min-sum BP decoder once (pre-iteration) and then in up to
/// `num_sets` relay legs with different memory strengths, returning the converged run with
/// the smallest decoding quality value.
#[derive(Clone)]
pub struct RelayDecoder<N: PartialEq + Default + Clone + Copy> {
    // Underlying min-sum decoder driven by every run.
    bp_decoder: MinSumBPDecoder<N>,
    relay_config: Arc<RelayDecoderConfig>,
    // RNG and distribution for randomly drawn memory strengths.
    posterior_update_state: PosteriorUpdateState,
    // Logging buffers indexed by run: 0 is the pre-iteration run, 1..=num_sets are the relay
    // legs. They hold num_sets + 1 entries only when logging is enabled, otherwise a single
    // unused placeholder. `sets_quality` is recorded but not written to the log file.
    sets_quality: Array1<f64>,
    sets_iter: Array1<usize>,
    sets_conv: Array1<bool>,
    sets_best: Array1<bool>,
    // Number of runs (pre-iteration included) executed by the most recent decode.
    num_executed_sets: usize,
}
86
87impl<N> RelayDecoder<N>
88where
89 N: PartialEq
90 + Debug
91 + Default
92 + Clone
93 + Copy
94 + Signed
95 + Bounded
96 + FromPrimitive
97 + ToPrimitive
98 + std::cmp::PartialOrd
99 + std::ops::Add
100 + std::ops::AddAssign
101 + std::ops::DivAssign
102 + std::ops::Mul<N>
103 + std::ops::MulAssign
104 + Send
105 + Sync
106 + std::fmt::Display
107 + 'static,
108{
109 pub fn new(
110 check_matrix: Arc<SparseBitMatrix>,
111 min_sum_config: Arc<MinSumDecoderConfig>,
112 relay_config: Arc<RelayDecoderConfig>,
113 ) -> RelayDecoder<N> {
114 if relay_config.logging {
115 let log_line = format!(
116 "# pre_iter: {}: sets: {} set_max_iter: {}\n\
117 # gamma_distribution: {:?} # set_idx, num_iter, converged, unique_best_solution\n",
118 relay_config.pre_iter,
119 relay_config.num_sets,
120 relay_config.set_max_iter,
121 relay_config.gamma_dist_interval,
122 );
123 let mut file =
124 File::create("relay_logging.out").expect("Unable to create file for logging.");
125 file.write_all(log_line.as_bytes())
126 .expect("Unable to write Relay logging data.");
127 }
128
129 let (sets_quality, sets_iter, sets_conv, sets_best);
131 if relay_config.logging {
132 sets_quality = Array1::<f64>::from_elem(relay_config.num_sets + 1, f64::MAX);
133 sets_iter =
134 Array1::<usize>::from_elem(relay_config.num_sets + 1, relay_config.set_max_iter);
135 sets_conv = Array1::<bool>::from_elem(relay_config.num_sets + 1, false);
136 sets_best = Array1::<bool>::from_elem(relay_config.num_sets + 1, false);
137 } else {
138 sets_quality = Array1::<f64>::zeros(1);
139 sets_iter = Array1::<usize>::zeros(1);
140 sets_conv = Array1::<bool>::from_elem(1, false);
141 sets_best = Array1::<bool>::from_elem(1, false);
142 }
143
144 if relay_config.explicit_gammas.is_some() {
145 let gammas_shape = relay_config.explicit_gammas.as_ref().unwrap().shape();
146 let num_variable_nodes = check_matrix.cols();
147 if num_variable_nodes != gammas_shape[1] {
148 println!("ERROR: Number of specified gammas {} does not match the number of variable nodes {}.", gammas_shape[1], num_variable_nodes);
149 exit(-1);
150 };
151 if relay_config.num_sets > gammas_shape[0] {
152 println!("WARNING: Number of different gamma sets {} is smaller than the number of Relay legs {}. Legs will be reused.", gammas_shape[0], relay_config.num_sets)
153 }
154 }
155
156 let num_executed_sets = 0;
158
159 let bp_decoder = MinSumBPDecoder::new(check_matrix, min_sum_config);
160
161 let posterior_update_state = Self::init_dismem_state(&relay_config);
162
163 RelayDecoder {
164 bp_decoder,
165 relay_config,
166 posterior_update_state,
167 sets_quality,
168 sets_iter,
169 sets_conv,
170 sets_best,
171 num_executed_sets,
172 }
173 }
174
175 fn init_dismem_state(relay_config: &RelayDecoderConfig) -> PosteriorUpdateState {
176 let rng_std: rand::prelude::StdRng = rand::rngs::StdRng::seed_from_u64(relay_config.seed);
177 let low = relay_config.gamma_dist_interval.0;
178 let high = relay_config.gamma_dist_interval.1;
179 let uniform: rand::distributions::Uniform<f64> = Uniform::new(low, high);
180 PosteriorUpdateState { rng_std, uniform }
181 }
182
183 fn init_next_set(&mut self, set_idx: usize) {
184 let mut gammas = Array1::zeros(self.check_matrix().cols());
185 if self.relay_config.explicit_gammas.is_some() {
186 let gammas_num_sets = self.relay_config.explicit_gammas.as_ref().unwrap().shape()[0];
187 for i in 0..gammas.len() {
188 gammas[i] = *self
189 .relay_config
190 .explicit_gammas
191 .as_ref()
192 .unwrap()
193 .get((set_idx % gammas_num_sets, i))
194 .unwrap();
195 }
196 self.bp_decoder.set_memory_strengths_f64(gammas);
197 return;
198 }
199 for i in 0..gammas.len() {
200 gammas[i] = self
201 .posterior_update_state
202 .uniform
203 .sample(&mut self.posterior_update_state.rng_std);
204 }
205 self.bp_decoder.set_memory_strengths_f64(gammas);
206 }
207
208 fn decode_inner(&mut self, detectors: ArrayView1<Bit>, max_iter: usize) -> DecodeResult {
210 let mut success: bool = false;
211 let mut decoded_detectors = Array1::default(detectors.dim());
212
213 for _ in 0..max_iter {
214 self.bp_decoder.run_iteration(detectors);
215 decoded_detectors = self.bp_decoder.compute_decoded_detectors();
216 success = self
217 .bp_decoder
218 .check_convergence(detectors, decoded_detectors.view());
219
220 if success {
222 debug!(
223 "Succeeded on iteration {:?}",
224 self.bp_decoder.current_iteration
225 );
226 break;
227 }
228 self.bp_decoder.current_iteration += 1;
229 }
230
231 self.bp_decoder
232 .build_result(success, decoded_detectors, max_iter)
233 }
234
235 fn write_log(&mut self, file: File) {
236 let mut buf_writer = BufWriter::new(file);
237 for set in 0..self.num_executed_sets {
238 let log_line = format!(
239 "{}, {}, {}, {}\n",
240 (set - 1) as i32,
241 self.sets_iter[set],
242 self.sets_conv[set] as u8,
243 self.sets_best[set] as u8
244 );
245 buf_writer
246 .write_all(log_line.as_bytes())
247 .expect("Unable to write Relay logging data.");
248 }
249 buf_writer
250 .flush()
251 .expect("Unable to write Relay logging data.");
252 }
253}
254
255impl<N> Decoder for RelayDecoder<N>
256where
257 N: PartialEq
258 + Debug
259 + Default
260 + Clone
261 + Copy
262 + Signed
263 + Bounded
264 + FromPrimitive
265 + ToPrimitive
266 + std::cmp::PartialOrd
267 + std::ops::Add
268 + std::ops::AddAssign
269 + std::ops::DivAssign
270 + std::ops::Mul<N>
271 + std::ops::MulAssign
272 + Send
273 + Sync
274 + std::fmt::Display
275 + 'static,
276{
277 fn check_matrix(&self) -> Arc<SparseBitMatrix> {
278 self.bp_decoder.check_matrix()
279 }
280
281 fn log_prior_ratios(&mut self) -> Array1<f64> {
282 self.bp_decoder.log_prior_ratios()
283 }
284
285 fn decode_detailed(&mut self, detectors: ArrayView1<Bit>) -> DecodeResult {
286 let mut num_conv = 0;
288 let mut min_pm = f64::MAX;
289 let mut num_sets_best = 0;
290 let mut best_set_idx = 0;
291 let mut total_iterations: usize = 0;
292 self.num_executed_sets = 0;
293 let stopping_criterion = self.relay_config.stopping_criterion.clone();
294
295 self.bp_decoder.initialize_decoder();
297 let mut result = self.decode_inner(detectors, self.relay_config.pre_iter);
298 self.num_executed_sets = 1;
299
300 if self.relay_config.logging {
302 self.sets_iter[0] = result.iterations;
303 self.sets_conv[0] = result.success;
304 }
305
306 if result.success {
308 num_conv += 1;
309 min_pm = result.decoding_quality;
310 num_sets_best += 1;
311 if self.relay_config.logging {
312 self.sets_quality[0] = result.decoding_quality
313 };
314
315 let mut done = false;
316 if stopping_criterion == StoppingCriterion::PreIter {
317 done = true;
318 } else if let StoppingCriterion::NConv { stop_after } = stopping_criterion {
319 if num_conv >= stop_after {
320 done = true;
321 }
322 }
323 if done {
325 if self.relay_config.logging {
326 self.sets_best[0] = true;
327 let file = OpenOptions::new()
328 .append(true)
329 .open("relay_logging.out")
330 .unwrap();
331 self.write_log(file);
332 }
333 return result;
334 }
335 }
336
337 total_iterations += result.iterations;
339 for set in 1..=self.relay_config.num_sets {
340 self.init_next_set(set);
343 self.bp_decoder.current_iteration = 0;
344 self.bp_decoder.initialize_check_to_variable();
345 self.bp_decoder.initialize_variable_to_check();
346 let temp_result = self.decode_inner(detectors, self.relay_config.set_max_iter);
347
348 self.num_executed_sets += 1;
349 total_iterations += temp_result.iterations;
350 if temp_result.success {
351 num_conv += 1;
352 let pm = temp_result.decoding_quality;
353 if self.relay_config.logging {
354 self.sets_conv[set] = true;
355 self.sets_iter[set] = temp_result.iterations;
356 self.sets_quality[set] = pm;
357 }
358 if pm == min_pm {
359 num_sets_best += 1;
361 }
362 if pm < min_pm {
363 num_sets_best = 1;
365 best_set_idx = set;
366 min_pm = pm;
367 result = temp_result;
368 }
369 if let StoppingCriterion::NConv { stop_after } = stopping_criterion {
370 if num_conv >= stop_after {
371 break;
372 }
373 }
374 }
375 }
376 result.iterations = total_iterations;
377
378 if self.relay_config.logging {
380 if num_sets_best == 1 {
381 self.sets_best[best_set_idx] = true;
382 }
383 let file = OpenOptions::new()
384 .append(true)
385 .open("relay_logging.out")
386 .unwrap();
387 self.write_log(file);
388 }
389
390 result
391 }
392
393 fn get_decoding_quality(&mut self, errors: ArrayView1<u8>) -> f64 {
394 self.bp_decoder.get_decoding_quality(errors)
395 }
396}
397
398impl<N> DecoderRunner for RelayDecoder<N> where
399 N: PartialEq
400 + Debug
401 + Default
402 + Clone
403 + Copy
404 + Signed
405 + Bounded
406 + FromPrimitive
407 + ToPrimitive
408 + std::cmp::PartialOrd
409 + std::ops::Add
410 + std::ops::AddAssign
411 + std::ops::DivAssign
412 + std::ops::Mul<N>
413 + std::ops::MulAssign
414 + Send
415 + Sync
416 + std::fmt::Display
417 + 'static
418{
419}
420
#[cfg(test)]
mod tests {

    use super::*;

    use crate::bipartite_graph::{BipartiteGraph, SparseBipartiteGraph};
    use env_logger;
    use ndarray::prelude::*;

    use crate::dem::DetectorErrorModel;
    use crate::utilities::test::get_test_data_path;
    use ndarray::Array2;
    use ndarray_npy::read_npy;

    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    #[test]
    fn default_stopping_criterion_stops_after_first_convergence() {
        assert_eq!(
            StoppingCriterion::default(),
            StoppingCriterion::NConv { stop_after: 1 }
        );
    }

    #[test]
    fn default_relay_config() {
        let config = RelayDecoderConfig::default();
        assert_eq!(config.pre_iter, 80);
        assert_eq!(config.num_sets, 300);
        assert_eq!(config.set_max_iter, 60);
        assert_eq!(config.gamma_dist_interval, (-0.24, 0.66));
        assert!(config.explicit_gammas.is_none());
        assert_eq!(config.stopping_criterion, StoppingCriterion::default());
        assert!(!config.logging);
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn min_sum_decode_repetition_code() {
        init();

        let check_matrix = array![[1, 1, 0], [0, 1, 1],];

        let check_matrix: SparseBipartiteGraph<_> = SparseBipartiteGraph::from_dense(check_matrix);
        let check_matrix_arc = Arc::new(check_matrix);

        let iterations = 10;
        let bp_config = MinSumDecoderConfig {
            error_priors: array![0.003, 0.003, 0.003],
            max_iter: iterations,
            alpha: Some(1.),
            alpha_iteration_scaling_factor: 1.,
            gamma0: None,
            ..Default::default()
        };
        let bp_config_arc = Arc::new(bp_config);

        let relay_config = RelayDecoderConfig {
            pre_iter: iterations,
            num_sets: 0,
            set_max_iter: 150,
            stopping_criterion: StoppingCriterion::PreIter,
            explicit_gammas: None,
            ..Default::default()
        };
        let relay_config_arc = Arc::new(relay_config);

        let mut decoder: RelayDecoder<f32> =
            RelayDecoder::new(check_matrix_arc, bp_config_arc, relay_config_arc);

        // (expected error, detectors): the trivial syndrome and every single-bit error.
        let cases: [(Array1<_>, Array1<Bit>); 4] = [
            (array![0, 0, 0], array![0, 0]),
            (array![1, 0, 0], array![1, 0]),
            (array![0, 1, 0], array![1, 1]),
            (array![0, 0, 1], array![0, 1]),
        ];

        for (error, detectors) in cases.iter() {
            let result = decoder.decode_detailed(detectors.view());

            assert_eq!(&result.decoding, error, "detectors: {}", detectors);
            assert_eq!(&result.decoded_detectors, detectors);
            assert_eq!(result.max_iter, iterations);
            assert!(result.success, "decoding failed for detectors: {}", detectors);
        }
    }

    #[test]
    fn decode_144_12_12() {
        let resources = get_test_data_path();
        let code_144_12_12 =
            DetectorErrorModel::load(resources.join("144_12_12")).expect("Unable to load the code");
        let detectors_144_12_12: Array2<Bit> =
            read_npy(resources.join("144_12_12_detectors.npy")).expect("Unable to open file");
        let bp_config_144_12_12 = MinSumDecoderConfig {
            error_priors: code_144_12_12.error_priors,
            max_iter: 200,
            alpha: None,
            alpha_iteration_scaling_factor: 0.,
            gamma0: Some(0.9),
            ..Default::default()
        };
        let relay_config = RelayDecoderConfig::default();
        let check_matrix = Arc::new(code_144_12_12.detector_error_matrix);
        let bp_config = Arc::new(bp_config_144_12_12);
        let config = Arc::new(relay_config);
        let mut decoder_144_12_12: RelayDecoder<f64> =
            RelayDecoder::new(check_matrix, bp_config, config);
        let num_errors = 100;
        let detectors_slice = detectors_144_12_12.slice(s![..num_errors, ..]);
        let results = decoder_144_12_12.par_decode_detailed_batch(detectors_slice);

        let num_successes = results.iter().filter(|x| x.success).count();
        assert_eq!(num_successes, detectors_slice.shape()[0]);

        assert_eq!(results[0].decoding.len(), 8785);
    }

    #[test]
    fn decode_144_12_12_int() {
        let resources = get_test_data_path();
        let code_144_12_12 =
            DetectorErrorModel::load(resources.join("144_12_12")).expect("Unable to load the code");
        let detectors_144_12_12: Array2<Bit> =
            read_npy(resources.join("144_12_12_detectors.npy")).expect("Unable to open file");

        let bits = 16;
        let scale = 8.0;

        let bp_config_144_12_12 = MinSumDecoderConfig {
            error_priors: code_144_12_12.error_priors,
            max_iter: 200,
            alpha: None,
            alpha_iteration_scaling_factor: 0.,
            gamma0: Some(0.9),
            max_data_value: Some(((1 << bits) - 1) as f64),
            data_scale_value: Some(scale),
            ..Default::default()
        };
        let relay_config = RelayDecoderConfig::default();
        let check_matrix = Arc::new(code_144_12_12.detector_error_matrix);
        let bp_config = Arc::new(bp_config_144_12_12);
        let config = Arc::new(relay_config);
        let mut decoder_144_12_12: RelayDecoder<isize> =
            RelayDecoder::new(check_matrix, bp_config, config);
        let num_errors = 100;
        let detectors_slice = detectors_144_12_12.slice(s![..num_errors, ..]);
        let results = decoder_144_12_12.par_decode_detailed_batch(detectors_slice);

        let num_successes = results.iter().filter(|x| x.success).count();
        assert_eq!(num_successes, detectors_slice.shape()[0]);

        assert_eq!(results[0].decoding.len(), 8785);
    }
}