1#![cfg_attr(feature = "unsafe_pointer", allow(dropping_references))]
7use super::dual_module::*;
8use super::dual_module_parallel::*;
9use super::pointers::*;
10use super::primal_module::*;
11use super::primal_module_serial::*;
12use super::util::*;
13use super::visualize::*;
14use crate::rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use std::ops::DerefMut;
17use std::sync::{Arc, Condvar, Mutex};
18use std::time::{Duration, Instant};
19
/// Parallel primal module: one serial primal module per partition unit, driven on a private
/// rayon thread pool.
pub struct PrimalModuleParallel {
    /// One unit per entry of `partition_info.units`, indexed by unit index.
    pub units: Vec<PrimalModuleParallelUnitPtr>,
    /// Scheduling / thread-pool configuration.
    pub config: PrimalModuleParallelConfig,
    /// Partition description shared with every unit.
    pub partition_info: Arc<PartitionInfo>,
    /// Thread pool on which units are constructed, cleared and solved.
    pub thread_pool: Arc<rayon::ThreadPool>,
    /// Reset at the start of every parallel solve; unit event times and streaming-decode delays
    /// are measured relative to it.
    pub last_solve_start_time: ArcRwLock<Instant>,
}
32
/// One partition unit of [`PrimalModuleParallel`]: either a base partition (no children) or a
/// fusion of two child units.
pub struct PrimalModuleParallelUnit {
    /// Index of this unit in `PrimalModuleParallel::units` and `PartitionInfo::units`.
    pub unit_index: usize,
    /// Dual module interface owned by this unit (its `unit_index` is set to this unit's index).
    pub interface_ptr: DualModuleInterfacePtr,
    /// Partition description shared with the other units.
    pub partition_info: Arc<PartitionInfo>,
    /// Whether this unit currently holds a live part of the matching. Base partitions are active
    /// after `clear`, children are deactivated when they are fused into their parent, and a unit
    /// becomes active once it is solved. Inactive units are skipped by `snapshot` and
    /// `intermediate_matching`.
    pub is_active: bool,
    /// Serial primal module that does the actual work for this unit.
    pub serial_module: PrimalModuleSerialPtr,
    /// Weak links to the (left, right) child units; `None` for base partitions.
    pub children: Option<(PrimalModuleParallelUnitWeak, PrimalModuleParallelUnitWeak)>,
    /// Weak link to the parent unit, if any.
    pub parent: Option<PrimalModuleParallelUnitWeak>,
    /// Start/end timing of the last time this unit was solved.
    pub event_time: Option<PrimalModuleParallelUnitEventTime>,
    /// Artificial start delay, only set for base partitions when
    /// `streaming_decode_mock_measure_interval` is configured.
    pub streaming_decode_mocker: Option<StreamingDecodeMocker>,
}
53
/// Shared, lockable handle to a [`PrimalModuleParallelUnit`].
pub type PrimalModuleParallelUnitPtr = ArcManualSafeLock<PrimalModuleParallelUnit>;
/// Weak counterpart of [`PrimalModuleParallelUnitPtr`], used for parent/children links.
pub type PrimalModuleParallelUnitWeak = WeakManualSafeLock<PrimalModuleParallelUnit>;
56
57impl std::fmt::Debug for PrimalModuleParallelUnitPtr {
58 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59 let unit = self.read_recursive();
60 write!(f, "{}", unit.unit_index)
61 }
62}
63
64impl std::fmt::Debug for PrimalModuleParallelUnitWeak {
65 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
66 self.upgrade_force().fmt(f)
67 }
68}
69
/// Timing record of one unit solve, reported by `generate_profiler_report`.
#[derive(Debug, Clone, Serialize)]
pub struct PrimalModuleParallelUnitEventTime {
    /// Seconds since `last_solve_start_time` when the unit started solving.
    pub start: f64,
    /// Seconds since `last_solve_start_time` when the unit finished solving.
    pub end: f64,
    /// Rayon worker index of the thread that created this record (`0` outside a rayon pool).
    pub thread_index: usize,
}
80
81impl Default for PrimalModuleParallelUnitEventTime {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl PrimalModuleParallelUnitEventTime {
88 pub fn new() -> Self {
89 Self {
90 start: 0.,
91 end: 0.,
92 thread_index: rayon::current_thread_index().unwrap_or(0),
93 }
94 }
95}
96
/// Configuration of [`PrimalModuleParallel`], usually deserialized from JSON.
///
/// Unknown fields are rejected. Every field except `streaming_decode_mock_measure_interval` has an
/// explicit serde default in [`primal_module_parallel_default_configs`]; that field is an `Option`
/// and is `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PrimalModuleParallelConfig {
    /// Number of threads in the pool; `0` leaves the choice to rayon.
    #[serde(default = "primal_module_parallel_default_configs::thread_pool_size")]
    pub thread_pool_size: usize,
    /// Solve units one at a time and forward the step callback to every unit (used by tests and
    /// visualization).
    #[serde(default = "primal_module_parallel_default_configs::debug_sequential")]
    pub debug_sequential: bool,
    /// Schedule units by index (base partitions first, fusions wait for their children) instead
    /// of recursing from the last unit.
    #[serde(default = "primal_module_parallel_default_configs::prioritize_base_partition")]
    pub prioritize_base_partition: bool,
    /// With `prioritize_base_partition`, start issuing fusion units interleaved with the base
    /// partitions after this many base partitions. Values at or above the number of fusions
    /// disable interleaving.
    #[serde(default = "primal_module_parallel_default_configs::interleaving_base_fusion")]
    pub interleaving_base_fusion: usize,
    /// Pin rayon worker `i` to core `i` (workers beyond the core count stay unpinned).
    #[serde(default = "primal_module_parallel_default_configs::pin_threads_to_cores")]
    pub pin_threads_to_cores: bool,
    /// If set (seconds), base partition `i` does not start solving before `(i + 1) * interval`
    /// has elapsed since the solve started, to mock streaming decoding.
    pub streaming_decode_mock_measure_interval: Option<f64>,
    /// Busy-wait instead of sleeping / blocking on condition variables, both for the streaming
    /// mock delay and for waiting on child units.
    #[serde(default = "primal_module_parallel_default_configs::streaming_decode_use_spin_lock")]
    pub streaming_decode_use_spin_lock: bool,
    /// Copied into every unit's serial primal module as its `max_tree_size`.
    #[serde(default = "primal_module_parallel_default_configs::max_tree_size")]
    pub max_tree_size: usize,
}
123
124impl Default for PrimalModuleParallelConfig {
125 fn default() -> Self {
126 serde_json::from_value(json!({})).unwrap()
127 }
128}
129
/// Serde default values for [`PrimalModuleParallelConfig`] fields missing from the input.
pub mod primal_module_parallel_default_configs {
    /// Maximum serial tree size: unlimited.
    pub fn max_tree_size() -> usize {
        usize::MAX
    }

    /// Interleaving threshold: never interleave fusions with base partitions.
    pub fn interleaving_base_fusion() -> usize {
        usize::MAX
    }

    /// Thread pool size `0`: rayon picks the number of threads.
    pub fn thread_pool_size() -> usize {
        0
    }

    /// Base partitions are scheduled first by default.
    pub fn prioritize_base_partition() -> bool {
        true
    }

    /// Units are solved in parallel by default.
    pub fn debug_sequential() -> bool {
        false
    }

    /// Worker threads are not pinned to cores by default.
    pub fn pin_threads_to_cores() -> bool {
        false
    }

    /// Waiting uses sleeping / condition variables rather than spinning by default.
    pub fn streaming_decode_use_spin_lock() -> bool {
        false
    }
}
154
/// Artificial start delay of a base partition, used to mock streaming decoding.
pub struct StreamingDecodeMocker {
    /// Time after `last_solve_start_time` before which the unit does not start solving.
    pub bias: Duration,
}
159
160impl PrimalModuleParallel {
161 pub fn new_config(
163 initializer: &SolverInitializer,
164 partition_info: &PartitionInfo,
165 config: PrimalModuleParallelConfig,
166 ) -> Self {
167 let partition_info = Arc::new(partition_info.clone());
168 let mut thread_pool_builder = rayon::ThreadPoolBuilder::new();
169 if config.thread_pool_size != 0 {
170 thread_pool_builder = thread_pool_builder.num_threads(config.thread_pool_size);
171 }
172 if config.pin_threads_to_cores {
173 let core_ids = core_affinity::get_core_ids().unwrap();
174 thread_pool_builder = thread_pool_builder.start_handler(move |thread_index| {
176 if thread_index < core_ids.len() {
178 crate::core_affinity::set_for_current(core_ids[thread_index]);
179 } });
181 }
182 let thread_pool = thread_pool_builder.build().expect("creating thread pool failed");
183 let mut units = vec![];
184 let unit_count = partition_info.units.len();
185 thread_pool.scope(|_| {
186 (0..unit_count)
187 .into_par_iter()
188 .map(|unit_index| {
189 let primal_module = PrimalModuleSerialPtr::new_empty(initializer);
191 primal_module.write().max_tree_size = config.max_tree_size;
192 PrimalModuleParallelUnitPtr::new_wrapper(primal_module, unit_index, Arc::clone(&partition_info))
193 })
194 .collect_into_vec(&mut units);
195 });
196 for unit_index in 0..unit_count {
198 let mut unit = units[unit_index].write();
199 if let Some((left_children_index, right_children_index)) = &partition_info.units[unit_index].children {
200 unit.children = Some((
201 units[*left_children_index].downgrade(),
202 units[*right_children_index].downgrade(),
203 ))
204 }
205 if let Some(parent_index) = &partition_info.units[unit_index].parent {
206 unit.parent = Some(units[*parent_index].downgrade());
207 }
208 if let Some(measure_interval) = config.streaming_decode_mock_measure_interval {
209 if unit_index < partition_info.config.partitions.len() {
210 unit.streaming_decode_mocker = Some(StreamingDecodeMocker {
212 bias: Duration::from_secs_f64(measure_interval * (unit_index + 1) as f64),
213 })
214 }
215 }
216 }
217 Self {
218 units,
219 config,
220 partition_info,
221 thread_pool: Arc::new(thread_pool),
222 last_solve_start_time: ArcRwLock::new_value(Instant::now()),
223 }
224 }
225}
226
227impl PrimalModuleImpl for PrimalModuleParallel {
228 fn new_empty(initializer: &SolverInitializer) -> Self {
229 Self::new_config(
230 initializer,
231 &PartitionConfig::new(initializer.vertex_num).info(),
232 PrimalModuleParallelConfig::default(),
233 )
234 }
235
236 #[inline(never)]
237 fn clear(&mut self) {
238 self.thread_pool.scope(|_| {
239 self.units.par_iter().enumerate().for_each(|(unit_idx, unit_ptr)| {
240 let mut unit = unit_ptr.write();
241 let partition_unit_info = &unit.partition_info.units[unit_idx];
242 let is_active = partition_unit_info.children.is_none();
243 unit.clear();
244 unit.is_active = is_active;
245 });
246 });
247 }
248
249 fn load_defect_dual_node(&mut self, _dual_node_ptr: &DualNodePtr) {
250 panic!("load interface directly into the parallel primal module is forbidden, use `parallel_solve` instead");
251 }
252
253 fn resolve<D: DualModuleImpl>(
254 &mut self,
255 _group_max_update_length: GroupMaxUpdateLength,
256 _interface: &DualModuleInterfacePtr,
257 _dual_module: &mut D,
258 ) {
259 panic!("parallel primal module cannot handle global resolve requests, use `parallel_solve` instead");
260 }
261
262 fn intermediate_matching<D: DualModuleImpl>(
263 &mut self,
264 interface: &DualModuleInterfacePtr,
265 dual_module: &mut D,
266 ) -> IntermediateMatching {
267 let mut intermediate_matching = IntermediateMatching::new();
268 for unit_ptr in self.units.iter() {
269 lock_write!(unit, unit_ptr);
270 if !unit.is_active {
271 continue;
272 } intermediate_matching.append(&mut unit.serial_module.intermediate_matching(interface, dual_module));
274 }
275 intermediate_matching
276 }
277
278 fn generate_profiler_report(&self) -> serde_json::Value {
279 let event_time_vec: Vec<_> = self.units.iter().map(|ptr| ptr.read_recursive().event_time.clone()).collect();
280 json!({
281 "event_time_vec": event_time_vec,
282 })
283 }
284}
285
286impl PrimalModuleParallel {
287 pub fn parallel_solve<DualSerialModule: DualModuleImpl + Send + Sync>(
288 &mut self,
289 syndrome_pattern: &SyndromePattern,
290 parallel_dual_module: &DualModuleParallel<DualSerialModule>,
291 ) {
292 self.parallel_solve_step_callback(syndrome_pattern, parallel_dual_module, |_, _, _, _| {})
293 }
294
295 pub fn parallel_solve_visualizer<DualSerialModule: DualModuleImpl + Send + Sync + FusionVisualizer>(
296 &mut self,
297 syndrome_pattern: &SyndromePattern,
298 parallel_dual_module: &DualModuleParallel<DualSerialModule>,
299 visualizer: Option<&mut Visualizer>,
300 ) {
301 if let Some(visualizer) = visualizer {
302 self.parallel_solve_step_callback(
303 syndrome_pattern,
304 parallel_dual_module,
305 |interface_ptr, dual_module, primal_module, group_max_update_length| {
306 if let Some(group_max_update_length) = group_max_update_length {
307 if cfg!(debug_assertions) {
308 println!("group_max_update_length: {:?}", group_max_update_length);
309 }
310 if let Some(length) = group_max_update_length.get_none_zero_growth() {
311 visualizer
312 .snapshot_combined(format!("grow {length}"), vec![interface_ptr, dual_module, primal_module])
313 .unwrap();
314 } else {
315 let first_conflict = format!("{:?}", group_max_update_length.peek().unwrap());
316 visualizer
317 .snapshot_combined(
318 format!("resolve {first_conflict}"),
319 vec![interface_ptr, dual_module, primal_module],
320 )
321 .unwrap();
322 };
323 } else {
324 visualizer
325 .snapshot_combined("unit solved".to_string(), vec![interface_ptr, dual_module, primal_module])
326 .unwrap();
327 }
328 },
329 );
330 let last_unit = self.units.last().unwrap().read_recursive();
331 visualizer
332 .snapshot_combined(
333 "solved".to_string(),
334 vec![&last_unit.interface_ptr, parallel_dual_module, self],
335 )
336 .unwrap();
337 } else {
338 self.parallel_solve(syndrome_pattern, parallel_dual_module);
339 }
340 }
341
342 pub fn parallel_solve_step_callback<DualSerialModule: DualModuleImpl + Send + Sync, F>(
343 &mut self,
344 syndrome_pattern: &SyndromePattern,
345 parallel_dual_module: &DualModuleParallel<DualSerialModule>,
346 mut callback: F,
347 ) where
348 F: FnMut(
349 &DualModuleInterfacePtr,
350 &DualModuleParallelUnit<DualSerialModule>,
351 &PrimalModuleSerialPtr,
352 Option<&GroupMaxUpdateLength>,
353 ) + Send
354 + Sync,
355 {
356 let thread_pool = Arc::clone(&self.thread_pool);
357 *self.last_solve_start_time.write() = Instant::now();
358 if self.config.prioritize_base_partition {
359 if self.config.debug_sequential {
360 for unit_index in 0..self.partition_info.units.len() {
361 let unit_ptr = self.units[unit_index].clone();
362 unit_ptr.children_ready_solve::<DualSerialModule, F>(
363 self,
364 PartitionedSyndromePattern::new(syndrome_pattern),
365 parallel_dual_module,
366 &mut Some(&mut callback),
367 );
368 }
369 } else {
370 use std::sync::atomic::{AtomicUsize, Ordering};
371 let ready_vec: Vec<_> = {
372 (0..self.partition_info.units.len())
373 .map(|_| Arc::new((Mutex::new(false), Condvar::new(), Arc::new(AtomicUsize::new(0)))))
374 .collect()
375 };
376 thread_pool.scope_fifo(|s| {
377 let issue_unit = |unit_index: usize| {
378 let ready_vec = &ready_vec;
379 let units = &self.units;
380 let partition_info = &self.partition_info;
381 let parallel_unit = &self;
382 let parallel_dual_module = ¶llel_dual_module;
383 let streaming_decode_use_spin_lock = self.config.streaming_decode_use_spin_lock;
384 s.spawn_fifo(move |_| {
385 let ready_pair = ready_vec[unit_index].clone();
386 let (ready, condvar, spin_ready) = &*ready_pair;
387 if streaming_decode_use_spin_lock {
388 let unit_ptr = units[unit_index].clone();
389 if unit_index >= partition_info.config.partitions.len() {
390 let fusion_index = unit_index - partition_info.config.partitions.len();
392 let (left_unit_index, right_unit_index) = partition_info.config.fusions[fusion_index];
393 for child_unit_index in [left_unit_index, right_unit_index] {
394 let child_ready_pair = ready_vec[child_unit_index].clone();
395 let (_, _, child_spin_ready) = &*child_ready_pair;
396 while child_spin_ready.load(Ordering::SeqCst) != 1 {
397 std::hint::spin_loop();
399 }
401 }
402 }
403 unit_ptr.children_ready_solve::<DualSerialModule, F>(
404 parallel_unit,
405 PartitionedSyndromePattern::new(syndrome_pattern),
406 parallel_dual_module,
407 &mut None,
408 );
409 spin_ready.store(1, Ordering::SeqCst);
410 } else {
411 let mut is_ready = ready.lock().unwrap();
412 let unit_ptr = units[unit_index].clone();
413 if unit_index >= partition_info.config.partitions.len() {
414 let fusion_index = unit_index - partition_info.config.partitions.len();
416 let (left_unit_index, right_unit_index) = partition_info.config.fusions[fusion_index];
417 for child_unit_index in [left_unit_index, right_unit_index] {
418 let child_ready_pair = ready_vec[child_unit_index].clone();
419 let (child_ready, child_condvar, _) = &*child_ready_pair;
420 let mut child_is_ready = child_ready.lock().unwrap();
421 while !*child_is_ready {
422 child_is_ready = child_condvar.wait(child_is_ready).unwrap();
424 }
425 }
426 }
427 unit_ptr.children_ready_solve::<DualSerialModule, F>(
428 parallel_unit,
429 PartitionedSyndromePattern::new(syndrome_pattern),
430 parallel_dual_module,
431 &mut None,
432 );
433 *is_ready = true;
434 condvar.notify_one();
435 }
436 })
437 };
438 if self.config.interleaving_base_fusion >= self.partition_info.config.fusions.len() {
439 for unit_index in 0..self.partition_info.units.len() {
440 issue_unit(unit_index);
441 }
442 } else {
443 for unit_index in 0..self.partition_info.config.partitions.len() {
444 if unit_index >= self.config.interleaving_base_fusion {
445 let fusion_index = self.partition_info.config.partitions.len()
446 + (unit_index - self.config.interleaving_base_fusion);
447 issue_unit(fusion_index);
448 }
449 issue_unit(unit_index);
450 }
451 for bias_index in 1..self.config.interleaving_base_fusion {
452 issue_unit(self.partition_info.units.len() - self.config.interleaving_base_fusion + bias_index);
453 }
454 }
455 });
456 }
457 } else {
458 let last_unit_ptr = self.units.last().unwrap().clone();
459 thread_pool.scope(|_| {
460 last_unit_ptr.iterative_solve_step_callback(
461 self,
462 PartitionedSyndromePattern::new(syndrome_pattern),
463 parallel_dual_module,
464 &mut Some(&mut callback),
465 )
466 })
467 }
468 }
469}
470
471impl FusionVisualizer for PrimalModuleParallel {
472 fn snapshot(&self, abbrev: bool) -> serde_json::Value {
473 let mut value = json!({});
476 for unit_ptr in self.units.iter() {
477 let unit = unit_ptr.read_recursive();
478 if !unit.is_active {
479 continue;
480 } let value_2 = unit.snapshot(abbrev);
482 snapshot_combine_values(&mut value, value_2, abbrev);
483 }
484 value
485 }
486}
487
488impl FusionVisualizer for PrimalModuleParallelUnit {
489 fn snapshot(&self, abbrev: bool) -> serde_json::Value {
490 self.serial_module.snapshot(abbrev)
491 }
492}
493
impl PrimalModuleParallelUnitPtr {
    /// Wrap `serial_module` as unit `unit_index` of `partition_info`.
    ///
    /// The unit starts active only if the partition gives it no children (a base partition). It
    /// gets a fresh dual module interface tagged with `unit_index`, and has no parent/children
    /// links, event time or streaming-decode mocker yet.
    pub fn new_wrapper(serial_module: PrimalModuleSerialPtr, unit_index: usize, partition_info: Arc<PartitionInfo>) -> Self {
        let partition_unit_info = &partition_info.units[unit_index];
        let is_active = partition_unit_info.children.is_none();
        let interface_ptr = DualModuleInterfacePtr::new_empty();
        interface_ptr.write().unit_index = unit_index;
        Self::new_value(PrimalModuleParallelUnit {
            unit_index,
            interface_ptr,
            partition_info,
            is_active,
            serial_module,
            children: None,
            parent: None,
            event_time: None,
            streaming_decode_mocker: None,
        })
    }

    /// Solve this unit, assuming its children (if any) have already been solved.
    ///
    /// Steps:
    /// 1. If a streaming-decode mocker is set, wait (spin or sleep, per config) until its `bias`
    ///    has elapsed since `last_solve_start_time`.
    /// 2. For a fusion unit: deactivate both children, fuse them into this unit, and break
    ///    temporary matches to virtual vertices this unit owns. Then load the defects owned by
    ///    this unit and resume solving.
    /// 3. For a base partition: solve the unit's owned defects from scratch.
    /// 4. Mark the unit active and record its event time.
    ///
    /// `callback` is called with `None` after fusion, with `Some(..)` after every solver step, and
    /// with `None` once the unit is solved. The unit's write lock and its dual unit's write lock
    /// are held for the whole solve.
    // NOTE(review): presumably allowed because `defect_index` is not `usize` in every build
    // configuration (see the `as usize` cast) — confirm
    #[allow(clippy::unnecessary_cast)]
    // NOTE(review): no needless borrow is evident in this body; this allow may be stale — confirm
    #[allow(clippy::needless_borrow)]
    fn children_ready_solve<DualSerialModule: DualModuleImpl + Send + Sync, F>(
        &self,
        primal_module_parallel: &PrimalModuleParallel,
        partitioned_syndrome_pattern: PartitionedSyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        callback: &mut Option<&mut F>,
    ) where
        F: FnMut(
                &DualModuleInterfacePtr,
                &DualModuleParallelUnit<DualSerialModule>,
                &PrimalModuleSerialPtr,
                Option<&GroupMaxUpdateLength>,
            ) + Send
            + Sync,
    {
        let mut primal_unit = self.write();
        // streaming-decode mock: hold this unit back until `bias` has elapsed since the solve start
        if let Some(mocker) = &primal_unit.streaming_decode_mocker {
            if primal_module_parallel.config.streaming_decode_use_spin_lock {
                while primal_module_parallel.last_solve_start_time.read_recursive().elapsed() < mocker.bias {
                    std::hint::spin_loop();
                }
            } else {
                let mut elapsed = primal_module_parallel.last_solve_start_time.read_recursive().elapsed();
                // re-check the elapsed time after every sleep
                while elapsed < mocker.bias {
                    std::thread::sleep(mocker.bias - elapsed);
                    elapsed = primal_module_parallel.last_solve_start_time.read_recursive().elapsed();
                }
            }
        }
        let mut event_time = PrimalModuleParallelUnitEventTime::new();
        event_time.start = primal_module_parallel
            .last_solve_start_time
            .read_recursive()
            .elapsed()
            .as_secs_f64();
        let dual_module_ptr = parallel_dual_module.get_unit(primal_unit.unit_index);
        let mut dual_unit = dual_module_ptr.write();
        let partition_unit_info = &primal_unit.partition_info.units[primal_unit.unit_index];
        // only the defects owned by this unit are loaded here; children already handled theirs
        let (owned_defect_range, _) = partitioned_syndrome_pattern.partition(partition_unit_info);
        let interface_ptr = primal_unit.interface_ptr.clone();
        if let Some((left_child_weak, right_child_weak)) = primal_unit.children.as_ref() {
            // fusion unit: both children are merged into this unit and stop being active
            {
                for child_weak in [left_child_weak, right_child_weak] {
                    let child_ptr = child_weak.upgrade_force();
                    let mut child = child_ptr.write();
                    debug_assert!(child.is_active, "cannot fuse inactive children");
                    child.is_active = false;
                }
            }
            primal_unit.fuse(&mut dual_unit);
            if let Some(callback) = callback.as_mut() {
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
            primal_unit.break_matching_with_mirror(dual_unit.deref_mut());
            for defect_index in owned_defect_range.whole_defect_range.iter() {
                let defect_vertex = partitioned_syndrome_pattern.syndrome_pattern.defect_vertices[defect_index as usize];
                primal_unit
                    .serial_module
                    .load_defect(defect_vertex, &interface_ptr, dual_unit.deref_mut());
            }
            // resume from the fused state instead of solving from scratch
            primal_unit.serial_module.solve_step_callback_interface_loaded(
                &interface_ptr,
                dual_unit.deref_mut(),
                |interface, dual_module, primal_module, group_max_update_length| {
                    if let Some(callback) = callback.as_mut() {
                        callback(interface, dual_module, primal_module, Some(group_max_update_length));
                    }
                },
            );
            if let Some(callback) = callback.as_mut() {
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
        } else {
            // base partition: solve its own defects from scratch
            debug_assert!(primal_unit.is_active, "leaf must be active to be solved");
            let syndrome_pattern = owned_defect_range.expand();
            primal_unit.serial_module.solve_step_callback(
                &interface_ptr,
                &syndrome_pattern,
                dual_unit.deref_mut(),
                |interface, dual_module, primal_module, group_max_update_length| {
                    if let Some(callback) = callback.as_mut() {
                        callback(interface, dual_module, primal_module, Some(group_max_update_length));
                    }
                },
            );
            if let Some(callback) = callback.as_mut() {
                callback(&primal_unit.interface_ptr, &dual_unit, &primal_unit.serial_module, None);
            }
        }
        primal_unit.is_active = true;
        event_time.end = primal_module_parallel
            .last_solve_start_time
            .read_recursive()
            .elapsed()
            .as_secs_f64();
        primal_unit.event_time = Some(event_time);
    }

    /// Recursively solve the subtree rooted at this unit.
    ///
    /// Children are solved first: sequentially and with `callback` when `debug_sequential` is
    /// set, otherwise concurrently via `rayon::join` without a callback. Then this unit is solved
    /// with [`Self::children_ready_solve`]. The read lock taken here is released first, because
    /// `children_ready_solve` takes this unit's write lock.
    fn iterative_solve_step_callback<DualSerialModule: DualModuleImpl + Send + Sync, F>(
        &self,
        primal_module_parallel: &PrimalModuleParallel,
        partitioned_syndrome_pattern: PartitionedSyndromePattern,
        parallel_dual_module: &DualModuleParallel<DualSerialModule>,
        callback: &mut Option<&mut F>,
    ) where
        F: FnMut(
                &DualModuleInterfacePtr,
                &DualModuleParallelUnit<DualSerialModule>,
                &PrimalModuleSerialPtr,
                Option<&GroupMaxUpdateLength>,
            ) + Send
            + Sync,
    {
        let primal_unit = self.read_recursive();
        let debug_sequential = primal_module_parallel.config.debug_sequential;
        if let Some((left_child_weak, right_child_weak)) = primal_unit.children.as_ref() {
            debug_assert!(
                !primal_unit.is_active,
                "parent must be inactive at the time of solving children"
            );
            let partition_unit_info = &primal_unit.partition_info.units[primal_unit.unit_index];
            // split the remaining defects between the two subtrees
            let (_, (left_partitioned, right_partitioned)) = partitioned_syndrome_pattern.partition(partition_unit_info);
            if debug_sequential {
                left_child_weak.upgrade_force().iterative_solve_step_callback(
                    primal_module_parallel,
                    left_partitioned,
                    parallel_dual_module,
                    callback,
                );
                right_child_weak.upgrade_force().iterative_solve_step_callback(
                    primal_module_parallel,
                    right_partitioned,
                    parallel_dual_module,
                    callback,
                );
            } else {
                // the FnMut callback cannot be shared by both branches, so children get none
                rayon::join(
                    || {
                        left_child_weak
                            .upgrade_force()
                            .iterative_solve_step_callback::<DualSerialModule, F>(
                                primal_module_parallel,
                                left_partitioned,
                                parallel_dual_module,
                                &mut None,
                            )
                    },
                    || {
                        right_child_weak
                            .upgrade_force()
                            .iterative_solve_step_callback::<DualSerialModule, F>(
                                primal_module_parallel,
                                right_partitioned,
                                parallel_dual_module,
                                &mut None,
                            )
                    },
                );
            };
        }
        drop(primal_unit);
        self.children_ready_solve(
            primal_module_parallel,
            partitioned_syndrome_pattern,
            parallel_dual_module,
            callback,
        );
    }
}
691
692impl PrimalModuleParallelUnit {
693 pub fn fuse<DualSerialModule: DualModuleImpl + Send + Sync>(
696 &mut self,
697 dual_unit: &mut DualModuleParallelUnit<DualSerialModule>,
698 ) {
699 let (left_child_ptr, right_child_ptr) = (
700 self.children.as_ref().unwrap().0.upgrade_force(),
701 self.children.as_ref().unwrap().1.upgrade_force(),
702 );
703 let left_child = left_child_ptr.read_recursive();
704 let right_child = right_child_ptr.read_recursive();
705 dual_unit.fuse(&self.interface_ptr, (&left_child.interface_ptr, &right_child.interface_ptr));
706 self.serial_module.fuse(&left_child.serial_module, &right_child.serial_module);
707 }
708
709 #[allow(clippy::unnecessary_cast)]
711 pub fn break_matching_with_mirror(&mut self, dual_module: &mut impl DualModuleImpl) {
712 let mut possible_break = vec![];
714 let module = self.serial_module.read_recursive();
715 for node_index in module.possible_break.iter() {
716 let primal_node_ptr = module.get_node(*node_index);
717 if let Some(primal_node_ptr) = primal_node_ptr {
718 let mut primal_node = primal_node_ptr.write();
719 if let Some((MatchTarget::VirtualVertex(vertex_index), _)) = &primal_node.temporary_match {
720 if self.partition_info.vertex_to_owning_unit[*vertex_index as usize] == self.unit_index {
721 primal_node.temporary_match = None;
722 self.interface_ptr.set_grow_state(
723 &primal_node.origin.upgrade_force(),
724 DualNodeGrowState::Grow,
725 dual_module,
726 );
727 } else {
728 possible_break.push(*node_index);
730 }
731 }
732 }
733 }
734 drop(module);
735 self.serial_module.write().possible_break = possible_break;
736 }
737}
738
739impl PrimalModuleImpl for PrimalModuleParallelUnit {
740 fn new_empty(_initializer: &SolverInitializer) -> Self {
741 panic!("creating parallel unit directly from initializer is forbidden, use `PrimalModuleParallel::new` instead");
742 }
743
744 fn clear(&mut self) {
745 self.serial_module.clear();
746 self.interface_ptr.clear();
747 }
748
749 fn load(&mut self, interface_ptr: &DualModuleInterfacePtr) {
750 self.serial_module.load(interface_ptr)
751 }
752
753 fn load_defect_dual_node(&mut self, dual_node_ptr: &DualNodePtr) {
754 self.serial_module.load_defect_dual_node(dual_node_ptr)
755 }
756
757 fn resolve<D: DualModuleImpl>(
758 &mut self,
759 group_max_update_length: GroupMaxUpdateLength,
760 interface: &DualModuleInterfacePtr,
761 dual_module: &mut D,
762 ) {
763 self.serial_module.resolve(group_max_update_length, interface, dual_module)
764 }
765
766 fn intermediate_matching<D: DualModuleImpl>(
767 &mut self,
768 interface: &DualModuleInterfacePtr,
769 dual_module: &mut D,
770 ) -> IntermediateMatching {
771 self.serial_module.intermediate_matching(interface, dual_module)
772 }
773}
774
#[cfg(test)]
pub mod tests {
    use super::super::dual_module_serial::*;
    use super::super::example_codes::*;
    use super::*;

    /// Same as `primal_module_parallel_basic_standard_syndrome_optional_viz_config` with the
    /// default test primal configuration (`primal_config_json = None`).
    pub fn primal_module_parallel_basic_standard_syndrome_optional_viz<F>(
        code: impl ExampleCode,
        visualize_filename: Option<String>,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        primal_module_parallel_basic_standard_syndrome_optional_viz_config(
            code,
            visualize_filename,
            defect_vertices,
            final_dual,
            partition_func,
            reordered_vertices,
            None,
        )
    }

    /// Decode `defect_vertices` on `code` with a parallel primal/dual module pair and check the result.
    ///
    /// * `visualize_filename` — when given, snapshots are written to this file in the visualize data folder.
    /// * `final_dual` — expected final dual value; the last unit's summed dual variables must equal `2 * final_dual`.
    /// * `partition_func` — fills in the partition configuration for the code's initializer.
    /// * `reordered_vertices` — optional vertex reordering applied to the code (defects are translated accordingly).
    /// * `primal_config_json` — primal configuration as JSON; defaults to `debug_sequential: true`.
    ///
    /// When `max_tree_size` is unlimited, also asserts that the summed dual variables equal the
    /// total weight of the resulting subgraph.
    pub fn primal_module_parallel_basic_standard_syndrome_optional_viz_config<F>(
        mut code: impl ExampleCode,
        visualize_filename: Option<String>,
        mut defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
        primal_config_json: Option<serde_json::Value>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        println!("{defect_vertices:?}");
        if let Some(reordered_vertices) = &reordered_vertices {
            code.reorder_vertices(reordered_vertices);
            defect_vertices = translated_defect_to_reordered(reordered_vertices, &defect_vertices);
        }
        let mut visualizer = match visualize_filename.as_ref() {
            Some(visualize_filename) => {
                let visualizer = Visualizer::new(
                    Some(visualize_data_folder() + visualize_filename.as_str()),
                    code.get_positions(),
                    true,
                )
                .unwrap();
                print_visualize_link(visualize_filename.clone());
                Some(visualizer)
            }
            None => None,
        };
        let initializer = code.get_initializer();
        let mut partition_config = PartitionConfig::new(initializer.vertex_num);
        partition_func(&initializer, &mut partition_config);
        let partition_info = partition_config.info();
        let mut dual_module =
            DualModuleParallel::new_config(&initializer, &partition_info, DualModuleParallelConfig::default());
        let primal_config = if let Some(value) = primal_config_json {
            serde_json::from_value(value).unwrap()
        } else {
            PrimalModuleParallelConfig {
                debug_sequential: true,
                ..Default::default()
            }
        };
        let mut primal_module = PrimalModuleParallel::new_config(&initializer, &partition_info, primal_config.clone());
        code.set_defect_vertices(&defect_vertices);
        primal_module.parallel_solve_visualizer(&code.get_syndrome(), &dual_module, visualizer.as_mut());
        // NOTE(review): the name suggests the parallel module ignores this interface argument — confirm
        let useless_interface_ptr = DualModuleInterfacePtr::new_empty();
        let perfect_matching = primal_module.perfect_matching(&useless_interface_ptr, &mut dual_module);
        let mut subgraph_builder = SubGraphBuilder::new(&initializer);
        subgraph_builder.load_perfect_matching(&perfect_matching);
        let subgraph = subgraph_builder.get_subgraph();
        if let Some(visualizer) = visualizer.as_mut() {
            let last_interface_ptr = &primal_module.units.last().unwrap().read_recursive().interface_ptr;
            visualizer
                .snapshot_combined(
                    "perfect matching and subgraph".to_string(),
                    vec![
                        last_interface_ptr,
                        &dual_module,
                        &perfect_matching,
                        &VisualizeSubgraph::new(&subgraph),
                    ],
                )
                .unwrap();
        }
        let sum_dual_variables = primal_module
            .units
            .last()
            .unwrap()
            .read_recursive()
            .interface_ptr
            .sum_dual_variables();
        // with a bounded tree size the matching need not be optimal, so only check optimality when unlimited
        if primal_config.max_tree_size == usize::MAX {
            assert_eq!(
                sum_dual_variables,
                subgraph_builder.total_weight(),
                "unmatched sum dual variables"
            );
        }
        assert_eq!(sum_dual_variables, final_dual * 2, "unexpected final dual variable sum");
        (primal_module, dual_module)
    }

    /// Like `primal_module_parallel_basic_standard_syndrome_optional_viz`, but always visualizes.
    pub fn primal_module_parallel_standard_syndrome<F>(
        code: impl ExampleCode,
        visualize_filename: String,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
        partition_func: F,
        reordered_vertices: Option<Vec<VertexIndex>>,
    ) -> (PrimalModuleParallel, DualModuleParallel<DualModuleSerial>)
    where
        F: Fn(&SolverInitializer, &mut PartitionConfig),
    {
        primal_module_parallel_basic_standard_syndrome_optional_viz(
            code,
            Some(visualize_filename),
            defect_vertices,
            final_dual,
            partition_func,
            reordered_vertices,
        )
    }

    /// Default partition (the partition config is left unchanged).
    #[test]
    fn primal_module_parallel_basic_1() {
        let visualize_filename = "primal_module_parallel_basic_1.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |initializer, _config| {
                println!("initializer: {initializer:?}");
            },
            None,
        );
    }

    /// Two partitions [0, 72) and [84, 132), fused once.
    #[test]
    fn primal_module_parallel_basic_2() {
        let visualize_filename = "primal_module_parallel_basic_2.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 72),
                    VertexRange::new(84, 132),
                ];
                config.fusions = vec![
                    (0, 1),
                ];
            },
            None,
        );
    }

    /// Two partitions [0, 60) and [72, 132), fused once.
    #[test]
    fn primal_module_parallel_basic_3() {
        let visualize_filename = "primal_module_parallel_basic_3.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 60),
                    VertexRange::new(72, 132),
                ];
                config.fusions = vec![
                    (0, 1),
                ];
            },
            None,
        );
    }

    /// Four partitions on a reordered code (split after row 6 and column 5), fused pairwise and
    /// then together.
    #[test]
    fn primal_module_parallel_basic_4() {
        let visualize_filename = "primal_module_parallel_basic_4.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 36),
                    VertexRange::new(42, 72),
                    VertexRange::new(84, 108),
                    VertexRange::new(112, 132),
                ];
                config.fusions = vec![(0, 1), (2, 3), (4, 5)];
            },
            // rows hold 12 vertices; this ordering makes every partition range above contiguous.
            // The gaps between the ranges are the split column (per half) and the split row
            Some({
                let mut reordered_vertices = vec![];
                let split_horizontal = 6;
                let split_vertical = 5;
                for i in 0..split_horizontal {
                    for j in 0..split_vertical {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 11);
                }
                for i in 0..split_horizontal {
                    reordered_vertices.push(i * 12 + split_vertical);
                }
                for i in 0..split_horizontal {
                    for j in (split_vertical + 1)..10 {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 10);
                }
                {
                    for j in 0..12 {
                        reordered_vertices.push(split_horizontal * 12 + j);
                    }
                }
                for i in (split_horizontal + 1)..11 {
                    for j in 0..split_vertical {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 11);
                }
                for i in (split_horizontal + 1)..11 {
                    reordered_vertices.push(i * 12 + split_vertical);
                }
                for i in (split_horizontal + 1)..11 {
                    for j in (split_vertical + 1)..10 {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 10);
                }
                reordered_vertices
            }),
        );
    }

    /// Same as `primal_module_parallel_basic_4`, but split after row 5 and column 4.
    #[test]
    fn primal_module_parallel_basic_5() {
        let visualize_filename = "primal_module_parallel_basic_5.json".to_string();
        let defect_vertices = vec![39, 52, 63, 90, 100];
        let half_weight = 500;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            9 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 25),
                    VertexRange::new(30, 60),
                    VertexRange::new(72, 97),
                    VertexRange::new(102, 132),
                ];
                config.fusions = vec![(0, 1), (2, 3), (4, 5)];
            },
            Some({
                let mut reordered_vertices = vec![];
                let split_horizontal = 5;
                let split_vertical = 4;
                for i in 0..split_horizontal {
                    for j in 0..split_vertical {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 11);
                }
                for i in 0..split_horizontal {
                    reordered_vertices.push(i * 12 + split_vertical);
                }
                for i in 0..split_horizontal {
                    for j in (split_vertical + 1)..10 {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 10);
                }
                {
                    for j in 0..12 {
                        reordered_vertices.push(split_horizontal * 12 + j);
                    }
                }
                for i in (split_horizontal + 1)..11 {
                    for j in 0..split_vertical {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 11);
                }
                for i in (split_horizontal + 1)..11 {
                    reordered_vertices.push(i * 12 + split_vertical);
                }
                for i in (split_horizontal + 1)..11 {
                    for j in (split_vertical + 1)..10 {
                        reordered_vertices.push(i * 12 + j);
                    }
                    reordered_vertices.push(i * 12 + 10);
                }
                reordered_vertices
            }),
        );
    }

    /// Split a distance-`d` planar code (`d + 1` vertices per row) into two partitions.
    ///
    /// The partitions lie above and below row `split_horizontal = (d + 1) / 2`, and the vertices
    /// of that row are in neither range. The two partitions are fused once.
    fn primal_module_parallel_debug_planar_code_common(
        d: VertexNum,
        visualize_filename: String,
        defect_vertices: Vec<VertexIndex>,
        final_dual: Weight,
    ) {
        let half_weight = 500;
        let split_horizontal = (d + 1) / 2;
        let row_count = d + 1;
        primal_module_parallel_standard_syndrome(
            CodeCapacityPlanarCode::new(d, 0.1, half_weight),
            visualize_filename,
            defect_vertices,
            final_dual * half_weight,
            |initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, split_horizontal * row_count),
                    VertexRange::new((split_horizontal + 1) * row_count, initializer.vertex_num),
                ];
                config.fusions = vec![(0, 1)];
            },
            None,
        );
    }

    /// Two-partition split of a distance-15 planar code.
    #[test]
    fn primal_module_parallel_debug_1() {
        let visualize_filename = "primal_module_parallel_debug_1.json".to_string();
        let defect_vertices = vec![88, 89, 102, 103, 105, 106, 118, 120, 122, 134, 138];
        primal_module_parallel_debug_planar_code_common(15, visualize_filename, defect_vertices, 10);
    }

    /// Two partitions with `max_tree_size = 0`, so only the final dual sum is checked.
    #[test]
    fn primal_module_parallel_union_find_basic_1() {
        let visualize_filename = "primal_module_parallel_union_find_basic_1.json".to_string();
        let defect_vertices = vec![51, 52, 53, 88];
        let half_weight = 500;
        primal_module_parallel_basic_standard_syndrome_optional_viz_config(
            CodeCapacityPlanarCode::new(11, 0.1, half_weight),
            Some(visualize_filename),
            defect_vertices,
            4 * half_weight,
            |_initializer, config| {
                config.partitions = vec![
                    VertexRange::new(0, 72),
                    VertexRange::new(84, 132),
                ];
                config.fusions = vec![
                    (0, 1),
                ];
            },
            None,
            Some(json!({ "max_tree_size": 0, "debug_sequential": true })),
        );
    }
}