// fusion_blossom/primal_module_parallel.rs

//! Parallel Primal Module
//!
//! A parallel implementation of the primal module, by calling functions provided by the serial primal module
//!

6#![cfg_attr(feature = "unsafe_pointer", allow(dropping_references))]
7use super::dual_module::*;
8use super::dual_module_parallel::*;
9use super::pointers::*;
10use super::primal_module::*;
11use super::primal_module_serial::*;
12use super::util::*;
13use super::visualize::*;
14use crate::rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use std::ops::DerefMut;
17use std::sync::{Arc, Condvar, Mutex};
18use std::time::{Duration, Instant};
19
/// a parallel implementation of the primal module, consisting of multiple serial primal
/// units organized by the partition tree; units are solved in parallel and then fused
pub struct PrimalModuleParallel {
    /// the basic wrapped serial modules at the beginning, afterwards the fused units are appended after them
    pub units: Vec<PrimalModuleParallelUnitPtr>,
    /// local configuration
    pub config: PrimalModuleParallelConfig,
    /// partition information generated by the config
    pub partition_info: Arc<PartitionInfo>,
    /// thread pool used to execute async functions in parallel
    pub thread_pool: Arc<rayon::ThreadPool>,
    /// the time of calling [`PrimalModuleParallel::parallel_solve_step_callback`] method;
    /// streaming decode mockers and per-unit event times are measured relative to this instant
    pub last_solve_start_time: ArcRwLock<Instant>,
}
32
/// a single parallel unit: either a leaf unit that owns a base partition, or a fusion
/// unit that merges the results of its two children
pub struct PrimalModuleParallelUnit {
    /// the index
    pub unit_index: usize,
    /// the dual module interface, for constant-time clear
    pub interface_ptr: DualModuleInterfacePtr,
    /// partition information generated by the config
    pub partition_info: Arc<PartitionInfo>,
    /// whether it's active or not; some units are "placeholder" units that are not active until they actually fuse their children
    pub is_active: bool,
    /// the owned serial primal module
    pub serial_module: PrimalModuleSerialPtr,
    /// left and right children primal units; [`None`] for a leaf unit
    pub children: Option<(PrimalModuleParallelUnitWeak, PrimalModuleParallelUnitWeak)>,
    /// parent primal unit; [`None`] for the root of the partition tree
    pub parent: Option<PrimalModuleParallelUnitWeak>,
    /// record the time of events, for profiling
    pub event_time: Option<PrimalModuleParallelUnitEventTime>,
    /// streaming decode mocker, if exists, base partition will wait until specified time and then start decoding
    pub streaming_decode_mocker: Option<StreamingDecodeMocker>,
}
53
/// strong (owning) pointer to a [`PrimalModuleParallelUnit`]
pub type PrimalModuleParallelUnitPtr = ArcManualSafeLock<PrimalModuleParallelUnit>;
/// weak pointer counterpart, used for the parent/children links between units
pub type PrimalModuleParallelUnitWeak = WeakManualSafeLock<PrimalModuleParallelUnit>;
56
57impl std::fmt::Debug for PrimalModuleParallelUnitPtr {
58    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59        let unit = self.read_recursive();
60        write!(f, "{}", unit.unit_index)
61    }
62}
63
64impl std::fmt::Debug for PrimalModuleParallelUnitWeak {
65    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
66        self.upgrade_force().fmt(f)
67    }
68}
69
70/// the time of critical events, for profiling purposes
/// the time of critical events, for profiling purposes; timestamps are in seconds,
/// measured relative to [`PrimalModuleParallel::last_solve_start_time`]
#[derive(Debug, Clone, Serialize)]
pub struct PrimalModuleParallelUnitEventTime {
    /// unit starts executing
    pub start: f64,
    /// unit ends executing
    pub end: f64,
    /// index of the rayon thread that executed this unit
    pub thread_index: usize,
}
80
81impl Default for PrimalModuleParallelUnitEventTime {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl PrimalModuleParallelUnitEventTime {
88    pub fn new() -> Self {
89        Self {
90            start: 0.,
91            end: 0.,
92            thread_index: rayon::current_thread_index().unwrap_or(0),
93        }
94    }
95}
96
/// configuration of the parallel primal module
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PrimalModuleParallelConfig {
    /// the number of threads in the thread pool; 0 means using the number of CPU cores
    #[serde(default = "primal_module_parallel_default_configs::thread_pool_size")]
    pub thread_pool_size: usize,
    /// debug by sequentially run the fusion tasks, user must enable this for visualizer to work properly during the execution
    #[serde(default = "primal_module_parallel_default_configs::debug_sequential")]
    pub debug_sequential: bool,
    /// schedule base partition tasks in the front
    #[serde(default = "primal_module_parallel_default_configs::prioritize_base_partition")]
    pub prioritize_base_partition: bool,
    /// start interleaving fusion tasks between base partition tasks after this many base partitions have been issued
    #[serde(default = "primal_module_parallel_default_configs::interleaving_base_fusion")]
    pub interleaving_base_fusion: usize,
    /// pin threads to cores sequentially
    #[serde(default = "primal_module_parallel_default_configs::pin_threads_to_cores")]
    pub pin_threads_to_cores: bool,
    /// streaming decode mocker: mocked measurement interval in seconds; optional, `None` when absent
    pub streaming_decode_mock_measure_interval: Option<f64>,
    /// streaming decoder using spin lock instead of `thread::sleep` to avoid context switch
    #[serde(default = "primal_module_parallel_default_configs::streaming_decode_use_spin_lock")]
    pub streaming_decode_use_spin_lock: bool,
    /// max tree size for the serial modules, for faster speed at the cost of less accuracy
    #[serde(default = "primal_module_parallel_default_configs::max_tree_size")]
    pub max_tree_size: usize,
}
123
124impl Default for PrimalModuleParallelConfig {
125    fn default() -> Self {
126        serde_json::from_value(json!({})).unwrap()
127    }
128}
129
/// default values for [`PrimalModuleParallelConfig`]
pub mod primal_module_parallel_default_configs {
    /// 0 lets the thread pool use the number of CPU cores; set to 1 to debug on a single core
    pub fn thread_pool_size() -> usize {
        0
    }
    /// disabled by default; enable it when you need the visualizer to work properly during execution
    pub fn debug_sequential() -> bool {
        false
    }
    /// disabled by default; pinning threads to cores gives more stable benchmark results
    pub fn pin_threads_to_cores() -> bool {
        false
    }
    /// enabled by default because placing the time-consuming base-partition tasks in the front is faster
    pub fn prioritize_base_partition() -> bool {
        true
    }
    /// interleaving of base and fusion tasks starts after this unit index; `usize::MAX` disables interleaving
    pub fn interleaving_base_fusion() -> usize {
        usize::MAX
    }
    /// use `thread::sleep` by default; enable the spin lock only when benchmarking latency
    pub fn streaming_decode_use_spin_lock() -> bool {
        false
    }
    /// do not limit the tree size by default
    pub fn max_tree_size() -> usize {
        usize::MAX
    }
}
154
/// mocks the streaming arrival of hardware syndrome measurements: a base partition
/// is blocked from decoding until `last_solve_start_time + bias` has passed
pub struct StreamingDecodeMocker {
    /// indicating the syndrome ready time = `last_solve_start_time` + bias
    pub bias: Duration,
}
159
impl PrimalModuleParallel {
    /// recommended way to create a new instance, given a customized configuration;
    /// builds the thread pool, constructs one serial unit per partition unit, then
    /// wires up the parent/children references according to the partition tree
    pub fn new_config(
        initializer: &SolverInitializer,
        partition_info: &PartitionInfo,
        config: PrimalModuleParallelConfig,
    ) -> Self {
        let partition_info = Arc::new(partition_info.clone());
        let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
        // 0 lets rayon choose the number of threads (defaults to the number of CPU cores)
        if config.thread_pool_size != 0 {
            thread_pool_builder = thread_pool_builder.num_threads(config.thread_pool_size);
        }
        if config.pin_threads_to_cores {
            let core_ids = core_affinity::get_core_ids().unwrap();
            // println!("core_ids: {core_ids:?}");
            // pin each worker thread to a distinct core for more stable results;
            // https://stackoverflow.com/questions/7274585/linux-find-out-hyper-threaded-core-id
            thread_pool_builder = thread_pool_builder.start_handler(move |thread_index| {
                if thread_index < core_ids.len() {
                    crate::core_affinity::set_for_current(core_ids[thread_index]);
                } // otherwise let OS decide which core to execute
            });
        }
        let thread_pool = thread_pool_builder.build().expect("creating thread pool failed");
        let mut units = vec![];
        let unit_count = partition_info.units.len();
        // construct the units in parallel inside the thread pool
        thread_pool.scope(|_| {
            (0..unit_count)
                .into_par_iter()
                .map(|unit_index| {
                    // println!("unit_index: {unit_index}");
                    let primal_module = PrimalModuleSerialPtr::new_empty(initializer);
                    primal_module.write().max_tree_size = config.max_tree_size;
                    PrimalModuleParallelUnitPtr::new_wrapper(primal_module, unit_index, Arc::clone(&partition_info))
                })
                .collect_into_vec(&mut units);
        });
        // fill in the children and parent references
        for unit_index in 0..unit_count {
            let mut unit = units[unit_index].write();
            if let Some((left_children_index, right_children_index)) = &partition_info.units[unit_index].children {
                unit.children = Some((
                    units[*left_children_index].downgrade(),
                    units[*right_children_index].downgrade(),
                ))
            }
            if let Some(parent_index) = &partition_info.units[unit_index].parent {
                unit.parent = Some(units[*parent_index].downgrade());
            }
            if let Some(measure_interval) = config.streaming_decode_mock_measure_interval {
                if unit_index < partition_info.config.partitions.len() {
                    // only base partition is blocked by mock hardware syndrome measurement;
                    // the i-th base partition becomes ready at (i + 1) * measure_interval
                    unit.streaming_decode_mocker = Some(StreamingDecodeMocker {
                        bias: Duration::from_secs_f64(measure_interval * (unit_index + 1) as f64),
                    })
                }
            }
        }
        Self {
            units,
            config,
            partition_info,
            thread_pool: Arc::new(thread_pool),
            last_solve_start_time: ArcRwLock::new_value(Instant::now()),
        }
    }
}
226
impl PrimalModuleImpl for PrimalModuleParallel {
    /// create with a trivial single-unit partition and the default configuration
    fn new_empty(initializer: &SolverInitializer) -> Self {
        Self::new_config(
            initializer,
            &PartitionConfig::new(initializer.vertex_num).info(),
            PrimalModuleParallelConfig::default(),
        )
    }

    #[inline(never)]
    fn clear(&mut self) {
        // clear every unit in parallel; after clearing, only leaf units (those without
        // children in the partition tree) are active — fusion units are re-activated
        // later when they actually fuse their children
        self.thread_pool.scope(|_| {
            self.units.par_iter().enumerate().for_each(|(unit_idx, unit_ptr)| {
                let mut unit = unit_ptr.write();
                let partition_unit_info = &unit.partition_info.units[unit_idx];
                // compute before `clear()` so the flag survives the reset
                let is_active = partition_unit_info.children.is_none();
                unit.clear();
                unit.is_active = is_active;
            });
        });
    }

    /// not supported: defects must be loaded per-unit via `parallel_solve`
    fn load_defect_dual_node(&mut self, _dual_node_ptr: &DualNodePtr) {
        panic!("load interface directly into the parallel primal module is forbidden, use `parallel_solve` instead");
    }

    /// not supported: conflicts are resolved per-unit via `parallel_solve`
    fn resolve<D: DualModuleImpl>(
        &mut self,
        _group_max_update_length: GroupMaxUpdateLength,
        _interface: &DualModuleInterfacePtr,
        _dual_module: &mut D,
    ) {
        panic!("parallel primal module cannot handle global resolve requests, use `parallel_solve` instead");
    }

    /// gather the intermediate matching from every active unit
    fn intermediate_matching<D: DualModuleImpl>(
        &mut self,
        interface: &DualModuleInterfacePtr,
        dual_module: &mut D,
    ) -> IntermediateMatching {
        let mut intermediate_matching = IntermediateMatching::new();
        for unit_ptr in self.units.iter() {
            lock_write!(unit, unit_ptr);
            if !unit.is_active {
                continue;
            } // inactive units are placeholders or have been fused into their parent
            intermediate_matching.append(&mut unit.serial_module.intermediate_matching(interface, dual_module));
        }
        intermediate_matching
    }

    /// report the per-unit start/end event times recorded during the last solve
    fn generate_profiler_report(&self) -> serde_json::Value {
        let event_time_vec: Vec<_> = self.units.iter().map(|ptr| ptr.read_recursive().event_time.clone()).collect();
        json!({
            "event_time_vec": event_time_vec,
        })
    }
}
285
impl PrimalModuleParallel {
    /// solve the syndrome pattern in parallel, without any step callback
    pub fn parallel_solve<DualSerialModule: DualModuleImpl + Send + Sync>(
        &mut self,
        syndrome_pattern: &SyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
    ) {
        self.parallel_solve_step_callback(syndrome_pattern, parallel_dual_module, |_, _, _, _| {})
    }

    /// solve in parallel while taking visualizer snapshots; note that per-step
    /// snapshots only take effect in the sequential execution paths (see
    /// `debug_sequential`), because callbacks are suppressed in the parallel paths
    pub fn parallel_solve_visualizer<DualSerialModule: DualModuleImpl + Send + Sync + FusionVisualizer>(
        &mut self,
        syndrome_pattern: &SyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        visualizer: Option<&mut Visualizer>,
    ) {
        if let Some(visualizer) = visualizer {
            self.parallel_solve_step_callback(
                syndrome_pattern,
                parallel_dual_module,
                |interface_ptr, dual_module, primal_module, group_max_update_length| {
                    if let Some(group_max_update_length) = group_max_update_length {
                        if cfg!(debug_assertions) {
                            println!("group_max_update_length: {:?}", group_max_update_length);
                        }
                        if let Some(length) = group_max_update_length.get_none_zero_growth() {
                            visualizer
                                .snapshot_combined(format!("grow {length}"), vec![interface_ptr, dual_module, primal_module])
                                .unwrap();
                        } else {
                            let first_conflict = format!("{:?}", group_max_update_length.peek().unwrap());
                            visualizer
                                .snapshot_combined(
                                    format!("resolve {first_conflict}"),
                                    vec![interface_ptr, dual_module, primal_module],
                                )
                                .unwrap();
                        };
                    } else {
                        // `None` marks a unit-level event (fuse completed / unit solved)
                        visualizer
                            .snapshot_combined("unit solved".to_string(), vec![interface_ptr, dual_module, primal_module])
                            .unwrap();
                    }
                },
            );
            // take a final global snapshot after all units are solved
            let last_unit = self.units.last().unwrap().read_recursive();
            visualizer
                .snapshot_combined(
                    "solved".to_string(),
                    vec![&last_unit.interface_ptr, parallel_dual_module, self],
                )
                .unwrap();
        } else {
            self.parallel_solve(syndrome_pattern, parallel_dual_module);
        }
    }

    /// the main entry point of parallel solving: schedules every unit on the thread
    /// pool and synchronizes each fusion unit with its children; `callback` only
    /// takes effect in the sequential paths, because it is unsafe to execute it
    /// from multiple threads
    pub fn parallel_solve_step_callback<DualSerialModule: DualModuleImpl + Send + Sync, F>(
        &mut self,
        syndrome_pattern: &SyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        mut callback: F,
    ) where
        F: FnMut(
                &DualModuleInterfacePtr,
                &DualModuleParallelUnit<DualSerialModule>,
                &PrimalModuleSerialPtr,
                Option<&GroupMaxUpdateLength>,
            ) + Send
            + Sync,
    {
        let thread_pool = Arc::clone(&self.thread_pool);
        // record the solve start time; streaming decode mockers and event times are relative to it
        *self.last_solve_start_time.write() = Instant::now();
        if self.config.prioritize_base_partition {
            if self.config.debug_sequential {
                // run every unit sequentially in partition order, so the callback is safe
                for unit_index in 0..self.partition_info.units.len() {
                    let unit_ptr = self.units[unit_index].clone();
                    unit_ptr.children_ready_solve::<DualSerialModule, F>(
                        self,
                        PartitionedSyndromePattern::new(syndrome_pattern),
                        parallel_dual_module,
                        &mut Some(&mut callback),
                    );
                }
            } else {
                use std::sync::atomic::{AtomicUsize, Ordering};
                // per-unit readiness state: a (mutex flag + condvar) pair for the sleeping
                // path, and an atomic flag for the spin-lock path
                let ready_vec: Vec<_> = {
                    (0..self.partition_info.units.len())
                        .map(|_| Arc::new((Mutex::new(false), Condvar::new(), Arc::new(AtomicUsize::new(0)))))
                        .collect()
                };
                thread_pool.scope_fifo(|s| {
                    // spawn a FIFO task that waits for the unit's children (if any) and then solves it
                    let issue_unit = |unit_index: usize| {
                        let ready_vec = &ready_vec;
                        let units = &self.units;
                        let partition_info = &self.partition_info;
                        let parallel_unit = &self;
                        let parallel_dual_module = &parallel_dual_module;
                        let streaming_decode_use_spin_lock = self.config.streaming_decode_use_spin_lock;
                        s.spawn_fifo(move |_| {
                            let ready_pair = ready_vec[unit_index].clone();
                            let (ready, condvar, spin_ready) = &*ready_pair;
                            if streaming_decode_use_spin_lock {
                                let unit_ptr = units[unit_index].clone();
                                // units are ordered with base partitions first, so any index at or
                                // beyond `partitions.len()` is a fusion unit that waits for its children
                                if unit_index >= partition_info.config.partitions.len() {
                                    // wait for children to complete
                                    let fusion_index = unit_index - partition_info.config.partitions.len();
                                    let (left_unit_index, right_unit_index) = partition_info.config.fusions[fusion_index];
                                    for child_unit_index in [left_unit_index, right_unit_index] {
                                        let child_ready_pair = ready_vec[child_unit_index].clone();
                                        let (_, _, child_spin_ready) = &*child_ready_pair;
                                        // busy-wait to avoid a context switch, at the cost of burning a core
                                        while child_spin_ready.load(Ordering::SeqCst) != 1 {
                                            // hopefully this asserts false at the beginning
                                            std::hint::spin_loop();
                                            // println!("spin_loop");
                                        }
                                    }
                                }
                                unit_ptr.children_ready_solve::<DualSerialModule, F>(
                                    parallel_unit,
                                    PartitionedSyndromePattern::new(syndrome_pattern),
                                    parallel_dual_module,
                                    &mut None,
                                );
                                // publish completion for any parent spinning on this flag
                                spin_ready.store(1, Ordering::SeqCst);
                            } else {
                                let mut is_ready = ready.lock().unwrap();
                                let unit_ptr = units[unit_index].clone();
                                if unit_index >= partition_info.config.partitions.len() {
                                    // wait for children to complete
                                    let fusion_index = unit_index - partition_info.config.partitions.len();
                                    let (left_unit_index, right_unit_index) = partition_info.config.fusions[fusion_index];
                                    for child_unit_index in [left_unit_index, right_unit_index] {
                                        let child_ready_pair = ready_vec[child_unit_index].clone();
                                        let (child_ready, child_condvar, _) = &*child_ready_pair;
                                        let mut child_is_ready = child_ready.lock().unwrap();
                                        // standard condvar wait loop: re-check the flag after every wakeup
                                        while !*child_is_ready {
                                            // hopefully this asserts false at the beginning
                                            child_is_ready = child_condvar.wait(child_is_ready).unwrap();
                                        }
                                    }
                                }
                                unit_ptr.children_ready_solve::<DualSerialModule, F>(
                                    parallel_unit,
                                    PartitionedSyndromePattern::new(syndrome_pattern),
                                    parallel_dual_module,
                                    &mut None,
                                );
                                // each unit has at most one parent waiting, so notify_one suffices
                                *is_ready = true;
                                condvar.notify_one();
                            }
                        })
                    };
                    if self.config.interleaving_base_fusion >= self.partition_info.config.fusions.len() {
                        // no interleaving: issue all units in order, base partitions first
                        for unit_index in 0..self.partition_info.units.len() {
                            issue_unit(unit_index);
                        }
                    } else {
                        // interleave fusion tasks between base partition tasks once
                        // `interleaving_base_fusion` base partitions have been issued
                        for unit_index in 0..self.partition_info.config.partitions.len() {
                            if unit_index >= self.config.interleaving_base_fusion {
                                let fusion_index = self.partition_info.config.partitions.len()
                                    + (unit_index - self.config.interleaving_base_fusion);
                                issue_unit(fusion_index);
                            }
                            issue_unit(unit_index);
                        }
                        // issue the remaining fusion units at the tail
                        for bias_index in 1..self.config.interleaving_base_fusion {
                            issue_unit(self.partition_info.units.len() - self.config.interleaving_base_fusion + bias_index);
                        }
                    }
                });
            }
        } else {
            // start from the root (last) unit and recursively solve children first
            let last_unit_ptr = self.units.last().unwrap().clone();
            thread_pool.scope(|_| {
                last_unit_ptr.iterative_solve_step_callback(
                    self,
                    PartitionedSyndromePattern::new(syndrome_pattern),
                    parallel_dual_module,
                    &mut Some(&mut callback),
                )
            })
        }
    }
}
470
471impl FusionVisualizer for PrimalModuleParallel {
472    fn snapshot(&self, abbrev: bool) -> serde_json::Value {
473        // do the sanity check first before taking snapshot
474        // self.sanity_check().unwrap();
475        let mut value = json!({});
476        for unit_ptr in self.units.iter() {
477            let unit = unit_ptr.read_recursive();
478            if !unit.is_active {
479                continue;
480            } // do not visualize inactive units
481            let value_2 = unit.snapshot(abbrev);
482            snapshot_combine_values(&mut value, value_2, abbrev);
483        }
484        value
485    }
486}
487
488impl FusionVisualizer for PrimalModuleParallelUnit {
489    fn snapshot(&self, abbrev: bool) -> serde_json::Value {
490        self.serial_module.snapshot(abbrev)
491    }
492}
493
impl PrimalModuleParallelUnitPtr {
    /// create a simple wrapper over a serial primal module
    pub fn new_wrapper(serial_module: PrimalModuleSerialPtr, unit_index: usize, partition_info: Arc<PartitionInfo>) -> Self {
        let partition_unit_info = &partition_info.units[unit_index];
        let is_active = partition_unit_info.children.is_none();
        let interface_ptr = DualModuleInterfacePtr::new_empty();
        interface_ptr.write().unit_index = unit_index;
        Self::new_value(PrimalModuleParallelUnit {
            unit_index,
            interface_ptr,
            partition_info,
            is_active, // only activate the leaves in the dependency tree
            serial_module,
            children: None, // to be filled later
            parent: None,   // to be filled later
            event_time: None,
            streaming_decode_mocker: None,
        })
    }

    /// solve this unit, assuming all of its children (if any) are already solved;
    /// a fusion unit deactivates its children, fuses them, breaks the matched pairs
    /// with the mirrored vertices, loads its owned defects and re-solves; a leaf unit
    /// simply solves its owned portion of the syndrome;
    /// call this only if children is guaranteed to be ready and solved
    #[allow(clippy::unnecessary_cast)]
    #[allow(clippy::needless_borrow)]
    fn children_ready_solve<DualSerialModule: DualModuleImpl + Send + Sync, F>(
        &self,
        primal_module_parallel: &PrimalModuleParallel,
        partitioned_syndrome_pattern: PartitionedSyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        callback: &mut Option<&mut F>,
    ) where
        F: FnMut(
                &DualModuleInterfacePtr,
                &DualModuleParallelUnit<DualSerialModule>,
                &PrimalModuleSerialPtr,
                Option<&GroupMaxUpdateLength>,
            ) + Send
            + Sync,
    {
        let mut primal_unit = self.write();
        // mock streaming decode: block until the mocked syndrome-ready time
        // (`last_solve_start_time` + bias) has passed
        if let Some(mocker) = &primal_unit.streaming_decode_mocker {
            if primal_module_parallel.config.streaming_decode_use_spin_lock {
                while primal_module_parallel.last_solve_start_time.read_recursive().elapsed() < mocker.bias {
                    std::hint::spin_loop(); // spin to avoid context switch
                }
            } else {
                // sleep for the remaining time; loop because the measured elapsed
                // time may still fall short of the bias after waking up
                let mut elapsed = primal_module_parallel.last_solve_start_time.read_recursive().elapsed();
                while elapsed < mocker.bias {
                    std::thread::sleep(mocker.bias - elapsed);
                    elapsed = primal_module_parallel.last_solve_start_time.read_recursive().elapsed();
                }
            }
        }
        // record start/end times relative to the solve start, for the profiler report
        let mut event_time = PrimalModuleParallelUnitEventTime::new();
        event_time.start = primal_module_parallel
            .last_solve_start_time
            .read_recursive()
            .elapsed()
            .as_secs_f64();
        let dual_module_ptr = parallel_dual_module.get_unit(primal_unit.unit_index);
        let mut dual_unit = dual_module_ptr.write();
        let partition_unit_info = &primal_unit.partition_info.units[primal_unit.unit_index];
        let (owned_defect_range, _) = partitioned_syndrome_pattern.partition(partition_unit_info);
        let interface_ptr = primal_unit.interface_ptr.clone();
        if let Some((left_child_weak, right_child_weak)) = primal_unit.children.as_ref() {
            // fusion unit: fuse the two solved children and re-solve across the boundary
            {
                // set children to inactive to avoid being solved twice
                for child_weak in [left_child_weak, right_child_weak] {
                    let child_ptr = child_weak.upgrade_force();
                    let mut child = child_ptr.write();
                    debug_assert!(child.is_active, "cannot fuse inactive children");
                    child.is_active = false;
                }
            }
            primal_unit.fuse(&mut dual_unit);
            if let Some(callback) = callback.as_mut() {
                // do callback before actually breaking the matched pairs, for ease of visualization
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
            primal_unit.break_matching_with_mirror(dual_unit.deref_mut());
            // load the defects owned by this fusion unit into the serial module
            for defect_index in owned_defect_range.whole_defect_range.iter() {
                let defect_vertex = partitioned_syndrome_pattern.syndrome_pattern.defect_vertices[defect_index as usize];
                primal_unit
                    .serial_module
                    .load_defect(defect_vertex, &interface_ptr, dual_unit.deref_mut());
            }
            primal_unit.serial_module.solve_step_callback_interface_loaded(
                &interface_ptr,
                dual_unit.deref_mut(),
                |interface, dual_module, primal_module, group_max_update_length| {
                    if let Some(callback) = callback.as_mut() {
                        callback(interface, dual_module, primal_module, Some(group_max_update_length));
                    }
                },
            );
            if let Some(callback) = callback.as_mut() {
                // final callback with `None` marks that this unit is solved
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
        } else {
            // leaf unit: solve the owned portion of the syndrome directly
            debug_assert!(primal_unit.is_active, "leaf must be active to be solved");
            let syndrome_pattern = owned_defect_range.expand();
            primal_unit.serial_module.solve_step_callback(
                &interface_ptr,
                &syndrome_pattern,
                dual_unit.deref_mut(),
                |interface, dual_module, primal_module, group_max_update_length| {
                    if let Some(callback) = callback.as_mut() {
                        callback(interface, dual_module, primal_module, Some(group_max_update_length));
                    }
                },
            );
            if let Some(callback) = callback.as_mut() {
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
        }
        // the unit now holds a valid solution for its subtree and becomes active
        primal_unit.is_active = true;
        event_time.end = primal_module_parallel
            .last_solve_start_time
            .read_recursive()
            .elapsed()
            .as_secs_f64();
        primal_unit.event_time = Some(event_time);
    }

    /// call on the last primal node, and it will spawn tasks on the previous ones;
    /// recursively solves both children (in parallel via [`rayon::join`] unless
    /// `debug_sequential` is set) before solving this unit itself
    fn iterative_solve_step_callback<DualSerialModule: DualModuleImpl + Send + Sync, F>(
        &self,
        primal_module_parallel: &PrimalModuleParallel,
        partitioned_syndrome_pattern: PartitionedSyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        callback: &mut Option<&mut F>,
    ) where
        F: FnMut(
                &DualModuleInterfacePtr,
                &DualModuleParallelUnit<DualSerialModule>,
                &PrimalModuleSerialPtr,
                Option<&GroupMaxUpdateLength>,
            ) + Send
            + Sync,
    {
        let primal_unit = self.read_recursive();
        // only when sequentially running the tasks will the callback take effect, otherwise it's unsafe to execute it from multiple threads
        let debug_sequential = primal_module_parallel.config.debug_sequential;
        if let Some((left_child_weak, right_child_weak)) = primal_unit.children.as_ref() {
            // make children ready
            debug_assert!(
                !primal_unit.is_active,
                "parent must be inactive at the time of solving children"
            );
            // split the syndrome into the two children's portions
            let partition_unit_info = &primal_unit.partition_info.units[primal_unit.unit_index];
            let (_, (left_partitioned, right_partitioned)) = partitioned_syndrome_pattern.partition(partition_unit_info);
            if debug_sequential {
                left_child_weak.upgrade_force().iterative_solve_step_callback(
                    primal_module_parallel,
                    left_partitioned,
                    parallel_dual_module,
                    callback,
                );
                right_child_weak.upgrade_force().iterative_solve_step_callback(
                    primal_module_parallel,
                    right_partitioned,
                    parallel_dual_module,
                    callback,
                );
            } else {
                // solve the two subtrees in parallel; callbacks are disabled here
                // because they cannot be safely shared across threads
                rayon::join(
                    || {
                        left_child_weak
                            .upgrade_force()
                            .iterative_solve_step_callback::<DualSerialModule, F>(
                                primal_module_parallel,
                                left_partitioned,
                                parallel_dual_module,
                                &mut None,
                            )
                    },
                    || {
                        right_child_weak
                            .upgrade_force()
                            .iterative_solve_step_callback::<DualSerialModule, F>(
                                primal_module_parallel,
                                right_partitioned,
                                parallel_dual_module,
                                &mut None,
                            )
                    },
                );
            };
        }
        // release the read lock before `children_ready_solve` acquires the write lock
        drop(primal_unit);
        self.children_ready_solve(
            primal_module_parallel,
            partitioned_syndrome_pattern,
            parallel_dual_module,
            callback,
        );
    }
}
691
impl PrimalModuleParallelUnit {
    /// fuse two units together, by copying the right child's content into the left child's content and resolve index;
    /// note that this operation doesn't update on the dual module, call [`Self::break_matching_with_mirror`] if needed
    pub fn fuse<DualSerialModule: DualModuleImpl + Send + Sync>(
        &mut self,
        dual_unit: &mut DualModuleParallelUnit<DualSerialModule>,
    ) {
        // `fuse` may only be called on a parent unit: both children must exist
        let (left_child_ptr, right_child_ptr) = (
            self.children.as_ref().unwrap().0.upgrade_force(),
            self.children.as_ref().unwrap().1.upgrade_force(),
        );
        let left_child = left_child_ptr.read_recursive();
        let right_child = right_child_ptr.read_recursive();
        // first fuse the dual interfaces, then merge the two serial primal modules into this unit's
        dual_unit.fuse(&self.interface_ptr, (&left_child.interface_ptr, &right_child.interface_ptr));
        self.serial_module.fuse(&left_child.serial_module, &right_child.serial_module);
    }

    /// break the matched pairs of interface vertices
    ///
    /// For each candidate node in `possible_break`: if its matched virtual vertex is now owned by
    /// this unit, the match is cleared and the node is set back to [`DualNodeGrowState::Grow`];
    /// otherwise the node is kept in `possible_break` so a later fusion round can break it.
    #[allow(clippy::unnecessary_cast)]
    pub fn break_matching_with_mirror(&mut self, dual_module: &mut impl DualModuleImpl) {
        // use `possible_break` to efficiently break those
        let mut possible_break = vec![];
        let module = self.serial_module.read_recursive();
        for node_index in module.possible_break.iter() {
            let primal_node_ptr = module.get_node(*node_index);
            if let Some(primal_node_ptr) = primal_node_ptr {
                let mut primal_node = primal_node_ptr.write();
                if let Some((MatchTarget::VirtualVertex(vertex_index), _)) = &primal_node.temporary_match {
                    // only break the match when this unit owns the virtual vertex
                    if self.partition_info.vertex_to_owning_unit[*vertex_index as usize] == self.unit_index {
                        primal_node.temporary_match = None;
                        self.interface_ptr.set_grow_state(
                            &primal_node.origin.upgrade_force(),
                            DualNodeGrowState::Grow,
                            dual_module,
                        );
                    } else {
                        // still possible break
                        possible_break.push(*node_index);
                    }
                }
            }
        }
        // release the read lock before acquiring the write lock to store the filtered list
        drop(module);
        self.serial_module.write().possible_break = possible_break;
    }
}
738
impl PrimalModuleImpl for PrimalModuleParallelUnit {
    /// a parallel unit cannot be built directly from an initializer; see the panic message
    fn new_empty(_initializer: &SolverInitializer) -> Self {
        panic!("creating parallel unit directly from initializer is forbidden, use `PrimalModuleParallel::new` instead");
    }

    /// clear both the owned serial primal module and the dual interface
    fn clear(&mut self) {
        self.serial_module.clear();
        self.interface_ptr.clear();
    }

    /// delegate to the owned serial primal module
    fn load(&mut self, interface_ptr: &DualModuleInterfacePtr) {
        self.serial_module.load(interface_ptr)
    }

    /// delegate to the owned serial primal module
    fn load_defect_dual_node(&mut self, dual_node_ptr: &DualNodePtr) {
        self.serial_module.load_defect_dual_node(dual_node_ptr)
    }

    /// delegate conflict resolution to the owned serial primal module
    fn resolve<D: DualModuleImpl>(
        &mut self,
        group_max_update_length: GroupMaxUpdateLength,
        interface: &DualModuleInterfacePtr,
        dual_module: &mut D,
    ) {
        self.serial_module.resolve(group_max_update_length, interface, dual_module)
    }

    /// delegate to the owned serial primal module to report the current matching
    fn intermediate_matching<D: DualModuleImpl>(
        &mut self,
        interface: &DualModuleInterfacePtr,
        dual_module: &mut D,
    ) -> IntermediateMatching {
        self.serial_module.intermediate_matching(interface, dual_module)
    }
}
774
#[cfg(test)]
pub mod tests {
    use super::super::dual_module_serial::*;
    use super::super::example_codes::*;
    use super::*;

    /// wrapper of [`primal_module_parallel_basic_standard_syndrome_optional_viz_config`] that uses
    /// the default primal configuration (with `debug_sequential: true`)
    pub fn primal_module_parallel_basic_standard_syndrome_optional_viz<F>(
        code: impl ExampleCode,
        visualize_filename: Option<String>,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        primal_module_parallel_basic_standard_syndrome_optional_viz_config(
            code,
            visualize_filename,
            defect_vertices,
            final_dual,
            partition_func,
            reordered_vertices,
            None,
        )
    }

    /// solve a standard syndrome with the parallel primal module and assert on the final dual sum;
    /// `partition_func` configures the partition, `reordered_vertices` optionally reorders vertices
    /// (translating `defect_vertices` accordingly), and `primal_config_json` optionally overrides
    /// the primal configuration
    pub fn primal_module_parallel_basic_standard_syndrome_optional_viz_config<F>(
        mut code: impl ExampleCode,
        visualize_filename: Option<String>,
        mut defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
        primal_config_json: Option<serde_json::Value>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        println!("{defect_vertices:?}");
        if let Some(reordered_vertices) = &reordered_vertices {
            code.reorder_vertices(reordered_vertices);
            defect_vertices = translated_defect_to_reordered(reordered_vertices, &defect_vertices);
        }
        let mut visualizer = match visualize_filename.as_ref() {
            Some(visualize_filename) => {
                let visualizer = Visualizer::new(
                    Some(visualize_data_folder() + visualize_filename.as_str()),
                    code.get_positions(),
                    true,
                )
                .unwrap();
                print_visualize_link(visualize_filename.clone());
                Some(visualizer)
            }
            None => None,
        };
        let initializer = code.get_initializer();
        let mut partition_config = PartitionConfig::new(initializer.vertex_num);
        partition_func(&initializer, &mut partition_config);
        let partition_info = partition_config.info();
        let mut dual_module =
            DualModuleParallel::new_config(&initializer, &partition_info, DualModuleParallelConfig::default());
        let primal_config = if let Some(value) = primal_config_json {
            serde_json::from_value(value).unwrap()
        } else {
            PrimalModuleParallelConfig {
                debug_sequential: true,
                ..Default::default()
            }
        };
        let mut primal_module = PrimalModuleParallel::new_config(&initializer, &partition_info, primal_config.clone());
        code.set_defect_vertices(&defect_vertices);
        primal_module.parallel_solve_visualizer(&code.get_syndrome(), &dual_module, visualizer.as_mut());
        let useless_interface_ptr = DualModuleInterfacePtr::new_empty(); // don't actually use it
        let perfect_matching = primal_module.perfect_matching(&useless_interface_ptr, &mut dual_module);
        let mut subgraph_builder = SubGraphBuilder::new(&initializer);
        subgraph_builder.load_perfect_matching(&perfect_matching);
        let subgraph = subgraph_builder.get_subgraph();
        if let Some(visualizer) = visualizer.as_mut() {
            let last_interface_ptr = &primal_module.units.last().unwrap().read_recursive().interface_ptr;
            visualizer
                .snapshot_combined(
                    "perfect matching and subgraph".to_string(),
                    vec![
                        last_interface_ptr,
                        &dual_module,
                        &perfect_matching,
                        &VisualizeSubgraph::new(&subgraph),
                    ],
                )
                .unwrap();
        }
        // the last unit is the root of the fusion tree, holding the overall interface
        let sum_dual_variables = primal_module
            .units
            .last()
            .unwrap()
            .read_recursive()
            .interface_ptr
            .sum_dual_variables();
        if primal_config.max_tree_size == usize::MAX {
            // otherwise it's not necessarily MWPM
            assert_eq!(
                sum_dual_variables,
                subgraph_builder.total_weight(),
                "unmatched sum dual variables"
            );
        }
        assert_eq!(sum_dual_variables, final_dual * 2, "unexpected final dual variable sum");
        (primal_module, dual_module)
    }

    /// like [`primal_module_parallel_basic_standard_syndrome_optional_viz`], but visualization is mandatory
    pub fn primal_module_parallel_standard_syndrome<F>(
        code: impl ExampleCode,
        visualize_filename: String,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        primal_module_parallel_basic_standard_syndrome_optional_viz(
            code,
            Some(visualize_filename),
            defect_vertices,
            final_dual,
            partition_func,
            reordered_vertices,
        )
    }

    /// construct the vertex reordering for a d=11 planar code (12 vertices per row, 11 rows) split
    /// into 4 blocks by one horizontal cut (at row `split_horizontal`) and one vertical cut (at
    /// column `split_vertical`), so that each block and each interface occupies a contiguous vertex
    /// index range; shared by the 4-partition tests below
    fn reordered_vertices_for_quad_partition(
        split_horizontal: VertexIndex,
        split_vertical: VertexIndex,
    ) -> Vec<VertexIndex> {
        let mut reordered_vertices = vec![];
        for i in 0..split_horizontal {
            // left-top block
            for j in 0..split_vertical {
                reordered_vertices.push(i * 12 + j);
            }
            reordered_vertices.push(i * 12 + 11);
        }
        for i in 0..split_horizontal {
            // interface between the left-top block and the right-top block
            reordered_vertices.push(i * 12 + split_vertical);
        }
        for i in 0..split_horizontal {
            // right-top block
            for j in (split_vertical + 1)..10 {
                reordered_vertices.push(i * 12 + j);
            }
            reordered_vertices.push(i * 12 + 10);
        }
        {
            // the big interface between top and bottom
            for j in 0..12 {
                reordered_vertices.push(split_horizontal * 12 + j);
            }
        }
        for i in (split_horizontal + 1)..11 {
            // left-bottom block
            for j in 0..split_vertical {
                reordered_vertices.push(i * 12 + j);
            }
            reordered_vertices.push(i * 12 + 11);
        }
        for i in (split_horizontal + 1)..11 {
            // interface between the left-bottom block and the right-bottom block
            reordered_vertices.push(i * 12 + split_vertical);
        }
        for i in (split_horizontal + 1)..11 {
            // right-bottom block
            for j in (split_vertical + 1)..10 {
                reordered_vertices.push(i * 12 + j);
            }
            reordered_vertices.push(i * 12 + 10);
        }
        reordered_vertices
    }

    /// test a simple case
    #[test]
    fn primal_module_parallel_basic_1() {
        // cargo test primal_module_parallel_basic_1 -- --nocapture
        let visualize_filename = "primal_module_parallel_basic_1.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |initializer, _config| {
                println!("initializer: {initializer:?}");
            },
            None,
        );
    }

    /// split into 2, with no syndrome vertex on the interface
    #[test]
    fn primal_module_parallel_basic_2() {
        // cargo test primal_module_parallel_basic_2 -- --nocapture
        let visualize_filename = "primal_module_parallel_basic_2.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 72),   // unit 0
                    VertexRange::new(84, 132), // unit 1
                ];
                config.fusions = vec![
                    (0, 1), // unit 2, by fusing 0 and 1
                ];
            },
            None,
        );
    }

    /// split into 2, with a syndrome vertex on the interface
    #[test]
    fn primal_module_parallel_basic_3() {
        // cargo test primal_module_parallel_basic_3 -- --nocapture
        let visualize_filename = "primal_module_parallel_basic_3.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 60),   // unit 0
                    VertexRange::new(72, 132), // unit 1
                ];
                config.fusions = vec![
                    (0, 1), // unit 2, by fusing 0 and 1
                ];
            },
            None,
        );
    }

    /// split into 4, with no syndrome vertex on the interface
    #[test]
    fn primal_module_parallel_basic_4() {
        // cargo test primal_module_parallel_basic_4 -- --nocapture
        let visualize_filename = "primal_module_parallel_basic_4.json".to_string();
        // reorder vertices to enable the partition;
        let defect_vertices = vec![39, 52, 63, 90, 100]; // indices are before the reorder
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 36),
                    VertexRange::new(42, 72),
                    VertexRange::new(84, 108),
                    VertexRange::new(112, 132),
                ];
                config.fusions = vec![(0, 1), (2, 3), (4, 5)];
            },
            Some(reordered_vertices_for_quad_partition(6, 5)),
        );
    }

    /// split into 4, with 2 defect vertices on parent interfaces
    #[test]
    fn primal_module_parallel_basic_5() {
        // cargo test primal_module_parallel_basic_5 -- --nocapture
        let visualize_filename = "primal_module_parallel_basic_5.json".to_string();
        // reorder vertices to enable the partition;
        let defect_vertices = vec![39, 52, 63, 90, 100]; // indices are before the reorder
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 25),
                    VertexRange::new(30, 60),
                    VertexRange::new(72, 97),
                    VertexRange::new(102, 132),
                ];
                config.fusions = vec![(0, 1), (2, 3), (4, 5)];
            },
            Some(reordered_vertices_for_quad_partition(5, 4)),
        );
    }

    /// 2-partition helper for regression (debug) tests on a distance-`d` planar code
    fn primal_module_parallel_debug_planar_code_common(
        d: VertexNum,
        visualize_filename: String,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
    ) {
        let half_weight = 500;
        let split_horizontal = (d + 1) / 2;
        let row_count = d + 1;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(d, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            final_dual * half_weight,
            |initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, split_horizontal * row_count),
                    VertexRange::new((split_horizontal + 1) * row_count, initializer.vertex_num),
                ];
                config.fusions = vec![(0, 1)];
            },
            None,
        );
    }

    /// 68000 vs 69000 dual variable: probably missing some interface node
    /// panicked at 'vacating a non-boundary vertex is forbidden', src/dual_module_serial.rs:899:25
    /// reason: when executing sync events, I forgot to add the new propagated dual module to the active list;
    /// why it didn't show up before: because usually a node is created when executing sync event, in which case it's automatically in the active list
    /// if this node already exists before, and it's again synchronized, then it's not in the active list, leading to strange growth
    #[test]
    fn primal_module_parallel_debug_1() {
        // cargo test primal_module_parallel_debug_1 -- --nocapture
        let visualize_filename = "primal_module_parallel_debug_1.json".to_string();
        let defect_vertices = vec![88, 89, 102, 103, 105, 106, 118, 120, 122, 134, 138]; // indices are before the reorder
        primal_module_parallel_debug_planar_code_common(15, visualize_filename, defect_vertices, 10);
    }

    /// test fusion union-find
    #[test]
    fn primal_module_parallel_union_find_basic_1() {
        // cargo test primal_module_parallel_union_find_basic_1 -- --nocapture
        let visualize_filename = "primal_module_parallel_union_find_basic_1.json".to_string();
        let defect_vertices = vec![51, 52, 53, 88];
        let half_weight = 500;
        primal_module_parallel_basic_standard_syndrome_optional_viz_config(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            Some(visualize_filename),
            defect_vertices,
            4 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 72),   // unit 0
                    VertexRange::new(84, 132), // unit 1
                ];
                config.fusions = vec![
                    (0, 1), // unit 2, by fusing 0 and 1
                ];
            },
            None,
            Some(json!({ "max_tree_size": 0, "debug_sequential": true })),
        );
    }
}