use num_complex::Complex64;
/// Minimal communication interface over a group of cooperating ranks.
///
/// [`SerialComm`] implements it for a single rank, where every operation is a
/// local copy or identity. The MPI-backed implementation maps the collective
/// methods onto MPI collectives, so there every rank must call them in the
/// same order.
pub trait RankComm: std::fmt::Debug + Send + Sync {
    // ---- topology -------------------------------------------------------

    /// Index of the calling rank, in `0..self.size()`.
    fn rank(&self) -> usize;

    /// Number of ranks taking part in this communicator.
    fn size(&self) -> usize;

    // ---- collectives ----------------------------------------------------

    /// Concatenates every rank's `local` buffer in rank order and returns the
    /// combined buffer (length `local.len() * self.size()`) on every rank.
    ///
    /// Every rank is expected to contribute a buffer of the same length.
    fn allgather_c64(&self, local: &[Complex64]) -> Vec<Complex64>;

    /// `f64` counterpart of [`RankComm::allgather_c64`].
    fn allgather_f64(&self, local: &[f64]) -> Vec<f64>;

    /// Sums `value` across all ranks and returns the total on every rank.
    fn allreduce_sum_f64(&self, value: f64) -> f64;

    /// Blocks until every rank has reached the barrier.
    fn barrier(&self);

    // ---- point-to-point -------------------------------------------------

    /// Sends `send` to rank `partner` while receiving that partner's buffer
    /// into `recv`. `send` and `recv` are expected to have equal length.
    fn sendrecv_c64(&self, partner: usize, send: &[Complex64], recv: &mut [Complex64]);
}
/// Single-process communicator: rank `0` of a world of size `1`.
///
/// Every collective degenerates to a local copy or the identity, so code
/// written against [`RankComm`] runs unchanged without an MPI runtime.
#[derive(Clone, Copy, Debug, Default)]
pub struct SerialComm;

impl RankComm for SerialComm {
    /// Always `0`, the only rank.
    #[inline]
    fn rank(&self) -> usize {
        0
    }

    /// Always `1`.
    #[inline]
    fn size(&self) -> usize {
        1
    }

    /// With a single participant the gathered buffer is a copy of `local`.
    #[inline]
    fn allgather_c64(&self, local: &[Complex64]) -> Vec<Complex64> {
        Vec::from(local)
    }

    /// With a single participant the gathered buffer is a copy of `local`.
    #[inline]
    fn allgather_f64(&self, local: &[f64]) -> Vec<f64> {
        Vec::from(local)
    }

    /// The sum over one rank is the value itself.
    #[inline]
    fn allreduce_sum_f64(&self, value: f64) -> f64 {
        value
    }

    /// Exchange with oneself: `send` is copied into `recv`; `partner` is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `send.len() != recv.len()` (enforced by `copy_from_slice`).
    #[inline]
    fn sendrecv_c64(&self, _partner: usize, send: &[Complex64], recv: &mut [Complex64]) {
        debug_assert_eq!(send.len(), recv.len());
        recv.copy_from_slice(send);
    }

    /// There is nobody to wait for, so this returns immediately.
    #[inline]
    fn barrier(&self) {}
}
// Reinterpreting `[Complex64]` as `[f64]` needs two things:
// - `Complex64` must be exactly two `f64`s with no padding; num-complex declares
//   `Complex<T>` as `#[repr(C)]` with fields `re`, `im`.
// - Its alignment must satisfy `f64`'s, so the cast pointer stays aligned.
// Both are checked at compile time here.
#[cfg(feature = "distributed-mpi")]
const _: () = {
    assert!(std::mem::size_of::<Complex64>() == 2 * std::mem::size_of::<f64>());
    assert!(std::mem::align_of::<Complex64>() >= std::mem::align_of::<f64>());
};

/// Views a complex slice as its interleaved `[re0, im0, re1, im1, ...]` storage.
#[cfg(feature = "distributed-mpi")]
#[inline]
fn as_f64(slice: &[Complex64]) -> &[f64] {
    // SAFETY: `Complex64` is two contiguous `f64`s with no padding and at least
    // `f64` alignment (asserted above), so the pointer is valid and aligned for
    // `2 * len` reads of `f64`. Every bit pattern is a valid `f64`. The element
    // count cannot overflow because the slice already occupies `len * 16` bytes.
    // The returned slice borrows `slice`, so it cannot outlive the data.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<f64>(), slice.len() * 2) }
}

/// Mutable view of a complex slice as its interleaved `[re, im, ...]` storage.
#[cfg(feature = "distributed-mpi")]
#[inline]
fn as_f64_mut(slice: &mut [Complex64]) -> &mut [f64] {
    // SAFETY: same layout argument as `as_f64`. The exclusive borrow of `slice`
    // is transferred to the returned view, so no aliasing access exists while it
    // lives. Any `f64` written is a valid `re`/`im` value.
    unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<f64>(), slice.len() * 2) }
}
/// MPI-backed [`RankComm`].
///
/// Owns the rsmpi `Universe`, so MPI stays initialized for as long as this
/// value is alive. `rank` and `size` are cached copies of the communicator's
/// values.
///
/// NOTE(review): `RankComm` requires `Send + Sync`, but MPI is only safe to call
/// from multiple threads at a sufficient threading level. Confirm how MPI is
/// initialized and that callers never drive it from several threads at once.
#[cfg(feature = "distributed-mpi")]
pub struct MpiComm {
    // Field order matters: struct fields are dropped in declaration order, so
    // keeping `world` first releases the communicator handle before
    // `_universe` is dropped and finalizes MPI.
    world: mpi::topology::SimpleCommunicator,
    _universe: mpi::environment::Universe,
    rank: usize,
    size: usize,
}
#[cfg(feature = "distributed-mpi")]
impl std::fmt::Debug for MpiComm {
    /// Prints only the cached rank and size; the MPI handles are not
    /// meaningfully printable.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MpiComm");
        builder.field("rank", &self.rank);
        builder.field("size", &self.size);
        builder.finish()
    }
}
#[cfg(feature = "distributed-mpi")]
impl MpiComm {
    /// Initializes MPI and wraps the world communicator.
    ///
    /// Returns `None` when `mpi::initialize()` does. Per the rsmpi docs, that
    /// happens when MPI has already been initialized in this process. The
    /// returned value owns the MPI `Universe`, whose drop finalizes MPI.
    ///
    /// NOTE(review): `mpi::initialize()` requests no specific threading level,
    /// while `RankComm` requires `Sync` — confirm single-threaded MPI use.
    pub fn world() -> Option<Self> {
        use mpi::traits::Communicator;
        mpi::initialize().map(|universe| {
            let world = universe.world();
            // MPI ranks and sizes are non-negative `i32`s.
            let (rank, size) = (world.rank() as usize, world.size() as usize);
            Self {
                world,
                rank,
                size,
                _universe: universe,
            }
        })
    }
}
#[cfg(feature = "distributed-mpi")]
impl RankComm for MpiComm {
    /// Cached rank of this process in the world communicator.
    fn rank(&self) -> usize {
        self.rank
    }

    /// Cached number of processes in the world communicator.
    fn size(&self) -> usize {
        self.size
    }

    /// Gathers `local` from every rank into a buffer of `local.len() * size`
    /// elements, laid out in rank order, on every rank.
    ///
    /// MPI_Allgather requires every rank to contribute the same element count.
    /// The output is sized from this rank's `local.len()` on that assumption.
    fn allgather_c64(&self, local: &[Complex64]) -> Vec<Complex64> {
        use mpi::traits::CommunicatorCollectives;
        let mut out = vec![Complex64::new(0.0, 0.0); local.len() * self.size];
        // Complex values travel as interleaved (re, im) f64 pairs.
        self.world
            .all_gather_into(as_f64(local), as_f64_mut(&mut out));
        out
    }

    /// `f64` all-gather. Like `allgather_c64`, every rank must pass the same length.
    fn allgather_f64(&self, local: &[f64]) -> Vec<f64> {
        use mpi::traits::CommunicatorCollectives;
        let mut out = vec![0.0_f64; local.len() * self.size];
        self.world.all_gather_into(local, &mut out);
        out
    }

    /// Sums `value` over all ranks; every rank receives the total.
    fn allreduce_sum_f64(&self, value: f64) -> f64 {
        use mpi::traits::CommunicatorCollectives;
        let mut out = 0.0_f64;
        self.world
            .all_reduce_into(&value, &mut out, mpi::collective::SystemOperation::sum());
        out
    }

    /// Combined send to / receive from `partner`, using MPI's sendrecv.
    ///
    /// # Panics
    ///
    /// Panics if `partner >= self.size()`. In debug builds it also panics if
    /// `send` and `recv` differ in length.
    fn sendrecv_c64(&self, partner: usize, send: &[Complex64], recv: &mut [Complex64]) {
        use mpi::point_to_point as p2p;
        use mpi::traits::Communicator;
        debug_assert_eq!(send.len(), recv.len());
        // Without this check an out-of-range rank would be silently truncated
        // by the cast below, or rejected deep inside MPI (typically aborting
        // the whole job with little context).
        assert!(
            partner < self.size,
            "sendrecv partner rank {} out of range for communicator of size {}",
            partner,
            self.size
        );
        // Lossless: `partner < self.size`, and `self.size` originated as an i32.
        let peer = self.world.process_at_rank(partner as i32);
        p2p::send_receive_into(as_f64(send), &peer, as_f64_mut(recv), &peer);
    }

    /// Blocks until every rank in the world communicator has entered the barrier.
    fn barrier(&self) {
        use mpi::traits::CommunicatorCollectives;
        self.world.barrier();
    }
}