//! jsmpi 0.1.0
//!
//! A browser-oriented MPI compatibility layer for Rust/WASM using Web Workers.
//!
//! This module implements the collective-communication primitives
//! (broadcast, gather, scatter, reduce) exposed through the [`Root`] trait.
use std::ops::Add;

use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::topology::{Communicator, Process};

pub(crate) const BROADCAST_TAG: i32 = -1;
pub(crate) const GATHER_TAG: i32 = -2;
pub(crate) const SCATTER_TAG: i32 = -3;
pub(crate) const REDUCE_TAG: i32 = -4;

/// Selector for the reduction operation used by the `reduce_into_root` /
/// `all_reduce_into` collectives, mirroring MPI's built-in operations.
///
/// Only summation is supported so far; new variants can be added as further
/// operations are implemented.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SystemOperation {
    /// Element-wise addition (the `MPI_SUM` analogue).
    Sum,
}

impl SystemOperation {
    /// Returns the summation operation (the `MPI_SUM` analogue).
    ///
    /// Convenience constructor matching the rsmpi-style API. Declared
    /// `const` so it can be used in constant contexts; `#[must_use]`
    /// because constructing an operation and discarding it is a bug.
    #[must_use]
    pub const fn sum() -> Self {
        Self::Sum
    }
}

/// Collective operations rooted at a particular process.
///
/// Implemented for the process designated as the root of a collective;
/// every rank in the communicator must call the matching method for the
/// collective to complete.
pub trait Root {
    /// Broadcasts `value` from the root to all other ranks. On non-root
    /// ranks, `value` is overwritten with the root's copy.
    fn broadcast_into<T>(&self, value: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone;

    /// Gathers one `value` from every rank into `out` on the root, in rank
    /// order. `out` is only written on the root.
    fn gather_into_root<T>(&self, value: &T, out: &mut Vec<T>)
    where
        T: Serialize + DeserializeOwned + Clone;

    /// Scatters one element of `input` to each rank's `out`. `input` is
    /// only read on the root and must have exactly communicator-size
    /// elements there.
    fn scatter_into_root<T>(&self, input: &[T], out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone;

    /// Sums every rank's `value` into `out` on the root. `out` is only
    /// written on the root.
    fn reduce_sum_into_root<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;

    /// Sums every rank's `value` and writes the total into `out` on every
    /// rank (reduce at root, then broadcast).
    fn all_reduce_sum_into<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;

    /// Reduces every rank's `value` into `out` on the root using `op`.
    fn reduce_into_root<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;

    /// Reduces every rank's `value` using `op` and writes the result into
    /// `out` on every rank.
    fn all_reduce_into<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;
}

impl<'a> Root for Process<'a> {
    /// Binomial-tree broadcast of the root's `value` to every rank.
    ///
    /// Ranks are mapped into a virtual rank space where the root is 0; in
    /// step `s`, virtual ranks `< s` that already hold the value send it to
    /// virtual rank `vrank + s`, and ranks in `[s, 2s)` receive it.
    fn broadcast_into<T>(&self, value: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let size = self.communicator.size();
        // A singleton communicator has nothing to exchange.
        if size <= 1 {
            return;
        }

        let rank = self.communicator.rank();
        let root = self.rank;
        // Map global ranks into a virtual rank space where root is 0.
        let vrank = (rank - root).rem_euclid(size);

        let mut have_value = vrank == 0;
        let mut step = 1;
        while step < size {
            // In this step, virtual ranks in [step, 2*step) receive from the
            // peer exactly `step` below them.
            if !have_value && vrank >= step && vrank < step * 2 {
                let src_vrank = vrank - step;
                let src = (src_vrank + root).rem_euclid(size);
                let (received, _status) = self
                    .communicator
                    .runtime
                    .receive(Some(src), Some(BROADCAST_TAG))
                    .expect("jsmpi broadcast failed");
                *value = received;
                have_value = true;
            }

            // BUGFIX: only ranks that held the value *before* this step
            // (vrank < step) may send. The previous bare `if have_value`
            // guard let a rank that had just received in this same step also
            // send to `vrank + step`; that message is never matched — the
            // target rank receives its copy from a different, source-filtered
            // sender at a later step — leaking orphaned BROADCAST_TAG
            // messages into the runtime queue.
            if have_value && vrank < step {
                let dst_vrank = vrank + step;
                if dst_vrank < size {
                    let dst = (dst_vrank + root).rem_euclid(size);
                    self.communicator
                        .runtime
                        .send(rank, dst, BROADCAST_TAG, value)
                        .expect("jsmpi broadcast failed");
                }
            }

            step <<= 1;
        }
    }

    /// Linear gather: every non-root rank sends its `value` to the root,
    /// which collects all contributions into `out` in ascending rank order.
    /// `out` is only written on the root.
    fn gather_into_root<T>(&self, value: &T, out: &mut Vec<T>)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;

        if self_rank == root_rank {
            out.clear();
            out.reserve(self.communicator.size() as usize);
            for src in 0..self.communicator.size() {
                if src == root_rank {
                    // The root contributes its own value in place.
                    out.push(value.clone());
                } else {
                    let (received, _status) = self
                        .communicator
                        .runtime
                        .receive(Some(src), Some(GATHER_TAG))
                        .expect("jsmpi gather failed");
                    out.push(received);
                }
            }
        } else {
            self.communicator
                .runtime
                .send(self_rank, root_rank, GATHER_TAG, value)
                .expect("jsmpi gather failed");
        }
    }

    /// Linear scatter: the root sends `input[dst]` to each rank `dst`; every
    /// rank (including the root) ends up with its slot in `out`. `input` is
    /// only read on the root and must have exactly communicator-size
    /// elements there.
    fn scatter_into_root<T>(&self, input: &[T], out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;

        if self_rank == root_rank {
            assert_eq!(
                input.len(),
                self.communicator.size() as usize,
                "scatter input length must equal communicator size"
            );

            for dst in 0..self.communicator.size() {
                if dst == root_rank {
                    // Root keeps its own slot locally — no message needed.
                    *out = input[dst as usize].clone();
                } else {
                    self.communicator
                        .runtime
                        .send(root_rank, dst, SCATTER_TAG, &input[dst as usize])
                        .expect("jsmpi scatter failed");
                }
            }
        } else {
            let (received, _status) = self
                .communicator
                .runtime
                .receive(Some(root_rank), Some(SCATTER_TAG))
                .expect("jsmpi scatter failed");
            *out = received;
        }
    }

    /// Linear sum-reduction: the root folds its own value with each other
    /// rank's contribution in ascending rank order. `out` is only written on
    /// the root.
    fn reduce_sum_into_root<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;

        if self_rank == root_rank {
            let mut reduced = value.clone();
            for src in 0..self.communicator.size() {
                if src == root_rank {
                    // Root's contribution already seeded `reduced`.
                    continue;
                }
                let (received, _status) = self
                    .communicator
                    .runtime
                    .receive(Some(src), Some(REDUCE_TAG))
                    .expect("jsmpi reduce failed");
                reduced = reduced + received;
            }
            *out = reduced;
        } else {
            self.communicator
                .runtime
                .send(self_rank, root_rank, REDUCE_TAG, value)
                .expect("jsmpi reduce failed");
        }
    }

    /// Sum-reduce at the root, then broadcast the total so every rank ends
    /// up with the same result in `out`.
    fn all_reduce_sum_into<T>(&self, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        let self_rank = self.communicator.rank();
        let root_rank = self.rank;

        if self_rank == root_rank {
            let mut reduced = value.clone();
            for src in 0..self.communicator.size() {
                if src == root_rank {
                    continue;
                }
                let (received, _status) = self
                    .communicator
                    .runtime
                    .receive(Some(src), Some(REDUCE_TAG))
                    .expect("jsmpi all_reduce failed");
                reduced = reduced + received;
            }

            // `reduced` is not used afterwards, so move it instead of the
            // previous redundant clone.
            *out = reduced;
            // Use broadcast path for dissemination so large communicator fan-out
            // doesn't overload root with linear sends.
            self.broadcast_into(out);
        } else {
            self.communicator
                .runtime
                .send(self_rank, root_rank, REDUCE_TAG, value)
                .expect("jsmpi all_reduce failed");

            // Receive the final total via the broadcast tree.
            self.broadcast_into(out);
        }
    }

    /// Dispatches `op` to the matching rooted reduction.
    fn reduce_into_root<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        match op {
            SystemOperation::Sum => self.reduce_sum_into_root(value, out),
        }
    }

    /// Dispatches `op` to the matching all-reduce.
    fn all_reduce_into<T>(&self, value: &T, out: &mut T, op: SystemOperation)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
    {
        match op {
            SystemOperation::Sum => self.all_reduce_sum_into(value, out),
        }
    }
}