// qecp/decoder_mwpm.rs

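//! Minimum-weight perfect matching (MWPM) decoder.
//!
//! Nontrivial measurements are matched with each other or with the boundary using the Blossom V algorithm;
//! detected erasures temporarily lower the weights of the affected model-graph edges to zero during decoding.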
3
4use super::blossom_v;
5use super::complete_model_graph::*;
6use super::erasure_graph::*;
7use super::model_graph::*;
8use super::noise_model::*;
9use super::serde_json;
10use super::simulator::*;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14
15/// MWPM decoder, initialized and cloned for multiple threads
16#[derive(Debug, Clone, Serialize)]
17pub struct MWPMDecoder {
18    /// model graph is immutably shared
19    pub model_graph: Arc<ModelGraph>,
20    /// erasure graph is immutably shared
21    pub erasure_graph: Arc<ErasureGraph>,
22    /// complete model graph each thread maintain its own precomputed data; the internal model_graph might be copied and modified if erasure error exists
23    pub complete_model_graph: CompleteModelGraph,
24    /// save configuration for later usage
25    pub config: MWPMDecoderConfig,
26    /// an immutably shared simulator that is used to change model graph on the fly for correcting erasure errors
27    pub simulator: Arc<Simulator>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(deny_unknown_fields)]
32pub struct MWPMDecoderConfig {
33    /// build complete model graph at first, but this will consume O(N^2) memory and increase initialization time,
34    /// disable this when you're simulating large code
35    #[serde(alias = "pcmg")] // abbreviation
36    #[serde(default = "mwpm_default_configs::precompute_complete_model_graph")]
37    pub precompute_complete_model_graph: bool,
38    /// weight function, by default using [`WeightFunction::AutotuneImproved`]
39    #[serde(alias = "wf")] // abbreviation
40    #[serde(default = "mwpm_default_configs::weight_function")]
41    pub weight_function: WeightFunction,
42    /// combined probability can improve accuracy, but will cause probabilities differ a lot even in the case of i.i.d. noise model
43    #[serde(alias = "ucp")] // abbreviation
44    #[serde(default = "mwpm_default_configs::use_combined_probability")]
45    pub use_combined_probability: bool,
46    #[serde(default = "mwpm_default_configs::log_matchings")]
47    pub log_matchings: bool,
48}
49
50pub mod mwpm_default_configs {
51    use super::*;
52    pub fn precompute_complete_model_graph() -> bool {
53        false
54    } // save for erasure noise model and also large code distance
55    pub fn weight_function() -> WeightFunction {
56        WeightFunction::AutotuneImproved
57    }
58    pub fn use_combined_probability() -> bool {
59        true
60    } // default use combined probability for better accuracy
61    pub fn log_matchings() -> bool {
62        false
63    }
64}
65
66impl MWPMDecoder {
67    /// create a new MWPM decoder with decoder configuration
68    pub fn new(
69        simulator: &Simulator,
70        noise_model: Arc<NoiseModel>,
71        decoder_configuration: &serde_json::Value,
72        parallel: usize,
73        use_brief_edge: bool,
74    ) -> Self {
75        // read attribute of decoder configuration
76        let config: MWPMDecoderConfig = serde_json::from_value(decoder_configuration.clone()).unwrap();
77        // build model graph
78        let mut simulator = simulator.clone();
79        let mut model_graph = ModelGraph::new(&simulator);
80        model_graph.build(
81            &mut simulator,
82            Arc::clone(&noise_model),
83            &config.weight_function,
84            parallel,
85            config.use_combined_probability,
86            use_brief_edge,
87        );
88        let model_graph = Arc::new(model_graph);
89        // build erasure graph
90        let mut erasure_graph = ErasureGraph::new(&simulator);
91        erasure_graph.build(&mut simulator, Arc::clone(&noise_model), parallel);
92        let erasure_graph = Arc::new(erasure_graph);
93        // build complete model graph
94        let mut complete_model_graph = CompleteModelGraph::new(&simulator, Arc::clone(&model_graph));
95        complete_model_graph.precompute(&simulator, config.precompute_complete_model_graph, parallel);
96        Self {
97            model_graph,
98            erasure_graph,
99            complete_model_graph,
100            config,
101            simulator: Arc::new(simulator),
102        }
103    }
104
105    /// decode given measurement results
106    #[allow(dead_code)]
107    pub fn decode(&mut self, sparse_measurement: &SparseMeasurement) -> (SparseCorrection, serde_json::Value) {
108        self.decode_with_erasure(sparse_measurement, &SparseErasures::new())
109    }
110
111    /// decode given measurement results and detected erasures
112    pub fn decode_with_erasure(
113        &mut self,
114        sparse_measurement: &SparseMeasurement,
115        sparse_detected_erasures: &SparseErasures,
116    ) -> (SparseCorrection, serde_json::Value) {
117        if !sparse_detected_erasures.is_empty() {
118            assert!(!self.config.precompute_complete_model_graph, "if erasure happens, the precomputed complete graph is invalid; please disable `precompute_complete_model_graph` or `pcmg` in the decoder configuration");
119        }
120        let mut correction = SparseCorrection::new();
121        // list nontrivial measurements to be matched
122        let to_be_matched = sparse_measurement.to_vec();
123        let mut time_prepare_graph = 0.;
124        let mut time_blossom_v = 0.;
125        let mut time_build_correction = 0.;
126        let mut matching_edges: Vec<(Position, Position)> = Vec::with_capacity(0);
127        if !to_be_matched.is_empty() {
128            // println!{"to_be_matched: {:?}", to_be_matched};
129            let begin = Instant::now();
130            // add the edges to the graph
131            let m_len = to_be_matched.len(); // virtual boundary of `i` is `i + m_len`
132            let node_num = m_len * 2;
133            // Z (X) stabilizers are (fully) connected, boundaries are fully connected
134            // stabilizer to boundary is one-to-one connected
135            let mut weighted_edges = Vec::<(usize, usize, f64)>::new();
136            // update model graph weights to consider erasure information
137            let mut erasure_graph_modifier = ErasureGraphModifier::<f64>::new();
138            if !sparse_detected_erasures.is_empty() {
139                // if erasure exists, the model graph will be duplicated on demand
140                let erasure_edges = sparse_detected_erasures.get_erasure_edges(&self.erasure_graph);
141                let model_graph_mut = self.complete_model_graph.get_model_graph_mut();
142                for erasure_edge in erasure_edges.iter() {
143                    match erasure_edge {
144                        ErasureEdge::Connection(position1, position2) => {
145                            let node1 = model_graph_mut.get_node_mut_unwrap(position1);
146                            let edge12 = node1.edges.get_mut(position2).expect("neighbor must exist");
147                            let original_weight12 = edge12.weight;
148                            edge12.weight = 0.; // set to 0 because of erasure
149                            let node2 = model_graph_mut.get_node_mut_unwrap(position2);
150                            let edge21 = node2.edges.get_mut(position1).expect("neighbor must exist");
151                            assert_eq!(original_weight12, edge21.weight, "model graph edge must be symmetric");
152                            edge21.weight = 0.; // set to 0 because of erasure
153                            erasure_graph_modifier.push_modified_edge(
154                                ErasureEdge::Connection(position1.clone(), position2.clone()),
155                                original_weight12,
156                            );
157                        }
158                        ErasureEdge::Boundary(position) => {
159                            let node = model_graph_mut.get_node_mut_unwrap(position);
160                            let boundary = node.boundary.as_mut().expect("boundary must exist").as_mut();
161                            let original_weight = boundary.weight;
162                            boundary.weight = 0.;
163                            erasure_graph_modifier
164                                .push_modified_edge(ErasureEdge::Boundary(position.clone()), original_weight);
165                        }
166                    }
167                }
168                self.complete_model_graph.model_graph_changed(&self.simulator);
169            }
170            // invalidate previous cache to save memory
171            self.complete_model_graph.invalidate_previous_dijkstra();
172            for i in 0..m_len {
173                let position = &to_be_matched[i];
174                let (edges, boundary) = self.complete_model_graph.get_edges(position, &to_be_matched);
175                if let Some(weight) = boundary {
176                    // eprintln!{"boundary {} {} ", i, weight};
177                    weighted_edges.push((i, i + m_len, weight));
178                }
179                for &(j, weight) in edges.iter() {
180                    if i < j {
181                        // remove duplicated edges
182                        // eprintln!{"edge {} {} {} ", i, j, weight};
183                        weighted_edges.push((i, j, weight));
184                    }
185                }
186                for j in (i + 1)..m_len {
187                    // virtual boundaries are always fully connected
188                    weighted_edges.push((i + m_len, j + m_len, 0.));
189                }
190            }
191            time_prepare_graph += begin.elapsed().as_secs_f64();
192            // run the Blossom algorithm
193            let begin = Instant::now();
194            let matching = blossom_v::safe_minimum_weight_perfect_matching(node_num, weighted_edges);
195            time_blossom_v += begin.elapsed().as_secs_f64();
196            // build correction based on the matching
197            let begin = Instant::now();
198            for i in 0..m_len {
199                let j = matching[i];
200                let a: &Position = &to_be_matched[i];
201                if j < i {
202                    // only add correction if j < i, so that the same correction is not applied twice
203                    // println!("match peer {:?} {:?}", to_be_matched[i], to_be_matched[j]);
204                    let b = &to_be_matched[j];
205                    let matching_correction = self.complete_model_graph.build_correction_matching(a, b);
206                    correction.extend(&matching_correction);
207                } else if j >= m_len {
208                    // matched with boundary
209                    // println!("match boundary {:?}", to_be_matched[i]);
210                    let boundary_correction = self.complete_model_graph.build_correction_boundary(a);
211                    correction.extend(&boundary_correction);
212                }
213                if self.config.log_matchings {
214                    let peer_position = if j < i {
215                        Some(to_be_matched[j].clone())
216                    } else if j >= m_len {
217                        Some(
218                            self.complete_model_graph
219                                .get_node_unwrap(a)
220                                .precomputed
221                                .as_ref()
222                                .unwrap()
223                                .boundary
224                                .as_ref()
225                                .unwrap()
226                                .next
227                                .clone(),
228                        )
229                    } else {
230                        None
231                    };
232                    if let Some(peer_position) = peer_position {
233                        matching_edges.push((a.clone(), peer_position));
234                    }
235                }
236            }
237            time_build_correction += begin.elapsed().as_secs_f64();
238            // recover the modified edges
239            if !sparse_detected_erasures.is_empty() {
240                let model_graph_mut = self.complete_model_graph.get_model_graph_mut();
241                while erasure_graph_modifier.has_modified_edges() {
242                    let (erasure_edge, weight) = erasure_graph_modifier.pop_modified_edge();
243                    match erasure_edge {
244                        ErasureEdge::Connection(position1, position2) => {
245                            let node1 = model_graph_mut.get_node_mut_unwrap(&position1);
246                            let edge12 = node1.edges.get_mut(&position2).expect("neighbor must exist");
247                            assert_eq!(edge12.weight, 0., "why a non-zero edge needs to be recovered");
248                            edge12.weight = weight; // recover the weight
249                            let node2 = model_graph_mut.get_node_mut_unwrap(&position2);
250                            let edge21 = node2.edges.get_mut(&position1).expect("neighbor must exist");
251                            assert_eq!(edge21.weight, 0., "why a non-zero edge needs to be recovered");
252                            edge21.weight = weight; // recover the weight
253                        }
254                        ErasureEdge::Boundary(position) => {
255                            let node = model_graph_mut.get_node_mut_unwrap(&position);
256                            let boundary = node.boundary.as_mut().expect("boundary must exist").as_mut();
257                            assert_eq!(boundary.weight, 0., "why a non-zero edge needs to be recovered");
258                            boundary.weight = weight;
259                        }
260                    }
261                }
262                // need to call here because if next round there are no erasure errors, the complete mode graph must still be in a consistent state
263                self.complete_model_graph.model_graph_changed(&self.simulator);
264            }
265        }
266        let mut runtime_statistics = json!({
267            "to_be_matched": to_be_matched.len(),
268            "time_prepare_graph": time_prepare_graph,
269            "time_blossom_v": time_blossom_v,
270            "time_build_correction": time_build_correction,
271        });
272        if self.config.log_matchings {
273            let runtime_statistics = runtime_statistics.as_object_mut().unwrap();
274            runtime_statistics.insert(
275                "log_matchings".to_string(),
276                json!([{
277                    "name": "matching",
278                    "description": "minimum-weight perfect matching",
279                    "edges": matching_edges,
280                }]),
281            );
282        }
283        (correction, runtime_statistics)
284    }
285}
286
#[cfg(feature = "blossom_v")]
#[cfg(test)]
mod tests {
    use super::super::code_builder::*;
    use super::super::noise_model_builder::*;
    use super::*;

    // 2022.6.16 regression: the MWPM decoder should correct this pattern because the UF decoder does.
    // error pattern: {"[0][1][5]":"Z","[0][2][6]":"Z","[0][4][4]":"X","[0][5][7]":"X","[0][9][7]":"Y"}
    // erasures: ["[0][1][3]","[0][1][5]","[0][2][6]","[0][4][4]","[0][5][7]","[0][6][6]","[0][9][7]"]
    // found by: cargo run --release -- tool benchmark [5] [0] [0] --pes [0.1] --max_repeats 0 --min_failed_cases 10 --time_budget 60 --decoder mwpm --code_type StandardPlanarCode --noise_model erasure-only-phenomenological -p0 --debug_print failed-error-pattern
    #[test]
    fn mwpm_decoder_debug_1() {
        // cargo test mwpm_decoder_debug_1 -- --nocapture
        let d = 5;
        let noisy_measurements = 0; // perfect measurement
        let p = 0.;
        let pe = 0.1;
        // build simulator
        let mut simulator = Simulator::new(CodeType::StandardPlanarCode, CodeSize::new(noisy_measurements, d, d));
        code_builder_sanity_check(&simulator).unwrap();
        // build erasure-only noise model
        let mut noise_model = NoiseModel::new(&simulator);
        let noise_model_builder = NoiseModelBuilder::ErasureOnlyPhenomenological;
        noise_model_builder.apply(&mut simulator, &mut noise_model, &json!({}), p, 1., pe);
        simulator.compress_error_rates(&mut noise_model);
        noise_model_sanity_check(&simulator, &noise_model).unwrap();
        let noise_model = Arc::new(noise_model);
        // build decoder with the default configuration; `MWPMDecoder::new` clones the simulator itself
        let decoder_config = json!({});
        let mut mwpm_decoder = MWPMDecoder::new(&simulator, Arc::clone(&noise_model), &decoder_config, 1, false);
        // load errors onto the simulator
        let sparse_error_pattern: SparseErrorPattern =
            serde_json::from_value(json!({"[0][1][5]":"Z","[0][2][6]":"Z","[0][4][4]":"X","[0][5][7]":"X","[0][9][7]":"Y"}))
                .unwrap();
        let sparse_detected_erasures: SparseErasures = serde_json::from_value(json!([
            "[0][1][3]",
            "[0][1][5]",
            "[0][2][6]",
            "[0][4][4]",
            "[0][5][7]",
            "[0][6][6]",
            "[0][9][7]"
        ]))
        .unwrap();
        simulator
            .load_sparse_error_pattern(&sparse_error_pattern, &noise_model)
            .expect("success");
        simulator
            .load_sparse_detected_erasures(&sparse_detected_erasures, &noise_model)
            .expect("success");
        simulator.propagate_errors();
        let sparse_measurement = simulator.generate_sparse_measurement();
        println!("sparse_measurement: {:?}", sparse_measurement);
        // decode with the erasures as reported by the simulator
        let sparse_detected_erasures = simulator.generate_sparse_detected_erasures();
        let (correction, _runtime_statistics) =
            mwpm_decoder.decode_with_erasure(&sparse_measurement, &sparse_detected_erasures);
        code_builder_sanity_check_correction(&mut simulator, &correction).unwrap();
        let (logical_i, logical_j) = simulator.validate_correction(&correction);
        assert!(!logical_i && !logical_j);
    }
}