use super::model_graph::*;
use super::noise_model::*;
use super::serde_json;
use super::simulator::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use crate::decoder_fusion::{FusionBlossomAdaptor, FusionDecoderConfig};
use super::decoder_mwpm::*;
use super::derivative::*;
use crate::fusion_blossom::mwpm_solver::*;
use crate::fusion_blossom::util::*;
/// Decoder backed by a fusion-blossom `SolverParallel`.
///
/// The solver's vertices are split according to the `partition_config` field of the
/// decoder configuration. Syndromes are translated into solver input, and solver subgraphs
/// back into corrections, by the shared `FusionBlossomAdaptor`.
#[derive(Derivative, Serialize)]
#[derivative(Debug)]
pub struct ParallelFusionDecoder {
/// Decoding graph (`initializer`, `vertex_to_position_mapping`) plus syndrome/correction
/// translation; wrapped in `Arc` so clones share it instead of rebuilding it.
pub adaptor: Arc<FusionBlossomAdaptor>,
/// Stateful parallel MWPM solver. It is excluded from serialization and from `Debug`
/// output, and `Clone` builds a fresh one rather than copying it.
#[serde(skip)]
#[derivative(Debug = "ignore")]
pub fusion_solver: fusion_blossom::mwpm_solver::SolverParallel,
/// Configuration this decoder was built from.
pub config: ParallelFusionDecoderConfig,
}
impl Clone for ParallelFusionDecoder {
fn clone(&self) -> Self {
let partition_info = self.config.partition_config.clone().unwrap_or(PartitionConfig::new(self.adaptor.vertex_to_position_mapping.len())).info();
let fusion_solver = if self.config.skip_decoding {
fusion_blossom::mwpm_solver::SolverParallel::new(&SolverInitializer {
vertex_num: 0,
weighted_edges: vec![],
virtual_vertices: vec![],
}, &partition_info, self.config.primal_dual_config.clone())
} else {
fusion_blossom::mwpm_solver::SolverParallel::new(&self.adaptor.initializer, &partition_info, self.config.primal_dual_config.clone())
};
Self {
adaptor: self.adaptor.clone(),
fusion_solver,
config: self.config.clone(),
}
}
}
/// JSON-deserializable configuration of `ParallelFusionDecoder`.
///
/// Every field has a serde default, taken from `mwpm_default_configs` or
/// `parallel_fusion_default_configs`, so any subset may be given. Unknown fields are rejected
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ParallelFusionDecoderConfig {
/// Edge-weight function forwarded to the fusion-blossom adaptor (alias `wf`).
#[serde(alias = "wf")] #[serde(default = "mwpm_default_configs::weight_function")]
pub weight_function: WeightFunction,
/// Forwarded to the adaptor as `use_combined_probability` (alias `ucp`).
/// NOTE(review): presumably combines probabilities of parallel error sources into one edge
/// weight; confirm in `FusionBlossomAdaptor`.
#[serde(alias = "ucp")] #[serde(default = "mwpm_default_configs::use_combined_probability")]
pub use_combined_probability: bool,
/// Forwarded to the adaptor (default `false`). By its name, presumably restricts decoding
/// to Z stabilizers; confirm in the adaptor.
#[serde(default = "parallel_fusion_default_configs::only_stab_z")]
pub only_stab_z: bool,
/// Forwarded to the adaptor (alias `mhw`, default 5000).
/// NOTE(review): presumably the upper bound used when scaling weights to integers; confirm.
#[serde(alias = "mhw")] #[serde(default = "parallel_fusion_default_configs::max_half_weight")]
pub max_half_weight: usize,
/// When `true`, `decode_with_erasure` immediately returns an empty correction and empty
/// statistics (default `false`).
#[serde(default = "parallel_fusion_default_configs::skip_decoding")]
pub skip_decoding: bool,
/// When `true`, runtime statistics gain a `log_matchings` entry holding the subgraph and
/// perfect-matching edges as position pairs (default `false`).
#[serde(default = "parallel_fusion_default_configs::log_matchings")]
pub log_matchings: bool,
/// Passed verbatim to `SolverParallel::new` (default: empty JSON object).
#[serde(default = "parallel_fusion_default_configs::primal_dual_config")]
pub primal_dual_config: serde_json::Value,
/// Vertex partitioning for the parallel solver. `None` (the default) falls back to
/// `PartitionConfig::new` over all vertices known to the adaptor.
#[serde(default = "parallel_fusion_default_configs::partition_config")]
pub partition_config: Option<PartitionConfig>,
}
/// Serde default-value providers, referenced by the `#[serde(default = "...")]` attributes
/// of `ParallelFusionDecoderConfig`.
pub mod parallel_fusion_default_configs {
    use super::*;

    /// Half-weight bound used when the configuration omits `max_half_weight` / `mhw`.
    const DEFAULT_MAX_HALF_WEIGHT: usize = 5000;

    /// `only_stab_z` is disabled unless explicitly requested.
    pub fn only_stab_z() -> bool { false }

    /// Returns `DEFAULT_MAX_HALF_WEIGHT` (5000).
    pub fn max_half_weight() -> usize { DEFAULT_MAX_HALF_WEIGHT }

    /// Decoding is performed unless explicitly skipped.
    pub fn skip_decoding() -> bool { false }

    /// Matching logs are not recorded unless requested.
    pub fn log_matchings() -> bool { false }

    /// An empty JSON object, handed to the solver unchanged.
    pub fn primal_dual_config() -> serde_json::Value { json!({}) }

    /// No explicit partitioning; the decoder derives one from the adaptor's vertex count.
    pub fn partition_config() -> Option<PartitionConfig> { None }
}
impl ParallelFusionDecoder {
    /// Builds a parallel fusion-blossom decoder for `simulator` under `noise_model`.
    ///
    /// `decoder_configuration` is deserialized into `ParallelFusionDecoderConfig`. The decoding
    /// graph is built by `FusionBlossomAdaptor` on a private copy of `simulator`, so the caller's
    /// simulator is not modified. The solver is partitioned by `partition_config`, falling back
    /// to `PartitionConfig::new` over all adaptor vertices.
    ///
    /// # Panics
    /// Panics if `decoder_configuration` is not a valid `ParallelFusionDecoderConfig`
    /// (including unknown fields).
    pub fn new(
        simulator: &Simulator,
        noise_model: Arc<NoiseModel>,
        decoder_configuration: &serde_json::Value,
        parallel: usize,
        use_brief_edge: bool,
    ) -> Self {
        let config: ParallelFusionDecoderConfig = serde_json::from_value(decoder_configuration.clone())
            .expect("invalid parallel fusion decoder configuration");
        // The adaptor needs `&mut Simulator`; work on a copy so the caller's stays untouched.
        let mut simulator = simulator.clone();
        let adaptor = FusionBlossomAdaptor::new(
            &FusionDecoderConfig {
                weight_function: config.weight_function.clone(),
                use_combined_probability: config.use_combined_probability,
                only_stab_z: config.only_stab_z,
                max_half_weight: config.max_half_weight,
                skip_decoding: config.skip_decoding,
                log_matchings: config.log_matchings,
                // effectively unbounded
                max_tree_size: usize::MAX,
            },
            &mut simulator,
            noise_model,
            parallel,
            use_brief_edge,
        );
        // Build the fallback partition only when no explicit partition config was provided.
        let partition_info = config
            .partition_config
            .clone()
            .unwrap_or_else(|| PartitionConfig::new(adaptor.vertex_to_position_mapping.len()))
            .info();
        let fusion_solver = fusion_blossom::mwpm_solver::SolverParallel::new(
            &adaptor.initializer,
            &partition_info,
            config.primal_dual_config.clone(),
        );
        Self {
            adaptor: Arc::new(adaptor),
            fusion_solver,
            config,
        }
    }

    /// Decodes `sparse_measurement` assuming no detected erasures.
    ///
    /// Equivalent to `decode_with_erasure` with an empty `SparseErasures`.
    // `dead_code` is allowed because callers may only use `decode_with_erasure`.
    #[allow(dead_code)]
    pub fn decode(&mut self, sparse_measurement: &SparseMeasurement) -> (SparseCorrection, serde_json::Value) {
        self.decode_with_erasure(sparse_measurement, &SparseErasures::new())
    }

    /// Decodes a syndrome and returns `(correction, runtime_statistics)`.
    ///
    /// With `skip_decoding` set, returns an empty correction and `{}` without touching the solver.
    /// Otherwise the statistics object contains:
    /// - `to_be_matched`: number of defects in `sparse_measurement`;
    /// - `time_fusion`: seconds spent on syndrome conversion, solving, optional logging and
    ///   solver reset;
    /// - `time_build_correction`: seconds spent converting the subgraph into a correction;
    /// - `log_matchings` (only when enabled): subgraph and perfect-matching edge lists.
    ///
    /// # Panics
    /// Panics if `sparse_detected_erasures` is non-empty; erasures are not supported yet.
    pub fn decode_with_erasure(
        &mut self,
        sparse_measurement: &SparseMeasurement,
        sparse_detected_erasures: &SparseErasures,
    ) -> (SparseCorrection, serde_json::Value) {
        if self.config.skip_decoding {
            return (SparseCorrection::new(), json!({}));
        }
        assert!(sparse_detected_erasures.is_empty(), "fusion decoder doesn't support erasure error yet: we'll do it in the next version to support 0-weight edges and dynamic setting");
        let mut correction = SparseCorrection::new();
        let mut time_fusion = 0.;
        let mut time_build_correction = 0.;
        let mut log_matchings: Vec<serde_json::Value> = Vec::new();
        if !sparse_measurement.is_empty() {
            let begin = Instant::now();
            let syndrome_pattern = self
                .adaptor
                .generate_syndrome_pattern(sparse_measurement, sparse_detected_erasures);
            self.fusion_solver.solve(&syndrome_pattern);
            let subgraph: Vec<usize> = self.fusion_solver.subgraph();
            if self.config.log_matchings {
                log_matchings = self.collect_matching_logs(&subgraph);
            }
            self.fusion_solver.clear();
            time_fusion += begin.elapsed().as_secs_f64();
            // Restart the clock so this statistic measures only correction construction,
            // not the fusion work timed above.
            let begin = Instant::now();
            correction = self.adaptor.subgraph_to_correction(&subgraph);
            time_build_correction += begin.elapsed().as_secs_f64();
        }
        let mut runtime_statistics = json!({
            "to_be_matched": sparse_measurement.len(),
            "time_fusion": time_fusion,
            "time_build_correction": time_build_correction,
        });
        if self.config.log_matchings {
            runtime_statistics
                .as_object_mut()
                .expect("runtime statistics is built as a JSON object")
                .insert("log_matchings".to_string(), json!(log_matchings));
        }
        (correction, runtime_statistics)
    }

    /// Renders the solver's current result as two position-pair edge lists for visualization.
    ///
    /// Called between `solve` and `clear` while the solver still holds the result. Returns a
    /// `"subgraph"` entry with one edge per index in `subgraph`, then a `"perfect matching"`
    /// entry with peer matchings followed by virtual matchings.
    fn collect_matching_logs(&mut self, subgraph: &[usize]) -> Vec<serde_json::Value> {
        let adaptor = &self.adaptor;
        let subgraph_edges: Vec<_> = subgraph
            .iter()
            .map(|&edge_index| {
                let (vertex_1, vertex_2, _) = adaptor.initializer.weighted_edges[edge_index];
                (
                    adaptor.vertex_to_position_mapping[vertex_1].clone(),
                    adaptor.vertex_to_position_mapping[vertex_2].clone(),
                )
            })
            .collect();
        let perfect_matching = self.fusion_solver.perfect_matching();
        let mut perfect_matching_edges = Vec::with_capacity(
            perfect_matching.peer_matchings.len() + perfect_matching.virtual_matchings.len(),
        );
        for (node_ptr_1, node_ptr_2) in perfect_matching.peer_matchings.iter() {
            let vertex_1 = node_ptr_1.get_representative_vertex();
            let vertex_2 = node_ptr_2.get_representative_vertex();
            perfect_matching_edges.push((
                adaptor.vertex_to_position_mapping[vertex_1].clone(),
                adaptor.vertex_to_position_mapping[vertex_2].clone(),
            ));
        }
        for (node_ptr, virtual_vertex) in perfect_matching.virtual_matchings.iter() {
            let vertex = node_ptr.get_representative_vertex();
            perfect_matching_edges.push((
                adaptor.vertex_to_position_mapping[vertex].clone(),
                adaptor.vertex_to_position_mapping[*virtual_vertex].clone(),
            ));
        }
        vec![
            json!({
                "name": "subgraph",
                "description": "elementary fault edges",
                "edges": subgraph_edges,
            }),
            json!({
                "name": "perfect matching",
                "description": "the paths of the perfect matching",
                "edges": perfect_matching_edges,
            }),
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::super::code_builder::*;
    use super::super::noise_model_builder::*;
    use super::*;

    /// Decodes a fixed five-error pattern on a d=5 standard planar code with no noisy
    /// measurement rounds. The solver uses two vertex partitions joined by one fusion. The test
    /// checks that the correction passes the builder sanity check and leaves no logical error.
    #[test]
    fn parallel_fusion_decoder_debug_1() {
        let d = 5;
        let noisy_measurements = 0;
        let (p, pe) = (0., 0.1);
        let mut simulator = Simulator::new(CodeType::StandardPlanarCode, CodeSize::new(noisy_measurements, d, d));
        code_builder_sanity_check(&simulator).unwrap();
        // Erasure-only phenomenological noise, compressed and sanity-checked before sharing.
        let noise_model = {
            let mut model = NoiseModel::new(&simulator);
            NoiseModelBuilder::ErasureOnlyPhenomenological.apply(&mut simulator, &mut model, &json!({}), p, 1., pe);
            simulator.compress_error_rates(&mut model);
            noise_model_sanity_check(&simulator, &model).unwrap();
            Arc::new(model)
        };
        let decoder_config = json!({
            "partition_config": {
                "vertex_num": 60,
                "partitions": [(0, 20), (41, 60)],
                "fusions": [(0, 1)],
            }
        });
        // `new` clones the simulator internally, so the borrow here is enough.
        let mut decoder = ParallelFusionDecoder::new(&simulator, Arc::clone(&noise_model), &decoder_config, 1, false);
        let error_pattern: SparseErrorPattern = serde_json::from_value(
            json!({"[0][1][5]":"Z","[0][2][6]":"Z","[0][4][4]":"X","[0][5][7]":"X","[0][9][7]":"Y"}),
        )
        .unwrap();
        simulator.load_sparse_error_pattern(&error_pattern, &noise_model).expect("success");
        simulator.propagate_errors();
        let measurement = simulator.generate_sparse_measurement();
        println!("sparse_measurement: {:?}", measurement);
        let detected_erasures = simulator.generate_sparse_detected_erasures();
        let (correction, _runtime_statistics) = decoder.decode_with_erasure(&measurement, &detected_erasures);
        println!("correction: {:?}", correction);
        code_builder_sanity_check_correction(&mut simulator, &correction).unwrap();
        let (logical_i, logical_j) = simulator.validate_correction(&correction);
        assert!(!(logical_i || logical_j));
    }
}