1use super::mwpm_solver::PrimalDualSolver;
2use super::pointers::*;
3use super::rand_xoshiro;
4use crate::rand_xoshiro::rand_core::RngCore;
5#[cfg(feature = "python_binding")]
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::BTreeSet;
9use std::fs::File;
10use std::io::prelude::*;
11use std::time::Instant;
12
/// Integer type used for edge weights.
///
/// Defaults to `isize`; enabling the `i32_weight` feature switches it to `i32`.
#[cfg(feature = "i32_weight")]
pub type Weight = i32;
/// Integer type used for edge weights (`i32` with the `i32_weight` feature).
#[cfg(not(feature = "i32_weight"))]
pub type Weight = isize;
21
// Index and count types: `usize` by default, `u32` with the `u32_index` feature.
// Only the two base aliases depend on the feature; every other alias is defined
// in terms of `VertexIndex` in both configurations.

/// Index of an edge.
#[cfg(feature = "u32_index")]
pub type EdgeIndex = u32;
/// Index of an edge.
#[cfg(not(feature = "u32_index"))]
pub type EdgeIndex = usize;

/// Index of a vertex.
#[cfg(feature = "u32_index")]
pub type VertexIndex = u32;
/// Index of a vertex.
#[cfg(not(feature = "u32_index"))]
pub type VertexIndex = usize;

/// Index of a node (same representation as a vertex index).
pub type NodeIndex = VertexIndex;
/// Index into a list of defect vertices.
pub type DefectIndex = VertexIndex;
/// Index type shared by vertex, node and defect ranges.
pub type VertexNodeIndex = VertexIndex;
/// Count of vertices.
pub type VertexNum = VertexIndex;
/// Count of nodes.
pub type NodeNum = VertexIndex;
42
/// Adds Python `to_json` / `from_json` methods to a serde-enabled `#[pyclass]`.
///
/// Both directions report serde errors as a Python `TypeError` carrying the
/// debug representation of the error.
#[cfg(feature = "python_binding")]
macro_rules! bind_trait_python_json {
    ($struct_name:ident) => {
        #[pymethods]
        impl $struct_name {
            /// Serializes this object into a JSON string.
            #[pyo3(name = "to_json")]
            fn python_to_json(&self) -> PyResult<String> {
                serde_json::to_string(self)
                    .map_err(|error| pyo3::exceptions::PyTypeError::new_err(format!("{:?}", error)))
            }
            /// Builds an object from a JSON string.
            #[staticmethod]
            #[pyo3(name = "from_json")]
            fn python_from_json(value: String) -> PyResult<Self> {
                let parsed: Result<Self, serde_json::Error> = serde_json::from_str(&value);
                parsed.map_err(|error| pyo3::exceptions::PyTypeError::new_err(format!("{:?}", error)))
            }
        }
    };
}
61
/// Weighted graph description; NOTE(review): presumably the input used to
/// construct a solver — confirm against the solver constructors.
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverInitializer {
    /// Number of vertices in the graph.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_num: VertexNum,
    /// Edges as `(vertex_1, vertex_2, weight)`; an edge's position in this
    /// vector is the `EdgeIndex` used by `syndrome_of`.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub weighted_edges: Vec<(VertexIndex, VertexIndex, Weight)>,
    /// Vertices marked as virtual; `syndrome_of` never reports them as defects.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub virtual_vertices: Vec<VertexIndex>,
}

// Expose `to_json` / `from_json` on the Python class.
#[cfg(feature = "python_binding")]
bind_trait_python_json! {SolverInitializer}
79
/// A set of defect vertices, optionally with erased edges or per-edge weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct SyndromePattern {
    /// Vertices reported as defects.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub defect_vertices: Vec<VertexIndex>,
    /// Indices of erased edges; empty when the field is absent from serialized input.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    #[serde(default = "default_erasures")]
    pub erasures: Vec<EdgeIndex>,
    /// `(edge, weight)` pairs; empty when the field is absent from serialized input.
    /// NOTE(review): presumably these override the graph's edge weights for this
    /// pattern — confirm in the solver.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    #[serde(default = "default_dynamic_weights")]
    pub dynamic_weights: Vec<(EdgeIndex, Weight)>,
}
98
99pub fn default_dynamic_weights() -> Vec<(EdgeIndex, Weight)> {
100 vec![]
101}
102
103pub fn default_erasures() -> Vec<EdgeIndex> {
104 vec![]
105}
106
107impl SyndromePattern {
108 pub fn new(defect_vertices: Vec<VertexIndex>, erasures: Vec<EdgeIndex>) -> Self {
109 Self {
110 defect_vertices,
111 erasures,
112 dynamic_weights: vec![],
113 }
114 }
115 pub fn new_dynamic_weights(
116 defect_vertices: Vec<VertexIndex>,
117 erasures: Vec<EdgeIndex>,
118 dynamic_weights: Vec<(EdgeIndex, Weight)>,
119 ) -> Self {
120 Self {
121 defect_vertices,
122 erasures,
123 dynamic_weights,
124 }
125 }
126}
127
128#[cfg_attr(feature = "python_binding", cfg_eval)]
129#[cfg_attr(feature = "python_binding", pymethods)]
130impl SyndromePattern {
131 #[cfg_attr(feature = "python_binding", new)]
132 #[cfg_attr(feature = "python_binding", pyo3(signature = (defect_vertices=vec![], erasures=vec![], dynamic_weights=vec![], syndrome_vertices=None)))]
133 pub fn py_new(
134 mut defect_vertices: Vec<VertexIndex>,
135 erasures: Vec<EdgeIndex>,
136 dynamic_weights: Vec<(EdgeIndex, Weight)>,
137 syndrome_vertices: Option<Vec<VertexIndex>>,
138 ) -> Self {
139 if let Some(syndrome_vertices) = syndrome_vertices {
140 assert!(
141 defect_vertices.is_empty(),
142 "do not pass both `syndrome_vertices` and `defect_vertices` since they're aliasing"
143 );
144 defect_vertices = syndrome_vertices;
145 }
146 assert!(
147 erasures.is_empty() || dynamic_weights.is_empty(),
148 "erasures and dynamic_weights cannot be provided at the same time"
149 );
150 Self::new_dynamic_weights(defect_vertices, erasures, dynamic_weights)
151 }
152 #[cfg_attr(feature = "python_binding", staticmethod)]
153 pub fn new_vertices(defect_vertices: Vec<VertexIndex>) -> Self {
154 Self::new(defect_vertices, vec![])
155 }
156 #[cfg_attr(feature = "python_binding", staticmethod)]
157 pub fn new_empty() -> Self {
158 Self::new(vec![], vec![])
159 }
160 #[cfg(feature = "python_binding")]
161 fn __repr__(&self) -> String {
162 format!("{:?}", self)
163 }
164}
165
/// A view over a contiguous range of a [`SyndromePattern`]'s defect vertices.
#[derive(Debug, Clone, Serialize)]
pub struct PartitionedSyndromePattern<'a> {
    /// The full syndrome pattern this view borrows from.
    pub syndrome_pattern: &'a SyndromePattern,
    /// Range of indices into `syndrome_pattern.defect_vertices` covered by this view.
    pub whole_defect_range: DefectRange,
}
174
175impl<'a> PartitionedSyndromePattern<'a> {
176 pub fn new(syndrome_pattern: &'a SyndromePattern) -> Self {
177 assert!(
178 syndrome_pattern.erasures.is_empty(),
179 "erasure partition not supported yet;
180 even if the edges in the erasure is well ordered, they may not be able to be represented as
181 a single range simply because the partition is vertex-based. need more consideration"
182 );
183 Self {
184 syndrome_pattern,
185 whole_defect_range: DefectRange::new(0, syndrome_pattern.defect_vertices.len() as DefectIndex),
186 }
187 }
188}
189
/// Half-open index range `[range[0], range[1])`, shared by vertex, node and defect ranges.
///
/// Serialized transparently as the two-element array `range`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct IndexRange {
    /// `[start, end]` where `end` is exclusive.
    pub range: [VertexNodeIndex; 2],
}

/// Range of vertex indices.
pub type VertexRange = IndexRange;
/// Range of node indices.
pub type NodeRange = IndexRange;
/// Range of indices into a defect-vertex list.
pub type DefectRange = IndexRange;
202
203#[cfg_attr(feature = "python_binding", cfg_eval)]
204#[cfg_attr(feature = "python_binding", pymethods)]
205impl IndexRange {
206 #[cfg_attr(feature = "python_binding", new)]
207 pub fn new(start: VertexNodeIndex, end: VertexNodeIndex) -> Self {
208 debug_assert!(end >= start, "invalid range [{}, {})", start, end);
209 Self { range: [start, end] }
210 }
211 #[cfg_attr(feature = "python_binding", staticmethod)]
212 pub fn new_length(start: VertexNodeIndex, length: VertexNodeIndex) -> Self {
213 Self::new(start, start + length)
214 }
215 pub fn is_empty(&self) -> bool {
216 self.range[1] == self.range[0]
217 }
218 #[allow(clippy::unnecessary_cast)]
219 pub fn len(&self) -> usize {
220 (self.range[1] - self.range[0]) as usize
221 }
222 pub fn start(&self) -> VertexNodeIndex {
223 self.range[0]
224 }
225 pub fn end(&self) -> VertexNodeIndex {
226 self.range[1]
227 }
228 pub fn append_by(&mut self, append_count: VertexNodeIndex) {
229 self.range[1] += append_count;
230 }
231 pub fn bias_by(&mut self, bias: VertexNodeIndex) {
232 self.range[0] += bias;
233 self.range[1] += bias;
234 }
235 pub fn sanity_check(&self) {
236 assert!(self.start() <= self.end(), "invalid vertex range {:?}", self);
237 }
238 pub fn contains(&self, vertex_index: VertexNodeIndex) -> bool {
239 vertex_index >= self.start() && vertex_index < self.end()
240 }
241 pub fn fuse(&self, other: &Self) -> (Self, Self) {
243 self.sanity_check();
244 other.sanity_check();
245 assert!(self.range[1] <= other.range[0], "only lower range can fuse higher range");
246 (
247 Self::new(self.range[0], other.range[1]),
248 Self::new(self.range[1], other.range[0]),
249 )
250 }
251 #[cfg(feature = "python_binding")]
252 #[pyo3(name = "contains_any")]
253 pub fn python_contains_any(&self, vertex_indices: Vec<VertexNodeIndex>) -> bool {
254 self.contains_any(&vertex_indices)
255 }
256 #[cfg(feature = "python_binding")]
257 fn __repr__(&self) -> String {
258 format!("{:?}", self)
259 }
260}
261
262impl IndexRange {
263 pub fn iter(&self) -> std::ops::Range<VertexNodeIndex> {
264 self.range[0]..self.range[1]
265 }
266 pub fn contains_any(&self, vertex_indices: &[VertexNodeIndex]) -> bool {
267 for vertex_index in vertex_indices.iter() {
268 if self.contains(*vertex_index) {
269 return true;
270 }
271 }
272 false
273 }
274}
275
/// State of one partition unit, shared through [`PartitionUnitPtr`].
#[derive(Debug, Clone)]
pub struct PartitionUnit {
    /// Index of this unit; NOTE(review): presumably the same index as in
    /// `PartitionInfo::units` — confirm.
    pub unit_index: usize,
    /// Whether the unit is enabled; printed as `E`/`D` by the pointer's `Debug` impl.
    pub enabled: bool,
}
284
285pub type PartitionUnitPtr = ArcManualSafeLock<PartitionUnit>;
286pub type PartitionUnitWeak = WeakManualSafeLock<PartitionUnit>;
287
288impl std::fmt::Debug for PartitionUnitPtr {
289 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
290 let partition_unit = self.read_recursive();
291 write!(
292 f,
293 "{}{}",
294 if partition_unit.enabled { "E" } else { "D" },
295 partition_unit.unit_index
296 )
297 }
298}
299
300impl std::fmt::Debug for PartitionUnitWeak {
301 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
302 self.upgrade_force().fmt(f)
303 }
304}
305
/// Description of how the vertices are split into partitions and how those
/// partitions are fused back together; validated by `PartitionConfig::info`.
/// Deserialization rejects unknown fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionConfig {
    /// Total number of vertices.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_num: VertexNum,
    /// Leaf partitions; unit `i` (for `i < partitions.len()`) is `partitions[i]`.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub partitions: Vec<VertexRange>,
    /// Fusion steps `(left, right)`; fusion `k` creates unit `partitions.len() + k`
    /// from two earlier units, where `left`'s range must lie below `right`'s.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub fusions: Vec<(usize, usize)>,
}

// Expose `to_json` / `from_json` on the Python class.
#[cfg(feature = "python_binding")]
bind_trait_python_json! {PartitionConfig}
325
326#[cfg_attr(feature = "python_binding", cfg_eval)]
327#[cfg_attr(feature = "python_binding", pymethods)]
328impl PartitionConfig {
329 #[cfg_attr(feature = "python_binding", new)]
330 pub fn new(vertex_num: VertexNum) -> Self {
331 Self {
332 vertex_num,
333 partitions: vec![VertexRange::new(0, vertex_num as VertexIndex)],
334 fusions: vec![],
335 }
336 }
337
338 #[cfg(feature = "python_binding")]
339 fn __repr__(&self) -> String {
340 format!("{:?}", self)
341 }
342
343 #[allow(clippy::unnecessary_cast)]
344 pub fn info(&self) -> PartitionInfo {
345 assert!(!self.partitions.is_empty(), "at least one partition must exist");
346 let mut whole_ranges = vec![];
347 let mut owning_ranges = vec![];
348 for &partition in self.partitions.iter() {
349 partition.sanity_check();
350 assert!(
351 partition.end() <= self.vertex_num as VertexIndex,
352 "invalid vertex index {} in partitions",
353 partition.end()
354 );
355 whole_ranges.push(partition);
356 owning_ranges.push(partition);
357 }
358 let unit_count = self.partitions.len() + self.fusions.len();
359 let mut parents: Vec<Option<usize>> = (0..unit_count).map(|_| None).collect();
360 for (fusion_index, (left_index, right_index)) in self.fusions.iter().enumerate() {
361 let unit_index = fusion_index + self.partitions.len();
362 assert!(
363 *left_index < unit_index,
364 "dependency wrong, {} depending on {}",
365 unit_index,
366 left_index
367 );
368 assert!(
369 *right_index < unit_index,
370 "dependency wrong, {} depending on {}",
371 unit_index,
372 right_index
373 );
374 assert!(parents[*left_index].is_none(), "cannot fuse {} twice", left_index);
375 assert!(parents[*right_index].is_none(), "cannot fuse {} twice", right_index);
376 parents[*left_index] = Some(unit_index);
377 parents[*right_index] = Some(unit_index);
378 let (whole_range, interface_range) = whole_ranges[*left_index].fuse(&whole_ranges[*right_index]);
380 whole_ranges.push(whole_range);
381 owning_ranges.push(interface_range);
382 }
383 for (unit_index, parent) in parents.iter().enumerate().take(unit_count - 1) {
385 assert!(parent.is_some(), "found unit {} without being fused", unit_index);
386 }
387 let last_unit_index = self.partitions.len() + self.fusions.len() - 1;
389 assert!(
390 whole_ranges[last_unit_index].start() == 0,
391 "final range not covering all vertices {:?}",
392 whole_ranges[last_unit_index]
393 );
394 assert!(
395 whole_ranges[last_unit_index].end() == self.vertex_num as VertexIndex,
396 "final range not covering all vertices {:?}",
397 whole_ranges[last_unit_index]
398 );
399 let mut partition_unit_info: Vec<_> = (0..self.partitions.len() + self.fusions.len())
401 .map(|i| PartitionUnitInfo {
402 whole_range: whole_ranges[i],
403 owning_range: owning_ranges[i],
404 children: if i >= self.partitions.len() {
405 Some(self.fusions[i - self.partitions.len()])
406 } else {
407 None
408 },
409 parent: parents[i],
410 leaves: if i < self.partitions.len() { vec![i] } else { vec![] },
411 descendants: BTreeSet::new(),
412 })
413 .collect();
414 for (fusion_index, (left_index, right_index)) in self.fusions.iter().enumerate() {
416 let unit_index = fusion_index + self.partitions.len();
417 let mut leaves = vec![];
418 leaves.extend(partition_unit_info[*left_index].leaves.iter());
419 leaves.extend(partition_unit_info[*right_index].leaves.iter());
420 partition_unit_info[unit_index].leaves.extend(leaves.iter());
421 let mut descendants = vec![];
422 descendants.push(*left_index);
423 descendants.push(*right_index);
424 descendants.extend(partition_unit_info[*left_index].descendants.iter());
425 descendants.extend(partition_unit_info[*right_index].descendants.iter());
426 partition_unit_info[unit_index].descendants.extend(descendants.iter());
427 }
428 let mut vertex_to_owning_unit: Vec<_> = (0..self.vertex_num).map(|_| usize::MAX).collect();
429 for (unit_index, unit_range) in partition_unit_info.iter().map(|x| x.owning_range).enumerate() {
430 for vertex_index in unit_range.iter() {
431 vertex_to_owning_unit[vertex_index as usize] = unit_index;
432 }
433 }
434 PartitionInfo {
435 config: self.clone(),
436 units: partition_unit_info,
437 vertex_to_owning_unit,
438 }
439 }
440}
441
/// Validated partition tree derived from a [`PartitionConfig`] by `PartitionConfig::info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionInfo {
    /// The configuration this info was derived from.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub config: PartitionConfig,
    /// Per-unit information: leaf partitions first, then one entry per fusion.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub units: Vec<PartitionUnitInfo>,
    /// For each vertex, the index of the unit whose owning range contains it.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_to_owning_unit: Vec<usize>,
}

// Expose `to_json` / `from_json` on the Python class.
#[cfg(feature = "python_binding")]
bind_trait_python_json! {PartitionInfo}
460
461#[cfg_attr(feature = "python_binding", pymethods)]
462impl PartitionInfo {
463 #[allow(clippy::unnecessary_cast)]
466 pub fn partition_syndrome_unordered(&self, syndrome_pattern: &SyndromePattern) -> Vec<SyndromePattern> {
467 let mut partitioned_syndrome: Vec<_> = (0..self.units.len()).map(|_| SyndromePattern::new_empty()).collect();
468 for defect_vertex in syndrome_pattern.defect_vertices.iter() {
469 let unit_index = self.vertex_to_owning_unit[*defect_vertex as usize];
470 partitioned_syndrome[unit_index].defect_vertices.push(*defect_vertex);
471 }
472 partitioned_syndrome
474 }
475
476 #[cfg(feature = "python_binding")]
477 fn __repr__(&self) -> String {
478 format!("{:?}", self)
479 }
480}
481
482impl<'a> PartitionedSyndromePattern<'a> {
483 #[allow(clippy::unnecessary_cast)]
485 pub fn partition(&self, partition_unit_info: &PartitionUnitInfo) -> (Self, (Self, Self)) {
486 let owning_start_index = {
488 let mut left_index = self.whole_defect_range.start();
489 let mut right_index = self.whole_defect_range.end();
490 while left_index != right_index {
491 let mid_index = (left_index + right_index) / 2;
492 let mid_defect_vertex = self.syndrome_pattern.defect_vertices[mid_index as usize];
493 if mid_defect_vertex < partition_unit_info.owning_range.start() {
494 left_index = mid_index + 1;
495 } else {
496 right_index = mid_index;
497 }
498 }
499 left_index
500 };
501 let owning_end_index = {
503 let mut left_index = self.whole_defect_range.start();
504 let mut right_index = self.whole_defect_range.end();
505 while left_index != right_index {
506 let mid_index = (left_index + right_index) / 2;
507 let mid_defect_vertex = self.syndrome_pattern.defect_vertices[mid_index as usize];
508 if mid_defect_vertex < partition_unit_info.owning_range.end() {
509 left_index = mid_index + 1;
510 } else {
511 right_index = mid_index;
512 }
513 }
514 left_index
515 };
516 (
517 Self {
518 syndrome_pattern: self.syndrome_pattern,
519 whole_defect_range: DefectRange::new(owning_start_index, owning_end_index),
520 },
521 (
522 Self {
523 syndrome_pattern: self.syndrome_pattern,
524 whole_defect_range: DefectRange::new(self.whole_defect_range.start(), owning_start_index),
525 },
526 Self {
527 syndrome_pattern: self.syndrome_pattern,
528 whole_defect_range: DefectRange::new(owning_end_index, self.whole_defect_range.end()),
529 },
530 ),
531 )
532 }
533
534 #[allow(clippy::unnecessary_cast)]
535 pub fn expand(&self) -> SyndromePattern {
536 let mut defect_vertices = Vec::with_capacity(self.whole_defect_range.len());
537 for defect_index in self.whole_defect_range.iter() {
538 defect_vertices.push(self.syndrome_pattern.defect_vertices[defect_index as usize]);
539 }
540 SyndromePattern::new(defect_vertices, vec![])
541 }
542}
543
/// Description of one unit in the partition tree built by `PartitionConfig::info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionUnitInfo {
    /// Vertices spanned by this unit together with all of its descendants.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub whole_range: VertexRange,
    /// Vertices owned by this unit itself: the whole range for a leaf partition,
    /// the gap between the two children for a fusion unit.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub owning_range: VertexRange,
    /// `(left, right)` child units of a fusion unit; `None` for a leaf partition.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub children: Option<(usize, usize)>,
    /// The fusion unit that consumes this one; `None` for the final (root) unit
    /// produced by `PartitionConfig::info`.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub parent: Option<usize>,
    /// Leaf partition indices beneath this unit (a leaf lists only itself).
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub leaves: Vec<usize>,
    /// Every unit below this one in the tree.
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub descendants: BTreeSet<usize>,
}

// Expose `to_json` / `from_json` on the Python class.
#[cfg(feature = "python_binding")]
bind_trait_python_json! {PartitionUnitInfo}

#[cfg(feature = "python_binding")]
#[pymethods]
impl PartitionUnitInfo {
    fn __repr__(&self) -> String {
        format!("{:?}", self)
    }
}
578
/// Solver initializer for a single partition unit.
/// NOTE(review): the field descriptions below are partly inferred from names and
/// types — confirm against the code that builds this struct.
#[derive(Debug, Clone)]
pub struct PartitionedSolverInitializer {
    /// Index of the partition unit this initializer belongs to.
    pub unit_index: usize,
    /// Number of vertices; presumably of the whole graph rather than of this unit — confirm.
    pub vertex_num: VertexNum,
    /// Number of edges; presumably of the whole graph — confirm.
    pub edge_num: usize,
    /// Vertices owned by this unit.
    pub owning_range: VertexRange,
    /// Interface unit that owns this unit, if any.
    pub owning_interface: Option<PartitionUnitWeak>,
    /// Interface units paired with `(vertex, flag)` lists; the meaning of the
    /// `bool` is not visible here — TODO document.
    pub interfaces: Vec<(PartitionUnitWeak, Vec<(VertexIndex, bool)>)>,
    /// Edges as `(vertex_1, vertex_2, weight, edge_index)`; the trailing index
    /// presumably refers to the edge in the unpartitioned graph — confirm.
    pub weighted_edges: Vec<(VertexIndex, VertexIndex, Weight, EdgeIndex)>,
    /// Virtual vertices relevant to this unit.
    pub virtual_vertices: Vec<VertexIndex>,
}
601
602#[allow(clippy::unnecessary_cast)]
604pub fn build_old_to_new(reordered_vertices: &[VertexIndex]) -> Vec<Option<VertexIndex>> {
605 let mut old_to_new: Vec<Option<VertexIndex>> = (0..reordered_vertices.len()).map(|_| None).collect();
606 for (new_index, old_index) in reordered_vertices.iter().enumerate() {
607 assert_eq!(old_to_new[*old_index as usize], None, "duplicate vertex found {}", old_index);
608 old_to_new[*old_index as usize] = Some(new_index as VertexIndex);
609 }
610 old_to_new
611}
612
613#[allow(clippy::unnecessary_cast)]
615pub fn translated_defect_to_reordered(
616 reordered_vertices: &[VertexIndex],
617 old_defect_vertices: &[VertexIndex],
618) -> Vec<VertexIndex> {
619 let old_to_new = build_old_to_new(reordered_vertices);
620 old_defect_vertices
621 .iter()
622 .map(|old_index| old_to_new[*old_index as usize].unwrap())
623 .collect()
624}
625
626#[cfg_attr(feature = "python_binding", cfg_eval)]
627#[cfg_attr(feature = "python_binding", pymethods)]
628impl SolverInitializer {
629 #[cfg_attr(feature = "python_binding", new)]
630 pub fn new(
631 vertex_num: VertexNum,
632 weighted_edges: Vec<(VertexIndex, VertexIndex, Weight)>,
633 virtual_vertices: Vec<VertexIndex>,
634 ) -> SolverInitializer {
635 SolverInitializer {
636 vertex_num,
637 weighted_edges,
638 virtual_vertices,
639 }
640 }
641 #[cfg(feature = "python_binding")]
642 fn __repr__(&self) -> String {
643 format!("{:?}", self)
644 }
645}
646
647impl SolverInitializer {
648 #[allow(clippy::unnecessary_cast)]
649 pub fn syndrome_of(&self, subgraph: &[EdgeIndex]) -> BTreeSet<VertexIndex> {
650 let mut defects = BTreeSet::new();
651 for edge_index in subgraph {
652 let (left, right, _weight) = self.weighted_edges[*edge_index as usize];
653 for vertex_index in [left, right] {
654 if defects.contains(&vertex_index) {
655 defects.remove(&vertex_index);
656 } else {
657 defects.insert(vertex_index);
658 }
659 }
660 }
661 for vertex_index in self.virtual_vertices.iter() {
663 defects.remove(vertex_index);
664 }
665 defects
666 }
667}
668
669pub type FastClearTimestamp = usize;
671
672#[allow(dead_code)]
673pub type DeterministicRng = rand_xoshiro::Xoshiro256StarStar;
675
676pub trait F64Rng {
677 fn next_f64(&mut self) -> f64;
678}
679
680impl F64Rng for DeterministicRng {
681 fn next_f64(&mut self) -> f64 {
682 f64::from_bits(0x3FF << 52 | self.next_u64() >> 12) - 1.
683 }
684}
685
/// Collects per-round timing records and can stream them as JSON lines to a file.
pub struct BenchmarkProfiler {
    /// One entry per `begin`/`end` round, in order.
    pub records: Vec<BenchmarkProfilerEntry>,
    /// Sum of `round_time` over finished rounds, in seconds.
    pub sum_round_time: f64,
    /// Total number of defect vertices over finished rounds.
    pub sum_syndrome: usize,
    /// Number of noisy measurements; `brief` divides the per-record time by `1 + noisy_measurements`.
    pub noisy_measurements: VertexNum,
    /// Destination of the JSON-lines detail log, if one was requested.
    pub benchmark_profiler_output: Option<File>,
}
699
700impl BenchmarkProfiler {
701 pub fn new(noisy_measurements: VertexNum, detail_log_file: Option<(String, &PartitionInfo)>) -> Self {
702 let benchmark_profiler_output = detail_log_file.map(|(filename, partition_info)| {
703 let mut file = File::create(filename).unwrap();
704 file.write_all(serde_json::to_string(&partition_info.config).unwrap().as_bytes())
705 .unwrap();
706 file.write_all(b"\n").unwrap();
707 file.write_all(
708 serde_json::to_string(&json!({
709 "noisy_measurements": noisy_measurements,
710 }))
711 .unwrap()
712 .as_bytes(),
713 )
714 .unwrap();
715 file.write_all(b"\n").unwrap();
716 file
717 });
718 Self {
719 records: vec![],
720 sum_round_time: 0.,
721 sum_syndrome: 0,
722 noisy_measurements,
723 benchmark_profiler_output,
724 }
725 }
726 pub fn begin(&mut self, syndrome_pattern: &SyndromePattern) {
728 if let Some(last_entry) = self.records.last() {
730 assert!(
731 last_entry.is_complete(),
732 "the last benchmark profiler entry is not complete, make sure to call `begin` and `end` in pairs"
733 );
734 }
735 let entry = BenchmarkProfilerEntry::new(syndrome_pattern);
736 self.records.push(entry);
737 self.records.last_mut().unwrap().record_begin();
738 }
739 pub fn event(&mut self, event_name: String) {
740 let last_entry = self
741 .records
742 .last_mut()
743 .expect("last entry not exists, call `begin` before `end`");
744 last_entry.record_event(event_name);
745 }
746 pub fn end(&mut self, solver: Option<&dyn PrimalDualSolver>) {
748 let last_entry = self
749 .records
750 .last_mut()
751 .expect("last entry not exists, call `begin` before `end`");
752 last_entry.record_end();
753 self.sum_round_time += last_entry.round_time.unwrap();
754 self.sum_syndrome += last_entry.syndrome_pattern.defect_vertices.len();
755 if let Some(file) = self.benchmark_profiler_output.as_mut() {
756 let mut events = serde_json::Map::new();
757 for (event_name, time) in last_entry.events.iter() {
758 events.insert(event_name.clone(), json!(time));
759 }
760 let mut value = json!({
761 "round_time": last_entry.round_time.unwrap(),
762 "defect_num": last_entry.syndrome_pattern.defect_vertices.len(),
763 "events": events,
764 });
765 if let Some(solver) = solver {
766 let solver_profile = solver.generate_profiler_report();
767 value
768 .as_object_mut()
769 .unwrap()
770 .insert("solver_profile".to_string(), solver_profile);
771 }
772 file.write_all(serde_json::to_string(&value).unwrap().as_bytes()).unwrap();
773 file.write_all(b"\n").unwrap();
774 }
775 }
776 pub fn brief(&self) -> String {
778 let total = self.sum_round_time / (self.records.len() as f64);
779 let per_round = total / (1. + self.noisy_measurements as f64);
780 let per_defect = self.sum_round_time / (self.sum_syndrome as f64);
781 format!("total: {total:.3e}, round: {per_round:.3e}, defect: {per_defect:.3e},")
782 }
783}
784
785pub struct BenchmarkProfilerEntry {
786 pub syndrome_pattern: SyndromePattern,
788 begin_time: Option<Instant>,
790 pub events: Vec<(String, f64)>,
792 pub round_time: Option<f64>,
794}
795
796impl BenchmarkProfilerEntry {
797 pub fn new(syndrome_pattern: &SyndromePattern) -> Self {
798 Self {
799 syndrome_pattern: syndrome_pattern.clone(),
800 begin_time: None,
801 events: vec![],
802 round_time: None,
803 }
804 }
805 pub fn record_begin(&mut self) {
807 assert_eq!(self.begin_time, None, "do not call `record_begin` twice on the same entry");
808 self.begin_time = Some(Instant::now());
809 }
810 pub fn record_end(&mut self) {
812 let begin_time = self
813 .begin_time
814 .as_ref()
815 .expect("make sure to call `record_begin` before calling `record_end`");
816 self.round_time = Some(begin_time.elapsed().as_secs_f64());
817 }
818 pub fn record_event(&mut self, event_name: String) {
819 let begin_time = self
820 .begin_time
821 .as_ref()
822 .expect("make sure to call `record_begin` before calling `record_end`");
823 self.events.push((event_name, begin_time.elapsed().as_secs_f64()));
824 }
825 pub fn is_complete(&self) -> bool {
826 self.round_time.is_some()
827 }
828}
829
/// Python context manager around one attribute of a Python object.
///
/// `__enter__` reads `object.<attr_name>` and holds it; `__exit__` writes the
/// held value back with `setattr`.
#[cfg(feature = "python_binding")]
#[pyclass]
pub struct PyMut {
    /// Object whose attribute is borrowed.
    #[pyo3(get, set)]
    object: PyObject,
    /// Name of the borrowed attribute.
    #[pyo3(get, set)]
    attr_name: String,
    /// Value held between `__enter__` and `__exit__`; `None` outside the context.
    #[pyo3(get, set)]
    attr_object: Option<PyObject>,
}
853
#[cfg(feature = "python_binding")]
#[pymethods]
impl PyMut {
    /// Creates a context manager for `object.<attr_name>`; nothing is read yet.
    #[new]
    pub fn new(object: PyObject, attr_name: String) -> Self {
        Self { object, attr_name, attr_object: None }
    }

    /// Reads the attribute, keeps a reference to it and returns it.
    pub fn __enter__(&mut self) -> PyObject {
        assert!(self.attr_object.is_none(), "do not enter twice");
        Python::with_gil(|py| {
            let value = self.object.getattr(py, self.attr_name.as_str()).unwrap();
            self.attr_object = Some(value.clone_ref(py));
            value
        })
    }

    /// Writes the held value back to the attribute and clears it.
    pub fn __exit__(&mut self, _exc_type: PyObject, _exc_val: PyObject, _exc_tb: PyObject) {
        Python::with_gil(|py| {
            let held_value = self.attr_object.take().unwrap();
            self.object.setattr(py, self.attr_name.as_str(), held_value).unwrap()
        })
    }
}
881
/// Converts a JSON value into the matching Python object.
///
/// `null` becomes `None`; booleans, numbers and strings become `bool`, `int`,
/// `float` and `str`; arrays become `list` and objects become `dict`, recursively.
/// The `py` token means the caller already holds the GIL.
#[cfg(feature = "python_binding")]
pub fn json_to_pyobject_locked<'py>(value: serde_json::Value, py: Python<'py>) -> PyObject {
    match value {
        serde_json::Value::Null => py.None(),
        serde_json::Value::Bool(value) => value.to_object(py).into(),
        serde_json::Value::Number(value) => {
            if let Some(integer) = value.as_i64() {
                integer.to_object(py).into()
            } else if let Some(unsigned) = value.as_u64() {
                // integers above i64::MAX stay exact instead of being rounded through f64
                unsigned.to_object(py).into()
            } else {
                value.as_f64().to_object(py).into()
            }
        }
        serde_json::Value::String(value) => value.to_object(py).into(),
        serde_json::Value::Array(array) => {
            let elements: Vec<PyObject> = array.into_iter().map(|value| json_to_pyobject_locked(value, py)).collect();
            pyo3::types::PyList::new(py, elements).into()
        }
        serde_json::Value::Object(map) => {
            let pydict = pyo3::types::PyDict::new(py);
            for (key, value) in map.into_iter() {
                let pyobject = json_to_pyobject_locked(value, py);
                pydict.set_item(key, pyobject).unwrap();
            }
            pydict.into()
        }
    }
}

/// Same as [`json_to_pyobject_locked`], acquiring the GIL itself.
#[cfg(feature = "python_binding")]
pub fn json_to_pyobject(value: serde_json::Value) -> PyObject {
    Python::with_gil(|py| json_to_pyobject_locked(value, py))
}
914
/// Converts a Python object made of nested `dict`/`list`, `str`, `int`, `float`,
/// `bool` and `None` into a JSON value. The `py` token means the caller already
/// holds the GIL.
///
/// # Panics
/// On any other Python type, on non-string dict keys, or on integers outside
/// the `i64`/`u64` range.
#[cfg(feature = "python_binding")]
pub fn pyobject_to_json_locked<'py>(value: PyObject, py: Python<'py>) -> serde_json::Value {
    let value: &PyAny = value.as_ref(py);
    if value.is_none() {
        serde_json::Value::Null
    } else if value.is_instance_of::<pyo3::types::PyBool>().unwrap() {
        // must precede the int check: Python's bool is a subclass of int
        json!(value.extract::<bool>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyInt>().unwrap() {
        // large non-negative ints do not fit i64; fall back to u64 instead of panicking
        match value.extract::<i64>() {
            Ok(integer) => json!(integer),
            Err(_) => json!(value.extract::<u64>().unwrap()),
        }
    } else if value.is_instance_of::<pyo3::types::PyFloat>().unwrap() {
        json!(value.extract::<f64>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyString>().unwrap() {
        json!(value.extract::<String>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyList>().unwrap() {
        let elements: Vec<serde_json::Value> = value
            .extract::<Vec<PyObject>>()
            .unwrap()
            .into_iter()
            .map(|object| pyobject_to_json_locked(object, py))
            .collect();
        json!(elements)
    } else if value.is_instance_of::<pyo3::types::PyDict>().unwrap() {
        let map: &pyo3::types::PyDict = value.downcast().unwrap();
        let mut json_map = serde_json::Map::new();
        for (key, value) in map.iter() {
            json_map.insert(
                key.extract::<String>().unwrap(),
                pyobject_to_json_locked(value.to_object(py), py),
            );
        }
        serde_json::Value::Object(json_map)
    } else {
        unimplemented!("unsupported python type, should be (cascaded) dict, list and basic numerical types")
    }
}

/// Same as [`pyobject_to_json_locked`], acquiring the GIL itself.
#[cfg(feature = "python_binding")]
pub fn pyobject_to_json(value: PyObject) -> serde_json::Value {
    Python::with_gil(|py| pyobject_to_json_locked(value, py))
}
955
/// Registers this module's Python classes and the range aliases on `m`.
#[cfg(feature = "python_binding")]
#[pyfunction]
pub(crate) fn register(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    use crate::pyo3::PyTypeInfo;
    m.add_class::<SolverInitializer>()?;
    m.add_class::<PyMut>()?;
    m.add_class::<PartitionUnitInfo>()?;
    m.add_class::<PartitionInfo>()?;
    m.add_class::<PartitionConfig>()?;
    m.add_class::<SyndromePattern>()?;
    // every range alias is the same `IndexRange` class, exported under several names
    for alias in ["VertexRange", "DefectRange", "SyndromeRange", "NodeRange"] {
        m.add(alias, IndexRange::type_object(py))?;
    }
    Ok(())
}
973
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Two leaf partitions `[0, 72)` and `[84, 132)`, fused by unit 2, which owns `[72, 84)`.
    fn two_partition_info() -> PartitionInfo {
        let mut partition_config = PartitionConfig::new(132);
        partition_config.partitions = vec![VertexRange::new(0, 72), VertexRange::new(84, 132)];
        partition_config.fusions = vec![(0, 1)];
        partition_config.info()
    }

    #[test]
    fn util_partitioned_syndrome_pattern_1() {
        // cargo test util_partitioned_syndrome_pattern_1 -- --nocapture
        let partition_info = two_partition_info();
        let fusion_unit = &partition_info.units[2];
        let tests = vec![
            (vec![10, 11, 12, 71, 72, 73, 84, 85, 111], DefectRange::new(4, 6)),
            (vec![10, 11, 12, 13, 71, 72, 73, 84, 85, 111], DefectRange::new(5, 7)),
            (vec![10, 11, 12, 71, 72, 73, 83, 84, 85, 111], DefectRange::new(4, 7)),
            (
                vec![10, 11, 12, 71, 72, 73, 84, 85, 100, 101, 102, 103, 111],
                DefectRange::new(4, 6),
            ),
        ];
        for (defect_vertices, expected_defect_range) in tests.into_iter() {
            let defect_num = defect_vertices.len() as DefectIndex;
            let syndrome_pattern = SyndromePattern::new(defect_vertices, vec![]);
            let partitioned_syndrome_pattern = PartitionedSyndromePattern::new(&syndrome_pattern);
            let (owned_partitioned, (left_partitioned, right_partitioned)) =
                partitioned_syndrome_pattern.partition(fusion_unit);
            println!("defect_range: {:?}", owned_partitioned.whole_defect_range);
            assert_eq!(owned_partitioned.whole_defect_range, expected_defect_range);
            // the three views tile the whole defect list in order
            assert_eq!(
                left_partitioned.whole_defect_range,
                DefectRange::new(0, expected_defect_range.start())
            );
            assert_eq!(
                right_partitioned.whole_defect_range,
                DefectRange::new(expected_defect_range.end(), defect_num)
            );
            // every owned defect lies inside the fusion unit's owning range
            let owned_defects = owned_partitioned.expand().defect_vertices;
            assert_eq!(owned_defects.len(), expected_defect_range.len());
            assert!(owned_defects
                .iter()
                .all(|&vertex| fusion_unit.owning_range.contains(vertex)));
        }
    }

    #[test]
    fn util_partition_info_tree() {
        let partition_info = two_partition_info();
        assert_eq!(partition_info.units.len(), 3);
        let fusion_unit = &partition_info.units[2];
        assert_eq!(fusion_unit.whole_range, VertexRange::new(0, 132));
        assert_eq!(fusion_unit.owning_range, VertexRange::new(72, 84));
        assert_eq!(fusion_unit.children, Some((0, 1)));
        assert_eq!(fusion_unit.parent, None);
        assert_eq!(fusion_unit.leaves, vec![0, 1]);
        assert_eq!(fusion_unit.descendants, std::collections::BTreeSet::from([0, 1]));
        assert_eq!(partition_info.units[0].parent, Some(2));
        assert_eq!(partition_info.units[1].parent, Some(2));
        assert_eq!(partition_info.vertex_to_owning_unit[71], 0);
        assert_eq!(partition_info.vertex_to_owning_unit[72], 2);
        assert_eq!(partition_info.vertex_to_owning_unit[83], 2);
        assert_eq!(partition_info.vertex_to_owning_unit[84], 1);
    }
}