use mpi::collective::CommunicatorCollectives;
use mpi::topology::Communicator;
use crate::kernel::{Complex, Float};
use super::distribution::LocalPartition;
use super::error::MpiError;
pub trait MpiFloat: Float + mpi::datatype::Equivalence {}
impl MpiFloat for f32 {}
impl MpiFloat for f64 {}
pub struct MpiPool<C: Communicator> {
comm: C,
size: i32,
rank: i32,
}
impl<C: Communicator> MpiPool<C> {
pub fn new(comm: C) -> Self {
let size = comm.size();
let rank = comm.rank();
Self { comm, size, rank }
}
#[inline]
pub fn size(&self) -> usize {
self.size as usize
}
#[inline]
pub fn rank(&self) -> usize {
self.rank as usize
}
#[inline]
pub fn is_root(&self) -> bool {
self.rank == 0
}
pub fn comm(&self) -> &C {
&self.comm
}
pub fn local_partition(&self, global_n: usize) -> LocalPartition {
LocalPartition::new(global_n, self.size(), self.rank())
}
pub fn barrier(&self) {
self.comm.barrier();
}
}
impl<C: Communicator> MpiPool<C> {
pub fn all_to_all_complex<T: MpiFloat>(
&self,
send_data: &[Complex<T>],
recv_data: &mut [Complex<T>],
count: usize,
) -> Result<(), MpiError>
where
Complex<T>: mpi::datatype::Equivalence,
{
let expected_len = count * self.size();
if send_data.len() < expected_len {
return Err(MpiError::SizeMismatch {
expected: expected_len,
actual: send_data.len(),
});
}
if recv_data.len() < expected_len {
return Err(MpiError::SizeMismatch {
expected: expected_len,
actual: recv_data.len(),
});
}
self.comm
.all_to_all_into(&send_data[..expected_len], &mut recv_data[..expected_len]);
Ok(())
}
pub fn all_to_all_v_complex<T: MpiFloat>(
&self,
send_data: &[Complex<T>],
send_counts: &[i32],
send_displs: &[i32],
recv_data: &mut [Complex<T>],
recv_counts: &[i32],
recv_displs: &[i32],
) -> Result<(), MpiError>
where
Complex<T>: mpi::datatype::Equivalence,
{
use mpi::datatype::PartitionMut;
let send_partition =
mpi::datatype::Partition::new(send_data, send_counts.to_vec(), send_displs.to_vec());
let mut recv_partition =
PartitionMut::new(recv_data, recv_counts.to_vec(), recv_displs.to_vec());
self.comm
.all_to_all_varcount_into(&send_partition, &mut recv_partition);
Ok(())
}
pub fn broadcast_complex<T: MpiFloat>(
&self,
data: &mut [Complex<T>],
root: usize,
) -> Result<(), MpiError>
where
Complex<T>: mpi::datatype::Equivalence,
{
use mpi::collective::Root;
let root_process = self.comm.process_at_rank(root as i32);
root_process.broadcast_into(data);
Ok(())
}
pub fn all_gather_complex<T: MpiFloat>(
&self,
send_data: &[Complex<T>],
recv_data: &mut [Complex<T>],
) -> Result<(), MpiError>
where
Complex<T>: mpi::datatype::Equivalence,
{
let expected_recv_len = send_data.len() * self.size();
if recv_data.len() < expected_recv_len {
return Err(MpiError::SizeMismatch {
expected: expected_recv_len,
actual: recv_data.len(),
});
}
self.comm.all_gather_into(send_data, recv_data);
Ok(())
}
pub fn all_gather_v_complex<T: MpiFloat>(
&self,
send_data: &[Complex<T>],
recv_data: &mut [Complex<T>],
recv_counts: &[i32],
recv_displs: &[i32],
) -> Result<(), MpiError>
where
Complex<T>: mpi::datatype::Equivalence,
{
use mpi::datatype::PartitionMut;
let mut recv_partition =
PartitionMut::new(recv_data, recv_counts.to_vec(), recv_displs.to_vec());
self.comm
.all_gather_varcount_into(send_data, &mut recv_partition);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_local_partition() {
let partition = LocalPartition::new(100, 4, 1);
assert_eq!(partition.local_n, 25);
assert_eq!(partition.local_start, 25);
}
}