use mpi::traits::*;
use mpi::topology::Communicator;
use crate::data::PolyData;
use crate::parallel::decomposition::Partition;
/// Thin wrapper around an initialized MPI environment.
///
/// Owns the `mpi::environment::Universe` returned by `mpi::initialize`; every
/// helper below operates on that universe's world communicator.
pub struct MpiContext {
    universe: mpi::environment::Universe,
}

impl MpiContext {
    /// Largest number of elements handed to a single MPI call. MPI element
    /// counts are C `int`s, so larger buffers must be split across calls.
    const MAX_MPI_COUNT: usize = i32::MAX as usize;

    /// Initializes MPI and wraps the resulting universe.
    ///
    /// # Errors
    /// Returns an error string when `mpi::initialize` returns `None`
    /// (e.g. MPI was already initialized in this process).
    pub fn init() -> Result<Self, String> {
        let universe = mpi::initialize()
            .ok_or_else(|| "MPI already initialized or unavailable".to_string())?;
        Ok(Self { universe })
    }

    /// Rank of this process in the world communicator.
    pub fn rank(&self) -> usize {
        self.universe.world().rank() as usize
    }

    /// Number of processes in the world communicator.
    pub fn size(&self) -> usize {
        self.universe.world().size() as usize
    }

    /// Blocks until every process in the world communicator reaches this call.
    pub fn barrier(&self) {
        self.universe.world().barrier();
    }

    /// Broadcasts a byte buffer from rank 0 to every rank.
    ///
    /// Collective: every rank must call it. On rank 0, `data` is sent as-is;
    /// on all other ranks, `data` is resized to rank 0's length and its
    /// contents are overwritten with rank 0's bytes.
    pub fn broadcast_bytes(&self, data: &mut Vec<u8>) {
        let world = self.universe.world();
        let root = world.process_at_rank(0);

        // Send the payload length first so receivers can size their buffers.
        // On non-root ranks the initial value is overwritten by the broadcast.
        let mut len = data.len() as u64;
        root.broadcast_into(&mut len);
        if world.rank() != 0 {
            data.resize(len as usize, 0);
        }

        // Every rank now holds a buffer of identical length, so all ranks split
        // it into the same sequence of chunks and the broadcasts stay matched.
        // Chunking keeps each call within MPI's `int` element-count limit
        // (a single call on a >= 2 GiB slice would otherwise panic).
        for chunk in data.chunks_mut(Self::MAX_MPI_COUNT) {
            root.broadcast_into(chunk);
        }
    }

    /// Gathers one `f64` from every rank onto rank 0.
    ///
    /// Collective: every rank must call it. On rank 0, returns one value per
    /// rank in rank order (length = world size); on all other ranks, returns
    /// an empty vector.
    pub fn gather_f64(&self, local_value: f64) -> Vec<f64> {
        let world = self.universe.world();
        let root = world.process_at_rank(0);
        if world.rank() == 0 {
            let mut gathered = vec![0.0f64; world.size() as usize];
            root.gather_into_root(&local_value, &mut gathered);
            gathered
        } else {
            root.gather_into(&local_value);
            Vec::new()
        }
    }

    /// Sums `local` across all ranks; every rank receives the total.
    ///
    /// Collective: every rank must call it.
    pub fn allreduce_sum(&self, local: f64) -> f64 {
        let world = self.universe.world();
        let mut global = 0.0f64;
        world.all_reduce_into(&local, &mut global, mpi::collective::SystemOperation::sum());
        global
    }

    /// Takes the maximum of `local` across all ranks; every rank receives it.
    ///
    /// Collective: every rank must call it.
    pub fn allreduce_max(&self, local: f64) -> f64 {
        let world = self.universe.world();
        let mut global = 0.0f64;
        world.all_reduce_into(&local, &mut global, mpi::collective::SystemOperation::max());
        global
    }

    /// Sends `(num_points, num_cells)` to rank `dest` as a two-element `u64` message.
    ///
    /// Must be matched by a `recv_partition_size` call on rank `dest`.
    ///
    /// # Panics
    /// Panics if `dest` does not fit in an MPI rank (`i32`).
    pub fn send_partition_size(&self, dest: usize, num_points: usize, num_cells: usize) {
        let world = self.universe.world();
        let data = [num_points as u64, num_cells as u64];
        world.process_at_rank(Self::to_mpi_rank(dest)).send(&data[..]);
    }

    /// Receives the `(num_points, num_cells)` pair sent by `send_partition_size`
    /// on rank `source`.
    ///
    /// # Panics
    /// Panics if `source` does not fit in an MPI rank (`i32`), or if the
    /// received message does not contain exactly two elements.
    pub fn recv_partition_size(&self, source: usize) -> (usize, usize) {
        let world = self.universe.world();
        let (msg, _status) = world
            .process_at_rank(Self::to_mpi_rank(source))
            .receive_vec::<u64>();
        match msg.as_slice() {
            &[num_points, num_cells] => (num_points as usize, num_cells as usize),
            other => panic!(
                "expected a 2-element partition-size message from rank {}, got {} element(s)",
                source,
                other.len()
            ),
        }
    }

    /// Converts a `usize` rank index to an MPI rank, panicking rather than
    /// silently wrapping to a different (possibly valid) rank.
    fn to_mpi_rank(rank: usize) -> i32 {
        assert!(
            rank <= i32::MAX as usize,
            "rank index {} does not fit in an MPI rank (i32)",
            rank
        );
        rank as i32
    }
}