use crate::bin;
use crate::coordinator;
use crate::coordinator::{DecoderCacheKey, FingerprintSource, build_modifier_fingerprints};
use crate::decoder::BlackBoxDecoderClient;
use crate::decoder::blackbox_decoder::{self, DecodingHypergraph, Hyperedge};
use crate::decoder::blackbox_util::assert_parity_factor;
use crate::misc::bit_vector::{self, get_bit, set_bit};
use crate::misc::index::{ErrorIndex, WILDCARD};
use crate::misc::pauli_frame_tracker::PauliFrameTracker;
use crate::misc::relative_program::{self, RelativeMapping, RelativeProgram};
use crate::misc::sync::{TaskCounter, check_or_receiver, get_or_receiver, get_value};
use crate::misc::union_find::{UnionFindGeneric, UnionNodeTrait};
use crate::misc::util::exclusive_probability_of;
use crate::util::BitVector;
use binar::{BitVec, BitwiseMut};
use hashbrown::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[cfg(feature = "cli")]
use structdoc::StructDoc;
use tokio::sync::{Mutex, RwLock, oneshot, watch};
use tokio_util::sync::CancellationToken;
use tonic::{Request, Response, Status};
/// Configuration of a [`MonolithicCoordinator`].
///
/// Deserialized from JSON in `MonolithicCoordinator::new`; unknown fields are rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "cli", derive(StructDoc))]
#[serde(deny_unknown_fields)]
pub struct MonolithicCoordinatorConfig {
/// When set, every decoding result is passed to `assert_parity_factor` together with the
/// hypergraph and syndrome it was decoded from. Defaults to `false`.
#[serde(default)]
pub assert_parity_factor: bool,
/// When set, hyperedges that touch the same set of vertices are merged into a single
/// hyperedge before the hypergraph is handed to the decoder. Defaults to `true`.
#[serde(default = "default_true")]
pub merge_hyperedges: bool,
/// When set, remote references of check models and error models are expanded by background
/// tasks spawned when they are created; otherwise they are expanded in one batch when a
/// subgraph is decoded. Defaults to `true`.
#[serde(default = "default_true")]
pub async_expand: bool,
/// When set, hypergraphs are loaded into the black-box decoder once and reused through a
/// cache keyed by `DecoderCacheKey`. Defaults to `true`.
#[serde(default = "default_true")]
pub persistent_decoder: bool,
}
/// Serde default provider for configuration flags that are enabled unless stated otherwise.
fn default_true() -> bool {
true
}
/// Coordinator that keeps all types, instances and decoder state in a single process.
pub struct MonolithicCoordinator {
/// Behaviour switches, see [`MonolithicCoordinatorConfig`].
pub config: MonolithicCoordinatorConfig,
/// Port types registered through `load_library`, keyed by `ptype`.
pub port_types: RwLock<HashMap<u64, Arc<bin::PortType>>>,
/// Gadget types registered through `load_library`, keyed by `gtype`.
pub gadget_types: RwLock<HashMap<u64, Arc<bin::GadgetType>>>,
/// Check model types registered through `load_library`, keyed by `ctype`.
pub check_model_types: RwLock<HashMap<u64, Arc<bin::CheckModelType>>>,
/// Error model types registered through `load_library`, keyed by `etype`.
pub error_model_types: RwLock<HashMap<u64, Arc<bin::ErrorModelType>>>,
/// Live gadget instances keyed by `gid`; entries are removed when their subgraph is taken
/// for decoding. Shared (`Arc`) with the background expansion tasks.
pub gadgets: Arc<RwLock<HashMap<u64, Gadget>>>,
/// Live check model instances keyed by `cid`; shared with the background expansion tasks.
pub check_models: Arc<RwLock<HashMap<u64, CheckModel>>>,
/// Live error model instances keyed by `eid`; shared with the background expansion tasks.
pub error_models: Arc<RwLock<HashMap<u64, ErrorModel>>>,
/// Next candidate gadget id for requests that pass `gid == 0`; starts at 1.
pub next_gid: Mutex<u64>,
/// Next candidate check model id for requests that pass `cid == 0`; starts at 1.
pub next_cid: Mutex<u64>,
/// Next candidate error model id for requests that pass `eid == 0`; starts at 1.
pub next_eid: Mutex<u64>,
/// Union-find over created gadgets; gadgets joined by a connector share one component.
pub pending_subgraphs: Mutex<UnionFindGeneric<MonolithicUnionNode>>,
/// Maps a `gid` to the index of its node in `pending_subgraphs`.
pub gid_to_union_index: Mutex<HashMap<u64, usize>>,
/// Hypergraphs already loaded into the black-box decoder (used when
/// `config.persistent_decoder` is set).
pub loaded_decoders: RwLock<HashMap<DecoderCacheKey, LoadedDecoder>>,
/// Client of the decoder service that solves the decoding problems.
pub black_box_decoder: BlackBoxDecoderClient,
/// Tracker that receives every created gadget and the corrections found by decoding.
pub pauli_frame_tracker: Mutex<PauliFrameTracker>,
/// Token cloned into expansion tasks and into waits on their results.
pub cancellation: RwLock<CancellationToken>,
/// Counter whose guards are held by the spawned expansion tasks for their whole lifetime.
pub task_counter: Arc<TaskCounter>,
}
/// Exposes the parts of an [`ErrorModel`] that the `FingerprintSource` trait asks for.
impl FingerprintSource for ErrorModel {
/// Returns the error model instance as it was submitted (with its id filled in).
fn instance(&self) -> &bin::ErrorModel {
&self.instance
}
/// Returns the remote check model list after the instance's reroute modifiers were applied.
fn modified_remote_check_models(&self) -> &Arc<Vec<Option<bin::error_model_type::RemoteCheckModel>>> {
&self.modified_remote_check_models
}
}
/// Cache entry describing a hypergraph that has been loaded into the black-box decoder.
#[derive(Debug, Clone)]
pub struct LoadedDecoder {
/// Handle returned by the decoder's `load_hypergraph` call; passed back to `decode_loaded`.
pub hid: u64,
/// For every hyperedge of the loaded hypergraph, the error it stands for.
pub errors: Arc<Vec<ErrorIndex>>,
/// The loaded hypergraph itself; this file only stores it when
/// `config.assert_parity_factor` is set.
pub decoding_hypergraph: Option<Arc<DecodingHypergraph>>,
/// Always `None` where this file constructs the entry.
pub vertex_remap: Option<Arc<Vec<u64>>>,
}
/// Runtime state of one gadget instance.
pub struct Gadget {
/// The gadget as submitted, with `gid` filled in.
pub instance: bin::Gadget,
/// Measurement outcomes of the gadget; `None` when the gadget is created.
pub outcomes: Option<BitVector>,
/// Id of the check model bound to this gadget; `None` until one is created for it.
pub binding_cid: watch::Sender<Option<u64>>,
/// One channel per output port, holding the downstream connector once a later gadget
/// connects to that port.
pub outputs: Vec<watch::Sender<Option<bin::gadget::Connector>>>,
/// Sending half used to deliver the readouts computed by decoding.
pub tx: oneshot::Sender<BitVector>,
/// Receiving half of `tx`; stored as `Some` at creation.
pub rx: Option<oneshot::Receiver<BitVector>>,
}
/// Runtime state of one check model instance.
pub struct CheckModel {
/// The check model as submitted, with `cid` filled in.
pub instance: bin::CheckModel,
/// Ids of the error models attached to this check model, in creation order.
pub attaching_eid_vec: Vec<u64>,
/// Remote gadget descriptions of the type with the instance's reroutes applied; a `None`
/// entry has no description.
pub modified_remote_gadgets: Arc<Vec<Option<bin::check_model_type::RemoteGadget>>>,
/// Gadget id resolved for each remote gadget; the outer `None` means "not expanded yet".
pub expanded_remote_gadgets: watch::Sender<Option<Vec<Option<u64>>>>,
}
/// Runtime state of one error model instance.
pub struct ErrorModel {
/// The error model as submitted, with `eid` filled in.
pub instance: bin::ErrorModel,
/// Remote check model descriptions of the type with the instance's reroutes applied; a
/// `None` entry has no description.
pub modified_remote_check_models: Arc<Vec<Option<bin::error_model_type::RemoteCheckModel>>>,
/// Check model id resolved for each remote check model; the outer `None` means
/// "not expanded yet".
pub expanded_remote_check_models: watch::Sender<Option<Vec<Option<u64>>>>,
}
impl MonolithicCoordinator {
pub fn new(config: serde_json::Value, black_box_decoder: BlackBoxDecoderClient) -> Self {
let config: MonolithicCoordinatorConfig = serde_json::from_value(config).unwrap();
Self {
config,
port_types: Default::default(),
gadget_types: Default::default(),
check_model_types: Default::default(),
error_model_types: Default::default(),
gadgets: Default::default(),
check_models: Default::default(),
error_models: Default::default(),
next_gid: Mutex::new(1),
next_cid: Mutex::new(1),
next_eid: Mutex::new(1),
pending_subgraphs: Mutex::new(UnionFindGeneric::new(0)),
gid_to_union_index: Mutex::new(HashMap::new()),
loaded_decoders: Default::default(),
black_box_decoder,
pauli_frame_tracker: Default::default(),
cancellation: RwLock::new(CancellationToken::new()),
task_counter: TaskCounter::new(),
}
}
/// Collects the ids of all gadgets reachable from `gid` through connectors, in either
/// direction (a gadget's outputs and its input connectors).
///
/// # Panics
///
/// Panics when a visited gadget is not stored in `self.gadgets` or still has an output port
/// without a connected peer.
async fn get_subgraph(&self, gid: u64) -> HashSet<u64> {
    let gadgets = self.gadgets.read().await;
    let mut visited: HashSet<u64> = HashSet::new();
    visited.insert(gid);
    // Breadth-first traversal; a gadget enters the queue exactly once, when first seen.
    let mut queue = std::collections::VecDeque::from([gid]);
    while let Some(current) = queue.pop_front() {
        let gadget = gadgets.get(&current).unwrap();
        let downstream = gadget.outputs.iter().map(|output| output.borrow().unwrap());
        let upstream = gadget.instance.connectors.iter().copied();
        for neighbor in downstream.chain(upstream) {
            if visited.insert(neighbor.gid) {
                queue.push_back(neighbor.gid);
            }
        }
    }
    visited
}
async fn take_subgraph(&self, gid: u64) -> (HashMap<u64, Gadget>, HashMap<u64, CheckModel>, HashMap<u64, ErrorModel>) {
let subgraph = self.get_subgraph(gid).await;
if self.config.async_expand {
let token = self.cancellation.read().await.clone();
let mut handles = vec![];
let gadgets = self.gadgets.read().await;
let check_models = self.check_models.read().await;
let error_models = self.error_models.read().await;
for &gid in subgraph.iter() {
let gadget = &gadgets[&gid];
if let Some(&cid) = gadget.binding_cid.borrow().as_ref() {
let check_model = &check_models[&cid];
if let Err(receiver) = check_or_receiver(&check_model.expanded_remote_gadgets, token.clone()) {
handles.push(receiver);
}
for &eid in check_model.attaching_eid_vec.iter() {
let error_model = &error_models[&eid];
match check_or_receiver(&error_model.expanded_remote_check_models, token.clone()) {
Ok(..) => {}
Err(receiver) => handles.push(receiver),
}
}
}
}
drop(gadgets);
drop(check_models);
drop(error_models);
futures_util::future::join_all(handles).await;
}
let gadgets: HashMap<u64, Gadget> = {
let mut gadgets = self.gadgets.write().await;
subgraph.iter().map(|gid| (*gid, gadgets.remove(gid).unwrap())).collect()
};
let check_models: HashMap<u64, CheckModel> = {
let mut check_models = self.check_models.write().await;
subgraph
.iter()
.filter_map(|gid| {
let gadget = &gadgets[gid];
if let Some(&cid) = gadget.binding_cid.borrow().as_ref() {
Some((cid, check_models.remove(&cid).unwrap()))
} else {
None
}
})
.collect()
};
let error_models: HashMap<u64, ErrorModel> = {
let mut error_models = self.error_models.write().await;
check_models
.iter()
.flat_map(|(_, check_model)| {
check_model
.attaching_eid_vec
.iter()
.map(|eid| {
let error_model = error_models.remove(eid).unwrap();
(*eid, error_model)
})
.collect::<Vec<_>>()
.into_iter()
})
.collect()
};
(gadgets, check_models, error_models)
}
/// Expands the remote references of every check model and error model of a taken subgraph.
///
/// Used when `config.async_expand` is off. The maps are moved in, the expansion results are
/// published on each model's watch channel, and the same maps are returned.
async fn batch_expand(
    &self,
    gadgets: HashMap<u64, Gadget>,
    mut check_models: HashMap<u64, CheckModel>,
    mut error_models: HashMap<u64, ErrorModel>,
) -> (HashMap<u64, Gadget>, HashMap<u64, CheckModel>, HashMap<u64, ErrorModel>) {
    let token = self.cancellation.read().await.clone();
    // The expansion helpers expect lock-wrapped maps, so wrap the owned maps temporarily.
    let shared_gadgets = RwLock::new(gadgets);
    for check_model in check_models.values_mut() {
        let resolved = Self::expand_remote_gadgets(
            &check_model.instance,
            &check_model.modified_remote_gadgets,
            &shared_gadgets,
            token.clone(),
        )
        .await;
        check_model.expanded_remote_gadgets.send_replace(Some(resolved));
    }
    let shared_check_models = RwLock::new(check_models);
    for error_model in error_models.values_mut() {
        let resolved = Self::expand_remote_check_models(
            &error_model.instance,
            &error_model.modified_remote_check_models,
            &shared_gadgets,
            &shared_check_models,
            token.clone(),
        )
        .await;
        error_model.expanded_remote_check_models.send_replace(Some(resolved));
    }
    let gadgets = shared_gadgets.into_inner();
    let check_models = shared_check_models.into_inner();
    (gadgets, check_models, error_models)
}
/// Decodes the subgraph containing `gid` and delivers the resulting readouts.
///
/// The subgraph is taken out of the coordinator, described as a `RelativeProgram` (gadgets in
/// ascending `gid` order), decoded, and the decoded errors are loaded into the Pauli frame
/// tracker. Each gadget that receives an update gets its readouts through its `tx` channel.
///
/// Returns early, without decoding, when an expansion result cannot be obtained from
/// `get_value`.
async fn decode_subgraph(&self, gid: u64) {
    let (mut gadgets, mut check_models, mut error_models) = self.take_subgraph(gid).await;
    if !self.config.async_expand {
        (gadgets, check_models, error_models) = self.batch_expand(gadgets, check_models, error_models).await;
    }
    // Sorted ids give the relative program a deterministic gadget order.
    let mut sorted_gids: Vec<u64> = gadgets.keys().copied().collect();
    sorted_gids.sort();
    let token = self.cancellation.read().await.clone();
    let mut expanded_gadgets: Vec<relative_program::ExpandedGadget> = vec![];
    for &member_gid in sorted_gids.iter() {
        let gadget = gadgets.get(&member_gid).unwrap();
        let inputs: Vec<_> = gadget.instance.connectors.iter().cloned().map(Some).collect();
        let outputs: Vec<_> = gadget
            .outputs
            .iter()
            .map(|output| Some(output.borrow().unwrap()))
            .collect();
        let gtype = gadget.instance.gtype;
        let bound_cid = *gadget.binding_cid.borrow();
        let mut expanded_check_model = None;
        let mut expanded_error_models = vec![];
        if let Some(cid) = bound_cid {
            let check_model = check_models.get(&cid).unwrap();
            let Some(remote_gadgets) = get_value(&check_model.expanded_remote_gadgets, token.clone()).await else {
                return;
            };
            let count_checks = self
                .check_model_types
                .read()
                .await
                .get(&check_model.instance.ctype)
                .unwrap()
                .checks
                .len();
            expanded_check_model = Some(relative_program::ExpandedCheckModel {
                cid,
                ctype: check_model.instance.ctype,
                remote_gadgets,
                count_checks,
            });
            for &eid in check_model.attaching_eid_vec.iter() {
                let error_model = error_models.get(&eid).unwrap();
                let Some(remote_check_models) =
                    get_value(&error_model.expanded_remote_check_models, token.clone()).await
                else {
                    return;
                };
                expanded_error_models.push(relative_program::ExpandedErrorModel {
                    eid,
                    etype: error_model.instance.etype,
                    remote_check_models,
                });
            }
        }
        expanded_gadgets.push(relative_program::ExpandedGadget {
            gid: member_gid,
            gtype,
            inputs,
            outputs,
            check_model: expanded_check_model,
            error_models: expanded_error_models,
        });
    }
    let (relative_program, mapping) = RelativeProgram::new(&expanded_gadgets);
    let (parity_factor, errors) = self
        .decode_parity_factor(&relative_program, &mapping, &gadgets, &check_models, &error_models)
        .await;
    let updates = self
        .update_pauli_frame(&parity_factor, &errors, &relative_program, &mapping, &error_models)
        .await;
    for (updated_gid, readouts) in updates {
        let gadget = gadgets.remove(&updated_gid).unwrap();
        // The receiver may already be gone; a failed send is deliberately ignored.
        let _ = gadget.tx.send(readouts);
    }
}
/// Turns a decoding result into per-gadget corrections and loads them into the tracker.
///
/// For every error selected in `parity_factor.subgraph`, the error's `residual` and
/// `readout_flips` indices are toggled in the bit vectors of the gadget that owns the error.
/// The accumulated vectors are then passed to `load_correction`, one gadget at a time, and the
/// per-gadget results are returned as `(gid, readouts)` pairs in the order of
/// `mapping.global_gid_of`.
///
/// Returns an empty vector when the tracker does not know one of the gadgets.
async fn update_pauli_frame(
    &self,
    parity_factor: &blackbox_decoder::ParityFactor,
    errors: &[ErrorIndex],
    relative_program: &RelativeProgram,
    mapping: &RelativeMapping,
    error_models: &HashMap<u64, ErrorModel>,
) -> Vec<(u64, BitVector)> {
    let error_model_types = self.error_model_types.read().await;
    let mut tracker = self.pauli_frame_tracker.lock().await;
    let gadget_count = relative_program.local_gadgets.len();
    let mut residuals: Vec<BitVec> = Vec::with_capacity(gadget_count);
    let mut flips: Vec<BitVec> = Vec::with_capacity(gadget_count);
    for gid in mapping.global_gid_of.iter() {
        match tracker.gadgets.get(gid) {
            Some(tracked) => {
                residuals.push(BitVec::zeros(tracked.num_output_observables()));
                flips.push(BitVec::zeros(tracked.num_readouts()));
            }
            None => return vec![],
        }
    }
    for &selected in parity_factor.subgraph.iter() {
        let reference = &errors[selected as usize];
        let local_eid = reference.eid as usize;
        let global_eid = mapping.global_eid_of[local_eid];
        let error_model = error_models.get(&global_eid).unwrap();
        let error_model_type = error_model_types.get(&error_model.instance.etype).unwrap();
        let error = &error_model_type.errors[reference.error_index as usize];
        let owner = mapping.local_gid_of_local_eid[local_eid];
        for &bit in error.residual.iter() {
            residuals[owner].negate_index(bit as usize);
        }
        for &bit in error.readout_flips.iter() {
            flips[owner].negate_index(bit as usize);
        }
    }
    mapping
        .global_gid_of
        .iter()
        .zip(residuals)
        .zip(flips)
        .map(|((&gid, residual), readout_flips)| {
            let mut single_update = tracker.load_correction(gid, residual, readout_flips);
            debug_assert_eq!(single_update.keys().cloned().collect::<Vec<_>>(), vec![gid]);
            (gid, single_update.remove(&gid).unwrap())
        })
        .collect()
}
/// Computes the syndrome of a subgraph and asks the black-box decoder for a parity factor.
///
/// Returns the parity factor together with the list that maps each hyperedge of the decoded
/// hypergraph back to the error it stands for.
///
/// With `config.persistent_decoder` set, the hypergraph is looked up in `loaded_decoders`
/// under a `DecoderCacheKey`; on a hit only the syndrome is sent (`decode_loaded`), on a miss
/// the hypergraph is built, loaded once, cached and then decoded. Without it, hypergraph and
/// syndrome are sent together in a single `decode` call.
///
/// With `config.merge_hyperedges` set, hyperedges with the same (sorted) vertex list are merged:
/// their probabilities are combined with `exclusive_probability_of`, and the merged hyperedge
/// refers to the most probable of the original errors.
///
/// # Panics
///
/// Panics when a decoder request fails, and on the panics of the helpers it calls.
async fn decode_parity_factor(
    &self,
    relative_program: &RelativeProgram,
    mapping: &RelativeMapping,
    gadgets: &HashMap<u64, Gadget>,
    check_models: &HashMap<u64, CheckModel>,
    error_models: &HashMap<u64, ErrorModel>,
) -> (blackbox_decoder::ParityFactor, Arc<Vec<ErrorIndex>>) {
    let syndrome = self.get_syndrome(relative_program, mapping, gadgets, check_models).await;
    let cache_key = if self.config.persistent_decoder {
        let error_model_types = self.error_model_types.read().await;
        Some(DecoderCacheKey {
            relative_program: relative_program.clone(),
            error_model_fingerprints: build_modifier_fingerprints(mapping, error_models, &error_model_types),
            committing_local_cids: Vec::new(),
        })
    } else {
        None
    };
    if let Some(cache_key) = &cache_key {
        // Copy the entry out (its payload is behind `Arc`s) so that the read guard is released
        // at the end of this statement and is NOT held across the decoder round trip below;
        // otherwise every cache insert would have to wait for all in-flight decodes.
        let cached = self.loaded_decoders.read().await.get(cache_key).cloned();
        if let Some(loaded) = cached {
            let parity_factor = self
                .black_box_decoder
                .clone()
                .decode_loaded(blackbox_decoder::LoadedDecodingProblem {
                    hid: loaded.hid,
                    syndrome: Some(syndrome.clone()),
                })
                .await
                .unwrap();
            if self.config.assert_parity_factor {
                assert_parity_factor(loaded.decoding_hypergraph.as_ref().unwrap(), &parity_factor, &syndrome);
            }
            return (parity_factor, loaded.errors);
        }
    }
    let (mut decoding_hypergraph, mut errors) = self
        .decoding_hypergraph(relative_program, mapping, check_models, error_models)
        .await;
    if self.config.merge_hyperedges {
        // Sorted vertex list -> (index of the merged hyperedge, best single probability so far).
        let mut merged: HashMap<Vec<u64>, (usize, f64)> = HashMap::new();
        let mut merged_hyperedges: Vec<Hyperedge> = Vec::with_capacity(errors.len());
        let mut merged_errors = Vec::with_capacity(errors.len());
        for (hyperedge, error_index) in decoding_hypergraph.hyperedges.iter().zip(errors.iter()) {
            let mut vertices = hyperedge.vertices.clone();
            vertices.sort();
            // A hyperedge must not list the same vertex twice.
            debug_assert!(vertices.windows(2).all(|pair| pair[0] != pair[1]));
            if let Some((ei, best_p_e)) = merged.get_mut(&vertices) {
                let p_all = merged_hyperedges[*ei].probability;
                merged_hyperedges[*ei].probability = exclusive_probability_of(p_all, hyperedge.probability);
                // Keep the most probable original error as the representative.
                if hyperedge.probability > *best_p_e {
                    *best_p_e = hyperedge.probability;
                    merged_errors[*ei] = error_index.clone();
                }
            } else {
                let ei = merged_errors.len();
                merged_hyperedges.push(Hyperedge {
                    probability: hyperedge.probability,
                    vertices: vertices.clone(),
                });
                merged_errors.push(error_index.clone());
                merged.insert(vertices, (ei, hyperedge.probability));
            }
        }
        decoding_hypergraph = DecodingHypergraph {
            vertex_num: decoding_hypergraph.vertex_num,
            hyperedges: merged_hyperedges,
        };
        errors = Arc::new(merged_errors);
    }
    let decoding_hypergraph = Arc::new(decoding_hypergraph);
    let parity_factor = if let Some(cache_key) = cache_key {
        let hid = self
            .black_box_decoder
            .clone()
            .load_hypergraph(decoding_hypergraph.as_ref().clone())
            .await
            .unwrap()
            .hid;
        let mut loaded_decoders = self.loaded_decoders.write().await;
        loaded_decoders.insert(
            cache_key,
            LoadedDecoder {
                hid,
                errors: errors.clone(),
                // The hypergraph is only needed later for the optional assertion.
                decoding_hypergraph: self.config.assert_parity_factor.then_some(decoding_hypergraph.clone()),
                vertex_remap: None,
            },
        );
        // Release the write lock before the decoder round trip.
        drop(loaded_decoders);
        self.black_box_decoder
            .clone()
            .decode_loaded(blackbox_decoder::LoadedDecodingProblem {
                hid,
                syndrome: Some(syndrome.clone()),
            })
            .await
            .unwrap()
    } else {
        self.black_box_decoder
            .clone()
            .decode(blackbox_decoder::DecodingProblem {
                hypergraph: Some(decoding_hypergraph.as_ref().clone()),
                syndrome: Some(syndrome.clone()),
            })
            .await
            .unwrap()
    };
    if self.config.assert_parity_factor {
        assert_parity_factor(&decoding_hypergraph, &parity_factor, &syndrome);
    }
    (parity_factor, errors)
}
/// Evaluates every check of the subgraph and returns the resulting syndrome bits.
///
/// The syndrome has `relative_program.count_checks` bits. The checks of the check model at
/// position `i` of `mapping.global_cid_of` occupy the bits starting at
/// `mapping.start_indices[i]`. A check's bit is its `naturally_flipped` flag XOR-ed with all
/// measurement outcomes it lists; a measurement is read either from the check model's own
/// gadget or from a remote gadget, in which case the remote's `measurement_bias` is added to
/// the measurement index.
///
/// # Panics
///
/// Panics when a referenced model, gadget, expansion result or outcome vector is missing.
async fn get_syndrome(
    &self,
    relative_program: &RelativeProgram,
    mapping: &RelativeMapping,
    gadgets: &HashMap<u64, Gadget>,
    check_models: &HashMap<u64, CheckModel>,
) -> BitVector {
    let mut syndrome: BitVector = bit_vector::from_sparse_indices(relative_program.count_checks as u64, &[]);
    let check_model_types = self.check_model_types.read().await;
    let placements = mapping.global_cid_of.iter().zip(mapping.start_indices.iter());
    for (cid, &start_index) in placements {
        let check_model = check_models.get(cid).unwrap();
        let check_model_type = check_model_types.get(&check_model.instance.ctype).unwrap();
        let owner = gadgets.get(&check_model.instance.gid).unwrap();
        let expanded_guard = check_model.expanded_remote_gadgets.borrow();
        let expanded_remotes = expanded_guard.as_ref().unwrap();
        let local_outcomes = owner.outcomes.as_ref().unwrap();
        for (check_index, check) in check_model_type.checks.iter().enumerate() {
            let is_defect = check
                .measurements
                .iter()
                .fold(check.naturally_flipped, |parity, measurement| {
                    let outcome = match measurement.remote_gadget {
                        Some(ri) => {
                            let slot = ri as usize;
                            let remote_gid = expanded_remotes[slot].unwrap();
                            let remote = gadgets.get(&remote_gid).unwrap();
                            let bias = check_model.modified_remote_gadgets[slot]
                                .as_ref()
                                .unwrap()
                                .measurement_bias;
                            get_bit(remote.outcomes.as_ref().unwrap(), measurement.measurement_index + bias)
                        }
                        None => get_bit(local_outcomes, measurement.measurement_index),
                    };
                    parity ^ outcome
                });
            set_bit(&mut syndrome, (start_index + check_index) as u64, is_defect);
        }
    }
    syndrome
}
async fn decoding_hypergraph(
&self,
relative_program: &RelativeProgram,
mapping: &RelativeMapping,
check_models: &HashMap<u64, CheckModel>,
error_models: &HashMap<u64, ErrorModel>,
) -> (DecodingHypergraph, Arc<Vec<ErrorIndex>>) {
let error_model_types = self.error_model_types.read().await;
let mut hyperedges: Vec<Hyperedge> = vec![];
let mut error_reference: Vec<ErrorIndex> = vec![];
for (local_cid, &cid) in mapping.global_cid_of.iter().enumerate() {
let check_model = check_models.get(&cid).unwrap();
for &eid in &check_model.attaching_eid_vec {
let local_eid = mapping.local_eid_of[&eid];
let error_model = error_models.get(&eid).unwrap();
let error_model_type = error_model_types.get(&error_model.instance.etype).unwrap();
let expanded_remote_ref = error_model.expanded_remote_check_models.borrow();
let expanded_remotes = expanded_remote_ref.as_ref().unwrap();
let mut errors = &error_model_type.errors;
let modified_errors: Option<Vec<bin::error_model_type::Error>>;
if let Some(modifier) = &error_model.instance.modifier
&& let Some(probability_modifier) = &modifier.probability_modifier
{
let mut new_errors = errors.clone();
for (error_index, &probability) in probability_modifier.probabilities.iter().enumerate() {
new_errors[error_index].probability = probability;
}
for (&error_index, &probability) in probability_modifier
.sparse_indices
.iter()
.zip(probability_modifier.sparse_probabilities.iter())
{
new_errors[error_index as usize].probability = probability;
}
modified_errors = Some(new_errors);
errors = modified_errors.as_ref().unwrap();
}
let local_start_index = mapping.start_indices[local_cid] as u64;
for (error_index, error) in errors.iter().enumerate() {
if error.probability <= 0.0 {
continue;
}
let mut vertices: Vec<u64> = vec![];
for check in &error.checks {
if let Some(ri) = check.remote_check_model {
let remote_cid = expanded_remotes[ri as usize].unwrap();
let remote_local_cid = mapping.local_cid_of[&remote_cid];
let remote_start_index = mapping.start_indices[remote_local_cid] as u64;
vertices.push(
remote_start_index
+ check.check_index
+ error_model.modified_remote_check_models[ri as usize]
.as_ref()
.unwrap()
.check_bias,
);
} else {
vertices.push(local_start_index + check.check_index);
}
}
if vertices.is_empty() {
continue; }
error_reference.push(ErrorIndex {
eid: local_eid as u64,
error_index: error_index as u64,
});
hyperedges.push(Hyperedge {
vertices,
probability: error.probability,
});
}
}
}
let hypergraph = DecodingHypergraph {
vertex_num: relative_program.count_checks as u64,
hyperedges,
};
(hypergraph, Arc::new(error_reference))
}
/// Resolves every remote gadget description of a check model to a gadget id.
///
/// The result has one slot per entry of `modified_remote_gadgets`; a slot stays `None` when
/// the entry has no description or could not be resolved.
async fn expand_remote_gadgets(
    check_model: &bin::CheckModel,
    modified_remote_gadgets: &Vec<Option<bin::check_model_type::RemoteGadget>>,
    gadgets: &RwLock<HashMap<u64, Gadget>>,
    token: CancellationToken,
) -> Vec<Option<u64>> {
    let remote_count = modified_remote_gadgets.len();
    let owner_gid = check_model.gid;
    let mut resolved: Vec<Option<u64>> = vec![None; remote_count];
    // Entries may refer to each other; the helper resolves predecessors on demand and skips
    // slots that are already filled.
    for index in 0..remote_count {
        Self::expand_remote_gadget(
            &mut resolved,
            index,
            modified_remote_gadgets,
            owner_gid,
            gadgets,
            token.clone(),
        )
        .await;
    }
    resolved
}
/// Resolves remote gadget `ri` and stores the gadget id in `expanded_remote_gid_vec[ri]`.
///
/// * An entry with an `absolute_gid` resolves to that id.
/// * Otherwise the entry is one hop away from a starting gadget: the gadget resolved for
///   `previous_remote_gadget` (resolved recursively first), or `gid` when there is no
///   predecessor. The hop follows either an input connector of the starting gadget or one of
///   its outputs; for an output that is not connected yet, the function waits until it is
///   connected or the wait ends without a value.
///
/// The slot is left as `None` when the entry has no description, when the wait ends without a
/// value, or when the predecessor itself could not be resolved.
///
/// # Panics
///
/// Panics when the starting gadget is not in `gadgets`, when the description has no port, or
/// when a port index is out of range.
async fn expand_remote_gadget(
    expanded_remote_gid_vec: &mut Vec<Option<u64>>,
    ri: usize,
    remote_gadgets: &Vec<Option<bin::check_model_type::RemoteGadget>>,
    gid: u64,
    gadgets: &RwLock<HashMap<u64, Gadget>>,
    token: CancellationToken,
) {
    if expanded_remote_gid_vec[ri].is_some() {
        return;
    }
    let Some(remote_gadget) = remote_gadgets[ri].as_ref() else {
        return;
    };
    if let Some(absolute_gid) = remote_gadget.absolute_gid {
        expanded_remote_gid_vec[ri] = Some(absolute_gid);
        return;
    }
    let previous = match remote_gadget.previous_remote_gadget {
        Some(previous) => {
            let previous = previous as usize;
            // Recursion in an async fn needs a boxed future.
            Box::pin(Self::expand_remote_gadget(
                expanded_remote_gid_vec,
                previous,
                remote_gadgets,
                gid,
                gadgets,
                token.clone(),
            ))
            .await;
            // The predecessor legitimately stays unresolved when its wait ended without a
            // value or it has no description; this entry then stays unresolved as well
            // instead of panicking.
            let Some(previous_gid) = expanded_remote_gid_vec[previous] else {
                return;
            };
            previous_gid
        }
        None => gid,
    };
    let gadgets = gadgets.read().await;
    let gadget = gadgets.get(&previous).unwrap();
    match remote_gadget.port.unwrap() {
        bin::check_model_type::remote_gadget::Port::Output(port) => {
            let next = get_or_receiver(&gadget.outputs[port as usize], token);
            // Do not hold the map lock while waiting for the output to be connected.
            drop(gadgets);
            let next = match next {
                Ok(next) => Some(next),
                Err(handle) => handle.await.unwrap_or(None),
            };
            if let Some(next) = next {
                expanded_remote_gid_vec[ri] = Some(next.gid);
            }
        }
        bin::check_model_type::remote_gadget::Port::Input(port) => {
            let connector = &gadget.instance.connectors[port as usize];
            expanded_remote_gid_vec[ri] = Some(connector.gid);
        }
    }
}
/// Resolves every remote check model description of an error model to a check model id.
///
/// Works in two passes: first each description is resolved to the gadget it points at
/// (`expand_remote_check_model`), then each gadget id is replaced by the id of the check model
/// bound to that gadget, waiting for the binding if it does not exist yet.
///
/// The result normally has one slot per description; a slot is `None` when the description
/// could not be resolved to a gadget. When the wait for a binding is cancelled or its channel
/// closes, the slots collected so far are returned, so the result can be shorter.
///
/// # Panics
///
/// Panics when the error model's check model or a resolved gadget is missing from the maps.
async fn expand_remote_check_models(
error_model: &bin::ErrorModel,
modified_remote_check_models: &Vec<Option<bin::error_model_type::RemoteCheckModel>>,
gadgets: &RwLock<HashMap<u64, Gadget>>,
check_models: &RwLock<HashMap<u64, CheckModel>>,
token: CancellationToken,
) -> Vec<Option<u64>> {
// Relative descriptions start from the gadget that the error model's check model belongs to.
let gid = check_models.read().await.get(&error_model.cid).unwrap().instance.gid;
let mut expanded_remote_gid_vec: Vec<Option<u64>> = vec![None; modified_remote_check_models.len()];
for ri in 0..modified_remote_check_models.len() {
Self::expand_remote_check_model(
&mut expanded_remote_gid_vec,
ri,
modified_remote_check_models,
gid,
gadgets,
token.clone(),
)
.await;
}
let mut expanded_remote_cid_vec = Vec::with_capacity(modified_remote_check_models.len());
let mut gadgets_read = gadgets.read().await;
for (ri, gid) in expanded_remote_gid_vec.into_iter().enumerate() {
if let Some(gid) = gid {
// `u64::MAX - 1` is the marker written by `expand_remote_check_model` for descriptions
// that carry an absolute check model id; no gadget lookup is needed for those.
if gid == u64::MAX - 1 {
let absolute_cid = modified_remote_check_models[ri]
.as_ref()
.unwrap()
.absolute_cid
.expect("absolute_cid should be present when sentinel is used");
expanded_remote_cid_vec.push(Some(absolute_cid));
continue;
}
let gadget = gadgets_read.get(&gid).unwrap();
let cid = if let Some(&cid) = gadget.binding_cid.borrow().as_ref() {
cid
} else {
// No check model is bound yet: subscribe, release the map lock so that the binding
// can be created, and wait for it (or for cancellation).
let mut rx = gadget.binding_cid.subscribe();
drop(gadgets_read);
let cid = tokio::select! {
result = rx.wait_for(|v| v.is_some()) => {
match result {
Ok(v) => v.unwrap(),
Err(_) => return expanded_remote_cid_vec,
}
}
_ = token.cancelled() => { return expanded_remote_cid_vec; }
};
// Re-acquire the lock for the remaining slots.
gadgets_read = gadgets.read().await;
cid
};
expanded_remote_cid_vec.push(Some(cid));
} else {
expanded_remote_cid_vec.push(None);
}
}
expanded_remote_cid_vec
}
/// Resolves remote check model `ri` to the id of the gadget it points at and stores that id
/// in `expanded_remotes[ri]`.
///
/// * An entry with an `absolute_cid` is marked with the sentinel `u64::MAX - 1`; the caller
///   replaces the sentinel by the absolute check model id.
/// * Otherwise the entry is one hop away from a starting gadget: the gadget resolved for
///   `previous_remote_check_model` (resolved recursively first), or `gid` when there is no
///   predecessor. The hop follows either an input connector of the starting gadget or one of
///   its outputs; for an output that is not connected yet, the function waits until it is
///   connected or the wait ends without a value.
///
/// The slot is left as `None` when the entry has no description, when the wait ends without a
/// value, or when the predecessor itself could not be resolved.
///
/// # Panics
///
/// Panics when the starting gadget is not in `gadgets`, when the description has no port, or
/// when a port index is out of range.
// NOTE(review): a predecessor that resolved to the absolute-cid sentinel is not a gadget id,
// so the gadget lookup below panics for such a chain — confirm that chaining off an absolute
// entry is never produced.
async fn expand_remote_check_model(
    expanded_remotes: &mut Vec<Option<u64>>,
    ri: usize,
    remote_check_models: &Vec<Option<bin::error_model_type::RemoteCheckModel>>,
    gid: u64,
    gadgets: &RwLock<HashMap<u64, Gadget>>,
    token: CancellationToken,
) {
    if expanded_remotes[ri].is_some() {
        return;
    }
    let Some(remote_check_model) = remote_check_models[ri].as_ref() else {
        return;
    };
    if remote_check_model.absolute_cid.is_some() {
        expanded_remotes[ri] = Some(u64::MAX - 1);
        return;
    }
    let previous = match remote_check_model.previous_remote_check_model {
        Some(previous) => {
            let previous = previous as usize;
            // Recursion in an async fn needs a boxed future.
            Box::pin(Self::expand_remote_check_model(
                expanded_remotes,
                previous,
                remote_check_models,
                gid,
                gadgets,
                token.clone(),
            ))
            .await;
            // The predecessor legitimately stays unresolved when its wait ended without a
            // value or it has no description; this entry then stays unresolved as well
            // instead of panicking.
            let Some(previous_gid) = expanded_remotes[previous] else {
                return;
            };
            previous_gid
        }
        None => gid,
    };
    let gadgets = gadgets.read().await;
    let gadget = gadgets.get(&previous).unwrap();
    match remote_check_model.port.unwrap() {
        bin::error_model_type::remote_check_model::Port::Output(port) => {
            let next = get_or_receiver(&gadget.outputs[port as usize], token);
            // Do not hold the map lock while waiting for the output to be connected.
            drop(gadgets);
            let next = match next {
                Ok(next) => Some(next),
                Err(handle) => handle.await.unwrap_or(None),
            };
            if let Some(next) = next {
                expanded_remotes[ri] = Some(next.gid);
            }
        }
        bin::error_model_type::remote_check_model::Port::Input(port) => {
            let connector = &gadget.instance.connectors[port as usize];
            expanded_remotes[ri] = Some(connector.gid);
        }
    }
}
}
#[tonic::async_trait]
impl coordinator::coordinator_server::Coordinator for MonolithicCoordinator {
async fn load_library(&self, request: Request<bin::Library>) -> Result<Response<()>, Status> {
let library = request.into_inner();
let mut port_types = self.port_types.write().await;
for port_type in library.port_types.into_iter() {
if port_types.contains_key(&port_type.ptype) {
return Err(Status::already_exists(format!("ptype={}", port_type.ptype)));
}
port_types.insert(port_type.ptype, Arc::new(port_type));
}
drop(port_types);
let mut gadget_types = self.gadget_types.write().await;
for gadget_type in library.gadget_types.into_iter() {
if gadget_types.contains_key(&gadget_type.gtype) {
return Err(Status::already_exists(format!("gtype={}", gadget_type.gtype)));
}
gadget_types.insert(gadget_type.gtype, Arc::new(gadget_type));
}
drop(gadget_types);
let mut check_model_types = self.check_model_types.write().await;
for check_model_type in library.check_model_types.into_iter() {
if check_model_types.contains_key(&check_model_type.ctype) {
return Err(Status::already_exists(format!("ctype={}", check_model_type.ctype)));
}
check_model_types.insert(check_model_type.ctype, Arc::new(check_model_type));
}
drop(check_model_types);
let mut error_model_types = self.error_model_types.write().await;
for error_model_type in library.error_model_types.into_iter() {
if error_model_types.contains_key(&error_model_type.etype) {
return Err(Status::already_exists(format!("etype={}", error_model_type.etype)));
}
error_model_types.insert(error_model_type.etype, Arc::new(error_model_type));
}
drop(error_model_types);
Ok(().into())
}
async fn unload(&self, _unload: Request<coordinator::UnloadLibrary>) -> Result<Response<()>, Status> {
unimplemented!()
}
/// Executes one instruction, which creates a gadget, a check model or an error model, and
/// returns the id of the created object.
///
/// An id of 0 in the request asks the coordinator to assign one: the next counter value that
/// is not in use yet. Any other id is used as given.
///
/// # Errors
///
/// * `INVALID_ARGUMENT` when the instruction carries no `create` payload, when a check model
///   binds to an unknown gadget, or when an error model attaches to an unknown check model.
/// * `NOT_FOUND` when the referenced gadget / check model / error model type is not loaded.
///
/// # Panics
///
/// Several consistency conditions are only checked with `debug_assert!`; for example a
/// connector to an unknown gadget panics at the union-find lookup even in release builds.
async fn execute(&self, request: Request<bin::Instruction>) -> Result<Response<coordinator::ExecuteResponse>, Status> {
let instruction = request.into_inner();
let create = instruction
.create
.ok_or_else(|| Status::invalid_argument("unknown instruction"))?;
let id = match create {
bin::instruction::Create::Gadget(gadget) => {
let port_types = self.port_types.read().await;
let gadget_types = self.gadget_types.read().await;
let mut gadgets = self.gadgets.write().await;
let gid = if gadget.gid == 0 {
// Auto-assign: skip counter values that are already taken by explicit ids.
let mut next_gid = self.next_gid.lock().await;
while gadgets.contains_key(&*next_gid) {
*next_gid += 1;
}
let gid = *next_gid;
*next_gid += 1;
gid
} else {
gadget.gid
};
let gadget_type = gadget_types
.get(&gadget.gtype)
.ok_or_else(|| Status::not_found(format!("gtype={}", gadget.gtype)))?;
debug_assert!(gadget.connectors.len() == gadget_type.inputs.len());
// Register the gadget as a new union-find node and merge it with every upstream gadget.
let mut pending_subgraphs = self.pending_subgraphs.lock().await;
let mut gid_to_union_index = self.gid_to_union_index.lock().await;
let union_index = pending_subgraphs.payload.len();
pending_subgraphs.insert(MonolithicUnionNode::default());
gid_to_union_index.insert(gid, union_index);
for (port, connector) in gadget.connectors.iter().enumerate() {
debug_assert!(gadgets.contains_key(&connector.gid));
debug_assert!({
let peer_outputs = &gadgets[&connector.gid].outputs;
(connector.port as usize) < peer_outputs.len()
&& peer_outputs[connector.port as usize].borrow().is_none()
});
let peer_union_index = gid_to_union_index[&connector.gid];
pending_subgraphs.union(union_index, peer_union_index);
// Tell the upstream gadget which gadget and input port its output is now connected to.
gadgets.get_mut(&connector.gid).unwrap().outputs[connector.port as usize]
.send_replace(Some(bin::gadget::Connector { gid, port: port as u64 }));
}
// Component bookkeeping: the new gadget adds its own outputs as open ports and closes one
// open port per input connector.
let node = pending_subgraphs.get_mut(union_index);
node.num_unconnected_outputs += gadget_type.outputs.len();
node.num_unconnected_outputs -= gadget.connectors.len();
node.num_unloaded_gadgets += 1;
let mut tracker = self.pauli_frame_tracker.lock().await;
tracker.add_gadget(gid, gadget_type, gadget.modifier.as_ref(), &port_types, &gadget.connectors);
let (tx, rx) = oneshot::channel();
let mut gadget = gadget;
gadget.gid = gid;
gadgets.insert(
gid,
Gadget {
instance: gadget,
outcomes: None,
binding_cid: watch::channel(None).0,
outputs: gadget_type.outputs.iter().map(|_| watch::channel(None).0).collect(),
tx,
rx: Some(rx),
},
);
gid
}
bin::instruction::Create::CheckModel(check_model) => {
let check_model_types = self.check_model_types.read().await;
let mut gadgets = self.gadgets.write().await;
let mut check_models = self.check_models.write().await;
let cid = if check_model.cid == 0 {
// Auto-assign: skip counter values that are already taken by explicit ids.
let mut next_cid = self.next_cid.lock().await;
while check_models.contains_key(&*next_cid) {
*next_cid += 1;
}
let cid = *next_cid;
*next_cid += 1;
cid
} else {
check_model.cid
};
let check_model_type = check_model_types
.get(&check_model.ctype)
.ok_or_else(|| Status::not_found(format!("ctype={}", check_model.ctype)))?;
let gadget = gadgets.get_mut(&check_model.gid).ok_or_else(|| {
Status::invalid_argument(format!("cid={cid} binding to unknown gid={}", check_model.gid))
})?;
debug_assert!(check_model_type.gtype == WILDCARD || check_model_type.gtype == gadget.instance.gtype);
debug_assert!(gadget.binding_cid.borrow().is_none());
// Publish the binding; tasks waiting on `binding_cid` are woken by this.
gadget.binding_cid.send_replace(Some(cid));
// Start from the type's remote gadgets and apply the instance's reroutes; the list is
// padded with `None` when a reroute addresses an index beyond its end.
let mut modified_remote: Vec<_> = check_model_type.remote_gadgets.iter().cloned().map(Some).collect();
if let Some(modifier) = &check_model.modifier {
for rereoute in &modifier.reroute_remote_gadgets {
while (rereoute.remote_gadget_index as usize) >= modified_remote.len() {
modified_remote.push(None);
}
modified_remote[rereoute.remote_gadget_index as usize] = rereoute.value.clone();
}
}
let modified_remote = Arc::new(modified_remote);
let mut check_model = check_model;
check_model.cid = cid;
check_models.insert(
cid,
CheckModel {
instance: check_model.clone(),
attaching_eid_vec: vec![],
modified_remote_gadgets: modified_remote.clone(),
expanded_remote_gadgets: watch::channel(None).0,
},
);
// These bindings shadow the lock guards above with `Arc` handles for the spawned task.
// Shadowing does not drop the guards; they are released when this match arm ends.
let gadgets = self.gadgets.clone();
let check_models = self.check_models.clone();
if self.config.async_expand {
let token = self.cancellation.read().await.clone();
let _guard = self.task_counter.guard();
tokio::spawn(async move {
// Keep the task counter guard alive for the whole task.
let _guard = _guard;
let expanded_remote_gadgets =
Self::expand_remote_gadgets(&check_model, &modified_remote, gadgets.as_ref(), token).await;
let mut check_models = check_models.write().await;
// The check model may have been taken for decoding in the meantime.
if let Some(cm) = check_models.get_mut(&cid) {
cm.expanded_remote_gadgets.send_replace(Some(expanded_remote_gadgets));
}
});
}
cid
}
bin::instruction::Create::ErrorModel(error_model) => {
let error_model_types = self.error_model_types.read().await;
let mut check_models = self.check_models.write().await;
let mut error_models = self.error_models.write().await;
let eid = if error_model.eid == 0 {
// Auto-assign: skip counter values that are already taken by explicit ids.
let mut next_eid = self.next_eid.lock().await;
while error_models.contains_key(&*next_eid) {
*next_eid += 1;
}
let eid = *next_eid;
*next_eid += 1;
eid
} else {
error_model.eid
};
let error_model_type = error_model_types
.get(&error_model.etype)
.ok_or_else(|| Status::not_found(format!("etype={}", error_model.etype)))?;
let check_model = check_models.get_mut(&error_model.cid).ok_or_else(|| {
Status::invalid_argument(format!("eid={eid} attaching to unknown cid={}", error_model.cid))
})?;
debug_assert!(error_model_type.ctype == WILDCARD || error_model_type.ctype == check_model.instance.ctype);
check_model.attaching_eid_vec.push(eid);
// Start from the type's remote check models and apply the instance's reroutes; the list
// is padded with `None` when a reroute addresses an index beyond its end.
let mut modified_remote: Vec<_> = error_model_type.remote_check_models.iter().cloned().map(Some).collect();
if let Some(modifier) = &error_model.modifier {
for rereoute in &modifier.reroute_remote_check_models {
while (rereoute.remote_check_model_index as usize) >= modified_remote.len() {
modified_remote.push(None);
}
modified_remote[rereoute.remote_check_model_index as usize] = rereoute.value.clone();
}
}
let modified_remote = Arc::new(modified_remote);
let mut error_model = error_model;
error_model.eid = eid;
error_models.insert(
eid,
ErrorModel {
instance: error_model.clone(),
modified_remote_check_models: modified_remote.clone(),
expanded_remote_check_models: watch::channel(None).0,
},
);
// `check_models` and `error_models` shadow the lock guards above with `Arc` handles for
// the spawned task. Shadowing does not drop the guards; they are released when this match
// arm ends.
let gadgets = self.gadgets.clone();
let check_models = self.check_models.clone();
let error_models = self.error_models.clone();
if self.config.async_expand {
let token = self.cancellation.read().await.clone();
let _guard = self.task_counter.guard();
tokio::spawn(async move {
// Keep the task counter guard alive for the whole task.
let _guard = _guard;
let expanded_remote_check_models = Self::expand_remote_check_models(
&error_model,
&modified_remote,
gadgets.as_ref(),
check_models.as_ref(),
token,
)
.await;
let mut error_models = error_models.write().await;
// The error model may have been taken for decoding in the meantime.
if let Some(em) = error_models.get_mut(&eid) {
em.expanded_remote_check_models
.send_replace(Some(expanded_remote_check_models));
}
});
}
eid
}
};
Ok((coordinator::ExecuteResponse { id }).into())
}
/// Stores the measurement outcomes of one gadget, starts decoding of its pending subgraph when
/// this gadget was the last missing piece, and then waits for the readouts delivered through the
/// gadget's `rx` channel.
///
/// # Errors
/// - `not_found` when `gid` does not name a known gadget.
/// - `already_exists` when outcomes were already loaded for this gadget.
/// - `invalid_argument` when the request carries no outcomes payload.
/// - `internal` when the readout channel closes before a value is received.
///
/// # Panics
/// Panics when the gadget has no entry in `gid_to_union_index`, when its `rx` was already taken,
/// or when its gadget type is not registered.
async fn decode(&self, request: Request<coordinator::Outcomes>) -> Result<Response<coordinator::Readouts>, Status> {
    let outcomes = request.into_inner();
    // Keeps the task counter non-zero for the whole call (`reset` waits for it to reach zero).
    let _task_guard = self.task_counter.guard();
    let gid = outcomes.gid;

    // Every lock is confined to this block; at its end the guards are released in reverse
    // acquisition order: gid_to_union_index, pending_subgraphs, gadgets, gadget_types.
    let (is_final_gadget, rx) = {
        let gadget_types = self.gadget_types.read().await;
        let mut gadgets = self.gadgets.write().await;
        let gadget = gadgets
            .get_mut(&gid)
            .ok_or_else(|| Status::not_found(format!("gid={}", gid)))?;
        if gadget.outcomes.is_some() {
            return Err(Status::already_exists(format!("gid={} outcomes loaded", gid)));
        }
        let payload = outcomes
            .outcomes
            .ok_or_else(|| Status::invalid_argument("missing outcomes"))?;
        gadget.outcomes = Some(payload);

        // Book-keeping on the pending subgraph this gadget belongs to.
        let mut pending_subgraphs = self.pending_subgraphs.lock().await;
        let gid_to_union_index = self.gid_to_union_index.lock().await;
        let subgraph = pending_subgraphs.get_mut(gid_to_union_index[&gid]);
        subgraph.num_unloaded_gadgets -= 1;
        let is_final_gadget = subgraph.num_unloaded_gadgets == 0 && subgraph.num_unconnected_outputs == 0;

        let rx = gadget.rx.take().unwrap();
        let gadget_type = gadget_types.get(&gadget.instance.gtype).unwrap();
        let data: &BitVector = gadget.outcomes.as_ref().unwrap();
        // Each raw readout is the XOR of the measurement bits it references.
        let raw_readouts: Vec<bool> = gadget_type
            .readouts
            .iter()
            .map(|readout| {
                readout
                    .measurement_indices
                    .iter()
                    .fold(false, |parity, &mi| parity ^ get_bit(data, mi))
            })
            .collect();
        self.pauli_frame_tracker.lock().await.load_raw(gid, &raw_readouts, data);

        (is_final_gadget, rx)
    };

    if is_final_gadget {
        self.decode_subgraph(gid).await;
    }
    let readouts = rx.await.map_err(|_| Status::internal(format!("gid={} receive error", gid)))?;
    let reply = coordinator::Readouts {
        gid,
        readouts: Some(readouts),
        ..Default::default()
    };
    Ok(reply.into())
}
/// Cancels outstanding work and clears the coordinator state.
///
/// Steps, in order: cancel the current token, wait until the task counter reaches zero, install a
/// fresh token, optionally clear the type library (`reset_library`), clear all instances and id
/// counters, clear the pending-subgraph bookkeeping and the Pauli frame tracker, and finally reset
/// the black-box decoder service. The local decoder cache is cleared only when
/// `reset_decoder_service` is set and the remote reset succeeded.
///
/// # Errors
/// Returns `internal` when the decoder service reports an error for the reset call.
async fn reset(&self, request: Request<coordinator::ResetRequest>) -> Result<Response<()>, Status> {
    let flags = request.into_inner();

    // Stop in-flight tasks, wait for them to finish, then hand out a new token.
    self.cancellation.read().await.cancel();
    self.task_counter.wait_for_zero().await;
    *self.cancellation.write().await = CancellationToken::new();

    if flags.reset_library {
        self.port_types.write().await.clear();
        self.gadget_types.write().await.clear();
        self.check_model_types.write().await.clear();
        self.error_model_types.write().await.clear();
    }

    // Instances and their id allocators.
    self.gadgets.write().await.clear();
    self.check_models.write().await.clear();
    self.error_models.write().await.clear();
    *self.next_gid.lock().await = 1;
    *self.next_cid.lock().await = 1;
    *self.next_eid.lock().await = 1;

    // This guard stays alive until the function returns.
    let mut pending = self.pending_subgraphs.lock().await;
    pending.remove_all();
    self.gid_to_union_index.lock().await.clear();
    self.pauli_frame_tracker.lock().await.reset();

    let decoder_reset = blackbox_decoder::ResetRequest {
        reset_hypergraphs: flags.reset_decoder_service,
        ..Default::default()
    };
    self.black_box_decoder
        .clone()
        .reset(decoder_reset)
        .await
        .map_err(|e| Status::internal(format!("reset decoder service error: {}", e)))?;

    if flags.reset_decoder_service {
        self.loaded_decoders.write().await.clear();
    }
    Ok(().into())
}
}
/// Per-set payload stored in the union-find that tracks pending subgraphs.
///
/// All three counters are summed when two sets are merged (see the `UnionNodeTrait` impl).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MonolithicUnionNode {
/// Number of nodes in the set; compared during `union` and reset to 1 by `clear`.
pub set_size: usize,
/// Gadgets in the set whose outcomes have not been loaded yet; `decode` decrements it once per
/// loaded gadget.
pub num_unloaded_gadgets: usize,
/// `decode` treats a set as ready only when this and `num_unloaded_gadgets` are both zero.
/// NOTE(review): presumably counts output ports not yet connected; confirm where it is updated.
pub num_unconnected_outputs: usize,
}
impl UnionNodeTrait for MonolithicUnionNode {
    /// Combines the payloads of two sets by adding every counter field-wise.
    ///
    /// The returned flag is `true` when `left` is at least as large as `right`, so ties favour
    /// `left`. NOTE(review): presumably the union-find uses it to choose the surviving root.
    #[inline]
    fn union(left: &Self, right: &Self) -> (bool, Self) {
        let left_not_smaller = left.set_size >= right.set_size;
        let merged = Self {
            set_size: left.set_size + right.set_size,
            num_unloaded_gadgets: left.num_unloaded_gadgets + right.num_unloaded_gadgets,
            num_unconnected_outputs: left.num_unconnected_outputs + right.num_unconnected_outputs,
        };
        (left_not_smaller, merged)
    }

    /// Resets only `set_size` to 1; both gadget/output counters keep their current values.
    #[inline]
    fn clear(&mut self) {
        self.set_size = 1;
    }

    /// Payload of a fresh singleton set: size 1 and both counters at zero.
    #[inline]
    fn default() -> Self {
        let (num_unloaded_gadgets, num_unconnected_outputs) = (0, 0);
        Self {
            set_size: 1,
            num_unloaded_gadgets,
            num_unconnected_outputs,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bin::error_model::ErrorModelModifier;
    use crate::bin::error_model_type::Error;

    /// Error-model types keyed by `etype`.
    type EmtTable = HashMap<u64, Arc<bin::ErrorModelType>>;
    /// Error-model instances keyed by `eid`.
    type ModelTable = HashMap<u64, ErrorModel>;

    /// Mapping in which only the positional list of global eids is populated.
    fn mapping_with_eids(global_eid_of: Vec<u64>) -> RelativeMapping {
        RelativeMapping {
            global_eid_of,
            ..Default::default()
        }
    }

    /// Probability modifier using only the dense representation.
    fn pm_dense(probabilities: Vec<f64>) -> bin::ProbabilityModifier {
        bin::ProbabilityModifier {
            probabilities,
            sparse_indices: Vec::new(),
            sparse_probabilities: Vec::new(),
        }
    }

    /// Error-model instance attached to cid 1, optionally carrying a probability modifier.
    fn make_error_model_instance(eid: u64, etype: u64, modifier: Option<bin::ProbabilityModifier>) -> bin::ErrorModel {
        let modifier = modifier.map(|probability_modifier| ErrorModelModifier {
            probability_modifier: Some(probability_modifier),
            reroute_remote_check_models: Vec::new(),
        });
        bin::ErrorModel {
            eid,
            etype,
            cid: 1,
            modifier,
            ..Default::default()
        }
    }

    /// Wraps an instance into the coordinator-side `ErrorModel` with no remote check models.
    fn make_error_model(instance: bin::ErrorModel) -> ErrorModel {
        ErrorModel {
            instance,
            modified_remote_check_models: Arc::new(Vec::new()),
            expanded_remote_check_models: watch::channel(None).0,
        }
    }

    /// Error-model type bound to ctype 1 with the given errors and no remote check models.
    fn make_emt(etype: u64, errors: Vec<Error>) -> bin::ErrorModelType {
        bin::ErrorModelType {
            etype,
            ctype: 1,
            errors,
            remote_check_models: Vec::new(),
            ..Default::default()
        }
    }

    /// Single-check error (local check index 0) with the given probability.
    fn make_error(probability: f64) -> Error {
        let local_check = bin::error_model_type::RemoteCheck {
            remote_check_model: None,
            check_index: 0,
        };
        Error {
            checks: vec![local_check],
            probability,
            ..Default::default()
        }
    }

    /// Type table containing exactly one entry, etype 1, with the given errors.
    fn emt_table_for_etype_1(errors: Vec<Error>) -> EmtTable {
        let mut table = HashMap::new();
        table.insert(1, Arc::new(make_emt(1, errors)));
        table
    }

    /// Error model of etype 1 whose dense modifier holds a single probability.
    fn dense_model(eid: u64, probability: f64) -> ErrorModel {
        make_error_model(make_error_model_instance(eid, 1, Some(pm_dense(vec![probability]))))
    }

    /// Collects error models into a table keyed by each instance's eid.
    fn model_table(models: Vec<ErrorModel>) -> ModelTable {
        let mut table = HashMap::new();
        for model in models {
            table.insert(model.instance.eid, model);
        }
        table
    }

    #[test]
    fn build_modifier_fingerprints_picks_up_probability_modifier() {
        let mapping = mapping_with_eids(vec![1]);
        let emts = emt_table_for_etype_1(vec![make_error(0.1)]);
        let models_low = model_table(vec![dense_model(1, 0.1)]);
        let models_high = model_table(vec![dense_model(1, 0.2)]);

        let fps_low = build_modifier_fingerprints(&mapping, &models_low, &emts);
        let fps_high = build_modifier_fingerprints(&mapping, &models_high, &emts);

        // Differing modifier probabilities must change the fingerprint.
        assert_ne!(fps_low, fps_high);
        assert_eq!(fps_low.len(), 1);
    }

    #[test]
    fn build_modifier_fingerprints_picks_up_etype_structure() {
        let mapping = mapping_with_eids(vec![1]);
        let models = model_table(vec![make_error_model(make_error_model_instance(1, 1, None))]);
        let emts_one_error = emt_table_for_etype_1(vec![make_error(0.1)]);
        let emts_two_errors = emt_table_for_etype_1(vec![make_error(0.1), make_error(0.2)]);

        let fps_one = build_modifier_fingerprints(&mapping, &models, &emts_one_error);
        let fps_two = build_modifier_fingerprints(&mapping, &models, &emts_two_errors);

        // Same instance, different type definition: fingerprints must differ.
        assert_ne!(fps_one, fps_two);
    }

    #[test]
    fn build_modifier_fingerprints_is_positional() {
        let models = model_table(vec![dense_model(1, 0.1), dense_model(2, 0.9)]);
        let emts = emt_table_for_etype_1(vec![make_error(0.1)]);
        let forward = mapping_with_eids(vec![1, 2]);
        let backward = mapping_with_eids(vec![2, 1]);

        let fps_forward = build_modifier_fingerprints(&forward, &models, &emts);
        let fps_backward = build_modifier_fingerprints(&backward, &models, &emts);

        // Swapping the eid order in the mapping must be visible in the result.
        assert_ne!(fps_forward, fps_backward);
    }

    #[test]
    fn build_modifier_fingerprints_equal_for_identical_state() {
        let mapping = mapping_with_eids(vec![1]);
        let emts = emt_table_for_etype_1(vec![make_error(0.1)]);
        let models_first = model_table(vec![dense_model(1, 0.1)]);
        let models_second = model_table(vec![dense_model(1, 0.1)]);

        let fps_first = build_modifier_fingerprints(&mapping, &models_first, &emts);
        let fps_second = build_modifier_fingerprints(&mapping, &models_second, &emts);

        assert_eq!(fps_first, fps_second);
    }
}