use std::ops::Add;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::topology::{Communicator, Process};
// Message tags used internally by the collectives in this module. Every collective call tags
// its point-to-point traffic with one of these values.
// NOTE(review): presumably negative so they cannot collide with user-chosen tags; confirm
// against the runtime's tag rules.

/// Tag carried by every message sent by `broadcast_into`.
pub(crate) const BROADCAST_TAG: i32 = -1;
/// Tag carried by every message sent by `gather_into_root`.
pub(crate) const GATHER_TAG: i32 = -2;
/// Tag carried by every message sent by `scatter_into_root`.
pub(crate) const SCATTER_TAG: i32 = -3;
/// Tag carried by the reduce phase of both `reduce_sum_into_root` and `all_reduce_sum_into`.
pub(crate) const REDUCE_TAG: i32 = -4;
/// Reduction operators accepted by `Root::reduce_into_root` and `Root::all_reduce_into`.
///
/// `Sum` is currently the only operator; it combines values with `+`.
///
/// Besides `Clone`, `Copy` and `Debug`, the type derives `PartialEq`, `Eq` and `Hash`. Callers
/// can therefore compare operators or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SystemOperation {
    /// Combine contributions with `a + b`.
    Sum,
}

impl SystemOperation {
    /// Returns [`SystemOperation::Sum`].
    ///
    /// This is a `const fn`, so it can also be used to initialise constants and statics.
    #[must_use]
    pub const fn sum() -> Self {
        Self::Sum
    }
}
/// Collective operations coordinated by a single root process.
///
/// In the `Process` implementation in this module, the root is the rank stored in the process
/// handle (`self.rank`). Each call must be made by every rank of the communicator.
pub trait Root {
    /// Makes every rank's `value` equal to the root's `value`.
    fn broadcast_into<T: Serialize + DeserializeOwned + Clone>(&self, value: &mut T);

    /// Collects one `value` per rank into `out` on the root, in rank order.
    fn gather_into_root<T: Serialize + DeserializeOwned + Clone>(
        &self,
        value: &T,
        out: &mut Vec<T>,
    );

    /// Sends `input[r]` from the root to rank `r`, storing the received element in `out`.
    fn scatter_into_root<T: Serialize + DeserializeOwned + Clone>(&self, input: &[T], out: &mut T);

    /// Sums every rank's `value` and stores the total in `out` on the root.
    fn reduce_sum_into_root<T: Serialize + DeserializeOwned + Clone + Add<Output = T>>(
        &self,
        value: &T,
        out: &mut T,
    );

    /// Sums every rank's `value` and stores the total in `out` on every rank.
    fn all_reduce_sum_into<T: Serialize + DeserializeOwned + Clone + Add<Output = T>>(
        &self,
        value: &T,
        out: &mut T,
    );

    /// Reduces every rank's `value` with `op`, storing the result in `out` on the root.
    fn reduce_into_root<T: Serialize + DeserializeOwned + Clone + Add<Output = T>>(
        &self,
        value: &T,
        out: &mut T,
        op: SystemOperation,
    );

    /// Reduces every rank's `value` with `op`, storing the result in `out` on every rank.
    fn all_reduce_into<T: Serialize + DeserializeOwned + Clone + Add<Output = T>>(
        &self,
        value: &T,
        out: &mut T,
        op: SystemOperation,
    );
}
impl<'a> Root for Process<'a> {
    /// Broadcasts the root's `value` (the root being `self.rank`) to every rank.
    ///
    /// Ranks are rotated into "virtual" ranks so that the root is virtual rank 0. The value then
    /// travels down a binomial tree, one round per distance `step` (1, 2, 4, ...):
    ///
    /// - Every virtual rank below `step` already holds the value and forwards it to
    ///   `vrank + step`.
    /// - Every virtual rank in `step..2 * step` receives it from `vrank - step`.
    ///
    /// A rank that receives in a round starts forwarding only in the following round. As a
    /// result, every send is matched by exactly one receive, and no stray message is left queued
    /// for a later collective.
    ///
    /// On the root `value` is left unchanged; on every other rank it is overwritten.
    ///
    /// # Panics
    ///
    /// Panics with `"jsmpi broadcast failed"` if the runtime reports a send or receive error.
    fn broadcast_into<T>(&self, value: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let size = self.communicator.size();
        if size <= 1 {
            return;
        }
        let rank = self.communicator.rank();
        let root = self.rank;
        let vrank = (rank - root).rem_euclid(size);
        let mut step = 1;
        while step < size {
            if vrank < step {
                // Held the value before this round: forward it along one more tree edge.
                let dst_vrank = vrank + step;
                if dst_vrank < size {
                    let dst = (dst_vrank + root).rem_euclid(size);
                    self.communicator
                        .runtime
                        .send(rank, dst, BROADCAST_TAG, value)
                        .expect("jsmpi broadcast failed");
                }
            } else if vrank - step < step {
                // Reached in this round. `vrank - step < step` is `vrank < 2 * step` written so
                // that `2 * step` is never computed and cannot overflow.
                let src = (vrank - step + root).rem_euclid(size);
                let (received, _status) = self
                    .communicator
                    .runtime
                    .receive(Some(src), Some(BROADCAST_TAG))
                    .expect("jsmpi broadcast failed");
                *value = received;
            }
            step <<= 1;
        }
    }

    /// Gathers one `value` from every rank into `out` on the root (`self.rank`).
    ///
    /// On the root, `out` is cleared and refilled so that `out[r]` holds rank `r`'s value. The
    /// root contributes a clone of its own `value`. Other ranks send `value` to the root and
    /// leave `out` untouched.
    ///
    /// # Panics
    ///
    /// Panics with `"jsmpi gather failed"` if the runtime reports a send or receive error.
    fn gather_into_root<T>(&self, value: &T, out: &mut Vec<T>)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;
        if self_rank == root_rank {
            out.clear();
            out.reserve(self.communicator.size() as usize);
            for src in 0..self.communicator.size() {
                if src == root_rank {
                    out.push(value.clone());
                } else {
                    let (received, _status) = self
                        .communicator
                        .runtime
                        .receive(Some(src), Some(GATHER_TAG))
                        .expect("jsmpi gather failed");
                    out.push(received);
                }
            }
        } else {
            self.communicator
                .runtime
                .send(self_rank, root_rank, GATHER_TAG, value)
                .expect("jsmpi gather failed");
        }
    }

    /// Distributes `input[r]` from the root (`self.rank`) to rank `r`, storing it in `out`.
    ///
    /// The root stores a clone of its own element in `out`. `input` is only read on the root;
    /// other ranks ignore it.
    ///
    /// # Panics
    ///
    /// - On the root, panics if `input.len()` differs from the communicator size.
    /// - Panics with `"jsmpi scatter failed"` if the runtime reports a send or receive error.
    fn scatter_into_root<T>(&self, input: &[T], out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;
        if self_rank == root_rank {
            assert_eq!(
                input.len(),
                self.communicator.size() as usize,
                "scatter input length must equal communicator size"
            );
            for dst in 0..self.communicator.size() {
                if dst == root_rank {
                    *out = input[dst as usize].clone();
                } else {
                    self.communicator
                        .runtime
                        .send(root_rank, dst, SCATTER_TAG, &input[dst as usize])
                        .expect("jsmpi scatter failed");
                }
            }
        } else {
            let (received, _status) = self
                .communicator
                .runtime
                .receive(Some(root_rank), Some(SCATTER_TAG))
                .expect("jsmpi scatter failed");
            *out = received;
        }
    }

    /// Sums every rank's `value` and stores the total in `out` on the root (`self.rank`).
    ///
    /// Non-root ranks only send their contribution; their `out` is left untouched.
    ///
    /// # Panics
    ///
    /// Panics with `"jsmpi reduce failed"` if the runtime reports a send or receive error.
    fn reduce_sum_into_root<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        if let Some(total) = sum_at_root(self, value, "jsmpi reduce failed") {
            *out = total;
        }
    }

    /// Sums every rank's `value` and stores the total in `out` on every rank.
    ///
    /// The sum is first computed on the root and then broadcast with `broadcast_into`.
    ///
    /// # Panics
    ///
    /// - Panics with `"jsmpi all_reduce failed"` if the reduce phase reports an error.
    /// - Panics with `"jsmpi broadcast failed"` if the broadcast phase reports an error.
    fn all_reduce_sum_into<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        if let Some(total) = sum_at_root(self, value, "jsmpi all_reduce failed") {
            *out = total;
        }
        self.broadcast_into(out);
    }

    /// Reduces every rank's `value` with `op` onto the root; dispatches on the operator.
    fn reduce_into_root<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        match op {
            SystemOperation::Sum => self.reduce_sum_into_root(value, out),
        }
    }

    /// Reduces every rank's `value` with `op` and shares the result with every rank.
    fn all_reduce_into<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        match op {
            SystemOperation::Sum => self.all_reduce_sum_into(value, out),
        }
    }
}

/// Reduce phase shared by `reduce_sum_into_root` and `all_reduce_sum_into`.
///
/// - On the root (`process.rank`), receives one `REDUCE_TAG` message from every other rank.
///   It starts from its own `value`, adds the contributions in ascending rank order, and
///   returns `Some(sum)`.
/// - Every other rank sends `value` to the root and returns `None`.
///
/// `failure` is the panic message used if the runtime reports a send or receive error.
fn sum_at_root<'a, T>(process: &Process<'a>, value: &T, failure: &str) -> Option<T>
where
    T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
{
    let self_rank = process.communicator.rank();
    let root_rank = process.rank;
    if self_rank != root_rank {
        process
            .communicator
            .runtime
            .send(self_rank, root_rank, REDUCE_TAG, value)
            .expect(failure);
        return None;
    }
    let mut reduced = value.clone();
    for src in 0..process.communicator.size() {
        if src == root_rank {
            continue;
        }
        let (received, _status) = process
            .communicator
            .runtime
            .receive(Some(src), Some(REDUCE_TAG))
            .expect(failure);
        reduced = reduced + received;
    }
    Some(reduced)
}