use super::util::*;
use super::dual_module::*;
use crate::derivative::Derivative;
use std::collections::{BTreeMap, BTreeSet, HashMap};
use super::complete_graph::*;
use super::visualize::*;
use super::pointers::*;
#[cfg(feature="python_binding")]
use pyo3::prelude::*;
/// Matching produced by a primal module before blossoms are expanded into individual nodes.
///
/// Each matched side is stored as `(node, touching)`: `node` is the matched dual node (possibly a
/// blossom) and `touching` is the node inside it from which `expand_blossom` walks up the chain of
/// parent blossoms until it reaches `node`. `get_perfect_matching` performs that expansion.
#[derive(Derivative)]
#[derivative(Debug)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct IntermediateMatching {
    /// Pairs of sides matched to each other; each side is `(node, touching)`.
    pub peer_matchings: Vec<((DualNodePtr, DualNodeWeak), (DualNodePtr, DualNodeWeak))>,
    /// Sides `(node, touching)` matched to a virtual vertex.
    pub virtual_matchings: Vec<((DualNodePtr, DualNodeWeak), VertexIndex)>,
}
/// Matching between individual dual nodes, with every blossom already expanded.
///
/// NOTE(review): consumers such as `legacy_get_mwpm_result` and
/// `SubGraphBuilder::load_perfect_matching` treat any node that is not a defect vertex as
/// unreachable, so entries are expected to be defect-vertex nodes only.
#[derive(Derivative)]
#[derivative(Debug)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PerfectMatching {
    /// Pairs of nodes matched to each other.
    pub peer_matchings: Vec<(DualNodePtr, DualNodePtr)>,
    /// Nodes matched to a virtual vertex.
    pub virtual_matchings: Vec<(DualNodePtr, VertexIndex)>,
}
pub trait PrimalModuleImpl {
fn new_empty(solver_initializer: &SolverInitializer) -> Self;
fn clear(&mut self);
fn load_defect_dual_node(&mut self, dual_node_ptr: &DualNodePtr);
fn load_defect<D: DualModuleImpl>(&mut self, defect_vertex: VertexIndex, interface_ptr: &DualModuleInterfacePtr, dual_module: &mut D) {
interface_ptr.create_defect_node(defect_vertex, dual_module);
let interface = interface_ptr.read_recursive();
let index = interface.nodes_length - 1;
self.load_defect_dual_node(interface.nodes[index].as_ref().expect("must load a fresh dual module interface, found empty node"))
}
fn load(&mut self, interface_ptr: &DualModuleInterfacePtr) {
let interface = interface_ptr.read_recursive();
debug_assert!(interface.parent.is_none(), "cannot load an interface that is already fused");
debug_assert!(interface.children.is_none(), "please customize load function if interface is fused");
for index in 0..interface.nodes_length as NodeIndex {
let node = &interface.nodes[index as usize];
debug_assert!(node.is_some(), "must load a fresh dual module interface, found empty node");
let node_ptr = node.as_ref().unwrap();
let node = node_ptr.read_recursive();
debug_assert!(matches!(node.class, DualNodeClass::DefectVertex{ .. }), "must load a fresh dual module interface, found a blossom");
debug_assert_eq!(node.index, index, "must load a fresh dual module interface, found index out of order");
self.load_defect_dual_node(node_ptr);
}
}
fn resolve<D: DualModuleImpl>(&mut self, group_max_update_length: GroupMaxUpdateLength, interface: &DualModuleInterfacePtr, dual_module: &mut D);
fn intermediate_matching<D: DualModuleImpl>(&mut self, interface: &DualModuleInterfacePtr, dual_module: &mut D) -> IntermediateMatching;
fn perfect_matching<D: DualModuleImpl>(&mut self, interface: &DualModuleInterfacePtr, dual_module: &mut D) -> PerfectMatching {
let intermediate_matching = self.intermediate_matching(interface, dual_module);
intermediate_matching.get_perfect_matching()
}
fn solve<D: DualModuleImpl>(&mut self, interface: &DualModuleInterfacePtr, syndrome_pattern: &SyndromePattern, dual_module: &mut D) {
self.solve_step_callback(interface, syndrome_pattern, dual_module, |_, _, _, _| {})
}
fn solve_visualizer<D: DualModuleImpl + FusionVisualizer>(&mut self, interface: &DualModuleInterfacePtr, syndrome_pattern: &SyndromePattern, dual_module: &mut D
, visualizer: Option<&mut Visualizer>) where Self: FusionVisualizer + Sized {
if let Some(visualizer) = visualizer {
self.solve_step_callback(interface, syndrome_pattern, dual_module, |interface, dual_module, primal_module, group_max_update_length| {
if cfg!(debug_assertions) {
println!("group_max_update_length: {:?}", group_max_update_length);
}
if let Some(length) = group_max_update_length.get_none_zero_growth() {
visualizer.snapshot_combined(format!("grow {length}"), vec![interface, dual_module, primal_module]).unwrap();
} else {
let first_conflict = format!("{:?}", group_max_update_length.peek().unwrap());
visualizer.snapshot_combined(format!("resolve {first_conflict}"), vec![interface, dual_module, primal_module]).unwrap();
};
});
visualizer.snapshot_combined("solved".to_string(), vec![interface, dual_module, self]).unwrap();
} else {
self.solve(interface, syndrome_pattern, dual_module);
}
}
fn solve_step_callback<D: DualModuleImpl, F>(&mut self, interface: &DualModuleInterfacePtr, syndrome_pattern: &SyndromePattern, dual_module: &mut D, callback: F)
where F: FnMut(&DualModuleInterfacePtr, &mut D, &mut Self, &GroupMaxUpdateLength) {
interface.load(syndrome_pattern, dual_module);
self.load(interface);
self.solve_step_callback_interface_loaded(interface, dual_module, callback);
}
fn solve_step_callback_interface_loaded<D: DualModuleImpl, F>(&mut self, interface: &DualModuleInterfacePtr, dual_module: &mut D, mut callback: F)
where F: FnMut(&DualModuleInterfacePtr, &mut D, &mut Self, &GroupMaxUpdateLength) {
let mut group_max_update_length = dual_module.compute_maximum_update_length();
while !group_max_update_length.is_empty() {
callback(interface, dual_module, self, &group_max_update_length);
if let Some(length) = group_max_update_length.get_none_zero_growth() {
interface.grow(length, dual_module);
} else {
self.resolve(group_max_update_length, interface, dual_module);
}
group_max_update_length = dual_module.compute_maximum_update_length();
}
}
fn generate_profiler_report(&self) -> serde_json::Value { json!({}) }
}
impl Default for IntermediateMatching {
fn default() -> Self {
Self::new()
}
}
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pymethods)]
impl IntermediateMatching {
    /// Creates an intermediate matching with no peer and no virtual matchings.
    #[cfg_attr(feature = "python_binding", new)]
    pub fn new() -> Self {
        Self {
            peer_matchings: Vec::new(),
            virtual_matchings: Vec::new(),
        }
    }

    /// Moves all matchings of `other` to the end of `self`, leaving `other` empty.
    pub fn append(&mut self, other: &mut Self) {
        let Self { peer_matchings, virtual_matchings } = other;
        self.peer_matchings.append(peer_matchings);
        self.virtual_matchings.append(virtual_matchings);
    }

    /// Expands every blossom and returns the resulting node-level matching.
    ///
    /// The expanded peer matchings come first, in their original order. Then, for each virtual
    /// matching, the pairs internal to its blossom are appended to the peer matchings, and its
    /// touching node is recorded as matched to the virtual vertex.
    pub fn get_perfect_matching(&self) -> PerfectMatching {
        let mut result = PerfectMatching::new();
        result.peer_matchings.extend(self.peer_matchings.iter().flat_map(
            |((node_1, touching_weak_1), (node_2, touching_weak_2))| {
                Self::expand_peer_matching(node_1, &touching_weak_1.upgrade_force(), node_2, &touching_weak_2.upgrade_force())
            },
        ));
        for ((node_ptr, touching_weak), virtual_vertex) in &self.virtual_matchings {
            let touching_ptr = touching_weak.upgrade_force();
            result.peer_matchings.extend(Self::expand_blossom(node_ptr, &touching_ptr));
            result.virtual_matchings.push((touching_ptr, *virtual_vertex));
        }
        result
    }

    #[cfg(feature = "python_binding")]
    fn __repr__(&self) -> String { format!("{self:?}") }

    /// Peer matchings as `((node, touching), (node, touching))` index pairs.
    #[cfg(feature = "python_binding")]
    #[getter]
    pub fn get_peer_matchings(&self) -> Vec<((NodeIndex, NodeIndex), (NodeIndex, NodeIndex))> {
        self.peer_matchings.iter().map(|((node_1, touching_1), (node_2, touching_2))| {
            let side_1 = (node_1.updated_index(), touching_1.upgrade_force().updated_index());
            let side_2 = (node_2.updated_index(), touching_2.upgrade_force().updated_index());
            (side_1, side_2)
        }).collect()
    }

    /// Virtual matchings as `((node, touching), virtual_vertex)` tuples.
    #[cfg(feature = "python_binding")]
    #[getter]
    pub fn get_virtual_matchings(&self) -> Vec<((NodeIndex, NodeIndex), VertexIndex)> {
        self.virtual_matchings.iter().map(|((node, touching), virtual_vertex)| {
            let side = (node.updated_index(), touching.upgrade_force().updated_index());
            (side, *virtual_vertex)
        }).collect()
    }
}
impl IntermediateMatching {
    /// Expands a peer matching between two (possibly blossom) nodes into node-level pairs.
    ///
    /// The result contains the pairs internal to side 1, then those internal to side 2, and
    /// finally the pair formed by the two touching nodes themselves.
    pub fn expand_peer_matching(
        dual_node_ptr_1: &DualNodePtr,
        touching_ptr_1: &DualNodePtr,
        dual_node_ptr_2: &DualNodePtr,
        touching_ptr_2: &DualNodePtr,
    ) -> Vec<(DualNodePtr, DualNodePtr)> {
        let mut expanded = Self::expand_blossom(dual_node_ptr_1, touching_ptr_1);
        expanded.extend(Self::expand_blossom(dual_node_ptr_2, touching_ptr_2));
        expanded.push((touching_ptr_1.clone(), touching_ptr_2.clone()));
        expanded
    }

    /// Walks from `touching_ptr` up its chain of parent blossoms until `blossom_ptr` is reached,
    /// returning the pairs that match the remaining children of every blossom passed through.
    ///
    /// In each parent blossom, the child on the walked path is left for the outer level. The
    /// other children of the circle are paired consecutively, starting right after that child.
    /// Each pair is joined through the `.1` touching node of the first child and the `.0` touching
    /// node of the second (NOTE(review): presumably its right and left neighbour contacts; confirm
    /// against blossom construction), and is expanded recursively.
    ///
    /// # Panics
    /// If a node on the path has no parent blossom before `blossom_ptr` is reached, or if a child
    /// is missing from its parent's `nodes_circle`.
    pub fn expand_blossom(blossom_ptr: &DualNodePtr, touching_ptr: &DualNodePtr) -> Vec<(DualNodePtr, DualNodePtr)> {
        let mut expanded = Vec::new();
        let mut current_ptr = touching_ptr.clone();
        while &current_ptr != blossom_ptr {
            let current_weak = current_ptr.downgrade();
            // the read guard on the current node lives only inside this block, so `current_ptr`
            // can be reassigned once the parent has been found
            let parent_ptr = {
                let current = current_ptr.read_recursive();
                let parent_ptr = match current.parent_blossom.as_ref() {
                    Some(parent_weak) => parent_weak.upgrade_force(),
                    None => panic!("cannot find parent of {}", current.index),
                };
                {
                    let parent = parent_ptr.read_recursive();
                    if let DualNodeClass::Blossom{ nodes_circle, touching_children } = &parent.class {
                        let circle_len = nodes_circle.len();
                        let position = nodes_circle.iter().position(|ptr| ptr == &current_weak).expect("should find child");
                        debug_assert!(circle_len % 2 == 1 && circle_len >= 3, "must be a valid blossom");
                        for offset in (1..circle_len).step_by(2) {
                            let first = (position + offset) % circle_len;
                            let second = (position + offset + 1) % circle_len;
                            let first_node_ptr = nodes_circle[first].upgrade_force();
                            let second_node_ptr = nodes_circle[second].upgrade_force();
                            let first_touching_ptr = touching_children[first].1.upgrade_force();
                            let second_touching_ptr = touching_children[second].0.upgrade_force();
                            expanded.extend(Self::expand_peer_matching(
                                &first_node_ptr, &first_touching_ptr, &second_node_ptr, &second_touching_ptr,
                            ));
                        }
                    }
                }
                parent_ptr
            };
            current_ptr = parent_ptr;
        }
        expanded
    }
}
impl Default for PerfectMatching {
fn default() -> Self {
Self::new()
}
}
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pymethods)]
impl PerfectMatching {
#[cfg_attr(feature = "python_binding", new)]
pub fn new() -> Self {
Self {
peer_matchings: vec![],
virtual_matchings: vec![],
}
}
pub fn legacy_get_mwpm_result(&self, defect_vertices: Vec<VertexIndex>) -> Vec<SyndromeIndex> {
let mut peer_matching_maps = BTreeMap::<VertexIndex, VertexIndex>::new();
for (ptr_1, ptr_2) in self.peer_matchings.iter() {
let a_vid = {
let node = ptr_1.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
let b_vid = {
let node = ptr_2.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
peer_matching_maps.insert(a_vid, b_vid);
peer_matching_maps.insert(b_vid, a_vid);
}
let mut virtual_matching_maps = BTreeMap::<VertexIndex, VertexIndex>::new();
for (ptr, virtual_vertex) in self.virtual_matchings.iter() {
let a_vid = {
let node = ptr.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
virtual_matching_maps.insert(a_vid, *virtual_vertex);
}
let mut mwpm_result = Vec::with_capacity(defect_vertices.len());
for defect_vertex in defect_vertices.iter() {
if let Some(a) = peer_matching_maps.get(defect_vertex) {
mwpm_result.push(*a);
} else if let Some(v) = virtual_matching_maps.get(defect_vertex) {
mwpm_result.push(*v);
} else { panic!("cannot find defect vertex {}", defect_vertex) }
}
mwpm_result
}
#[cfg(feature = "python_binding")]
fn __repr__(&self) -> String { format!("{:?}", self) }
#[cfg(feature = "python_binding")]
#[getter]
pub fn get_peer_matchings(&self) -> Vec<(NodeIndex, NodeIndex)> {
self.peer_matchings.iter().map(|(a, b)|
(a.updated_index(), b.updated_index())).collect()
}
#[cfg(feature = "python_binding")]
#[getter]
pub fn get_virtual_matchings(&self) -> Vec<(NodeIndex, VertexIndex)> {
self.virtual_matchings.iter().map(|(a, b)|
(a.updated_index(), *b)).collect()
}
}
impl FusionVisualizer for PerfectMatching {
fn snapshot(&self, abbrev: bool) -> serde_json::Value {
let primal_nodes = if self.peer_matchings.is_empty() && self.virtual_matchings.is_empty() {
vec![]
} else {
let mut maximum_node_index = 0;
for (ptr_1, ptr_2) in self.peer_matchings.iter() {
maximum_node_index = std::cmp::max(maximum_node_index, ptr_1.get_ancestor_blossom().read_recursive().index);
maximum_node_index = std::cmp::max(maximum_node_index, ptr_2.get_ancestor_blossom().read_recursive().index);
}
for (ptr, _virtual_vertex) in self.virtual_matchings.iter() {
maximum_node_index = std::cmp::max(maximum_node_index, ptr.get_ancestor_blossom().read_recursive().index);
}
let mut primal_nodes = vec![json!(null); maximum_node_index as usize + 1];
for (ptr_1, ptr_2) in self.peer_matchings.iter() {
for (ptr_a, ptr_b) in [(ptr_1, ptr_2), (ptr_2, ptr_1)] {
primal_nodes[ptr_a.read_recursive().index as usize] = json!({
if abbrev { "m" } else { "temporary_match" }: {
if abbrev { "p" } else { "peer" }: ptr_b.read_recursive().index,
if abbrev { "t" } else { "touching" }: ptr_a.read_recursive().index,
},
if abbrev { "t" } else { "tree_node" }: {
if abbrev { "r" } else { "root" }: ptr_a.read_recursive().index,
if abbrev { "d" } else { "depth" }: 1,
},
});
}
}
for (ptr, virtual_vertex) in self.virtual_matchings.iter() {
primal_nodes[ptr.read_recursive().index as usize] = json!({
if abbrev { "m" } else { "temporary_match" }: {
if abbrev { "v" } else { "virtual_vertex" }: virtual_vertex,
if abbrev { "t" } else { "touching" }: ptr.read_recursive().index,
},
if abbrev { "t" } else { "tree_node" }: {
if abbrev { "r" } else { "root" }: ptr.read_recursive().index,
if abbrev { "d" } else { "depth" }: 1,
},
});
}
primal_nodes
};
json!({
"primal_nodes": primal_nodes,
})
}
}
/// Builds the set of decoding-graph edges (the subgraph) corresponding to a perfect matching.
///
/// Each matched pair is connected by a path taken from `complete_graph`. Every edge on that path
/// is toggled in `subgraph`, so starting from an empty set, an edge ends up included exactly when
/// it lies on an odd number of the paths.
#[derive(Debug, Clone)]
pub struct SubGraphBuilder {
    /// Number of vertices, copied from the solver initializer.
    pub vertex_num: VertexNum,
    /// Maps a vertex pair, normalized to `(smaller, larger)`, to the index of the edge between them.
    vertex_pair_edges: HashMap<(VertexIndex, VertexIndex), EdgeIndex>,
    /// Graph used to look up the path between two matched vertices
    /// (NOTE(review): presumably a shortest path; confirm in `CompleteGraph::get_path`).
    pub complete_graph: CompleteGraph,
    /// Edge indices currently in the subgraph.
    pub subgraph: BTreeSet<EdgeIndex>,
}
impl SubGraphBuilder {
pub fn new(initializer: &SolverInitializer) -> Self {
let mut vertex_pair_edges = HashMap::with_capacity(initializer.weighted_edges.len());
for (edge_index, (i, j, _)) in initializer.weighted_edges.iter().enumerate() {
let id = if i < j { (*i, *j) } else { (*j, *i) };
vertex_pair_edges.insert(id, edge_index as EdgeIndex);
}
Self {
vertex_num: initializer.vertex_num,
vertex_pair_edges,
complete_graph: CompleteGraph::new(initializer.vertex_num, &initializer.weighted_edges),
subgraph: BTreeSet::new(),
}
}
pub fn clear(&mut self) {
self.subgraph.clear();
self.complete_graph.reset();
}
pub fn load_erasures(&mut self, erasures: &[EdgeIndex]) {
self.complete_graph.load_erasures(erasures);
}
pub fn load_perfect_matching(&mut self, perfect_matching: &PerfectMatching) {
self.subgraph.clear();
for (ptr_1, ptr_2) in perfect_matching.peer_matchings.iter() {
let a_vid = {
let node = ptr_1.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
let b_vid = {
let node = ptr_2.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
self.add_matching(a_vid, b_vid);
}
for (ptr, virtual_vertex) in perfect_matching.virtual_matchings.iter() {
let a_vid = {
let node = ptr.read_recursive();
if let DualNodeClass::DefectVertex{ defect_index } = &node.class { *defect_index } else { unreachable!("can only be syndrome") }
};
self.add_matching(a_vid, *virtual_vertex);
}
}
pub fn add_matching(&mut self, vertex_1: VertexIndex, vertex_2: VertexIndex) {
let (path, _) = self.complete_graph.get_path(vertex_1, vertex_2);
let mut a = vertex_1;
for (vertex, _) in path.iter() {
let b = *vertex;
let id = if a < b { (a, b) } else { (b, a) };
let edge_index = *self.vertex_pair_edges.get(&id).expect("edge should exist");
if self.subgraph.contains(&edge_index) {
self.subgraph.remove(&edge_index);
} else {
self.subgraph.insert(edge_index);
}
a = b;
}
}
pub fn total_weight(&self) -> Weight {
let mut weight = 0;
for edge_index in self.subgraph.iter() {
weight += self.complete_graph.weighted_edges[*edge_index as usize].2;
}
weight
}
pub fn get_subgraph(&self) -> Vec<EdgeIndex> {
self.subgraph.iter().copied().collect()
}
}
/// Borrowed list of edge indices that the visualizer can render as a subgraph snapshot.
pub struct VisualizeSubgraph<'a> {
    /// The edge indices to display.
    pub subgraph: &'a Vec<EdgeIndex>,
}
impl<'a> VisualizeSubgraph<'a> {
pub fn new(subgraph: &'a Vec<EdgeIndex>) -> Self {
Self {
subgraph
}
}
}
impl FusionVisualizer for VisualizeSubgraph<'_> {
    /// Emits `{"subgraph": [edge indices...]}`; the key is the same whether or not `abbrev` is set.
    fn snapshot(&self, _abbrev: bool) -> serde_json::Value {
        let subgraph = self.subgraph;
        json!({ "subgraph": subgraph })
    }
}
/// Registers the matching classes of this module with the Python module `m`.
#[cfg(feature="python_binding")]
#[pyfunction]
pub(crate) fn register(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<IntermediateMatching>()?;
    m.add_class::<PerfectMatching>()
}