#![allow(clippy::module_name_repetitions)]
use mpi_crate::Count;
use ndarray::{s, Array2, ArrayBase, Data, DataMut, Dim, Ix2};
/// Block decomposition of an `N`-dimensional global array over a set of processes.
///
/// All index ranges are inclusive: this rank owns global indices `st[axis]..=en[axis]`
/// along every axis, i.e. `sz[axis]` points. The `*_procs` vectors hold the same
/// information for every rank, indexed by rank.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Distribution<const N: usize> {
    /// Number of points owned by this rank along each axis.
    pub sz: [usize; N],
    /// First global index (inclusive) owned by this rank along each axis.
    pub st: [usize; N],
    /// Last global index (inclusive) owned by this rank along each axis.
    pub en: [usize; N],
    /// `sz` of every rank, indexed by rank.
    pub sz_procs: Vec<[usize; N]>,
    /// `st` of every rank, indexed by rank.
    pub st_procs: Vec<[usize; N]>,
    /// `en` of every rank, indexed by rank.
    pub en_procs: Vec<[usize; N]>,
    /// Number of processes the global array is split over.
    pub nprocs: usize,
    /// Rank of the process this description belongs to.
    pub nrank: usize,
    /// Axis that is not split, i.e. kept whole on every rank.
    pub axis_contig: usize,
}
impl<const N: usize> Distribution<N> {
    /// Builds the decomposition of a global array of shape `n_global` over `nprocs`
    /// ranks, as seen from rank `nrank`.
    ///
    /// Every axis except `axis_contig` is cut into `nprocs` contiguous chunks, one per
    /// rank in rank order; `axis_contig` stays whole on every rank. When an axis length
    /// is not divisible by `nprocs`, the trailing ranks each receive one extra point.
    ///
    /// # Panics
    ///
    /// - if `axis_contig > 1`,
    /// - if `nprocs == 0` or `nrank >= nprocs`,
    /// - if any entry of `n_global` is zero,
    /// - if a split axis is shorter than `nprocs` (some rank would own no points, which
    ///   the inclusive `en` indices cannot represent).
    #[must_use]
    pub fn new(n_global: [usize; N], nprocs: usize, nrank: usize, axis_contig: usize) -> Self {
        if axis_contig > 1 {
            panic!(
                "axis_contig must be 0 (first axis) or 1 (second axis), got {}.",
                axis_contig
            );
        }
        assert!(nprocs > 0, "nprocs must be at least 1, got 0.");
        assert!(
            nrank < nprocs,
            "nrank must be smaller than nprocs, got nrank={} and nprocs={}.",
            nrank,
            nprocs
        );
        for (axis, &n) in n_global.iter().enumerate() {
            // `en` is inclusive, so an empty axis would need index -1.
            assert!(n > 0, "n_global[{}] must be non-zero.", axis);
        }

        // Start from "every rank owns everything" and then cut the split axes.
        let mut last_index = n_global;
        for n in &mut last_index {
            *n -= 1;
        }
        let mut st_procs: Vec<[usize; N]> = vec![[0; N]; nprocs];
        let mut sz_procs: Vec<[usize; N]> = vec![n_global; nprocs];
        let mut en_procs: Vec<[usize; N]> = vec![last_index; nprocs];

        for axis in (0..N).filter(|&axis| axis != axis_contig) {
            let (st_split, en_split, sz_split) = Self::distribute(n_global[axis], nprocs);
            let splits = st_split.into_iter().zip(en_split).zip(sz_split);
            for (rank, ((st, en), sz)) in splits.enumerate() {
                st_procs[rank][axis] = st;
                en_procs[rank][axis] = en;
                sz_procs[rank][axis] = sz;
            }
        }

        Self {
            st: st_procs[nrank],
            en: en_procs[nrank],
            sz: sz_procs[nrank],
            st_procs,
            en_procs,
            sz_procs,
            nprocs,
            nrank,
            axis_contig,
        }
    }

    /// Splits `0..n_global` into `nprocs` contiguous chunks.
    ///
    /// Returns `(starts, inclusive_ends, sizes)`, one entry per rank. The first
    /// `nprocs - n_global % nprocs` ranks get `n_global / nprocs` points and the
    /// remaining ranks get one point more.
    ///
    /// # Panics
    ///
    /// If `nprocs == 0` or `n_global < nprocs`.
    fn distribute(n_global: usize, nprocs: usize) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
        assert!(nprocs > 0, "nprocs must be at least 1, got 0.");
        assert!(
            n_global >= nprocs,
            "cannot split an axis of length {} over {} processes without leaving a process empty.",
            n_global,
            nprocs
        );
        let base = n_global / nprocs;
        let n_small = nprocs - n_global % nprocs;
        let sz: Vec<usize> = (0..nprocs)
            .map(|rank| if rank < n_small { base } else { base + 1 })
            .collect();
        // Exclusive prefix sum of the sizes gives each chunk's first index.
        let st: Vec<usize> = sz
            .iter()
            .scan(0_usize, |offset, &n| {
                let start = *offset;
                *offset += n;
                Some(start)
            })
            .collect();
        // Every chunk is non-empty here (base >= 1), so `n - 1` cannot underflow.
        let en: Vec<usize> = st.iter().zip(&sz).map(|(&s, &n)| s + n - 1).collect();
        (st, en, sz)
    }

    /// Total number of elements owned by this rank.
    #[must_use]
    pub fn length(&self) -> usize {
        self.sz.iter().product()
    }

    /// Checks that `data` has exactly this rank's local shape `sz`.
    ///
    /// # Panics
    ///
    /// If the shape of `data` differs from `self.sz`.
    pub fn check_shape<A, S>(&self, data: &ArrayBase<S, Dim<[usize; N]>>)
    where
        S: Data<Elem = A>,
        Dim<[usize; N]>: ndarray::Dimension,
    {
        if data.shape() != self.sz {
            panic!(
                "Shape mismatch, got {:?} expected {:?}.",
                data.shape(),
                self.sz
            );
        }
    }

    /// Per-rank element counts and displacements for gathering every rank's local
    /// block into one contiguous buffer, in rank order.
    ///
    /// `counts[r]` is the number of elements owned by rank `r`; `displs[r]` is the sum
    /// of the counts of all lower ranks.
    ///
    /// # Panics
    ///
    /// If a count or displacement does not fit into `Count`. The previous `as` cast
    /// silently truncated such values.
    #[must_use]
    pub fn get_counts_all_gather(&self) -> (Vec<Count>, Vec<Count>) {
        let sizes: Vec<usize> = self
            .sz_procs
            .iter()
            .map(|sz| sz.iter().product::<usize>())
            .collect();
        let to_count = |n: usize| -> Count {
            <Count as std::convert::TryFrom<usize>>::try_from(n).unwrap_or_else(|_| {
                panic!("{} elements do not fit into an MPI Count.", n)
            })
        };
        let counts: Vec<Count> = sizes.iter().map(|&n| to_count(n)).collect();
        // Accumulate in usize so that only offsets which are actually used must fit.
        let displs: Vec<Count> = sizes
            .iter()
            .scan(0_usize, |offset, &n| {
                let start = *offset;
                *offset += n;
                Some(start)
            })
            .map(to_count)
            .collect();
        (counts, displs)
    }
}
impl Distribution<2> {
    /// Per-rank element counts and displacements for an all-to-all exchange from
    /// this decomposition to `recv_dist`.
    ///
    /// `counts[r]` is the extent of rank `r` (in `recv_dist`) along `self.axis_contig`,
    /// multiplied by this rank's extent along `recv_dist.axis_contig`. `displs` is the
    /// exclusive prefix sum of `counts`.
    ///
    /// # Panics
    ///
    /// If a count or displacement does not fit into `Count`. The previous `as` cast
    /// silently truncated such values.
    #[must_use]
    pub fn get_counts_all_to_all(&self, recv_dist: &Distribution<2>) -> (Vec<Count>, Vec<Count>) {
        let sizes: Vec<usize> = recv_dist
            .sz_procs
            .iter()
            .map(|sz| sz[self.axis_contig] * self.sz[recv_dist.axis_contig])
            .collect();
        let to_count = |n: usize| -> Count {
            <Count as std::convert::TryFrom<usize>>::try_from(n).unwrap_or_else(|_| {
                panic!("{} elements do not fit into an MPI Count.", n)
            })
        };
        let counts: Vec<Count> = sizes.iter().map(|&n| to_count(n)).collect();
        // Accumulate in usize so that only offsets which are actually used must fit.
        let displs: Vec<Count> = sizes
            .iter()
            .scan(0_usize, |offset, &n| {
                let start = *offset;
                *offset += n;
                Some(start)
            })
            .map(to_count)
            .collect();
        (counts, displs)
    }

    /// Copies this rank's block `st..=en` (both axes) out of `global` into a new
    /// owned array.
    ///
    /// # Panics
    ///
    /// If the local index range lies outside `global` (ndarray slicing panics).
    #[must_use]
    pub fn split_array<S, T>(&self, global: &ArrayBase<S, Ix2>) -> Array2<T>
    where
        S: Data<Elem = T>,
        T: Copy,
    {
        global
            .slice(s![self.st[0]..=self.en[0], self.st[1]..=self.en[1]])
            .to_owned()
    }

    /// Writes this rank's block `st..=en` (both axes) of `global` into `pencil`.
    ///
    /// # Panics
    ///
    /// If the local index range lies outside `global`, or if the sliced block cannot
    /// be assigned (broadcast) to the shape of `pencil`.
    pub fn split_array_inplace<S1, S2, T>(
        &self,
        global: &ArrayBase<S1, Ix2>,
        pencil: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = T>,
        S2: Data<Elem = T> + DataMut,
        T: Copy,
    {
        pencil.assign(&global.slice(s![self.st[0]..=self.en[0], self.st[1]..=self.en[1]]));
    }
}