1use crate::distributed_simulator::{
16 CommunicationConfig, CommunicationPattern, DistributedSimulatorConfig, DistributionStrategy,
17 FaultToleranceConfig, LoadBalancingConfig, LoadBalancingStrategy, NetworkConfig,
18};
19use crate::large_scale_simulator::{LargeScaleSimulatorConfig, QuantumStateRepresentation};
20use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
21use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
22use scirs2_core::parallel_ops::{IndexedParallelIterator, ParallelIterator};
23use scirs2_core::Complex64;
24use serde::{Deserialize, Serialize};
25use std::collections::{BTreeMap, HashMap};
26use std::sync::{Arc, Mutex, RwLock};
27use std::time::{Duration, Instant};
28
29#[derive(Debug)]
35pub struct MPIQuantumSimulator {
36 communicator: MPICommunicator,
38 local_state: Arc<RwLock<LocalQuantumState>>,
40 config: MPISimulatorConfig,
42 stats: Arc<Mutex<MPISimulatorStats>>,
44 sync_manager: StateSynchronizationManager,
46 gate_handler: GateDistributionHandler,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct MPISimulatorConfig {
53 pub total_qubits: usize,
55 pub distribution_strategy: MPIDistributionStrategy,
57 pub collective_optimization: CollectiveOptimization,
59 pub overlap_config: CommunicationOverlapConfig,
61 pub checkpoint_config: CheckpointConfig,
63 pub memory_config: MemoryConfig,
65}
66
67impl Default for MPISimulatorConfig {
68 fn default() -> Self {
69 Self {
70 total_qubits: 20,
71 distribution_strategy: MPIDistributionStrategy::AmplitudePartition,
72 collective_optimization: CollectiveOptimization::default(),
73 overlap_config: CommunicationOverlapConfig::default(),
74 checkpoint_config: CheckpointConfig::default(),
75 memory_config: MemoryConfig::default(),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82pub enum MPIDistributionStrategy {
83 AmplitudePartition,
85 QubitPartition,
87 HybridPartition,
89 GateAwarePartition,
91 HilbertCurvePartition,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct CollectiveOptimization {
98 pub use_nonblocking: bool,
100 pub enable_fusion: bool,
102 pub buffer_size: usize,
104 pub allreduce_algorithm: AllreduceAlgorithm,
106 pub broadcast_algorithm: BroadcastAlgorithm,
108}
109
110impl Default for CollectiveOptimization {
111 fn default() -> Self {
112 Self {
113 use_nonblocking: true,
114 enable_fusion: true,
115 buffer_size: 16 * 1024 * 1024, allreduce_algorithm: AllreduceAlgorithm::RecursiveDoubling,
117 broadcast_algorithm: BroadcastAlgorithm::BinomialTree,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
124pub enum AllreduceAlgorithm {
125 Ring,
127 RecursiveDoubling,
129 Rabenseifner,
131 Automatic,
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
137pub enum BroadcastAlgorithm {
138 BinomialTree,
140 ScatterAllgather,
142 Pipeline,
144 Automatic,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct CommunicationOverlapConfig {
151 pub enable_overlap: bool,
153 pub pipeline_stages: usize,
155 pub prefetch_distance: usize,
157}
158
159impl Default for CommunicationOverlapConfig {
160 fn default() -> Self {
161 Self {
162 enable_overlap: true,
163 pipeline_stages: 4,
164 prefetch_distance: 2,
165 }
166 }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct CheckpointConfig {
172 pub enable: bool,
174 pub interval: usize,
176 pub storage_path: String,
178 pub use_compression: bool,
180}
181
182impl Default for CheckpointConfig {
183 fn default() -> Self {
184 Self {
185 enable: false,
186 interval: 1000,
187 storage_path: "/tmp/quantum_checkpoint".to_string(),
188 use_compression: true,
189 }
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct MemoryConfig {
196 pub max_memory_per_node: usize,
198 pub enable_pooling: bool,
200 pub pool_size: usize,
202}
203
204impl Default for MemoryConfig {
205 fn default() -> Self {
206 Self {
207 max_memory_per_node: 64 * 1024 * 1024 * 1024, enable_pooling: true,
209 pool_size: 1024 * 1024 * 1024, }
211 }
212}
213
/// Handle through which one rank communicates with the rest of the group.
#[derive(Debug)]
pub struct MPICommunicator {
    /// Rank of this process.
    rank: usize,
    /// Number of ranks in the group.
    size: usize,
    /// Transport implementation.
    backend: MPIBackend,
    /// Reusable byte buffers (NOTE(review): only ever initialised to empty in this file).
    buffer_pool: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Outstanding requests (NOTE(review): only ever initialised to empty in this file).
    pending_requests: Arc<Mutex<Vec<MPIRequest>>>,
}

/// Transport behind an [`MPICommunicator`].
#[derive(Debug, Clone)]
pub enum MPIBackend {
    /// In-process simulated transport.
    Simulated(SimulatedMPIBackend),
    /// Native MPI bindings; only compiled with the `mpi` feature.
    #[cfg(feature = "mpi")]
    Native(NativeMPIBackend),
    /// TCP-based transport.
    TCP(TCPMPIBackend),
}

/// Simulated backend; clones share the same state through the `Arc`.
#[derive(Debug, Clone)]
pub struct SimulatedMPIBackend {
    /// State shared by every clone of this backend.
    shared_state: Arc<RwLock<SimulatedMPIState>>,
}

/// Shared state of the simulated backend.
#[derive(Debug, Default)]
pub struct SimulatedMPIState {
    /// Queued messages, presumably keyed by destination rank.
    message_buffers: HashMap<usize, Vec<Vec<u8>>>,
    /// Barrier arrival counter.
    barrier_count: usize,
    /// Serialized collective results keyed by operation name.
    collective_results: HashMap<String, Vec<u8>>,
}

/// TCP backend.
#[derive(Debug, Clone)]
pub struct TCPMPIBackend {
    /// Peer socket addresses, presumably keyed by rank.
    connections: Arc<RwLock<HashMap<usize, std::net::SocketAddr>>>,
}

/// Native MPI backend.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub struct NativeMPIBackend {
    /// Opaque communicator handle.
    comm_handle: usize,
}

/// An in-flight communication request.
#[derive(Debug)]
pub struct MPIRequest {
    /// Request identifier.
    id: usize,
    /// What kind of operation the request represents.
    request_type: MPIRequestType,
    /// Completion flag, shareable between threads.
    completed: Arc<Mutex<bool>>,
}

/// Kind of an [`MPIRequest`].
#[derive(Debug, Clone)]
pub enum MPIRequestType {
    /// Point-to-point send.
    Send { dest: usize, tag: i32 },
    /// Point-to-point receive.
    Recv { source: usize, tag: i32 },
    /// Named collective operation.
    Collective { operation: String },
}
292
293#[derive(Debug)]
295pub struct LocalQuantumState {
296 amplitudes: Array1<Complex64>,
298 global_offset: usize,
300 local_qubits: Vec<usize>,
302 ghost_cells: GhostCells,
304}
305
306#[derive(Debug, Clone, Default)]
308pub struct GhostCells {
309 left: Vec<Complex64>,
311 right: Vec<Complex64>,
313 width: usize,
315}
316
/// Execution counters for an MPI simulation.
#[derive(Debug, Clone, Default)]
pub struct MPISimulatorStats {
    /// Number of gates applied successfully.
    pub gates_executed: u64,
    /// Accumulated time spent communicating.
    pub communication_time: Duration,
    /// Accumulated wall-clock time spent applying gates.
    pub computation_time: Duration,
    /// Number of synchronisation steps.
    pub sync_count: u64,
    /// Total bytes sent.
    pub bytes_sent: u64,
    /// Total bytes received.
    pub bytes_received: u64,
    /// Load-imbalance metric.
    pub load_imbalance: f64,
}
335
/// Tracks how and when distributed state is synchronised.
#[derive(Debug)]
pub struct StateSynchronizationManager {
    /// Selected synchronisation strategy.
    strategy: SyncStrategy,
    /// Queued synchronisation operations.
    pending: Arc<Mutex<Vec<SyncOperation>>>,
}

/// When synchronisation happens.
#[derive(Debug, Clone, Copy)]
pub enum SyncStrategy {
    /// Synchronise immediately.
    Eager,
    /// Defer synchronisation.
    Lazy,
    /// Choose between eager and lazy.
    Adaptive,
}

/// A single queued synchronisation step.
#[derive(Debug, Clone)]
pub struct SyncOperation {
    /// Qubits involved.
    qubits: Vec<usize>,
    /// Kind of step.
    op_type: SyncOpType,
}

/// Kinds of synchronisation steps.
#[derive(Debug, Clone)]
pub enum SyncOpType {
    /// Exchange boundary data with a partner.
    BoundaryExchange,
    /// Reduce across all ranks.
    GlobalReduction,
    /// Swap partitions between ranks.
    PartitionSwap,
}

/// Decides where gates run across ranks.
#[derive(Debug)]
pub struct GateDistributionHandler {
    /// Routing table (`usize -> usize`).
    routing_table: Arc<RwLock<HashMap<usize, usize>>>,
    /// Local/remote qubit classifier.
    gate_classifier: GateClassifier,
}

/// Knows which qubits are local to this rank.
#[derive(Debug)]
pub struct GateClassifier {
    /// Qubit indices held entirely on this rank.
    local_qubits: Vec<usize>,
}
388
389impl MPIQuantumSimulator {
390 pub fn new(config: MPISimulatorConfig) -> QuantRS2Result<Self> {
392 let communicator = MPICommunicator::new()?;
394
395 let total_amplitudes = 1usize << config.total_qubits;
397 let local_size = total_amplitudes / communicator.size;
398 let global_offset = communicator.rank * local_size;
399
400 let local_state = LocalQuantumState {
402 amplitudes: Array1::zeros(local_size),
403 global_offset,
404 local_qubits: Self::calculate_local_qubits(
405 config.total_qubits,
406 communicator.rank,
407 communicator.size,
408 ),
409 ghost_cells: GhostCells::default(),
410 };
411
412 let sync_manager = StateSynchronizationManager {
414 strategy: SyncStrategy::Adaptive,
415 pending: Arc::new(Mutex::new(Vec::new())),
416 };
417
418 let gate_handler = GateDistributionHandler {
420 routing_table: Arc::new(RwLock::new(HashMap::new())),
421 gate_classifier: GateClassifier {
422 local_qubits: local_state.local_qubits.clone(),
423 },
424 };
425
426 Ok(Self {
427 communicator,
428 local_state: Arc::new(RwLock::new(local_state)),
429 config,
430 stats: Arc::new(Mutex::new(MPISimulatorStats::default())),
431 sync_manager,
432 gate_handler,
433 })
434 }
435
436 fn calculate_local_qubits(total_qubits: usize, rank: usize, size: usize) -> Vec<usize> {
438 let partition_bits = (size as f64).log2().ceil() as usize;
440 let local_bits = total_qubits - partition_bits;
441
442 (0..local_bits).collect()
444 }
445
446 pub fn initialize(&mut self) -> QuantRS2Result<()> {
448 let mut state = self
449 .local_state
450 .write()
451 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
452
453 state.amplitudes.fill(Complex64::new(0.0, 0.0));
455
456 if self.communicator.rank == 0 {
458 state.amplitudes[0] = Complex64::new(1.0, 0.0);
459 }
460
461 Ok(())
462 }
463
464 pub fn apply_single_qubit_gate(
466 &mut self,
467 qubit: usize,
468 gate_matrix: &Array2<Complex64>,
469 ) -> QuantRS2Result<()> {
470 let start = Instant::now();
471
472 let state = self
474 .local_state
475 .read()
476 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
477
478 if state.local_qubits.contains(&qubit) {
479 drop(state);
481 self.apply_local_single_qubit_gate(qubit, gate_matrix)?;
482 } else {
483 drop(state);
485 self.apply_distributed_single_qubit_gate(qubit, gate_matrix)?;
486 }
487
488 let mut stats = self
490 .stats
491 .lock()
492 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
493 stats.gates_executed += 1;
494 stats.computation_time += start.elapsed();
495
496 Ok(())
497 }
498
499 fn apply_local_single_qubit_gate(
501 &self,
502 qubit: usize,
503 gate_matrix: &Array2<Complex64>,
504 ) -> QuantRS2Result<()> {
505 let mut state = self
506 .local_state
507 .write()
508 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
509
510 let n = state.amplitudes.len();
511 let stride = 1 << qubit;
512
513 let amplitudes = state.amplitudes.as_slice_mut().ok_or_else(|| {
515 QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
516 })?;
517
518 for i in 0..n / 2 {
520 let i0 = (i / stride) * (2 * stride) + (i % stride);
521 let i1 = i0 + stride;
522
523 let a0 = amplitudes[i0];
524 let a1 = amplitudes[i1];
525
526 amplitudes[i0] = gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1;
527 amplitudes[i1] = gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1;
528 }
529
530 Ok(())
531 }
532
533 fn apply_distributed_single_qubit_gate(
535 &self,
536 qubit: usize,
537 gate_matrix: &Array2<Complex64>,
538 ) -> QuantRS2Result<()> {
539 let partition_bit = qubit - self.gate_handler.gate_classifier.local_qubits.len();
541 let partner = self.communicator.rank ^ (1 << partition_bit);
542
543 self.exchange_boundary_data(partner)?;
545
546 let mut state = self
548 .local_state
549 .write()
550 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
551
552 let n = state.amplitudes.len();
553 let local_qubits = state.local_qubits.len();
554 let local_stride = 1 << local_qubits;
555
556 let is_lower = (self.communicator.rank >> partition_bit) & 1 == 0;
558
559 for i in 0..n {
560 let global_i = state.global_offset + i;
561 let partner_i = global_i ^ local_stride;
562
563 let partner_amp = if is_lower {
565 state
566 .ghost_cells
567 .right
568 .get(i)
569 .copied()
570 .unwrap_or(Complex64::new(0.0, 0.0))
571 } else {
572 state
573 .ghost_cells
574 .left
575 .get(i)
576 .copied()
577 .unwrap_or(Complex64::new(0.0, 0.0))
578 };
579
580 let local_amp = state.amplitudes[i];
581
582 let (a0, a1) = if is_lower {
584 (local_amp, partner_amp)
585 } else {
586 (partner_amp, local_amp)
587 };
588
589 let new_amp = if is_lower {
590 gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1
591 } else {
592 gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1
593 };
594
595 state.amplitudes[i] = new_amp;
596 }
597
598 Ok(())
599 }
600
601 fn exchange_boundary_data(&self, partner: usize) -> QuantRS2Result<()> {
603 let state = self
604 .local_state
605 .read()
606 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
607
608 let send_data: Vec<Complex64> = state.amplitudes.iter().copied().collect();
610 drop(state);
611
612 let recv_data = self.communicator.sendrecv(&send_data, partner)?;
614
615 let mut state = self
617 .local_state
618 .write()
619 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
620
621 if self.communicator.rank < partner {
622 state.ghost_cells.right = recv_data;
623 } else {
624 state.ghost_cells.left = recv_data;
625 }
626
627 Ok(())
628 }
629
630 pub fn apply_two_qubit_gate(
632 &mut self,
633 control: usize,
634 target: usize,
635 gate_matrix: &Array2<Complex64>,
636 ) -> QuantRS2Result<()> {
637 let start = Instant::now();
638
639 let state = self
640 .local_state
641 .read()
642 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
643
644 let control_local = state.local_qubits.contains(&control);
645 let target_local = state.local_qubits.contains(&target);
646 drop(state);
647
648 match (control_local, target_local) {
649 (true, true) => {
650 self.apply_local_two_qubit_gate(control, target, gate_matrix)?;
652 }
653 (true, false) | (false, true) => {
654 self.apply_partial_distributed_gate(control, target, gate_matrix)?;
656 }
657 (false, false) => {
658 self.apply_full_distributed_gate(control, target, gate_matrix)?;
660 }
661 }
662
663 let mut stats = self
665 .stats
666 .lock()
667 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
668 stats.gates_executed += 1;
669 stats.computation_time += start.elapsed();
670
671 Ok(())
672 }
673
674 fn apply_local_two_qubit_gate(
676 &self,
677 control: usize,
678 target: usize,
679 gate_matrix: &Array2<Complex64>,
680 ) -> QuantRS2Result<()> {
681 let mut state = self
682 .local_state
683 .write()
684 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
685
686 let n = state.amplitudes.len();
687 let control_stride = 1 << control;
688 let target_stride = 1 << target;
689
690 let (low_stride, high_stride) = if control < target {
692 (control_stride, target_stride)
693 } else {
694 (target_stride, control_stride)
695 };
696
697 for i in 0..n / 4 {
699 let base = (i / low_stride) * (2 * low_stride) + (i % low_stride);
701 let base = (base / high_stride) * (2 * high_stride) + (base % high_stride);
702
703 let i00 = base;
705 let i01 = base + target_stride;
706 let i10 = base + control_stride;
707 let i11 = base + control_stride + target_stride;
708
709 let a00 = state.amplitudes[i00];
711 let a01 = state.amplitudes[i01];
712 let a10 = state.amplitudes[i10];
713 let a11 = state.amplitudes[i11];
714
715 state.amplitudes[i00] = gate_matrix[[0, 0]] * a00
717 + gate_matrix[[0, 1]] * a01
718 + gate_matrix[[0, 2]] * a10
719 + gate_matrix[[0, 3]] * a11;
720 state.amplitudes[i01] = gate_matrix[[1, 0]] * a00
721 + gate_matrix[[1, 1]] * a01
722 + gate_matrix[[1, 2]] * a10
723 + gate_matrix[[1, 3]] * a11;
724 state.amplitudes[i10] = gate_matrix[[2, 0]] * a00
725 + gate_matrix[[2, 1]] * a01
726 + gate_matrix[[2, 2]] * a10
727 + gate_matrix[[2, 3]] * a11;
728 state.amplitudes[i11] = gate_matrix[[3, 0]] * a00
729 + gate_matrix[[3, 1]] * a01
730 + gate_matrix[[3, 2]] * a10
731 + gate_matrix[[3, 3]] * a11;
732 }
733
734 Ok(())
735 }
736
737 fn apply_partial_distributed_gate(
739 &self,
740 control: usize,
741 target: usize,
742 gate_matrix: &Array2<Complex64>,
743 ) -> QuantRS2Result<()> {
744 let state = self
746 .local_state
747 .read()
748 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
749
750 let (local_qubit, remote_qubit) = if state.local_qubits.contains(&control) {
751 (control, target)
752 } else {
753 (target, control)
754 };
755 drop(state);
756
757 let partition_bit = remote_qubit - self.gate_handler.gate_classifier.local_qubits.len();
759 let partner = self.communicator.rank ^ (1 << partition_bit);
760
761 self.exchange_boundary_data(partner)?;
763
764 let mut state = self
768 .local_state
769 .write()
770 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
771
772 let n = state.amplitudes.len();
773 let local_stride = 1 << local_qubit;
774
775 for i in 0..n / 2 {
776 let i0 = (i / local_stride) * (2 * local_stride) + (i % local_stride);
777 let i1 = i0 + local_stride;
778
779 let a0 = state.amplitudes[i0];
780 let a1 = state.amplitudes[i1];
781
782 state.amplitudes[i0] = gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1;
785 state.amplitudes[i1] = gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1;
786 }
787
788 Ok(())
789 }
790
791 fn apply_full_distributed_gate(
793 &self,
794 control: usize,
795 target: usize,
796 gate_matrix: &Array2<Complex64>,
797 ) -> QuantRS2Result<()> {
798 let local_qubits_len = self.gate_handler.gate_classifier.local_qubits.len();
801
802 let control_partition = control - local_qubits_len;
803 let target_partition = target - local_qubits_len;
804
805 let control_partner = self.communicator.rank ^ (1 << control_partition);
807 self.exchange_boundary_data(control_partner)?;
808
809 let target_partner = self.communicator.rank ^ (1 << target_partition);
811 self.exchange_boundary_data(target_partner)?;
812
813 let mut state = self
815 .local_state
816 .write()
817 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
818
819 let _ = gate_matrix; Ok(())
823 }
824
825 pub const fn barrier(&self) -> QuantRS2Result<()> {
827 self.communicator.barrier()
828 }
829
830 pub fn get_probability_distribution(&self) -> QuantRS2Result<Vec<f64>> {
832 let state = self
833 .local_state
834 .read()
835 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
836
837 let local_probs: Vec<f64> = state.amplitudes.iter().map(|a| (a * a.conj()).re).collect();
839
840 drop(state);
841
842 let global_probs = self.communicator.gather(&local_probs, 0)?;
844
845 Ok(global_probs)
846 }
847
848 pub fn measure_all(&self) -> QuantRS2Result<Vec<bool>> {
850 let probs = self.get_probability_distribution()?;
852
853 if self.communicator.rank == 0 {
855 let mut rng = scirs2_core::random::thread_rng();
857 let random: f64 = scirs2_core::random::Rng::gen(&mut rng);
858
859 let mut cumulative = 0.0;
860 let mut result_idx = 0;
861
862 for (i, &prob) in probs.iter().enumerate() {
863 cumulative += prob;
864 if random < cumulative {
865 result_idx = i;
866 break;
867 }
868 }
869
870 let result: Vec<bool> = (0..self.config.total_qubits)
872 .map(|i| (result_idx >> i) & 1 == 1)
873 .collect();
874
875 self.communicator.broadcast(&result, 0)
877 } else {
878 self.communicator.broadcast(&[], 0)
880 }
881 }
882
883 pub fn get_local_state(&self) -> QuantRS2Result<Array1<Complex64>> {
885 let state = self
886 .local_state
887 .read()
888 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
889 Ok(state.amplitudes.clone())
890 }
891
892 pub fn get_stats(&self) -> QuantRS2Result<MPISimulatorStats> {
894 let stats = self
895 .stats
896 .lock()
897 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
898 Ok(stats.clone())
899 }
900
901 pub fn reset(&mut self) -> QuantRS2Result<()> {
903 self.initialize()?;
904
905 let mut stats = self
907 .stats
908 .lock()
909 .map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
910 *stats = MPISimulatorStats::default();
911
912 Ok(())
913 }
914}
915
916impl MPICommunicator {
917 pub fn new() -> QuantRS2Result<Self> {
919 let shared_state = Arc::new(RwLock::new(SimulatedMPIState::default()));
921 let backend = MPIBackend::Simulated(SimulatedMPIBackend { shared_state });
922
923 Ok(Self {
924 rank: 0,
925 size: 1,
926 backend,
927 buffer_pool: Arc::new(Mutex::new(Vec::new())),
928 pending_requests: Arc::new(Mutex::new(Vec::new())),
929 })
930 }
931
932 #[must_use]
934 pub fn with_config(rank: usize, size: usize, backend: MPIBackend) -> Self {
935 Self {
936 rank,
937 size,
938 backend,
939 buffer_pool: Arc::new(Mutex::new(Vec::new())),
940 pending_requests: Arc::new(Mutex::new(Vec::new())),
941 }
942 }
943
944 #[must_use]
946 pub const fn rank(&self) -> usize {
947 self.rank
948 }
949
950 #[must_use]
952 pub const fn size(&self) -> usize {
953 self.size
954 }
955
956 pub const fn barrier(&self) -> QuantRS2Result<()> {
958 match &self.backend {
959 MPIBackend::Simulated(_) => {
960 Ok(())
962 }
963 MPIBackend::TCP(_) => {
964 Ok(())
966 }
967 #[cfg(feature = "mpi")]
968 MPIBackend::Native(_) => {
969 Ok(())
971 }
972 }
973 }
974
975 pub fn sendrecv(
977 &self,
978 send_data: &[Complex64],
979 partner: usize,
980 ) -> QuantRS2Result<Vec<Complex64>> {
981 match &self.backend {
982 MPIBackend::Simulated(_) => {
983 Ok(send_data.to_vec())
985 }
986 MPIBackend::TCP(_) => {
987 Ok(send_data.to_vec())
989 }
990 #[cfg(feature = "mpi")]
991 MPIBackend::Native(_) => {
992 Ok(send_data.to_vec())
994 }
995 }
996 }
997
998 pub fn gather<T: Clone>(&self, local_data: &[T], root: usize) -> QuantRS2Result<Vec<T>> {
1000 match &self.backend {
1001 MPIBackend::Simulated(_) => {
1002 Ok(local_data.to_vec())
1004 }
1005 MPIBackend::TCP(_) => {
1006 Ok(local_data.to_vec())
1008 }
1009 #[cfg(feature = "mpi")]
1010 MPIBackend::Native(_) => {
1011 Ok(local_data.to_vec())
1013 }
1014 }
1015 }
1016
1017 pub fn broadcast<T: Clone>(&self, data: &[T], root: usize) -> QuantRS2Result<Vec<T>> {
1019 match &self.backend {
1020 MPIBackend::Simulated(_) => {
1021 Ok(data.to_vec())
1023 }
1024 MPIBackend::TCP(_) => {
1025 Ok(data.to_vec())
1027 }
1028 #[cfg(feature = "mpi")]
1029 MPIBackend::Native(_) => {
1030 Ok(data.to_vec())
1032 }
1033 }
1034 }
1035
1036 pub fn allreduce(&self, local_data: &[f64], op: ReduceOp) -> QuantRS2Result<Vec<f64>> {
1038 match &self.backend {
1039 MPIBackend::Simulated(_) => {
1040 Ok(local_data.to_vec())
1042 }
1043 MPIBackend::TCP(_) => {
1044 Ok(local_data.to_vec())
1046 }
1047 #[cfg(feature = "mpi")]
1048 MPIBackend::Native(_) => {
1049 Ok(local_data.to_vec())
1051 }
1052 }
1053 }
1054}
1055
/// Element-wise reduction applied by `MPICommunicator::allreduce`.
#[derive(Debug, Clone, Copy)]
pub enum ReduceOp {
    /// Sum of contributions.
    Sum,
    /// Maximum of contributions.
    Max,
    /// Minimum of contributions.
    Min,
    /// Product of contributions.
    Prod,
}
1064
1065#[derive(Debug, Clone)]
1067pub struct MPISimulationResult {
1068 pub measurements: Vec<bool>,
1070 pub probabilities: Vec<f64>,
1072 pub stats: MPISimulatorStats,
1074}
1075
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim x dim` real-valued matrix from row-major entries.
    fn real_matrix(dim: usize, entries: &[f64]) -> Array2<Complex64> {
        Array2::from_shape_vec(
            (dim, dim),
            entries.iter().map(|&x| Complex64::new(x, 0.0)).collect(),
        )
        .expect("entries must contain dim * dim values")
    }

    fn x_gate() -> Array2<Complex64> {
        real_matrix(2, &[0.0, 1.0, 1.0, 0.0])
    }

    fn h_gate() -> Array2<Complex64> {
        let s = 1.0 / 2.0_f64.sqrt();
        real_matrix(2, &[s, s, s, -s])
    }

    /// CNOT in |control target> basis order.
    fn cnot_gate() -> Array2<Complex64> {
        real_matrix(
            4,
            &[
                1.0, 0.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, 0.0, //
                0.0, 0.0, 0.0, 1.0, //
                0.0, 0.0, 1.0, 0.0,
            ],
        )
    }

    fn initialized_simulator(total_qubits: usize) -> MPIQuantumSimulator {
        let config = MPISimulatorConfig {
            total_qubits,
            ..Default::default()
        };
        let mut simulator = MPIQuantumSimulator::new(config).expect("failed to create simulator");
        simulator.initialize().expect("failed to initialize");
        simulator
    }

    /// Asserts that the local state is exactly the computational basis state `index`.
    fn assert_basis_state(simulator: &MPIQuantumSimulator, index: usize) {
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        for (i, amp) in state.iter().enumerate() {
            let expected = if i == index { 1.0 } else { 0.0 };
            assert!(
                (*amp - Complex64::new(expected, 0.0)).norm() < 1e-10,
                "amplitude {} is {:?}, expected {}",
                i,
                amp,
                expected
            );
        }
    }

    #[test]
    fn test_mpi_simulator_creation() {
        let config = MPISimulatorConfig {
            total_qubits: 4,
            ..Default::default()
        };
        let simulator = MPIQuantumSimulator::new(config);
        assert!(simulator.is_ok());
    }

    #[test]
    fn test_mpi_simulator_initialization() {
        let simulator = initialized_simulator(4);
        assert_basis_state(&simulator, 0);
    }

    #[test]
    fn test_mpi_communicator_creation() {
        let comm = MPICommunicator::new();
        assert!(comm.is_ok());

        let comm = comm.expect("failed to create communicator");
        assert_eq!(comm.rank(), 0);
        assert_eq!(comm.size(), 1);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut simulator = initialized_simulator(4);

        let result = simulator.apply_single_qubit_gate(0, &x_gate());
        assert!(result.is_ok());
        // X on qubit 0 maps |0000> to basis index 1.
        assert_basis_state(&simulator, 1);
    }

    #[test]
    fn test_probability_distribution() {
        let simulator = initialized_simulator(2);

        let probs = simulator
            .get_probability_distribution()
            .expect("failed to get probability distribution");
        assert_eq!(probs.len(), 4);
        assert!((probs[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_probability_distribution_after_hadamard() {
        let mut simulator = initialized_simulator(2);
        simulator
            .apply_single_qubit_gate(0, &h_gate())
            .expect("failed to apply gate");

        let probs = simulator
            .get_probability_distribution()
            .expect("failed to get probability distribution");
        let expected = [0.5, 0.5, 0.0, 0.0];
        assert_eq!(probs.len(), expected.len());
        for (p, e) in probs.iter().zip(expected.iter()) {
            assert!((p - e).abs() < 1e-10, "got {:?}", probs);
        }
    }

    #[test]
    fn test_measure_all_basis_state() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(1, &x_gate())
            .expect("failed to apply gate");

        // The state is exactly |qubit1 = 1>, so the outcome is deterministic.
        let outcome = simulator.measure_all().expect("measurement failed");
        assert_eq!(outcome, vec![false, true, false, false]);
    }

    #[test]
    fn test_mpi_stats() {
        let mut simulator = initialized_simulator(4);

        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 0);

        simulator
            .apply_single_qubit_gate(0, &x_gate())
            .expect("failed to apply gate");
        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 1);
    }

    #[test]
    fn test_distribution_strategies() {
        let strategies = [
            MPIDistributionStrategy::AmplitudePartition,
            MPIDistributionStrategy::QubitPartition,
            MPIDistributionStrategy::HybridPartition,
            MPIDistributionStrategy::GateAwarePartition,
            MPIDistributionStrategy::HilbertCurvePartition,
        ];

        for strategy in strategies {
            let config = MPISimulatorConfig {
                total_qubits: 4,
                distribution_strategy: strategy,
                ..Default::default()
            };
            let simulator = MPIQuantumSimulator::new(config);
            assert!(simulator.is_ok());
        }
    }

    #[test]
    fn test_reset() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(0, &h_gate())
            .expect("failed to apply gate");

        simulator.reset().expect("failed to reset");

        assert_basis_state(&simulator, 0);
        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 0);
    }

    #[test]
    fn test_collective_optimization_config() {
        let config = CollectiveOptimization {
            use_nonblocking: true,
            enable_fusion: true,
            buffer_size: 32 * 1024 * 1024,
            allreduce_algorithm: AllreduceAlgorithm::Ring,
            broadcast_algorithm: BroadcastAlgorithm::Pipeline,
        };

        assert!(config.use_nonblocking);
        assert!(config.enable_fusion);
        assert_eq!(config.buffer_size, 32 * 1024 * 1024);
        assert_eq!(config.allreduce_algorithm, AllreduceAlgorithm::Ring);
        assert_eq!(config.broadcast_algorithm, BroadcastAlgorithm::Pipeline);
    }

    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig {
            enable: true,
            interval: 500,
            storage_path: "/custom/path".to_string(),
            use_compression: false,
        };

        assert!(config.enable);
        assert_eq!(config.interval, 500);
        assert_eq!(config.storage_path, "/custom/path");
        assert!(!config.use_compression);
    }

    #[test]
    fn test_two_qubit_gate() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(0, &x_gate())
            .expect("failed to apply gate");

        let result = simulator.apply_two_qubit_gate(0, 1, &cnot_gate());
        assert!(result.is_ok());
        // Control qubit 0 is set, so target qubit 1 flips: index 1 -> index 3.
        assert_basis_state(&simulator, 3);
    }

    #[test]
    fn test_two_qubit_gate_control_above_target() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(1, &x_gate())
            .expect("failed to apply gate");

        simulator
            .apply_two_qubit_gate(1, 0, &cnot_gate())
            .expect("failed to apply gate");
        // Control qubit 1 is set, so target qubit 0 flips: index 2 -> index 3.
        assert_basis_state(&simulator, 3);
    }
}