fusion_blossom/
util.rs

1use super::mwpm_solver::PrimalDualSolver;
2use super::pointers::*;
3use super::rand_xoshiro;
4use crate::rand_xoshiro::rand_core::RngCore;
5#[cfg(feature = "python_binding")]
6use pyo3::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::collections::BTreeSet;
9use std::fs::File;
10use std::io::prelude::*;
11use std::time::Instant;
12
// select the integer type used for edge weights at compile time through the `i32_weight` feature
cfg_if::cfg_if! {
    if #[cfg(feature="i32_weight")] {
        /// use i32 to store weight to be compatible with blossom V library (c_int)
        pub type Weight = i32;
    } else {
        /// without the `i32_weight` feature, weights are stored as `isize`
        pub type Weight = isize;
    }
}
21
// select the integer type used for all index-like aliases at compile time through the `u32_index` feature;
// within each branch every alias except `EdgeIndex` is defined in terms of `VertexIndex`
cfg_if::cfg_if! {
    if #[cfg(feature="u32_index")] {
        // use u32 to store index, for less memory usage
        pub type EdgeIndex = u32;
        pub type VertexIndex = u32;  // the vertex index in the decoding graph
        pub type NodeIndex = VertexIndex;
        pub type DefectIndex = VertexIndex;
        pub type VertexNodeIndex = VertexIndex;  // must be same as VertexIndex, NodeIndex, DefectIndex
        pub type VertexNum = VertexIndex;
        pub type NodeNum = VertexIndex;
    } else {
        // without the `u32_index` feature, indices are stored as `usize`
        pub type EdgeIndex = usize;
        pub type VertexIndex = usize;
        pub type NodeIndex = VertexIndex;
        pub type DefectIndex = VertexIndex;
        pub type VertexNodeIndex = VertexIndex;  // must be same as VertexIndex, NodeIndex, DefectIndex
        pub type VertexNum = VertexIndex;
        pub type NodeNum = VertexIndex;
    }
}
42
/// generate Python-visible `to_json` / `from_json` methods for a type that implements
/// `Serialize` and `Deserialize`; any serde_json error is reported as a Python `TypeError`
/// carrying the debug representation of the error
#[cfg(feature = "python_binding")]
macro_rules! bind_trait_python_json {
    ($target:ident) => {
        #[pymethods]
        impl $target {
            /// serialize `self` into a JSON string
            #[pyo3(name = "to_json")]
            fn python_to_json(&self) -> PyResult<String> {
                match serde_json::to_string(self) {
                    Ok(json_text) => Ok(json_text),
                    Err(err) => Err(pyo3::exceptions::PyTypeError::new_err(format!("{err:?}"))),
                }
            }
            /// build an instance from a JSON string
            #[staticmethod]
            #[pyo3(name = "from_json")]
            fn python_from_json(value: String) -> PyResult<Self> {
                match serde_json::from_str(value.as_str()) {
                    Ok(instance) => Ok(instance),
                    Err(err) => Err(pyo3::exceptions::PyTypeError::new_err(format!("{err:?}"))),
                }
            }
        }
    };
}
61
/// a weighted graph description: the number of vertices, the weighted edges between them and
/// the list of virtual vertices; with the `python_binding` feature it is also exposed as a Python class
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverInitializer {
    /// the number of vertices
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_num: VertexNum,
    /// weighted edges, where vertex indices are within the range [0, vertex_num);
    /// each entry is `(vertex, vertex, weight)`
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub weighted_edges: Vec<(VertexIndex, VertexIndex, Weight)>,
    /// the virtual vertices
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub virtual_vertices: Vec<VertexIndex>,
}
76
77#[cfg(feature = "python_binding")]
78bind_trait_python_json! {SolverInitializer}
79
/// one decoding input: the defect vertices plus optional erasures or dynamically weighted edges;
/// `erasures` and `dynamic_weights` fall back to empty vectors when missing during deserialization
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct SyndromePattern {
    /// the vertices corresponding to defect measurements
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub defect_vertices: Vec<VertexIndex>,
    /// the edges that experience erasures, i.e. known errors;
    /// note that erasure decoding can also be implemented using `dynamic_weights`,
    /// but for user convenience we keep this interface
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    #[serde(default = "default_erasures")]
    pub erasures: Vec<EdgeIndex>,
    /// general dynamically weighted edges, each entry is `(edge index, weight)`
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    #[serde(default = "default_dynamic_weights")]
    pub dynamic_weights: Vec<(EdgeIndex, Weight)>,
}
98
99pub fn default_dynamic_weights() -> Vec<(EdgeIndex, Weight)> {
100    vec![]
101}
102
103pub fn default_erasures() -> Vec<EdgeIndex> {
104    vec![]
105}
106
107impl SyndromePattern {
108    pub fn new(defect_vertices: Vec<VertexIndex>, erasures: Vec<EdgeIndex>) -> Self {
109        Self {
110            defect_vertices,
111            erasures,
112            dynamic_weights: vec![],
113        }
114    }
115    pub fn new_dynamic_weights(
116        defect_vertices: Vec<VertexIndex>,
117        erasures: Vec<EdgeIndex>,
118        dynamic_weights: Vec<(EdgeIndex, Weight)>,
119    ) -> Self {
120        Self {
121            defect_vertices,
122            erasures,
123            dynamic_weights,
124        }
125    }
126}
127
128#[cfg_attr(feature = "python_binding", cfg_eval)]
129#[cfg_attr(feature = "python_binding", pymethods)]
130impl SyndromePattern {
131    #[cfg_attr(feature = "python_binding", new)]
132    #[cfg_attr(feature = "python_binding", pyo3(signature = (defect_vertices=vec![], erasures=vec![], dynamic_weights=vec![], syndrome_vertices=None)))]
133    pub fn py_new(
134        mut defect_vertices: Vec<VertexIndex>,
135        erasures: Vec<EdgeIndex>,
136        dynamic_weights: Vec<(EdgeIndex, Weight)>,
137        syndrome_vertices: Option<Vec<VertexIndex>>,
138    ) -> Self {
139        if let Some(syndrome_vertices) = syndrome_vertices {
140            assert!(
141                defect_vertices.is_empty(),
142                "do not pass both `syndrome_vertices` and `defect_vertices` since they're aliasing"
143            );
144            defect_vertices = syndrome_vertices;
145        }
146        assert!(
147            erasures.is_empty() || dynamic_weights.is_empty(),
148            "erasures and dynamic_weights cannot be provided at the same time"
149        );
150        Self::new_dynamic_weights(defect_vertices, erasures, dynamic_weights)
151    }
152    #[cfg_attr(feature = "python_binding", staticmethod)]
153    pub fn new_vertices(defect_vertices: Vec<VertexIndex>) -> Self {
154        Self::new(defect_vertices, vec![])
155    }
156    #[cfg_attr(feature = "python_binding", staticmethod)]
157    pub fn new_empty() -> Self {
158        Self::new(vec![], vec![])
159    }
160    #[cfg(feature = "python_binding")]
161    fn __repr__(&self) -> String {
162        format!("{:?}", self)
163    }
164}
165
/// an efficient representation of partitioned vertices and erasures when they're ordered:
/// it borrows the original pattern and only stores a range of indices into its `defect_vertices`
#[derive(Debug, Clone, Serialize)]
pub struct PartitionedSyndromePattern<'a> {
    /// the original syndrome pattern to be partitioned
    pub syndrome_pattern: &'a SyndromePattern,
    /// the defect range of this partition: it must be continuous if the defect vertices are ordered;
    /// the bounds are positions in `syndrome_pattern.defect_vertices`, not vertex indices
    pub whole_defect_range: DefectRange,
}
174
175impl<'a> PartitionedSyndromePattern<'a> {
176    pub fn new(syndrome_pattern: &'a SyndromePattern) -> Self {
177        assert!(
178            syndrome_pattern.erasures.is_empty(),
179            "erasure partition not supported yet;
180        even if the edges in the erasure is well ordered, they may not be able to be represented as
181        a single range simply because the partition is vertex-based. need more consideration"
182        );
183        Self {
184            syndrome_pattern,
185            whole_defect_range: DefectRange::new(0, syndrome_pattern.defect_vertices.len() as DefectIndex),
186        }
187    }
188}
189
/// a pair of indices `[start, end]` used as the half-open range `[start, end)`;
/// `#[serde(transparent)]` makes it (de)serialize as the inner two-element array
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct IndexRange {
    /// `range[0]` is the inclusive start and `range[1]` is the exclusive end
    pub range: [VertexNodeIndex; 2],
}
197
// just to distinguish them in code, essentially nothing different
/// an [`IndexRange`] whose bounds are meant as vertex indices
pub type VertexRange = IndexRange;
/// an [`IndexRange`] whose bounds are meant as node indices
pub type NodeRange = IndexRange;
/// an [`IndexRange`] whose bounds are meant as defect indices
pub type DefectRange = IndexRange;
202
203#[cfg_attr(feature = "python_binding", cfg_eval)]
204#[cfg_attr(feature = "python_binding", pymethods)]
205impl IndexRange {
206    #[cfg_attr(feature = "python_binding", new)]
207    pub fn new(start: VertexNodeIndex, end: VertexNodeIndex) -> Self {
208        debug_assert!(end >= start, "invalid range [{}, {})", start, end);
209        Self { range: [start, end] }
210    }
211    #[cfg_attr(feature = "python_binding", staticmethod)]
212    pub fn new_length(start: VertexNodeIndex, length: VertexNodeIndex) -> Self {
213        Self::new(start, start + length)
214    }
215    pub fn is_empty(&self) -> bool {
216        self.range[1] == self.range[0]
217    }
218    #[allow(clippy::unnecessary_cast)]
219    pub fn len(&self) -> usize {
220        (self.range[1] - self.range[0]) as usize
221    }
222    pub fn start(&self) -> VertexNodeIndex {
223        self.range[0]
224    }
225    pub fn end(&self) -> VertexNodeIndex {
226        self.range[1]
227    }
228    pub fn append_by(&mut self, append_count: VertexNodeIndex) {
229        self.range[1] += append_count;
230    }
231    pub fn bias_by(&mut self, bias: VertexNodeIndex) {
232        self.range[0] += bias;
233        self.range[1] += bias;
234    }
235    pub fn sanity_check(&self) {
236        assert!(self.start() <= self.end(), "invalid vertex range {:?}", self);
237    }
238    pub fn contains(&self, vertex_index: VertexNodeIndex) -> bool {
239        vertex_index >= self.start() && vertex_index < self.end()
240    }
241    /// fuse two ranges together, returning (the whole range, the interfacing range)
242    pub fn fuse(&self, other: &Self) -> (Self, Self) {
243        self.sanity_check();
244        other.sanity_check();
245        assert!(self.range[1] <= other.range[0], "only lower range can fuse higher range");
246        (
247            Self::new(self.range[0], other.range[1]),
248            Self::new(self.range[1], other.range[0]),
249        )
250    }
251    #[cfg(feature = "python_binding")]
252    #[pyo3(name = "contains_any")]
253    pub fn python_contains_any(&self, vertex_indices: Vec<VertexNodeIndex>) -> bool {
254        self.contains_any(&vertex_indices)
255    }
256    #[cfg(feature = "python_binding")]
257    fn __repr__(&self) -> String {
258        format!("{:?}", self)
259    }
260}
261
262impl IndexRange {
263    pub fn iter(&self) -> std::ops::Range<VertexNodeIndex> {
264        self.range[0]..self.range[1]
265    }
266    pub fn contains_any(&self, vertex_indices: &[VertexNodeIndex]) -> bool {
267        for vertex_index in vertex_indices.iter() {
268            if self.contains(*vertex_index) {
269                return true;
270            }
271        }
272        false
273    }
274}
275
/// a general partition unit that could contain mirrored vertices;
/// it is shared through [`PartitionUnitPtr`] / [`PartitionUnitWeak`]
#[derive(Debug, Clone)]
pub struct PartitionUnit {
    /// unit index
    pub unit_index: usize,
    /// whether it's enabled; when disabled, the mirrored vertices behaves just like virtual vertices
    pub enabled: bool,
}
284
/// strong shared pointer to a [`PartitionUnit`]; the wrapper type comes from the `pointers` module
pub type PartitionUnitPtr = ArcManualSafeLock<PartitionUnit>;
/// weak counterpart of [`PartitionUnitPtr`]
pub type PartitionUnitWeak = WeakManualSafeLock<PartitionUnit>;
287
288impl std::fmt::Debug for PartitionUnitPtr {
289    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
290        let partition_unit = self.read_recursive();
291        write!(
292            f,
293            "{}{}",
294            if partition_unit.enabled { "E" } else { "D" },
295            partition_unit.unit_index
296        )
297    }
298}
299
// a weak pointer is printed by first upgrading it and then formatting the resulting pointer
impl std::fmt::Debug for PartitionUnitWeak {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // NOTE(review): `upgrade_force` is defined in the `pointers` module; presumably it panics
        // when the unit has already been dropped — confirm before relying on it
        self.upgrade_force().fmt(f)
    }
}
305
/// user input partition configuration;
/// unknown fields are rejected during deserialization (`deny_unknown_fields`)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionConfig {
    /// the number of vertices
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_num: VertexNum,
    /// detailed plan of partitioning serial modules: each serial module possesses a list of vertices, including all interface vertices
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub partitions: Vec<VertexRange>,
    /// detailed plan of interfacing vertices: the `i`-th entry `(left, right)` creates the unit with index
    /// `partitions.len() + i` by fusing two units that have smaller indices
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub fusions: Vec<(usize, usize)>,
}
322
323#[cfg(feature = "python_binding")]
324bind_trait_python_json! {PartitionConfig}
325
326#[cfg_attr(feature = "python_binding", cfg_eval)]
327#[cfg_attr(feature = "python_binding", pymethods)]
328impl PartitionConfig {
329    #[cfg_attr(feature = "python_binding", new)]
330    pub fn new(vertex_num: VertexNum) -> Self {
331        Self {
332            vertex_num,
333            partitions: vec![VertexRange::new(0, vertex_num as VertexIndex)],
334            fusions: vec![],
335        }
336    }
337
338    #[cfg(feature = "python_binding")]
339    fn __repr__(&self) -> String {
340        format!("{:?}", self)
341    }
342
343    #[allow(clippy::unnecessary_cast)]
344    pub fn info(&self) -> PartitionInfo {
345        assert!(!self.partitions.is_empty(), "at least one partition must exist");
346        let mut whole_ranges = vec![];
347        let mut owning_ranges = vec![];
348        for &partition in self.partitions.iter() {
349            partition.sanity_check();
350            assert!(
351                partition.end() <= self.vertex_num as VertexIndex,
352                "invalid vertex index {} in partitions",
353                partition.end()
354            );
355            whole_ranges.push(partition);
356            owning_ranges.push(partition);
357        }
358        let unit_count = self.partitions.len() + self.fusions.len();
359        let mut parents: Vec<Option<usize>> = (0..unit_count).map(|_| None).collect();
360        for (fusion_index, (left_index, right_index)) in self.fusions.iter().enumerate() {
361            let unit_index = fusion_index + self.partitions.len();
362            assert!(
363                *left_index < unit_index,
364                "dependency wrong, {} depending on {}",
365                unit_index,
366                left_index
367            );
368            assert!(
369                *right_index < unit_index,
370                "dependency wrong, {} depending on {}",
371                unit_index,
372                right_index
373            );
374            assert!(parents[*left_index].is_none(), "cannot fuse {} twice", left_index);
375            assert!(parents[*right_index].is_none(), "cannot fuse {} twice", right_index);
376            parents[*left_index] = Some(unit_index);
377            parents[*right_index] = Some(unit_index);
378            // fusing range
379            let (whole_range, interface_range) = whole_ranges[*left_index].fuse(&whole_ranges[*right_index]);
380            whole_ranges.push(whole_range);
381            owning_ranges.push(interface_range);
382        }
383        // check that all nodes except for the last one has been merged
384        for (unit_index, parent) in parents.iter().enumerate().take(unit_count - 1) {
385            assert!(parent.is_some(), "found unit {} without being fused", unit_index);
386        }
387        // check that the final node has the full range
388        let last_unit_index = self.partitions.len() + self.fusions.len() - 1;
389        assert!(
390            whole_ranges[last_unit_index].start() == 0,
391            "final range not covering all vertices {:?}",
392            whole_ranges[last_unit_index]
393        );
394        assert!(
395            whole_ranges[last_unit_index].end() == self.vertex_num as VertexIndex,
396            "final range not covering all vertices {:?}",
397            whole_ranges[last_unit_index]
398        );
399        // construct partition info
400        let mut partition_unit_info: Vec<_> = (0..self.partitions.len() + self.fusions.len())
401            .map(|i| PartitionUnitInfo {
402                whole_range: whole_ranges[i],
403                owning_range: owning_ranges[i],
404                children: if i >= self.partitions.len() {
405                    Some(self.fusions[i - self.partitions.len()])
406                } else {
407                    None
408                },
409                parent: parents[i],
410                leaves: if i < self.partitions.len() { vec![i] } else { vec![] },
411                descendants: BTreeSet::new(),
412            })
413            .collect();
414        // build descendants
415        for (fusion_index, (left_index, right_index)) in self.fusions.iter().enumerate() {
416            let unit_index = fusion_index + self.partitions.len();
417            let mut leaves = vec![];
418            leaves.extend(partition_unit_info[*left_index].leaves.iter());
419            leaves.extend(partition_unit_info[*right_index].leaves.iter());
420            partition_unit_info[unit_index].leaves.extend(leaves.iter());
421            let mut descendants = vec![];
422            descendants.push(*left_index);
423            descendants.push(*right_index);
424            descendants.extend(partition_unit_info[*left_index].descendants.iter());
425            descendants.extend(partition_unit_info[*right_index].descendants.iter());
426            partition_unit_info[unit_index].descendants.extend(descendants.iter());
427        }
428        let mut vertex_to_owning_unit: Vec<_> = (0..self.vertex_num).map(|_| usize::MAX).collect();
429        for (unit_index, unit_range) in partition_unit_info.iter().map(|x| x.owning_range).enumerate() {
430            for vertex_index in unit_range.iter() {
431                vertex_to_owning_unit[vertex_index as usize] = unit_index;
432            }
433        }
434        PartitionInfo {
435            config: self.clone(),
436            units: partition_unit_info,
437            vertex_to_owning_unit,
438        }
439    }
440}
441
/// the result of [`PartitionConfig::info`]: the configuration together with the derived
/// per-unit information and the vertex-to-unit mapping
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionInfo {
    /// the initial configuration that creates this info
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub config: PartitionConfig,
    /// individual info of each unit
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub units: Vec<PartitionUnitInfo>,
    /// the mapping from vertices to the owning unit: serial unit (holding real vertices) as well as parallel units (holding interfacing vertices);
    /// used for loading syndrome to the holding units
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub vertex_to_owning_unit: Vec<usize>,
}
457
458#[cfg(feature = "python_binding")]
459bind_trait_python_json! {PartitionInfo}
460
461#[cfg_attr(feature = "python_binding", pymethods)]
462impl PartitionInfo {
463    /// split a sequence of syndrome into multiple parts, each corresponds to a unit;
464    /// this is a slow method and should only be used when the syndrome pattern is not well-ordered
465    #[allow(clippy::unnecessary_cast)]
466    pub fn partition_syndrome_unordered(&self, syndrome_pattern: &SyndromePattern) -> Vec<SyndromePattern> {
467        let mut partitioned_syndrome: Vec<_> = (0..self.units.len()).map(|_| SyndromePattern::new_empty()).collect();
468        for defect_vertex in syndrome_pattern.defect_vertices.iter() {
469            let unit_index = self.vertex_to_owning_unit[*defect_vertex as usize];
470            partitioned_syndrome[unit_index].defect_vertices.push(*defect_vertex);
471        }
472        // TODO: partition edges
473        partitioned_syndrome
474    }
475
476    #[cfg(feature = "python_binding")]
477    fn __repr__(&self) -> String {
478        format!("{:?}", self)
479    }
480}
481
482impl<'a> PartitionedSyndromePattern<'a> {
483    /// partition the syndrome pattern into 2 partitioned syndrome pattern and my whole range
484    #[allow(clippy::unnecessary_cast)]
485    pub fn partition(&self, partition_unit_info: &PartitionUnitInfo) -> (Self, (Self, Self)) {
486        // first binary search the start of owning defect vertices
487        let owning_start_index = {
488            let mut left_index = self.whole_defect_range.start();
489            let mut right_index = self.whole_defect_range.end();
490            while left_index != right_index {
491                let mid_index = (left_index + right_index) / 2;
492                let mid_defect_vertex = self.syndrome_pattern.defect_vertices[mid_index as usize];
493                if mid_defect_vertex < partition_unit_info.owning_range.start() {
494                    left_index = mid_index + 1;
495                } else {
496                    right_index = mid_index;
497                }
498            }
499            left_index
500        };
501        // second binary search the end of owning defect vertices
502        let owning_end_index = {
503            let mut left_index = self.whole_defect_range.start();
504            let mut right_index = self.whole_defect_range.end();
505            while left_index != right_index {
506                let mid_index = (left_index + right_index) / 2;
507                let mid_defect_vertex = self.syndrome_pattern.defect_vertices[mid_index as usize];
508                if mid_defect_vertex < partition_unit_info.owning_range.end() {
509                    left_index = mid_index + 1;
510                } else {
511                    right_index = mid_index;
512                }
513            }
514            left_index
515        };
516        (
517            Self {
518                syndrome_pattern: self.syndrome_pattern,
519                whole_defect_range: DefectRange::new(owning_start_index, owning_end_index),
520            },
521            (
522                Self {
523                    syndrome_pattern: self.syndrome_pattern,
524                    whole_defect_range: DefectRange::new(self.whole_defect_range.start(), owning_start_index),
525                },
526                Self {
527                    syndrome_pattern: self.syndrome_pattern,
528                    whole_defect_range: DefectRange::new(owning_end_index, self.whole_defect_range.end()),
529                },
530            ),
531        )
532    }
533
534    #[allow(clippy::unnecessary_cast)]
535    pub fn expand(&self) -> SyndromePattern {
536        let mut defect_vertices = Vec::with_capacity(self.whole_defect_range.len());
537        for defect_index in self.whole_defect_range.iter() {
538            defect_vertices.push(self.syndrome_pattern.defect_vertices[defect_index as usize]);
539        }
540        SyndromePattern::new(defect_vertices, vec![])
541    }
542}
543
/// the derived information of a single unit in the fusion tree, as built by [`PartitionConfig::info`]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "python_binding", cfg_eval)]
#[cfg_attr(feature = "python_binding", pyclass)]
pub struct PartitionUnitInfo {
    /// the whole range of units
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub whole_range: VertexRange,
    /// the owning range of units, meaning vertices inside are exclusively belonging to the unit
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub owning_range: VertexRange,
    /// left and right; `None` for a unit that comes directly from a partition
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub children: Option<(usize, usize)>,
    /// parent dual module; `None` for a unit that is not fused into another one
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub parent: Option<usize>,
    /// all the leaf dual modules
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub leaves: Vec<usize>,
    /// all the descendants
    #[cfg_attr(feature = "python_binding", pyo3(get, set))]
    pub descendants: BTreeSet<usize>,
}
567
568#[cfg(feature = "python_binding")]
569bind_trait_python_json! {PartitionUnitInfo}
570
#[cfg(feature = "python_binding")]
#[pymethods]
impl PartitionUnitInfo {
    /// Python `repr()`: the Rust debug representation
    fn __repr__(&self) -> String {
        format!("{self:?}")
    }
}
578
/// the initializer of a single partition unit: global sizes plus the vertices, interfaces and
/// edges that this unit has to know about
#[derive(Debug, Clone)]
pub struct PartitionedSolverInitializer {
    /// unit index
    pub unit_index: usize,
    /// the number of all vertices (including those partitioned into other serial modules)
    pub vertex_num: VertexNum,
    /// the number of all edges (including those partitioned into other serial modules)
    pub edge_num: usize,
    /// vertices exclusively owned by this partition; this part must be a continuous range
    pub owning_range: VertexRange,
    /// applicable when all the owning vertices are partitioned (i.e. this belongs to a fusion unit)
    pub owning_interface: Option<PartitionUnitWeak>,
    /// if applicable, parent interface comes first, then the grandparent interface, ... note that some ancestor might be skipped because it has no mirrored vertices;
    /// we skip them because if the partition is in a chain, most of them would only have to know two interfaces on the left and on the right; nothing else necessary.
    /// (unit_index, list of vertices owned by this ancestor unit and should be mirrored at this partition and whether it's virtual)
    pub interfaces: Vec<(PartitionUnitWeak, Vec<(VertexIndex, bool)>)>,
    /// weighted edges, where the first vertex index is within the range [vertex_index_bias, vertex_index_bias + vertex_num) and
    /// the second is either in [vertex_index_bias, vertex_index_bias + vertex_num) or inside
    // NOTE(review): the sentence above ends abruptly ("or inside ..."); the intended target is not
    // visible here, so it is left untouched. Each entry is `(vertex, vertex, weight, edge index)`.
    pub weighted_edges: Vec<(VertexIndex, VertexIndex, Weight, EdgeIndex)>,
    /// the virtual vertices
    pub virtual_vertices: Vec<VertexIndex>,
}
601
602/// perform index transformation
603#[allow(clippy::unnecessary_cast)]
604pub fn build_old_to_new(reordered_vertices: &[VertexIndex]) -> Vec<Option<VertexIndex>> {
605    let mut old_to_new: Vec<Option<VertexIndex>> = (0..reordered_vertices.len()).map(|_| None).collect();
606    for (new_index, old_index) in reordered_vertices.iter().enumerate() {
607        assert_eq!(old_to_new[*old_index as usize], None, "duplicate vertex found {}", old_index);
608        old_to_new[*old_index as usize] = Some(new_index as VertexIndex);
609    }
610    old_to_new
611}
612
613/// translate defect vertices into the current new index given reordered_vertices
614#[allow(clippy::unnecessary_cast)]
615pub fn translated_defect_to_reordered(
616    reordered_vertices: &[VertexIndex],
617    old_defect_vertices: &[VertexIndex],
618) -> Vec<VertexIndex> {
619    let old_to_new = build_old_to_new(reordered_vertices);
620    old_defect_vertices
621        .iter()
622        .map(|old_index| old_to_new[*old_index as usize].unwrap())
623        .collect()
624}
625
626#[cfg_attr(feature = "python_binding", cfg_eval)]
627#[cfg_attr(feature = "python_binding", pymethods)]
628impl SolverInitializer {
629    #[cfg_attr(feature = "python_binding", new)]
630    pub fn new(
631        vertex_num: VertexNum,
632        weighted_edges: Vec<(VertexIndex, VertexIndex, Weight)>,
633        virtual_vertices: Vec<VertexIndex>,
634    ) -> SolverInitializer {
635        SolverInitializer {
636            vertex_num,
637            weighted_edges,
638            virtual_vertices,
639        }
640    }
641    #[cfg(feature = "python_binding")]
642    fn __repr__(&self) -> String {
643        format!("{:?}", self)
644    }
645}
646
647impl SolverInitializer {
648    #[allow(clippy::unnecessary_cast)]
649    pub fn syndrome_of(&self, subgraph: &[EdgeIndex]) -> BTreeSet<VertexIndex> {
650        let mut defects = BTreeSet::new();
651        for edge_index in subgraph {
652            let (left, right, _weight) = self.weighted_edges[*edge_index as usize];
653            for vertex_index in [left, right] {
654                if defects.contains(&vertex_index) {
655                    defects.remove(&vertex_index);
656                } else {
657                    defects.insert(vertex_index);
658                }
659            }
660        }
661        // remove virtual vertices
662        for vertex_index in self.virtual_vertices.iter() {
663            defects.remove(vertex_index);
664        }
665        defects
666    }
667}
668
/// timestamp type determines how many fast clear before a hard clear is required, see [`FastClear`];
/// currently a plain `usize`
pub type FastClearTimestamp = usize;
671
// `dead_code` is allowed so that the alias may stay unused without a warning
#[allow(dead_code)]
/// use Xoshiro256StarStar for deterministic random number generator
pub type DeterministicRng = rand_xoshiro::Xoshiro256StarStar;
675
/// a random number generator that can produce `f64` values
pub trait F64Rng {
    /// draw the next `f64`; the range of the value is defined by the implementation
    fn next_f64(&mut self) -> f64;
}
679
680impl F64Rng for DeterministicRng {
681    fn next_f64(&mut self) -> f64 {
682        f64::from_bits(0x3FF << 52 | self.next_u64() >> 12) - 1.
683    }
684}
685
/// record the decoding time of multiple syndrome patterns;
/// one entry is appended per `begin` call, and the sums are updated on `end`
pub struct BenchmarkProfiler {
    /// each record corresponds to a different syndrome pattern
    pub records: Vec<BenchmarkProfilerEntry>,
    /// summation of all decoding time
    pub sum_round_time: f64,
    /// syndrome count: the total number of defect vertices over all finished records
    pub sum_syndrome: usize,
    /// noisy measurement round
    pub noisy_measurements: VertexNum,
    /// the file to output the profiler results; `None` disables the detailed log
    pub benchmark_profiler_output: Option<File>,
}
699
700impl BenchmarkProfiler {
701    pub fn new(noisy_measurements: VertexNum, detail_log_file: Option<(String, &PartitionInfo)>) -> Self {
702        let benchmark_profiler_output = detail_log_file.map(|(filename, partition_info)| {
703            let mut file = File::create(filename).unwrap();
704            file.write_all(serde_json::to_string(&partition_info.config).unwrap().as_bytes())
705                .unwrap();
706            file.write_all(b"\n").unwrap();
707            file.write_all(
708                serde_json::to_string(&json!({
709                    "noisy_measurements": noisy_measurements,
710                }))
711                .unwrap()
712                .as_bytes(),
713            )
714            .unwrap();
715            file.write_all(b"\n").unwrap();
716            file
717        });
718        Self {
719            records: vec![],
720            sum_round_time: 0.,
721            sum_syndrome: 0,
722            noisy_measurements,
723            benchmark_profiler_output,
724        }
725    }
726    /// record the beginning of a decoding procedure
727    pub fn begin(&mut self, syndrome_pattern: &SyndromePattern) {
728        // sanity check last entry, if exists, is complete
729        if let Some(last_entry) = self.records.last() {
730            assert!(
731                last_entry.is_complete(),
732                "the last benchmark profiler entry is not complete, make sure to call `begin` and `end` in pairs"
733            );
734        }
735        let entry = BenchmarkProfilerEntry::new(syndrome_pattern);
736        self.records.push(entry);
737        self.records.last_mut().unwrap().record_begin();
738    }
739    pub fn event(&mut self, event_name: String) {
740        let last_entry = self
741            .records
742            .last_mut()
743            .expect("last entry not exists, call `begin` before `end`");
744        last_entry.record_event(event_name);
745    }
746    /// record the ending of a decoding procedure
747    pub fn end(&mut self, solver: Option<&dyn PrimalDualSolver>) {
748        let last_entry = self
749            .records
750            .last_mut()
751            .expect("last entry not exists, call `begin` before `end`");
752        last_entry.record_end();
753        self.sum_round_time += last_entry.round_time.unwrap();
754        self.sum_syndrome += last_entry.syndrome_pattern.defect_vertices.len();
755        if let Some(file) = self.benchmark_profiler_output.as_mut() {
756            let mut events = serde_json::Map::new();
757            for (event_name, time) in last_entry.events.iter() {
758                events.insert(event_name.clone(), json!(time));
759            }
760            let mut value = json!({
761                "round_time": last_entry.round_time.unwrap(),
762                "defect_num": last_entry.syndrome_pattern.defect_vertices.len(),
763                "events": events,
764            });
765            if let Some(solver) = solver {
766                let solver_profile = solver.generate_profiler_report();
767                value
768                    .as_object_mut()
769                    .unwrap()
770                    .insert("solver_profile".to_string(), solver_profile);
771            }
772            file.write_all(serde_json::to_string(&value).unwrap().as_bytes()).unwrap();
773            file.write_all(b"\n").unwrap();
774        }
775    }
776    /// print out a brief one-line statistics
777    pub fn brief(&self) -> String {
778        let total = self.sum_round_time / (self.records.len() as f64);
779        let per_round = total / (1. + self.noisy_measurements as f64);
780        let per_defect = self.sum_round_time / (self.sum_syndrome as f64);
781        format!("total: {total:.3e}, round: {per_round:.3e}, defect: {per_defect:.3e},")
782    }
783}
784
/// A single profiling record, covering one decoding procedure from
/// [`Self::record_begin`] to [`Self::record_end`].
pub struct BenchmarkProfilerEntry {
    /// the syndrome pattern of this decoding problem (an owned clone made in [`Self::new`])
    pub syndrome_pattern: SyndromePattern,
    /// the time of beginning a decoding procedure; `None` until [`Self::record_begin`] is called
    begin_time: Option<Instant>,
    /// additional named events, each paired with the seconds elapsed since `begin_time`
    /// at the moment [`Self::record_event`] was called
    pub events: Vec<(String, f64)>,
    /// interval in seconds between calling [`Self::record_begin`] and calling [`Self::record_end`];
    /// `None` until [`Self::record_end`] is called
    pub round_time: Option<f64>,
}
795
796impl BenchmarkProfilerEntry {
797    pub fn new(syndrome_pattern: &SyndromePattern) -> Self {
798        Self {
799            syndrome_pattern: syndrome_pattern.clone(),
800            begin_time: None,
801            events: vec![],
802            round_time: None,
803        }
804    }
805    /// record the beginning of a decoding procedure
806    pub fn record_begin(&mut self) {
807        assert_eq!(self.begin_time, None, "do not call `record_begin` twice on the same entry");
808        self.begin_time = Some(Instant::now());
809    }
810    /// record the ending of a decoding procedure
811    pub fn record_end(&mut self) {
812        let begin_time = self
813            .begin_time
814            .as_ref()
815            .expect("make sure to call `record_begin` before calling `record_end`");
816        self.round_time = Some(begin_time.elapsed().as_secs_f64());
817    }
818    pub fn record_event(&mut self, event_name: String) {
819        let begin_time = self
820            .begin_time
821            .as_ref()
822            .expect("make sure to call `record_begin` before calling `record_end`");
823        self.events.push((event_name, begin_time.elapsed().as_secs_f64()));
824    }
825    pub fn is_complete(&self) -> bool {
826        self.round_time.is_some()
827    }
828}
829
/**
 * Context manager that makes it easier to modify a nested field of a Python-visible Rust object.
 *
 * Reading a field of a Rust struct from Python returns a copy of it, to avoid memory unsafety.
 * Thus the typical way of modifying a nested field doesn't work, e.g. `obj.a.b.c = 1` won't actually modify `obj`.
 * This helper fetches the attribute when the `with` block is entered and assigns it back to the object
 * when the block exits; please note this can be very time consuming if not optimized well.
 *
 * Example:
 * with PyMut(code, "vertices") as vertices:
 *     with PyMut(vertices[0], "position") as position:
 *         position.i = 100
*/
#[cfg(feature = "python_binding")]
#[pyclass]
pub struct PyMut {
    /// the python object that provides getter and setter function for the attribute
    #[pyo3(get, set)]
    object: PyObject,
    /// the name of the attribute
    #[pyo3(get, set)]
    attr_name: String,
    /// the python attribute object obtained via `getattr(object, attr_name)`;
    /// only `Some` between `__enter__` and `__exit__`
    #[pyo3(get, set)]
    attr_object: Option<PyObject>,
}
853
#[cfg(feature = "python_binding")]
#[pymethods]
impl PyMut {
    /// Create a helper for attribute `attr_name` of `object`; nothing is fetched until `__enter__`.
    #[new]
    pub fn new(object: PyObject, attr_name: String) -> Self {
        let attr_object = None;
        Self {
            object,
            attr_name,
            attr_object,
        }
    }
    /// Fetch the attribute from the object, remember it, and hand a reference to the caller.
    ///
    /// Panics if the helper has already been entered or if the attribute cannot be read.
    pub fn __enter__(&mut self) -> PyObject {
        assert!(self.attr_object.is_none(), "do not enter twice");
        Python::with_gil(|py| {
            let fetched = self.object.getattr(py, self.attr_name.as_str()).unwrap();
            // keep one reference for `__exit__` and return the other one to Python
            let handed_out = fetched.clone_ref(py);
            self.attr_object = Some(fetched);
            handed_out
        })
    }
    /// Assign the remembered attribute object back to the object.
    ///
    /// Panics if the helper has not been entered or if the attribute cannot be set.
    pub fn __exit__(&mut self, _exc_type: PyObject, _exc_val: PyObject, _exc_tb: PyObject) {
        Python::with_gil(|py| {
            let remembered = self.attr_object.take().unwrap();
            let outcome = self.object.setattr(py, self.attr_name.as_str(), remembered);
            outcome.unwrap()
        })
    }
}
881
/// Convert a JSON value into the corresponding Python object, using an already acquired GIL token.
///
/// `null` maps to `None`, booleans to `bool`, strings to `str`, arrays to `list` and objects
/// to `dict`; nested values are converted recursively.
///
/// Numbers are converted to a Python `int` whenever they are integers: values that fit in
/// `i64` as well as non-negative values that only fit in `u64`. Everything else becomes a
/// Python `float`.
///
/// # Panics
///
/// Panics if inserting an item into a Python dict fails.
#[cfg(feature = "python_binding")]
pub fn json_to_pyobject_locked<'py>(value: serde_json::Value, py: Python<'py>) -> PyObject {
    match value {
        serde_json::Value::Null => py.None(),
        serde_json::Value::Bool(value) => value.to_object(py).into(),
        serde_json::Value::Number(value) => {
            if let Some(integer) = value.as_i64() {
                integer.to_object(py).into()
            } else if let Some(unsigned) = value.as_u64() {
                // integers above `i64::MAX` must stay integers instead of being rounded to a float
                unsigned.to_object(py).into()
            } else {
                value.as_f64().to_object(py).into()
            }
        }
        serde_json::Value::String(value) => value.to_object(py).into(),
        serde_json::Value::Array(array) => {
            let elements: Vec<PyObject> = array.into_iter().map(|value| json_to_pyobject_locked(value, py)).collect();
            pyo3::types::PyList::new(py, elements).into()
        }
        serde_json::Value::Object(map) => {
            let pydict = pyo3::types::PyDict::new(py);
            for (key, value) in map.into_iter() {
                let pyobject = json_to_pyobject_locked(value, py);
                pydict.set_item(key, pyobject).unwrap();
            }
            pydict.into()
        }
    }
}
909
/// Convert a JSON value into a Python object, acquiring the GIL for the duration of the conversion.
#[cfg(feature = "python_binding")]
pub fn json_to_pyobject(value: serde_json::Value) -> PyObject {
    Python::with_gil(|py| {
        // delegate to the variant that works with an explicit GIL token
        json_to_pyobject_locked(value, py)
    })
}
914
/// Convert a Python object into a JSON value, using an already acquired GIL token.
///
/// Supported inputs are `None`, `bool`, `int`, `float`, `str`, `list` and `dict` with string
/// keys; lists and dicts are converted recursively. `bool` is tested before `int` because
/// Python booleans are also integers.
///
/// Integers are read as `i64` first; values that do not fit are read as `u64`, so that the
/// whole integer range representable in JSON is accepted.
///
/// # Panics
///
/// Panics if the object (or any nested object) has an unsupported type, if an integer fits
/// neither `i64` nor `u64`, or if a dict key is not a string.
#[cfg(feature = "python_binding")]
pub fn pyobject_to_json_locked<'py>(value: PyObject, py: Python<'py>) -> serde_json::Value {
    let value: &PyAny = value.as_ref(py);
    if value.is_none() {
        serde_json::Value::Null
    } else if value.is_instance_of::<pyo3::types::PyBool>().unwrap() {
        json!(value.extract::<bool>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyInt>().unwrap() {
        match value.extract::<i64>() {
            Ok(integer) => json!(integer),
            // too large for `i64`: still representable in JSON if it fits in `u64`
            Err(_) => json!(value.extract::<u64>().unwrap()),
        }
    } else if value.is_instance_of::<pyo3::types::PyFloat>().unwrap() {
        json!(value.extract::<f64>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyString>().unwrap() {
        json!(value.extract::<String>().unwrap())
    } else if value.is_instance_of::<pyo3::types::PyList>().unwrap() {
        let elements: Vec<serde_json::Value> = value
            .extract::<Vec<PyObject>>()
            .unwrap()
            .into_iter()
            .map(|object| pyobject_to_json_locked(object, py))
            .collect();
        json!(elements)
    } else if value.is_instance_of::<pyo3::types::PyDict>().unwrap() {
        let map: &pyo3::types::PyDict = value.downcast().unwrap();
        let mut json_map = serde_json::Map::new();
        for (key, value) in map.iter() {
            json_map.insert(
                key.extract::<String>().unwrap(),
                pyobject_to_json_locked(value.to_object(py), py),
            );
        }
        serde_json::Value::Object(json_map)
    } else {
        unimplemented!("unsupported python type, should be (cascaded) dict, list and basic numerical types")
    }
}
950
/// Convert a Python object into a JSON value, acquiring the GIL for the duration of the conversion.
#[cfg(feature = "python_binding")]
pub fn pyobject_to_json(value: PyObject) -> serde_json::Value {
    Python::with_gil(|py| {
        // delegate to the variant that works with an explicit GIL token
        pyobject_to_json_locked(value, py)
    })
}
955
/// Register the Python-visible classes and range type aliases of this module on `m`.
#[cfg(feature = "python_binding")]
#[pyfunction]
pub(crate) fn register(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    use crate::pyo3::PyTypeInfo;
    m.add_class::<SolverInitializer>()?;
    m.add_class::<PyMut>()?;
    m.add_class::<PartitionUnitInfo>()?;
    m.add_class::<PartitionInfo>()?;
    m.add_class::<PartitionConfig>()?;
    m.add_class::<SyndromePattern>()?;
    // the range types are exposed under explicit names instead of `add_class::<IndexRange>()`
    let range_types = [
        ("VertexRange", VertexRange::type_object(py)),
        ("DefectRange", DefectRange::type_object(py)),
        ("SyndromeRange", DefectRange::type_object(py)), // backward compatibility
        ("NodeRange", NodeRange::type_object(py)),
    ];
    for (name, range_type) in range_types {
        m.add(name, range_type)?;
    }
    Ok(())
}
973
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Check which defects a fused unit owns after partitioning a syndrome pattern.
    ///
    /// Two units cover vertices `0..72` and `84..132`; in every case below the expected
    /// range selects exactly the defects whose vertices lie between those two units.
    #[test]
    fn util_partitioned_syndrome_pattern_1() {
        // cargo test util_partitioned_syndrome_pattern_1 -- --nocapture
        let mut partition_config = PartitionConfig::new(132);
        // unit 0 and unit 1
        partition_config.partitions = vec![VertexRange::new(0, 72), VertexRange::new(84, 132)];
        // unit 2, by fusing 0 and 1
        partition_config.fusions = vec![(0, 1)];
        let partition_info = partition_config.info();
        let fused_unit = &partition_info.units[2];
        let cases = [
            (vec![10, 11, 12, 71, 72, 73, 84, 85, 111], DefectRange::new(4, 6)),
            (vec![10, 11, 12, 13, 71, 72, 73, 84, 85, 111], DefectRange::new(5, 7)),
            (vec![10, 11, 12, 71, 72, 73, 83, 84, 85, 111], DefectRange::new(4, 7)),
            (
                vec![10, 11, 12, 71, 72, 73, 84, 85, 100, 101, 102, 103, 111],
                DefectRange::new(4, 6),
            ),
        ];
        for (defect_vertices, expected_defect_range) in cases {
            let syndrome_pattern = SyndromePattern::new(defect_vertices, vec![]);
            let whole_pattern = PartitionedSyndromePattern::new(&syndrome_pattern);
            // only the part owned by the fused unit is checked; the two child parts are ignored
            let (owned_partitioned, _children) = whole_pattern.partition(fused_unit);
            println!("defect_range: {:?}", owned_partitioned.whole_defect_range);
            assert_eq!(owned_partitioned.whole_defect_range, expected_defect_range);
        }
    }
}