1#![cfg_attr(feature = "unsafe_pointer", allow(dropping_references))]
7use super::complete_graph::*;
8use super::dual_module::*;
9use super::pointers::*;
10use super::util::*;
11use super::visualize::*;
12use crate::derivative::Derivative;
13#[cfg(feature = "python_binding")]
14use pyo3::prelude::*;
15use std::collections::{BTreeMap, BTreeSet, HashMap};
16
17#[derive(Derivative)]
18#[derivative(Debug)]
19#[cfg_attr(feature = "python_binding", cfg_eval)]
20#[cfg_attr(feature = "python_binding", pyclass)]
21pub struct IntermediateMatching {
22 pub peer_matchings: Vec<((DualNodePtr, DualNodeWeak), (DualNodePtr, DualNodeWeak))>,
24 pub virtual_matchings: Vec<((DualNodePtr, DualNodeWeak), VertexIndex)>,
26}
27
28#[derive(Derivative)]
29#[derivative(Debug)]
30#[cfg_attr(feature = "python_binding", cfg_eval)]
31#[cfg_attr(feature = "python_binding", pyclass)]
32pub struct PerfectMatching {
33 pub peer_matchings: Vec<(DualNodePtr, DualNodePtr)>,
35 pub virtual_matchings: Vec<(DualNodePtr, VertexIndex)>,
37}
38
39pub trait PrimalModuleImpl {
41 fn new_empty(solver_initializer: &SolverInitializer) -> Self;
43
44 fn clear(&mut self);
46
47 fn load_defect_dual_node(&mut self, dual_node_ptr: &DualNodePtr);
48
49 fn load_defect<D: DualModuleImpl>(
51 &mut self,
52 defect_vertex: VertexIndex,
53 interface_ptr: &DualModuleInterfacePtr,
54 dual_module: &mut D,
55 ) {
56 interface_ptr.create_defect_node(defect_vertex, dual_module);
57 let interface = interface_ptr.read_recursive();
58 let index = interface.nodes_length - 1;
59 self.load_defect_dual_node(
60 interface.nodes[index]
61 .as_ref()
62 .expect("must load a fresh dual module interface, found empty node"),
63 )
64 }
65
66 #[allow(clippy::unnecessary_cast)]
68 fn load(&mut self, interface_ptr: &DualModuleInterfacePtr) {
69 let interface = interface_ptr.read_recursive();
70 debug_assert!(interface.parent.is_none(), "cannot load an interface that is already fused");
71 debug_assert!(
72 interface.children.is_none(),
73 "please customize load function if interface is fused"
74 );
75 for index in 0..interface.nodes_length as NodeIndex {
76 let node = &interface.nodes[index as usize];
77 debug_assert!(node.is_some(), "must load a fresh dual module interface, found empty node");
78 let node_ptr = node.as_ref().unwrap();
79 let node = node_ptr.read_recursive();
80 debug_assert!(
81 matches!(node.class, DualNodeClass::DefectVertex { .. }),
82 "must load a fresh dual module interface, found a blossom"
83 );
84 debug_assert_eq!(
85 node.index, index,
86 "must load a fresh dual module interface, found index out of order"
87 );
88 self.load_defect_dual_node(node_ptr);
89 }
90 }
91
92 fn resolve<D: DualModuleImpl>(
97 &mut self,
98 group_max_update_length: GroupMaxUpdateLength,
99 interface: &DualModuleInterfacePtr,
100 dual_module: &mut D,
101 );
102
103 fn intermediate_matching<D: DualModuleImpl>(
105 &mut self,
106 interface: &DualModuleInterfacePtr,
107 dual_module: &mut D,
108 ) -> IntermediateMatching;
109
110 fn perfect_matching<D: DualModuleImpl>(
112 &mut self,
113 interface: &DualModuleInterfacePtr,
114 dual_module: &mut D,
115 ) -> PerfectMatching {
116 let intermediate_matching = self.intermediate_matching(interface, dual_module);
117 intermediate_matching.get_perfect_matching()
118 }
119
120 fn solve<D: DualModuleImpl>(
121 &mut self,
122 interface: &DualModuleInterfacePtr,
123 syndrome_pattern: &SyndromePattern,
124 dual_module: &mut D,
125 ) {
126 self.solve_step_callback(interface, syndrome_pattern, dual_module, |_, _, _, _| {})
127 }
128
129 fn solve_visualizer<D: DualModuleImpl + FusionVisualizer>(
130 &mut self,
131 interface: &DualModuleInterfacePtr,
132 syndrome_pattern: &SyndromePattern,
133 dual_module: &mut D,
134 visualizer: Option<&mut Visualizer>,
135 ) where
136 Self: FusionVisualizer + Sized,
137 {
138 if let Some(visualizer) = visualizer {
139 self.solve_step_callback(
140 interface,
141 syndrome_pattern,
142 dual_module,
143 |interface, dual_module, primal_module, group_max_update_length| {
144 #[cfg(test)]
145 println!("group_max_update_length: {:?}", group_max_update_length);
146 if let Some(length) = group_max_update_length.get_none_zero_growth() {
147 visualizer
148 .snapshot_combined(format!("grow {length}"), vec![interface, dual_module, primal_module])
149 .unwrap();
150 } else {
151 let first_conflict = format!("{:?}", group_max_update_length.peek().unwrap());
152 visualizer
153 .snapshot_combined(
154 format!("resolve {first_conflict}"),
155 vec![interface, dual_module, primal_module],
156 )
157 .unwrap();
158 };
159 },
160 );
161 visualizer
162 .snapshot_combined("solved".to_string(), vec![interface, dual_module, self])
163 .unwrap();
164 } else {
165 self.solve(interface, syndrome_pattern, dual_module);
166 }
167 }
168
169 fn solve_step_callback<D: DualModuleImpl, F>(
170 &mut self,
171 interface: &DualModuleInterfacePtr,
172 syndrome_pattern: &SyndromePattern,
173 dual_module: &mut D,
174 callback: F,
175 ) where
176 F: FnMut(&DualModuleInterfacePtr, &mut D, &mut Self, &GroupMaxUpdateLength),
177 {
178 interface.load(syndrome_pattern, dual_module);
179 self.load(interface);
180 self.solve_step_callback_interface_loaded(interface, dual_module, callback);
181 }
182
183 fn solve_step_callback_interface_loaded<D: DualModuleImpl, F>(
184 &mut self,
185 interface: &DualModuleInterfacePtr,
186 dual_module: &mut D,
187 mut callback: F,
188 ) where
189 F: FnMut(&DualModuleInterfacePtr, &mut D, &mut Self, &GroupMaxUpdateLength),
190 {
191 let mut group_max_update_length = dual_module.compute_maximum_update_length();
192 while !group_max_update_length.is_empty() {
193 callback(interface, dual_module, self, &group_max_update_length);
194 if let Some(length) = group_max_update_length.get_none_zero_growth() {
195 interface.grow(length, dual_module);
196 } else {
197 self.resolve(group_max_update_length, interface, dual_module);
198 }
199 group_max_update_length = dual_module.compute_maximum_update_length();
200 }
201 }
202
203 fn generate_profiler_report(&self) -> serde_json::Value {
205 json!({})
206 }
207}
208
209impl Default for IntermediateMatching {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215#[cfg_attr(feature = "python_binding", cfg_eval)]
216#[cfg_attr(feature = "python_binding", pymethods)]
217impl IntermediateMatching {
218 #[cfg_attr(feature = "python_binding", new)]
219 pub fn new() -> Self {
220 Self {
221 peer_matchings: vec![],
222 virtual_matchings: vec![],
223 }
224 }
225
226 pub fn append(&mut self, other: &mut Self) {
227 self.peer_matchings.append(&mut other.peer_matchings);
228 self.virtual_matchings.append(&mut other.virtual_matchings);
229 }
230
231 pub fn get_perfect_matching(&self) -> PerfectMatching {
233 let mut perfect_matching = PerfectMatching::new();
234 for ((dual_node_ptr_1, touching_weak_1), (dual_node_ptr_2, touching_weak_2)) in self.peer_matchings.iter() {
236 let touching_ptr_1 = touching_weak_1.upgrade_force();
237 let touching_ptr_2 = touching_weak_2.upgrade_force();
238 perfect_matching.peer_matchings.extend(Self::expand_peer_matching(
239 dual_node_ptr_1,
240 &touching_ptr_1,
241 dual_node_ptr_2,
242 &touching_ptr_2,
243 ));
244 }
245 for ((dual_node_ptr, touching_weak), virtual_vertex) in self.virtual_matchings.iter() {
247 let touching_ptr = touching_weak.upgrade_force();
248 perfect_matching
249 .peer_matchings
250 .extend(Self::expand_blossom(dual_node_ptr, &touching_ptr));
251 perfect_matching.virtual_matchings.push((touching_ptr, *virtual_vertex));
252 }
253 perfect_matching
254 }
255
256 #[cfg(feature = "python_binding")]
257 fn __repr__(&self) -> String {
258 format!("{:?}", self)
259 }
260
261 #[cfg(feature = "python_binding")]
262 #[getter]
263 pub fn get_peer_matchings(&self) -> Vec<((NodeIndex, NodeIndex), (NodeIndex, NodeIndex))> {
264 self.peer_matchings
265 .iter()
266 .map(|((a, b), (c, d))| {
267 (
268 (a.updated_index(), b.upgrade_force().updated_index()),
269 (c.updated_index(), d.upgrade_force().updated_index()),
270 )
271 })
272 .collect()
273 }
274
275 #[cfg(feature = "python_binding")]
276 #[getter]
277 pub fn get_virtual_matchings(&self) -> Vec<((NodeIndex, NodeIndex), VertexIndex)> {
278 self.virtual_matchings
279 .iter()
280 .map(|((a, b), c)| ((a.updated_index(), b.upgrade_force().updated_index()), *c))
281 .collect()
282 }
283}
284
285impl IntermediateMatching {
286 pub fn expand_peer_matching(
288 dual_node_ptr_1: &DualNodePtr,
289 touching_ptr_1: &DualNodePtr,
290 dual_node_ptr_2: &DualNodePtr,
291 touching_ptr_2: &DualNodePtr,
292 ) -> Vec<(DualNodePtr, DualNodePtr)> {
293 let mut perfect_matching = vec![];
295 perfect_matching.extend(Self::expand_blossom(dual_node_ptr_1, touching_ptr_1));
296 perfect_matching.extend(Self::expand_blossom(dual_node_ptr_2, touching_ptr_2));
297 perfect_matching.push((touching_ptr_1.clone(), touching_ptr_2.clone()));
298 perfect_matching
300 }
301
302 pub fn expand_blossom(blossom_ptr: &DualNodePtr, touching_ptr: &DualNodePtr) -> Vec<(DualNodePtr, DualNodePtr)> {
304 let mut perfect_matching = vec![];
306 let mut child_ptr = touching_ptr.clone();
307 while &child_ptr != blossom_ptr {
308 let child_weak = child_ptr.downgrade();
309 let child = child_ptr.read_recursive();
310 if let Some(parent_blossom_weak) = child.parent_blossom.as_ref() {
311 let parent_blossom_ptr = parent_blossom_weak.upgrade_force();
312 let parent_blossom = parent_blossom_ptr.read_recursive();
313 if let DualNodeClass::Blossom {
314 nodes_circle,
315 touching_children,
316 } = &parent_blossom.class
317 {
318 let idx = nodes_circle
319 .iter()
320 .position(|ptr| ptr == &child_weak)
321 .expect("should find child");
322 debug_assert!(
323 nodes_circle.len() % 2 == 1 && nodes_circle.len() >= 3,
324 "must be a valid blossom"
325 );
326 for i in (0..(nodes_circle.len() - 1)).step_by(2) {
327 let idx_1 = (idx + i + 1) % nodes_circle.len();
328 let idx_2 = (idx + i + 2) % nodes_circle.len();
329 let dual_node_ptr_1 = nodes_circle[idx_1].upgrade_force();
330 let dual_node_ptr_2 = nodes_circle[idx_2].upgrade_force();
331 let touching_ptr_1 = touching_children[idx_1].1.upgrade_force(); let touching_ptr_2 = touching_children[idx_2].0.upgrade_force(); perfect_matching.extend(Self::expand_peer_matching(
334 &dual_node_ptr_1,
335 &touching_ptr_1,
336 &dual_node_ptr_2,
337 &touching_ptr_2,
338 ))
339 }
340 }
341 drop(child);
342 child_ptr = parent_blossom_ptr.clone();
343 } else {
344 panic!("cannot find parent of {}", child.index)
345 }
346 }
347 perfect_matching
349 }
350}
351
352impl Default for PerfectMatching {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
358#[cfg_attr(feature = "python_binding", cfg_eval)]
359#[cfg_attr(feature = "python_binding", pymethods)]
360impl PerfectMatching {
361 #[cfg_attr(feature = "python_binding", new)]
362 pub fn new() -> Self {
363 Self {
364 peer_matchings: vec![],
365 virtual_matchings: vec![],
366 }
367 }
368
369 pub fn legacy_get_mwpm_result(&self, defect_vertices: Vec<VertexIndex>) -> Vec<DefectIndex> {
371 let mut peer_matching_maps = BTreeMap::<VertexIndex, VertexIndex>::new();
372 for (ptr_1, ptr_2) in self.peer_matchings.iter() {
373 let a_vid = {
374 let node = ptr_1.read_recursive();
375 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
376 *defect_index
377 } else {
378 unreachable!("can only be syndrome")
379 }
380 };
381 let b_vid = {
382 let node = ptr_2.read_recursive();
383 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
384 *defect_index
385 } else {
386 unreachable!("can only be syndrome")
387 }
388 };
389 peer_matching_maps.insert(a_vid, b_vid);
390 peer_matching_maps.insert(b_vid, a_vid);
391 }
392 let mut virtual_matching_maps = BTreeMap::<VertexIndex, VertexIndex>::new();
393 for (ptr, virtual_vertex) in self.virtual_matchings.iter() {
394 let a_vid = {
395 let node = ptr.read_recursive();
396 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
397 *defect_index
398 } else {
399 unreachable!("can only be syndrome")
400 }
401 };
402 virtual_matching_maps.insert(a_vid, *virtual_vertex);
403 }
404 let mut mwpm_result = Vec::with_capacity(defect_vertices.len());
405 for defect_vertex in defect_vertices.iter() {
406 if let Some(a) = peer_matching_maps.get(defect_vertex) {
407 mwpm_result.push(*a);
408 } else if let Some(v) = virtual_matching_maps.get(defect_vertex) {
409 mwpm_result.push(*v);
410 } else {
411 panic!("cannot find defect vertex {}", defect_vertex)
412 }
413 }
414 mwpm_result
415 }
416
417 #[cfg(feature = "python_binding")]
418 fn __repr__(&self) -> String {
419 format!("{:?}", self)
420 }
421
422 #[cfg(feature = "python_binding")]
423 #[getter]
424 pub fn get_peer_matchings(&self) -> Vec<(NodeIndex, NodeIndex)> {
425 self.peer_matchings
426 .iter()
427 .map(|(a, b)| (a.updated_index(), b.updated_index()))
428 .collect()
429 }
430
431 #[cfg(feature = "python_binding")]
432 #[getter]
433 pub fn get_virtual_matchings(&self) -> Vec<(NodeIndex, VertexIndex)> {
434 self.virtual_matchings.iter().map(|(a, b)| (a.updated_index(), *b)).collect()
435 }
436}
437
438impl FusionVisualizer for PerfectMatching {
439 #[allow(clippy::unnecessary_cast)]
440 fn snapshot(&self, abbrev: bool) -> serde_json::Value {
441 let primal_nodes = if self.peer_matchings.is_empty() && self.virtual_matchings.is_empty() {
442 vec![]
443 } else {
444 let mut maximum_node_index = 0;
445 for (ptr_1, ptr_2) in self.peer_matchings.iter() {
446 maximum_node_index = std::cmp::max(maximum_node_index, ptr_1.get_ancestor_blossom().read_recursive().index);
447 maximum_node_index = std::cmp::max(maximum_node_index, ptr_2.get_ancestor_blossom().read_recursive().index);
448 }
449 for (ptr, _virtual_vertex) in self.virtual_matchings.iter() {
450 maximum_node_index = std::cmp::max(maximum_node_index, ptr.get_ancestor_blossom().read_recursive().index);
451 }
452 let mut primal_nodes = vec![json!(null); maximum_node_index as usize + 1];
453 for (ptr_1, ptr_2) in self.peer_matchings.iter() {
454 for (ptr_a, ptr_b) in [(ptr_1, ptr_2), (ptr_2, ptr_1)] {
455 primal_nodes[ptr_a.read_recursive().index as usize] = json!({
456 if abbrev { "m" } else { "temporary_match" }: {
457 if abbrev { "p" } else { "peer" }: ptr_b.read_recursive().index,
458 if abbrev { "t" } else { "touching" }: ptr_a.read_recursive().index,
459 },
460 if abbrev { "t" } else { "tree_node" }: {
461 if abbrev { "r" } else { "root" }: ptr_a.read_recursive().index,
462 if abbrev { "d" } else { "depth" }: 1,
463 },
464 });
465 }
466 }
467 for (ptr, virtual_vertex) in self.virtual_matchings.iter() {
468 primal_nodes[ptr.read_recursive().index as usize] = json!({
469 if abbrev { "m" } else { "temporary_match" }: {
470 if abbrev { "v" } else { "virtual_vertex" }: virtual_vertex,
471 if abbrev { "t" } else { "touching" }: ptr.read_recursive().index,
472 },
473 if abbrev { "t" } else { "tree_node" }: {
474 if abbrev { "r" } else { "root" }: ptr.read_recursive().index,
475 if abbrev { "d" } else { "depth" }: 1,
476 },
477 });
478 }
479 primal_nodes
480 };
481 json!({
482 "primal_nodes": primal_nodes,
483 })
484 }
485}
486
487#[derive(Debug, Clone)]
489pub struct SubGraphBuilder {
490 pub vertex_num: VertexNum,
492 vertex_pair_edges: HashMap<(VertexIndex, VertexIndex), EdgeIndex>,
494 pub complete_graph: CompleteGraph,
496 pub subgraph: BTreeSet<EdgeIndex>,
498}
499
500impl SubGraphBuilder {
501 pub fn new(initializer: &SolverInitializer) -> Self {
502 let mut vertex_pair_edges = HashMap::with_capacity(initializer.weighted_edges.len());
503 for (edge_index, (i, j, _)) in initializer.weighted_edges.iter().enumerate() {
504 let id = if i < j { (*i, *j) } else { (*j, *i) };
505 vertex_pair_edges.insert(id, edge_index as EdgeIndex);
506 }
507 Self {
508 vertex_num: initializer.vertex_num,
509 vertex_pair_edges,
510 complete_graph: CompleteGraph::new(initializer.vertex_num, &initializer.weighted_edges),
511 subgraph: BTreeSet::new(),
512 }
513 }
514
515 pub fn clear(&mut self) {
516 self.subgraph.clear();
517 self.complete_graph.reset();
518 }
519
520 pub fn load_erasures(&mut self, erasures: &[EdgeIndex]) {
522 self.complete_graph.load_erasures(erasures);
523 }
524
525 pub fn load_dynamic_weights(&mut self, dynamic_weights: &[(EdgeIndex, Weight)]) {
526 self.complete_graph.load_dynamic_weights(dynamic_weights);
527 }
528
529 pub fn load_perfect_matching(&mut self, perfect_matching: &PerfectMatching) {
531 self.subgraph.clear();
532 for (ptr_1, ptr_2) in perfect_matching.peer_matchings.iter() {
533 let a_vid = {
534 let node = ptr_1.read_recursive();
535 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
536 *defect_index
537 } else {
538 unreachable!("can only be syndrome")
539 }
540 };
541 let b_vid = {
542 let node = ptr_2.read_recursive();
543 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
544 *defect_index
545 } else {
546 unreachable!("can only be syndrome")
547 }
548 };
549 self.add_matching(a_vid, b_vid);
550 }
551 for (ptr, virtual_vertex) in perfect_matching.virtual_matchings.iter() {
552 let a_vid = {
553 let node = ptr.read_recursive();
554 if let DualNodeClass::DefectVertex { defect_index } = &node.class {
555 *defect_index
556 } else {
557 unreachable!("can only be syndrome")
558 }
559 };
560 self.add_matching(a_vid, *virtual_vertex);
561 }
562 }
563
564 pub fn load_subgraph(&mut self, subgraph: &[EdgeIndex]) {
565 self.subgraph.clear();
566 self.subgraph.extend(subgraph);
567 }
568
569 pub fn add_matching(&mut self, vertex_1: VertexIndex, vertex_2: VertexIndex) {
571 let (path, _) = self.complete_graph.get_path(vertex_1, vertex_2);
572 let mut a = vertex_1;
573 for (vertex, _) in path.iter() {
574 let b = *vertex;
575 let id = if a < b { (a, b) } else { (b, a) };
576 let edge_index = *self.vertex_pair_edges.get(&id).expect("edge should exist");
577 if self.subgraph.contains(&edge_index) {
578 self.subgraph.remove(&edge_index);
579 } else {
580 self.subgraph.insert(edge_index);
581 }
582 a = b;
583 }
584 }
585
586 #[allow(clippy::unnecessary_cast)]
588 pub fn total_weight(&self) -> Weight {
589 let mut weight = 0;
590 for edge_index in self.subgraph.iter() {
591 weight += self.complete_graph.weighted_edges[*edge_index as usize].2;
592 }
593 weight
594 }
595
596 pub fn get_subgraph(&self) -> Vec<EdgeIndex> {
598 self.subgraph.iter().copied().collect()
599 }
600}
601
602pub struct VisualizeSubgraph<'a> {
604 pub subgraph: &'a Vec<EdgeIndex>,
605}
606
607impl<'a> VisualizeSubgraph<'a> {
608 pub fn new(subgraph: &'a Vec<EdgeIndex>) -> Self {
609 Self { subgraph }
610 }
611}
612
613impl FusionVisualizer for VisualizeSubgraph<'_> {
614 fn snapshot(&self, _abbrev: bool) -> serde_json::Value {
615 json!({
616 "subgraph": self.subgraph,
617 })
618 }
619}
620
/// Register this module's Python classes (`IntermediateMatching`, `PerfectMatching`) on `m`.
#[cfg(feature = "python_binding")]
#[pyfunction]
pub(crate) fn register(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<IntermediateMatching>()?;
    m.add_class::<PerfectMatching>()
}