use crate::distributed_simulator::{
CommunicationConfig, CommunicationPattern, DistributedSimulatorConfig, DistributionStrategy,
FaultToleranceConfig, LoadBalancingConfig, LoadBalancingStrategy, NetworkConfig,
};
use crate::large_scale_simulator::{LargeScaleSimulatorConfig, QuantumStateRepresentation};
use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
use scirs2_core::parallel_ops::{IndexedParallelIterator, ParallelIterator};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
/// State-vector quantum simulator whose amplitudes are partitioned across the ranks of an
/// [`MPICommunicator`]; each rank stores only its own contiguous slice of the state.
#[derive(Debug)]
pub struct MPIQuantumSimulator {
/// Provides this process's rank/size and the point-to-point/collective operations.
communicator: MPICommunicator,
/// This rank's amplitudes plus partner data received during exchanges (ghost cells).
local_state: Arc<RwLock<LocalQuantumState>>,
/// Configuration supplied to `new`; `total_qubits` determines the state size.
config: MPISimulatorConfig,
/// Gate counters and timings, cleared by `reset`.
stats: Arc<Mutex<MPISimulatorStats>>,
/// NOTE(review): constructed in `new` but not read by the gate paths in this file.
sync_manager: StateSynchronizationManager,
/// Holds the list of locally stored qubits, used to map remote qubits to rank bits.
gate_handler: GateDistributionHandler,
}
/// Top-level configuration for [`MPIQuantumSimulator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MPISimulatorConfig {
/// Number of qubits in the full (global) register; the state has `2^total_qubits` amplitudes.
pub total_qubits: usize,
/// NOTE(review): not consulted by `MPIQuantumSimulator` in this file, which always uses a
/// contiguous amplitude split.
pub distribution_strategy: MPIDistributionStrategy,
/// Tuning for collective operations.
pub collective_optimization: CollectiveOptimization,
/// Settings for overlapping communication with computation.
pub overlap_config: CommunicationOverlapConfig,
/// Checkpointing settings.
pub checkpoint_config: CheckpointConfig,
/// Per-node memory limits and pooling settings.
pub memory_config: MemoryConfig,
}
impl Default for MPISimulatorConfig {
    /// A 20-qubit register split by amplitude ranges, with every sub-configuration at its own
    /// default.
    fn default() -> Self {
        let distribution_strategy = MPIDistributionStrategy::AmplitudePartition;
        Self {
            total_qubits: 20,
            distribution_strategy,
            collective_optimization: CollectiveOptimization::default(),
            overlap_config: CommunicationOverlapConfig::default(),
            checkpoint_config: CheckpointConfig::default(),
            memory_config: MemoryConfig::default(),
        }
    }
}
/// Strategy for assigning parts of the state vector to ranks.
///
/// NOTE(review): `MPIQuantumSimulator::new` in this file does not branch on this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MPIDistributionStrategy {
AmplitudePartition,
QubitPartition,
HybridPartition,
GateAwarePartition,
HilbertCurvePartition,
}
/// Tuning knobs for collective communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectiveOptimization {
/// Whether non-blocking collectives should be used.
pub use_nonblocking: bool,
/// Whether adjacent collectives may be fused.
pub enable_fusion: bool,
/// Communication buffer size in bytes (the default is 16 MiB).
pub buffer_size: usize,
/// Algorithm choice for all-reduce.
pub allreduce_algorithm: AllreduceAlgorithm,
/// Algorithm choice for broadcast.
pub broadcast_algorithm: BroadcastAlgorithm,
}
impl Default for CollectiveOptimization {
    /// Non-blocking, fused collectives with a 16 MiB buffer, recursive-doubling all-reduce
    /// and binomial-tree broadcast.
    fn default() -> Self {
        // 16 MiB expressed as a shift: 16 * 2^20 bytes.
        const DEFAULT_BUFFER_BYTES: usize = 16 << 20;
        Self {
            use_nonblocking: true,
            enable_fusion: true,
            buffer_size: DEFAULT_BUFFER_BYTES,
            allreduce_algorithm: AllreduceAlgorithm::RecursiveDoubling,
            broadcast_algorithm: BroadcastAlgorithm::BinomialTree,
        }
    }
}
/// Algorithm selector for all-reduce collectives (`Automatic` leaves the choice to the
/// implementation).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AllreduceAlgorithm {
Ring,
RecursiveDoubling,
Rabenseifner,
Automatic,
}
/// Algorithm selector for broadcast collectives (`Automatic` leaves the choice to the
/// implementation).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BroadcastAlgorithm {
BinomialTree,
ScatterAllgather,
Pipeline,
Automatic,
}
/// Settings for overlapping communication with computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationOverlapConfig {
/// Master switch for overlap.
pub enable_overlap: bool,
/// Number of pipeline stages (default 4).
pub pipeline_stages: usize,
/// How far ahead to prefetch (default 2); units are not defined in this file.
pub prefetch_distance: usize,
}
impl Default for CommunicationOverlapConfig {
    /// Overlap enabled with a four-stage pipeline and a prefetch distance of two.
    fn default() -> Self {
        const STAGES: usize = 4;
        const PREFETCH: usize = 2;
        Self {
            pipeline_stages: STAGES,
            prefetch_distance: PREFETCH,
            enable_overlap: true,
        }
    }
}
/// Checkpointing settings (disabled by default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
/// Whether checkpoints are written.
pub enable: bool,
/// Checkpoint interval (default 1000); the unit is not defined in this file.
pub interval: usize,
/// Location checkpoints are written to.
pub storage_path: String,
/// Whether checkpoint data should be compressed.
pub use_compression: bool,
}
impl Default for CheckpointConfig {
    /// Checkpointing off, every 1000 units, compressed, stored under `/tmp`.
    fn default() -> Self {
        let storage_path = String::from("/tmp/quantum_checkpoint");
        Self {
            enable: false,
            interval: 1000,
            storage_path,
            use_compression: true,
        }
    }
}
/// Per-node memory limits and pooling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
/// Memory budget per node in bytes.
pub max_memory_per_node: usize,
/// Whether a buffer pool should be used.
pub enable_pooling: bool,
/// Pool size in bytes (default 1 GiB).
pub pool_size: usize,
}
impl Default for MemoryConfig {
    /// 64 GiB per node (saturated to `usize::MAX` where `usize` is narrower), pooling on,
    /// 1 GiB pool.
    fn default() -> Self {
        const GIB: u64 = 1 << 30;
        // Computing 64 GiB directly in `usize` overflows on 32-bit targets (a deny-by-default
        // compile error), so compute it in u64 and saturate when it does not fit.
        let max_memory_per_node: usize =
            std::convert::TryFrom::try_from(64 * GIB).unwrap_or(usize::MAX);
        Self {
            max_memory_per_node,
            enable_pooling: true,
            // 1 GiB fits in a 32-bit usize.
            pool_size: 1 << 30,
        }
    }
}
/// Handle to a group of ranks: this process's rank, the group size, and the transport.
#[derive(Debug)]
pub struct MPICommunicator {
/// Zero-based rank of this process.
rank: usize,
/// Number of ranks in the group.
size: usize,
/// Transport implementation the operations dispatch on.
backend: MPIBackend,
/// NOTE(review): initialised empty and not otherwise used in this file.
buffer_pool: Arc<Mutex<Vec<Vec<u8>>>>,
/// NOTE(review): initialised empty and not otherwise used in this file.
pending_requests: Arc<Mutex<Vec<MPIRequest>>>,
}
/// Transport used by an [`MPICommunicator`].
///
/// NOTE(review): in this file every communicator operation returns the caller's own data
/// (or `Ok(())` for `barrier`) for all variants; no real inter-process traffic happens yet.
#[derive(Debug, Clone)]
pub enum MPIBackend {
/// In-process simulation; the default created by `MPICommunicator::new`.
Simulated(SimulatedMPIBackend),
/// Native MPI binding, only compiled with the `mpi` feature.
#[cfg(feature = "mpi")]
Native(NativeMPIBackend),
/// TCP-based transport.
TCP(TCPMPIBackend),
}
/// Simulated backend; clones share the same state through the `Arc`.
#[derive(Debug, Clone)]
pub struct SimulatedMPIBackend {
shared_state: Arc<RwLock<SimulatedMPIState>>,
}
/// Shared bookkeeping for the simulated backend.
///
/// NOTE(review): these fields are default-initialised but never read or written elsewhere in
/// this file; the key meanings (e.g. rank ids, operation names) are presumed — confirm.
#[derive(Debug, Default)]
pub struct SimulatedMPIState {
message_buffers: HashMap<usize, Vec<Vec<u8>>>,
barrier_count: usize,
collective_results: HashMap<String, Vec<u8>>,
}
/// TCP transport; presumably maps rank -> peer address (not populated in this file).
#[derive(Debug, Clone)]
pub struct TCPMPIBackend {
connections: Arc<RwLock<HashMap<usize, std::net::SocketAddr>>>,
}
/// Native MPI transport, compiled only with the `mpi` feature.
#[cfg(feature = "mpi")]
#[derive(Debug, Clone)]
pub struct NativeMPIBackend {
/// Opaque communicator handle (semantics not defined in this file).
comm_handle: usize,
}
/// An outstanding asynchronous operation and its completion flag.
#[derive(Debug)]
pub struct MPIRequest {
id: usize,
request_type: MPIRequestType,
completed: Arc<Mutex<bool>>,
}
/// Kind of operation an [`MPIRequest`] tracks.
#[derive(Debug, Clone)]
pub enum MPIRequestType {
Send { dest: usize, tag: i32 },
Recv { source: usize, tag: i32 },
Collective { operation: String },
}
/// The slice of the global state vector owned by one rank.
#[derive(Debug)]
pub struct LocalQuantumState {
/// Local amplitudes; local index `i` is global index `global_offset + i`.
amplitudes: Array1<Complex64>,
/// Global index of `amplitudes[0]` (`rank * local_size`).
global_offset: usize,
/// Qubits whose bits vary within this slice (the low qubits).
local_qubits: Vec<usize>,
/// Copies of partner ranks' amplitudes received during exchanges.
ghost_cells: GhostCells,
}
/// Partner data received in exchanges: `right` holds data from a higher-ranked partner,
/// `left` from a lower-ranked one.
#[derive(Debug, Clone, Default)]
pub struct GhostCells {
left: Vec<Complex64>,
right: Vec<Complex64>,
/// NOTE(review): never set in this file; stays at its default of 0.
width: usize,
}
/// Counters and timings reported by `MPIQuantumSimulator::get_stats`.
#[derive(Debug, Clone, Default)]
pub struct MPISimulatorStats {
/// Number of gate applications that completed successfully.
pub gates_executed: u64,
/// Intended to track time spent communicating (NOTE(review): confirm which paths update it).
pub communication_time: Duration,
/// Wall-clock time of gate-application calls, including any exchanges they perform.
pub computation_time: Duration,
/// Intended to count synchronisations (NOTE(review): confirm which paths update it).
pub sync_count: u64,
/// Intended byte count sent (NOTE(review): confirm which paths update it).
pub bytes_sent: u64,
/// Intended byte count received (NOTE(review): confirm which paths update it).
pub bytes_received: u64,
/// Load-imbalance metric; not written in this file.
pub load_imbalance: f64,
}
/// Queue of pending synchronisation operations and the policy for flushing them.
///
/// NOTE(review): constructed with `SyncStrategy::Adaptive` and an empty queue; not consulted
/// elsewhere in this file.
#[derive(Debug)]
pub struct StateSynchronizationManager {
strategy: SyncStrategy,
pending: Arc<Mutex<Vec<SyncOperation>>>,
}
/// When pending synchronisations should be performed (semantics presumed from names).
#[derive(Debug, Clone, Copy)]
pub enum SyncStrategy {
Eager,
Lazy,
Adaptive,
}
/// A deferred synchronisation affecting `qubits`.
#[derive(Debug, Clone)]
pub struct SyncOperation {
qubits: Vec<usize>,
op_type: SyncOpType,
}
/// Kind of deferred synchronisation.
#[derive(Debug, Clone)]
pub enum SyncOpType {
BoundaryExchange,
GlobalReduction,
PartitionSwap,
}
/// Decides where gates execute.
#[derive(Debug)]
pub struct GateDistributionHandler {
/// NOTE(review): created empty and never populated in this file.
routing_table: Arc<RwLock<HashMap<usize, usize>>>,
/// Knows which qubits are local to this rank.
gate_classifier: GateClassifier,
}
/// Classifies gates by whether their qubits are local.
#[derive(Debug)]
pub struct GateClassifier {
/// Copy of the rank's local qubit list; its length is the first qubit stored in rank bits.
local_qubits: Vec<usize>,
}
impl MPIQuantumSimulator {
pub fn new(config: MPISimulatorConfig) -> QuantRS2Result<Self> {
let communicator = MPICommunicator::new()?;
let total_amplitudes = 1usize << config.total_qubits;
let local_size = total_amplitudes / communicator.size;
let global_offset = communicator.rank * local_size;
let local_state = LocalQuantumState {
amplitudes: Array1::zeros(local_size),
global_offset,
local_qubits: Self::calculate_local_qubits(
config.total_qubits,
communicator.rank,
communicator.size,
),
ghost_cells: GhostCells::default(),
};
let sync_manager = StateSynchronizationManager {
strategy: SyncStrategy::Adaptive,
pending: Arc::new(Mutex::new(Vec::new())),
};
let gate_handler = GateDistributionHandler {
routing_table: Arc::new(RwLock::new(HashMap::new())),
gate_classifier: GateClassifier {
local_qubits: local_state.local_qubits.clone(),
},
};
Ok(Self {
communicator,
local_state: Arc::new(RwLock::new(local_state)),
config,
stats: Arc::new(Mutex::new(MPISimulatorStats::default())),
sync_manager,
gate_handler,
})
}
fn calculate_local_qubits(total_qubits: usize, rank: usize, size: usize) -> Vec<usize> {
let partition_bits = (size as f64).log2().ceil() as usize;
let local_bits = total_qubits - partition_bits;
(0..local_bits).collect()
}
pub fn initialize(&mut self) -> QuantRS2Result<()> {
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
state.amplitudes.fill(Complex64::new(0.0, 0.0));
if self.communicator.rank == 0 {
state.amplitudes[0] = Complex64::new(1.0, 0.0);
}
Ok(())
}
pub fn apply_single_qubit_gate(
&mut self,
qubit: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let start = Instant::now();
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
if state.local_qubits.contains(&qubit) {
drop(state);
self.apply_local_single_qubit_gate(qubit, gate_matrix)?;
} else {
drop(state);
self.apply_distributed_single_qubit_gate(qubit, gate_matrix)?;
}
let mut stats = self
.stats
.lock()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
stats.gates_executed += 1;
stats.computation_time += start.elapsed();
Ok(())
}
fn apply_local_single_qubit_gate(
&self,
qubit: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let n = state.amplitudes.len();
let stride = 1 << qubit;
let amplitudes = state.amplitudes.as_slice_mut().ok_or_else(|| {
QuantRS2Error::InvalidInput("Failed to get mutable slice".to_string())
})?;
for i in 0..n / 2 {
let i0 = (i / stride) * (2 * stride) + (i % stride);
let i1 = i0 + stride;
let a0 = amplitudes[i0];
let a1 = amplitudes[i1];
amplitudes[i0] = gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1;
amplitudes[i1] = gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1;
}
Ok(())
}
fn apply_distributed_single_qubit_gate(
&self,
qubit: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let partition_bit = qubit - self.gate_handler.gate_classifier.local_qubits.len();
let partner = self.communicator.rank ^ (1 << partition_bit);
self.exchange_boundary_data(partner)?;
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let n = state.amplitudes.len();
let local_qubits = state.local_qubits.len();
let local_stride = 1 << local_qubits;
let is_lower = (self.communicator.rank >> partition_bit) & 1 == 0;
for i in 0..n {
let global_i = state.global_offset + i;
let partner_i = global_i ^ local_stride;
let partner_amp = if is_lower {
state
.ghost_cells
.right
.get(i)
.copied()
.unwrap_or(Complex64::new(0.0, 0.0))
} else {
state
.ghost_cells
.left
.get(i)
.copied()
.unwrap_or(Complex64::new(0.0, 0.0))
};
let local_amp = state.amplitudes[i];
let (a0, a1) = if is_lower {
(local_amp, partner_amp)
} else {
(partner_amp, local_amp)
};
let new_amp = if is_lower {
gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1
} else {
gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1
};
state.amplitudes[i] = new_amp;
}
Ok(())
}
fn exchange_boundary_data(&self, partner: usize) -> QuantRS2Result<()> {
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let send_data: Vec<Complex64> = state.amplitudes.iter().copied().collect();
drop(state);
let recv_data = self.communicator.sendrecv(&send_data, partner)?;
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
if self.communicator.rank < partner {
state.ghost_cells.right = recv_data;
} else {
state.ghost_cells.left = recv_data;
}
Ok(())
}
pub fn apply_two_qubit_gate(
&mut self,
control: usize,
target: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let start = Instant::now();
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let control_local = state.local_qubits.contains(&control);
let target_local = state.local_qubits.contains(&target);
drop(state);
match (control_local, target_local) {
(true, true) => {
self.apply_local_two_qubit_gate(control, target, gate_matrix)?;
}
(true, false) | (false, true) => {
self.apply_partial_distributed_gate(control, target, gate_matrix)?;
}
(false, false) => {
self.apply_full_distributed_gate(control, target, gate_matrix)?;
}
}
let mut stats = self
.stats
.lock()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
stats.gates_executed += 1;
stats.computation_time += start.elapsed();
Ok(())
}
fn apply_local_two_qubit_gate(
&self,
control: usize,
target: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let n = state.amplitudes.len();
let control_stride = 1 << control;
let target_stride = 1 << target;
let (low_stride, high_stride) = if control < target {
(control_stride, target_stride)
} else {
(target_stride, control_stride)
};
for i in 0..n / 4 {
let base = (i / low_stride) * (2 * low_stride) + (i % low_stride);
let base = (base / high_stride) * (2 * high_stride) + (base % high_stride);
let i00 = base;
let i01 = base + target_stride;
let i10 = base + control_stride;
let i11 = base + control_stride + target_stride;
let a00 = state.amplitudes[i00];
let a01 = state.amplitudes[i01];
let a10 = state.amplitudes[i10];
let a11 = state.amplitudes[i11];
state.amplitudes[i00] = gate_matrix[[0, 0]] * a00
+ gate_matrix[[0, 1]] * a01
+ gate_matrix[[0, 2]] * a10
+ gate_matrix[[0, 3]] * a11;
state.amplitudes[i01] = gate_matrix[[1, 0]] * a00
+ gate_matrix[[1, 1]] * a01
+ gate_matrix[[1, 2]] * a10
+ gate_matrix[[1, 3]] * a11;
state.amplitudes[i10] = gate_matrix[[2, 0]] * a00
+ gate_matrix[[2, 1]] * a01
+ gate_matrix[[2, 2]] * a10
+ gate_matrix[[2, 3]] * a11;
state.amplitudes[i11] = gate_matrix[[3, 0]] * a00
+ gate_matrix[[3, 1]] * a01
+ gate_matrix[[3, 2]] * a10
+ gate_matrix[[3, 3]] * a11;
}
Ok(())
}
fn apply_partial_distributed_gate(
&self,
control: usize,
target: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let (local_qubit, remote_qubit) = if state.local_qubits.contains(&control) {
(control, target)
} else {
(target, control)
};
drop(state);
let partition_bit = remote_qubit - self.gate_handler.gate_classifier.local_qubits.len();
let partner = self.communicator.rank ^ (1 << partition_bit);
self.exchange_boundary_data(partner)?;
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let n = state.amplitudes.len();
let local_stride = 1 << local_qubit;
for i in 0..n / 2 {
let i0 = (i / local_stride) * (2 * local_stride) + (i % local_stride);
let i1 = i0 + local_stride;
let a0 = state.amplitudes[i0];
let a1 = state.amplitudes[i1];
state.amplitudes[i0] = gate_matrix[[0, 0]] * a0 + gate_matrix[[0, 1]] * a1;
state.amplitudes[i1] = gate_matrix[[1, 0]] * a0 + gate_matrix[[1, 1]] * a1;
}
Ok(())
}
fn apply_full_distributed_gate(
&self,
control: usize,
target: usize,
gate_matrix: &Array2<Complex64>,
) -> QuantRS2Result<()> {
let local_qubits_len = self.gate_handler.gate_classifier.local_qubits.len();
let control_partition = control - local_qubits_len;
let target_partition = target - local_qubits_len;
let control_partner = self.communicator.rank ^ (1 << control_partition);
self.exchange_boundary_data(control_partner)?;
let target_partner = self.communicator.rank ^ (1 << target_partition);
self.exchange_boundary_data(target_partner)?;
let mut state = self
.local_state
.write()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let _ = gate_matrix;
Ok(())
}
pub const fn barrier(&self) -> QuantRS2Result<()> {
self.communicator.barrier()
}
pub fn get_probability_distribution(&self) -> QuantRS2Result<Vec<f64>> {
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
let local_probs: Vec<f64> = state.amplitudes.iter().map(|a| (a * a.conj()).re).collect();
drop(state);
let global_probs = self.communicator.gather(&local_probs, 0)?;
Ok(global_probs)
}
pub fn measure_all(&self) -> QuantRS2Result<Vec<bool>> {
let probs = self.get_probability_distribution()?;
if self.communicator.rank == 0 {
let mut rng = scirs2_core::random::thread_rng();
let random: f64 = rng.random();
let mut cumulative = 0.0;
let mut result_idx = 0;
for (i, &prob) in probs.iter().enumerate() {
cumulative += prob;
if random < cumulative {
result_idx = i;
break;
}
}
let result: Vec<bool> = (0..self.config.total_qubits)
.map(|i| (result_idx >> i) & 1 == 1)
.collect();
self.communicator.broadcast(&result, 0)
} else {
self.communicator.broadcast(&[], 0)
}
}
pub fn get_local_state(&self) -> QuantRS2Result<Array1<Complex64>> {
let state = self
.local_state
.read()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire state lock".to_string()))?;
Ok(state.amplitudes.clone())
}
pub fn get_stats(&self) -> QuantRS2Result<MPISimulatorStats> {
let stats = self
.stats
.lock()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
Ok(stats.clone())
}
pub fn reset(&mut self) -> QuantRS2Result<()> {
self.initialize()?;
let mut stats = self
.stats
.lock()
.map_err(|_| QuantRS2Error::InvalidInput("Failed to acquire stats lock".to_string()))?;
*stats = MPISimulatorStats::default();
Ok(())
}
}
impl MPICommunicator {
pub fn new() -> QuantRS2Result<Self> {
let shared_state = Arc::new(RwLock::new(SimulatedMPIState::default()));
let backend = MPIBackend::Simulated(SimulatedMPIBackend { shared_state });
Ok(Self {
rank: 0,
size: 1,
backend,
buffer_pool: Arc::new(Mutex::new(Vec::new())),
pending_requests: Arc::new(Mutex::new(Vec::new())),
})
}
#[must_use]
pub fn with_config(rank: usize, size: usize, backend: MPIBackend) -> Self {
Self {
rank,
size,
backend,
buffer_pool: Arc::new(Mutex::new(Vec::new())),
pending_requests: Arc::new(Mutex::new(Vec::new())),
}
}
#[must_use]
pub const fn rank(&self) -> usize {
self.rank
}
#[must_use]
pub const fn size(&self) -> usize {
self.size
}
pub const fn barrier(&self) -> QuantRS2Result<()> {
match &self.backend {
MPIBackend::Simulated(_) => {
Ok(())
}
MPIBackend::TCP(_) => {
Ok(())
}
#[cfg(feature = "mpi")]
MPIBackend::Native(_) => {
Ok(())
}
}
}
pub fn sendrecv(
&self,
send_data: &[Complex64],
partner: usize,
) -> QuantRS2Result<Vec<Complex64>> {
match &self.backend {
MPIBackend::Simulated(_) => {
Ok(send_data.to_vec())
}
MPIBackend::TCP(_) => {
Ok(send_data.to_vec())
}
#[cfg(feature = "mpi")]
MPIBackend::Native(_) => {
Ok(send_data.to_vec())
}
}
}
pub fn gather<T: Clone>(&self, local_data: &[T], root: usize) -> QuantRS2Result<Vec<T>> {
match &self.backend {
MPIBackend::Simulated(_) => {
Ok(local_data.to_vec())
}
MPIBackend::TCP(_) => {
Ok(local_data.to_vec())
}
#[cfg(feature = "mpi")]
MPIBackend::Native(_) => {
Ok(local_data.to_vec())
}
}
}
pub fn broadcast<T: Clone>(&self, data: &[T], root: usize) -> QuantRS2Result<Vec<T>> {
match &self.backend {
MPIBackend::Simulated(_) => {
Ok(data.to_vec())
}
MPIBackend::TCP(_) => {
Ok(data.to_vec())
}
#[cfg(feature = "mpi")]
MPIBackend::Native(_) => {
Ok(data.to_vec())
}
}
}
pub fn allreduce(&self, local_data: &[f64], op: ReduceOp) -> QuantRS2Result<Vec<f64>> {
match &self.backend {
MPIBackend::Simulated(_) => {
Ok(local_data.to_vec())
}
MPIBackend::TCP(_) => {
Ok(local_data.to_vec())
}
#[cfg(feature = "mpi")]
MPIBackend::Native(_) => {
Ok(local_data.to_vec())
}
}
}
}
/// Element-wise reduction operator for `MPICommunicator::allreduce`.
#[derive(Debug, Clone, Copy)]
pub enum ReduceOp {
Sum,
Max,
Min,
Prod,
}
/// Bundle of a simulation's outputs: sampled bits, probability distribution, and statistics.
///
/// NOTE(review): not constructed anywhere in this file.
#[derive(Debug, Clone)]
pub struct MPISimulationResult {
pub measurements: Vec<bool>,
pub probabilities: Vec<f64>,
pub stats: MPISimulatorStats,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim x dim` gate from row-major real entries.
    fn real_gate(dim: usize, entries: &[f64]) -> Array2<Complex64> {
        Array2::from_shape_vec(
            (dim, dim),
            entries.iter().map(|&re| Complex64::new(re, 0.0)).collect(),
        )
        .expect("valid square matrix shape")
    }

    fn x_gate() -> Array2<Complex64> {
        real_gate(2, &[0.0, 1.0, 1.0, 0.0])
    }

    fn h_gate() -> Array2<Complex64> {
        let s = 1.0 / 2.0_f64.sqrt();
        real_gate(2, &[s, s, s, -s])
    }

    /// CNOT in the |control, target> basis.
    fn cnot_gate() -> Array2<Complex64> {
        real_gate(
            4,
            &[
                1.0, 0.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, 0.0, //
                0.0, 0.0, 0.0, 1.0, //
                0.0, 0.0, 1.0, 0.0,
            ],
        )
    }

    /// Creates and initialises a simulator in |0...0>.
    fn initialized_simulator(total_qubits: usize) -> MPIQuantumSimulator {
        let config = MPISimulatorConfig {
            total_qubits,
            ..Default::default()
        };
        let mut simulator = MPIQuantumSimulator::new(config).expect("failed to create simulator");
        simulator.initialize().expect("failed to initialize");
        simulator
    }

    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!(
            (actual - expected).norm() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_mpi_simulator_creation() {
        let config = MPISimulatorConfig {
            total_qubits: 4,
            ..Default::default()
        };
        assert!(MPIQuantumSimulator::new(config).is_ok());
    }

    #[test]
    fn test_mpi_simulator_initialization() {
        let simulator = initialized_simulator(4);
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_eq!(state.len(), 16);
        assert_eq!(state[0], Complex64::new(1.0, 0.0));
        assert!(state.iter().skip(1).all(|a| a.norm() == 0.0));
    }

    #[test]
    fn test_mpi_communicator_creation() {
        let comm = MPICommunicator::new().expect("failed to create communicator");
        assert_eq!(comm.rank(), 0);
        assert_eq!(comm.size(), 1);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(0, &x_gate())
            .expect("failed to apply X");
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_close(state[0], Complex64::new(0.0, 0.0));
        assert_close(state[1], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_single_qubit_gate_on_high_qubit() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(3, &x_gate())
            .expect("failed to apply X");
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_close(state[8], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_probability_distribution() {
        let simulator = initialized_simulator(2);
        let probs = simulator
            .get_probability_distribution()
            .expect("failed to get probability distribution");
        assert_eq!(probs.len(), 4);
        assert!((probs[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_hadamard_probabilities() {
        let mut simulator = initialized_simulator(2);
        simulator
            .apply_single_qubit_gate(0, &h_gate())
            .expect("failed to apply H");
        let probs = simulator
            .get_probability_distribution()
            .expect("failed to get probability distribution");
        assert!((probs[0] - 0.5).abs() < 1e-10);
        assert!((probs[1] - 0.5).abs() < 1e-10);
        assert!(probs[2].abs() < 1e-10 && probs[3].abs() < 1e-10);
    }

    #[test]
    fn test_measure_ground_state() {
        let simulator = initialized_simulator(3);
        let bits = simulator.measure_all().expect("failed to measure");
        assert_eq!(bits, vec![false, false, false]);
    }

    #[test]
    fn test_mpi_stats() {
        let config = MPISimulatorConfig {
            total_qubits: 4,
            ..Default::default()
        };
        let simulator = MPIQuantumSimulator::new(config).expect("failed to create simulator");
        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 0);
    }

    #[test]
    fn test_distribution_strategies() {
        let strategies = [
            MPIDistributionStrategy::AmplitudePartition,
            MPIDistributionStrategy::QubitPartition,
            MPIDistributionStrategy::HybridPartition,
            MPIDistributionStrategy::GateAwarePartition,
            MPIDistributionStrategy::HilbertCurvePartition,
        ];
        for strategy in strategies {
            let config = MPISimulatorConfig {
                total_qubits: 4,
                distribution_strategy: strategy,
                ..Default::default()
            };
            assert!(MPIQuantumSimulator::new(config).is_ok());
        }
    }

    #[test]
    fn test_reset() {
        let mut simulator = initialized_simulator(4);
        simulator
            .apply_single_qubit_gate(0, &h_gate())
            .expect("failed to apply gate");
        simulator.reset().expect("failed to reset");
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_close(state[0], Complex64::new(1.0, 0.0));
        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 0);
    }

    #[test]
    fn test_collective_optimization_config() {
        let config = CollectiveOptimization {
            use_nonblocking: true,
            enable_fusion: true,
            buffer_size: 32 * 1024 * 1024,
            allreduce_algorithm: AllreduceAlgorithm::Ring,
            broadcast_algorithm: BroadcastAlgorithm::Pipeline,
        };
        assert!(config.use_nonblocking);
        assert!(config.enable_fusion);
        assert_eq!(config.buffer_size, 32 * 1024 * 1024);
    }

    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig {
            enable: true,
            interval: 500,
            storage_path: "/custom/path".to_string(),
            use_compression: false,
        };
        assert!(config.enable);
        assert_eq!(config.interval, 500);
        assert!(!config.use_compression);
    }

    #[test]
    fn test_two_qubit_gate() {
        let mut simulator = initialized_simulator(4);
        // Prepare |control=1> on qubit 0, then CNOT(0 -> 1) must set qubit 1 as well.
        simulator
            .apply_single_qubit_gate(0, &x_gate())
            .expect("failed to apply X");
        simulator
            .apply_two_qubit_gate(0, 1, &cnot_gate())
            .expect("failed to apply CNOT");
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_close(state[1], Complex64::new(0.0, 0.0));
        assert_close(state[3], Complex64::new(1.0, 0.0));
        let stats = simulator.get_stats().expect("failed to get stats");
        assert_eq!(stats.gates_executed, 2);
    }

    #[test]
    fn test_two_qubit_gate_control_off_is_identity() {
        let mut simulator = initialized_simulator(3);
        simulator
            .apply_two_qubit_gate(0, 2, &cnot_gate())
            .expect("failed to apply CNOT");
        let state = simulator
            .get_local_state()
            .expect("failed to get local state");
        assert_close(state[0], Complex64::new(1.0, 0.0));
    }
}