#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::similar_names)]
pub mod distribute;
pub mod functions;
pub use distribute::Distribution;
use mpi_crate::collective::CommunicatorCollectives;
use mpi_crate::collective::Root;
use mpi_crate::datatype::{Partition, PartitionMut};
use mpi_crate::environment::Universe;
use mpi_crate::topology::Communicator;
use mpi_crate::topology::SystemCommunicator;
use mpi_crate::traits::Equivalence;
use ndarray::{Array2, ArrayBase, Data, DataMut, Dim, Ix2};
use num_traits::Zero;
use std::fmt::Debug;
/// Two-dimensional pencil decomposition of a global array over the ranks of
/// an MPI world communicator.
///
/// Stores, for the calling rank, the array layout in two pencil orientations
/// (`x_pencil` and `y_pencil`), which the methods on this type use to
/// transpose, gather and scatter data.
#[derive(Debug, Clone)]
pub struct Decomp2d<'a> {
    /// MPI universe the decomposition was created from (borrowed for `'a`).
    pub universe: &'a Universe,
    /// World communicator (`universe.world()`) used for all collectives.
    pub world: SystemCommunicator,
    /// Number of ranks in `world`.
    pub nprocs: i32,
    /// Rank of the calling process within `world`.
    pub nrank: i32,
    /// Shape of the full, undistributed array.
    pub n_global: [usize; 2],
    /// Local layout in the x-pencil orientation (built with axis argument `0`).
    /// NOTE(review): the transposes index `x_pencil.st_procs[p][1]`, which
    /// suggests this pencil is split along axis 1 — confirm in `Distribution`.
    pub x_pencil: Distribution<2>,
    /// Local layout in the y-pencil orientation (built with axis argument `1`).
    /// NOTE(review): the transposes index `y_pencil.st_procs[p][0]`, which
    /// suggests this pencil is split along axis 0 — confirm in `Distribution`.
    pub y_pencil: Distribution<2>,
}
impl<'a> Decomp2d<'a> {
    /// Rank that acts as root for every gather / scatter operation.
    const ROOT_RANK: i32 = 0;

    /// Creates a decomposition of a global array of shape `n_global` over all
    /// ranks of `universe.world()`.
    ///
    /// One [`Distribution`] is built per pencil orientation: the x-pencil with
    /// axis argument `0` and the y-pencil with axis argument `1`.
    #[must_use]
    pub fn new(universe: &'a Universe, n_global: [usize; 2]) -> Self {
        let world = universe.world();
        let nprocs = world.size();
        let nrank = world.rank();
        let x_pencil = Distribution::new(n_global, nprocs as usize, nrank as usize, 0);
        let y_pencil = Distribution::new(n_global, nprocs as usize, nrank as usize, 1);
        Self {
            universe,
            world,
            nprocs,
            nrank,
            n_global,
            x_pencil,
            y_pencil,
        }
    }

    /// Shape of the full, undistributed array.
    #[must_use]
    pub fn get_global_shape(&self) -> [usize; 2] {
        self.n_global
    }

    /// Local shape of this rank's x-pencil.
    #[must_use]
    pub fn get_x_pencil_shape(&self) -> [usize; 2] {
        self.x_pencil.sz
    }

    /// Local shape of this rank's y-pencil.
    #[must_use]
    pub fn get_y_pencil_shape(&self) -> [usize; 2] {
        self.y_pencil.sz
    }

    /// Extracts this rank's x-pencil portion of `data` by delegating to
    /// `Distribution::split_array`.
    /// NOTE(review): presumably `data` must have the global shape — confirm.
    #[must_use]
    pub fn split_array_x_pencil<S, T>(&self, data: &ArrayBase<S, Ix2>) -> Array2<T>
    where
        S: Data<Elem = T>,
        T: Copy,
    {
        self.x_pencil.split_array(data)
    }

    /// Extracts this rank's y-pencil portion of `data` by delegating to
    /// `Distribution::split_array`.
    /// NOTE(review): presumably `data` must have the global shape — confirm.
    #[must_use]
    pub fn split_array_y_pencil<S, T>(&self, data: &ArrayBase<S, Ix2>) -> Array2<T>
    where
        S: Data<Elem = T>,
        T: Copy,
    {
        self.y_pencil.split_array(data)
    }

    /// Redistributes x-pencil data `snd` into y-pencil data `rcv` with a single
    /// variable-count all-to-all on the world communicator.
    ///
    /// Collective: every rank must call it.
    ///
    /// # Panics
    /// Panics if `snd` does not have the x-pencil shape or `rcv` does not have
    /// the y-pencil shape.
    pub fn transpose_x_to_y<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.x_pencil.sz);
        check_shape(rcv, self.y_pencil.sz);
        // Packed row-major; `merge_xy` unpacks the incoming blocks accordingly.
        let mut send = vec![T::zero(); self.x_pencil.length()];
        Self::split_xy(snd, &mut send);
        let mut recv = vec![T::zero(); self.y_pencil.length()];
        let (send_counts, send_displs) = self.x_pencil.get_counts_all_to_all(&self.y_pencil);
        let (recv_counts, recv_displs) = self.y_pencil.get_counts_all_to_all(&self.x_pencil);
        {
            let send_buffer = Partition::new(&send[..], &send_counts[..], &send_displs[..]);
            let mut recv_buffer =
                PartitionMut::new(&mut recv[..], &recv_counts[..], &recv_displs[..]);
            self.world
                .all_to_all_varcount_into(&send_buffer, &mut recv_buffer);
        }
        self.merge_xy(rcv, &recv);
    }

    /// Copies `data` into `buf` in logical row-major order (last axis fastest).
    fn split_xy<S, T>(data: &ArrayBase<S, Ix2>, buf: &mut [T])
    where
        S: Data<Elem = T>,
        T: Copy,
    {
        for (d, b) in data.iter().zip(buf.iter_mut()) {
            *b = *d;
        }
    }

    /// Scatters the all-to-all receive buffer of `transpose_x_to_y` into
    /// `data`: for each source rank in order, a block covering all local rows
    /// and that rank's x-pencil column range, stored row-major.
    fn merge_xy<S, T>(&self, data: &mut ArrayBase<S, Ix2>, buf: &[T])
    where
        S: DataMut<Elem = T>,
        T: Copy,
    {
        let mut pos = 0;
        for proc in 0..self.nprocs as usize {
            let j1 = self.x_pencil.st_procs[proc][1];
            let j2 = self.x_pencil.en_procs[proc][1];
            for i in 0..self.y_pencil.sz[0] {
                for j in j1..=j2 {
                    data[[i, j]] = buf[pos];
                    pos += 1;
                }
            }
        }
    }

    /// Redistributes y-pencil data `snd` into x-pencil data `rcv` with a single
    /// variable-count all-to-all on the world communicator.
    ///
    /// Collective: every rank must call it.
    ///
    /// # Panics
    /// Panics if `snd` does not have the y-pencil shape or `rcv` does not have
    /// the x-pencil shape.
    pub fn transpose_y_to_x<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.y_pencil.sz);
        check_shape(rcv, self.x_pencil.sz);
        // Packed transposed (first axis fastest); `merge_yx` unpacks to match.
        let mut send = vec![T::zero(); self.y_pencil.length()];
        Self::split_yx(snd, &mut send);
        let mut recv = vec![T::zero(); self.x_pencil.length()];
        let (send_counts, send_displs) = self.y_pencil.get_counts_all_to_all(&self.x_pencil);
        let (recv_counts, recv_displs) = self.x_pencil.get_counts_all_to_all(&self.y_pencil);
        {
            let send_buffer = Partition::new(&send[..], &send_counts[..], &send_displs[..]);
            let mut recv_buffer =
                PartitionMut::new(&mut recv[..], &recv_counts[..], &recv_displs[..]);
            self.world
                .all_to_all_varcount_into(&send_buffer, &mut recv_buffer);
        }
        self.merge_yx(rcv, &recv);
    }

    /// Copies `data` into `buf` in transposed order (first axis fastest).
    fn split_yx<S, T>(data: &ArrayBase<S, Ix2>, buf: &mut [T])
    where
        S: Data<Elem = T>,
        T: Copy,
    {
        let mut data_view = data.view();
        data_view.swap_axes(0, 1);
        for (d, b) in data_view.iter().zip(buf.iter_mut()) {
            *b = *d;
        }
    }

    /// Scatters the all-to-all receive buffer of `transpose_y_to_x` into
    /// `data`: for each source rank in order, a block covering all local
    /// columns and that rank's y-pencil row range, stored first-axis-fastest.
    fn merge_yx<S, T>(&self, data: &mut ArrayBase<S, Ix2>, buf: &[T])
    where
        S: DataMut<Elem = T>,
        T: Copy,
    {
        let mut pos = 0;
        for proc in 0..self.nprocs as usize {
            let i1 = self.y_pencil.st_procs[proc][0];
            let i2 = self.y_pencil.en_procs[proc][0];
            for j in 0..self.x_pencil.sz[1] {
                for i in i1..=i2 {
                    data[[i, j]] = buf[pos];
                    pos += 1;
                }
            }
        }
    }

    /// Writes `buf` into `data` in logical row-major order (last axis fastest).
    /// Inverse of `split_xy`.
    fn unpack_row_major<S, T>(data: &mut ArrayBase<S, Ix2>, buf: &[T])
    where
        S: DataMut<Elem = T>,
        T: Copy,
    {
        for (d, b) in data.iter_mut().zip(buf) {
            *d = *b;
        }
    }

    /// Writes `buf` into `data` in transposed order (first axis fastest).
    /// Inverse of `split_yx`; works on a view so `data`'s strides are untouched.
    fn unpack_transposed<S, T>(data: &mut ArrayBase<S, Ix2>, buf: &[T])
    where
        S: DataMut<Elem = T>,
        T: Copy,
    {
        let mut view = data.view_mut();
        view.swap_axes(0, 1);
        for (d, b) in view.iter_mut().zip(buf) {
            *d = *b;
        }
    }

    /// Non-root side of an x-pencil gather: sends this rank's x-pencil `snd`
    /// to the root rank, which must concurrently call `gather_x_root` or
    /// `gather_x_inplace_root`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the x-pencil shape, or if called on the
    /// root rank.
    pub fn gather_x<S1, T>(&self, snd: &ArrayBase<S1, Ix2>)
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.x_pencil.sz);
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut send = vec![T::zero(); self.x_pencil.length()];
        Self::split_yx(snd, &mut send);
        self.world
            .process_at_rank(Self::ROOT_RANK)
            .gather_varcount_into(&send[..]);
    }

    /// Root side of an x-pencil gather, returning a freshly allocated global
    /// array. See `gather_x_inplace_root`.
    ///
    /// # Panics
    /// Panics if called on a non-root rank or if `snd` does not have the
    /// x-pencil shape.
    #[must_use]
    pub fn gather_x_root<S1, T>(&self, snd: &ArrayBase<S1, Ix2>) -> Array2<T>
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut recv = Array2::<T>::zeros(self.n_global);
        self.gather_x_inplace_root(snd, &mut recv);
        recv
    }

    /// Root side of an x-pencil gather: collects every rank's x-pencil
    /// (including the root's own `snd`) into the global array `rcv`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the x-pencil shape, if called on a
    /// non-root rank, or (after the gather completes) if `rcv` does not have
    /// the global shape.
    pub fn gather_x_inplace_root<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.x_pencil.sz);
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut send = vec![T::zero(); self.x_pencil.length()];
        Self::split_yx(snd, &mut send);
        let mut recv = vec![T::zero(); self.n_global.iter().product::<usize>()];
        let (counts, displs) = self.x_pencil.get_counts_all_gather();
        {
            let mut partition = PartitionMut::new(&mut recv[..], &counts[..], &displs[..]);
            self.world
                .process_at_rank(Self::ROOT_RANK)
                .gather_varcount_into_root(&send[..], &mut partition);
        }
        check_shape(rcv, self.n_global);
        Self::unpack_transposed(rcv, &recv);
    }

    /// Non-root side of a y-pencil gather: sends this rank's y-pencil `snd`
    /// to the root rank, which must concurrently call `gather_y_root` or
    /// `gather_y_inplace_root`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the y-pencil shape, or if called on the
    /// root rank.
    pub fn gather_y<S1, T>(&self, snd: &ArrayBase<S1, Ix2>)
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.y_pencil.sz);
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut send = vec![T::zero(); self.y_pencil.length()];
        Self::split_xy(snd, &mut send);
        self.world
            .process_at_rank(Self::ROOT_RANK)
            .gather_varcount_into(&send[..]);
    }

    /// Root side of a y-pencil gather, returning a freshly allocated global
    /// array. See `gather_y_inplace_root`.
    ///
    /// # Panics
    /// Panics if called on a non-root rank or if `snd` does not have the
    /// y-pencil shape.
    #[must_use]
    pub fn gather_y_root<S1, T>(&self, snd: &ArrayBase<S1, Ix2>) -> Array2<T>
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut recv = Array2::<T>::zeros(self.n_global);
        self.gather_y_inplace_root(snd, &mut recv);
        recv
    }

    /// Root side of a y-pencil gather: collects every rank's y-pencil
    /// (including the root's own `snd`) into the global array `rcv`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the y-pencil shape, if called on a
    /// non-root rank, or (after the gather completes) if `rcv` does not have
    /// the global shape.
    pub fn gather_y_inplace_root<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.y_pencil.sz);
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut send = vec![T::zero(); self.y_pencil.length()];
        Self::split_xy(snd, &mut send);
        let mut recv = vec![T::zero(); self.n_global.iter().product::<usize>()];
        let (counts, displs) = self.y_pencil.get_counts_all_gather();
        {
            let mut partition = PartitionMut::new(&mut recv[..], &counts[..], &displs[..]);
            self.world
                .process_at_rank(Self::ROOT_RANK)
                .gather_varcount_into_root(&send[..], &mut partition);
        }
        check_shape(rcv, self.n_global);
        Self::unpack_row_major(rcv, &recv);
    }

    /// Non-root side of an x-pencil scatter, returning a freshly allocated
    /// x-pencil array. See `scatter_x_inplace`.
    ///
    /// # Panics
    /// Panics if called on the root rank.
    #[must_use]
    pub fn scatter_x<T>(&self) -> Array2<T>
    where
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut recv = Array2::<T>::zeros(self.x_pencil.sz);
        self.scatter_x_inplace(&mut recv);
        recv
    }

    /// Non-root side of an x-pencil scatter: receives this rank's x-pencil
    /// from the root (which must concurrently call `scatter_x_root` or
    /// `scatter_x_inplace_root`) into `rcv`.
    ///
    /// # Panics
    /// Panics if `rcv` does not have the x-pencil shape, or if called on the
    /// root rank.
    pub fn scatter_x_inplace<S2, T>(&self, rcv: &mut ArrayBase<S2, Ix2>)
    where
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(rcv, self.x_pencil.sz);
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut recv = vec![T::zero(); self.x_pencil.length()];
        self.world
            .process_at_rank(Self::ROOT_RANK)
            .scatter_varcount_into(&mut recv[..]);
        Self::unpack_transposed(rcv, &recv);
    }

    /// Root side of an x-pencil scatter, returning the root's own x-pencil in
    /// a freshly allocated array. See `scatter_x_inplace_root`.
    ///
    /// # Panics
    /// Panics if called on a non-root rank or if `snd` does not have the
    /// global shape.
    #[must_use]
    pub fn scatter_x_root<S1, T>(&self, snd: &ArrayBase<S1, Ix2>) -> Array2<T>
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut recv = Array2::<T>::zeros(self.x_pencil.sz);
        self.scatter_x_inplace_root(snd, &mut recv);
        recv
    }

    /// Root side of an x-pencil scatter: distributes the global array `snd`
    /// so that every rank receives its x-pencil; the root's own part is
    /// written to `rcv`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the global shape, if `rcv` does not have
    /// the x-pencil shape, or if called on a non-root rank.
    pub fn scatter_x_inplace_root<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.n_global);
        check_shape(rcv, self.x_pencil.sz);
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut send = vec![T::zero(); self.n_global.iter().product::<usize>()];
        Self::split_yx(snd, &mut send);
        let mut recv = vec![T::zero(); self.x_pencil.length()];
        let (counts, displs) = self.x_pencil.get_counts_all_gather();
        {
            let partition = Partition::new(&send[..], &counts[..], &displs[..]);
            self.world
                .process_at_rank(Self::ROOT_RANK)
                .scatter_varcount_into_root(&partition, &mut recv[..]);
        }
        Self::unpack_transposed(rcv, &recv);
    }

    /// Non-root side of a y-pencil scatter, returning a freshly allocated
    /// y-pencil array. See `scatter_y_inplace`.
    ///
    /// # Panics
    /// Panics if called on the root rank.
    #[must_use]
    pub fn scatter_y<T>(&self) -> Array2<T>
    where
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut recv = Array2::<T>::zeros(self.y_pencil.sz);
        self.scatter_y_inplace(&mut recv);
        recv
    }

    /// Non-root side of a y-pencil scatter: receives this rank's y-pencil
    /// from the root (which must concurrently call `scatter_y_root` or
    /// `scatter_y_inplace_root`) into `rcv`.
    ///
    /// # Panics
    /// Panics if `rcv` does not have the y-pencil shape, or if called on the
    /// root rank.
    pub fn scatter_y_inplace<S2, T>(&self, rcv: &mut ArrayBase<S2, Ix2>)
    where
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(rcv, self.y_pencil.sz);
        assert!(self.nrank != Self::ROOT_RANK, "Rank must not be root!");
        let mut recv = vec![T::zero(); self.y_pencil.length()];
        self.world
            .process_at_rank(Self::ROOT_RANK)
            .scatter_varcount_into(&mut recv[..]);
        Self::unpack_row_major(rcv, &recv);
    }

    /// Root side of a y-pencil scatter, returning the root's own y-pencil in
    /// a freshly allocated array. See `scatter_y_inplace_root`.
    ///
    /// # Panics
    /// Panics if called on a non-root rank or if `snd` does not have the
    /// global shape.
    #[must_use]
    pub fn scatter_y_root<S1, T>(&self, snd: &ArrayBase<S1, Ix2>) -> Array2<T>
    where
        S1: Data<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut recv = Array2::<T>::zeros(self.y_pencil.sz);
        self.scatter_y_inplace_root(snd, &mut recv);
        recv
    }

    /// Root side of a y-pencil scatter: distributes the global array `snd`
    /// so that every rank receives its y-pencil; the root's own part is
    /// written to `rcv`.
    ///
    /// # Panics
    /// Panics if `snd` does not have the global shape, if `rcv` does not have
    /// the y-pencil shape, or if called on a non-root rank.
    pub fn scatter_y_inplace_root<S1, S2, T>(
        &self,
        snd: &ArrayBase<S1, Ix2>,
        rcv: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.n_global);
        check_shape(rcv, self.y_pencil.sz);
        assert!(self.nrank == Self::ROOT_RANK, "Rank must be root!");
        let mut send = vec![T::zero(); self.n_global.iter().product::<usize>()];
        Self::split_xy(snd, &mut send);
        let mut recv = vec![T::zero(); self.y_pencil.length()];
        let (counts, displs) = self.y_pencil.get_counts_all_gather();
        {
            let partition = Partition::new(&send[..], &counts[..], &displs[..]);
            self.world
                .process_at_rank(Self::ROOT_RANK)
                .scatter_varcount_into_root(&partition, &mut recv[..]);
        }
        Self::unpack_row_major(rcv, &recv);
    }

    /// Gathers every rank's x-pencil into the global array `rcv` on all ranks
    /// (variable-count all-gather). Collective: every rank must call it.
    ///
    /// # Panics
    /// Panics if `snd` does not have the x-pencil shape or `rcv` does not have
    /// the global shape.
    pub fn all_gather_x<S1, S2, T>(&self, snd: &ArrayBase<S1, Ix2>, rcv: &mut ArrayBase<S2, Ix2>)
    where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence + Debug,
    {
        check_shape(snd, self.x_pencil.sz);
        check_shape(rcv, self.n_global);
        let mut send = vec![T::zero(); self.x_pencil.length()];
        Self::split_yx(snd, &mut send);
        let mut recv = vec![T::zero(); self.n_global.iter().product::<usize>()];
        let (counts, displs) = self.x_pencil.get_counts_all_gather();
        {
            let mut partition = PartitionMut::new(&mut recv[..], &counts[..], &displs[..]);
            self.world
                .all_gather_varcount_into(&send[..], &mut partition);
        }
        Self::unpack_transposed(rcv, &recv);
    }

    /// Gathers every rank's y-pencil into the global array `rcv` on all ranks
    /// (variable-count all-gather). Collective: every rank must call it.
    ///
    /// # Panics
    /// Panics if `snd` does not have the y-pencil shape or `rcv` does not have
    /// the global shape.
    pub fn all_gather_y<S1, S2, T>(&self, snd: &ArrayBase<S1, Ix2>, rcv: &mut ArrayBase<S2, Ix2>)
    where
        S1: Data<Elem = T>,
        S2: DataMut<Elem = T>,
        T: Zero + Clone + Copy + Equivalence,
    {
        check_shape(snd, self.y_pencil.sz);
        check_shape(rcv, self.n_global);
        let mut send = vec![T::zero(); self.y_pencil.length()];
        Self::split_xy(snd, &mut send);
        let mut recv = vec![T::zero(); self.n_global.iter().product::<usize>()];
        let (counts, displs) = self.y_pencil.get_counts_all_gather();
        {
            let mut partition = PartitionMut::new(&mut recv[..], &counts[..], &displs[..]);
            self.world
                .all_gather_varcount_into(&send[..], &mut partition);
        }
        Self::unpack_row_major(rcv, &recv);
    }
}
/// Asserts that `data` has exactly the expected `shape`.
///
/// # Panics
/// Panics with a "Shape mismatch" message listing both shapes when
/// `data.shape()` differs from `shape`.
fn check_shape<A, S, const N: usize>(data: &ArrayBase<S, Dim<[usize; N]>>, shape: [usize; N])
where
    S: Data<Elem = A>,
    Dim<[usize; N]>: ndarray::Dimension,
{
    let actual = data.shape();
    assert!(
        actual == shape,
        "Shape mismatch, got {:?} expected {:?}.",
        actual,
        shape
    );
}