use std::marker::PhantomData;
use std::cell::Cell;
use std::ops::Add;
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::collective::{BROADCAST_TAG, GATHER_TAG, REDUCE_TAG, SCATTER_TAG, Root};
use crate::datatype::{Rank, Status, Tag};
use crate::error::{Error, Result};
use crate::runtime::{RetryPolicy, Runtime};
/// Reads a numeric property named `key` from the worker's JavaScript global object.
///
/// Returns `None` if the property lookup throws, or if the value is not a
/// number (as reported by `JsValue::as_f64`).
#[cfg(target_arch = "wasm32")]
fn read_worker_numeric_config(key: &str) -> Option<f64> {
    let property = wasm_bindgen::JsValue::from_str(key);
    match js_sys::Reflect::get(&js_sys::global(), &property) {
        Ok(value) => value.as_f64(),
        Err(_) => None,
    }
}
/// Native fallback for worker configuration lookups.
///
/// There is no JavaScript global object outside `wasm32`, so every key is
/// reported as unset and callers fall back to their built-in defaults.
#[cfg(not(target_arch = "wasm32"))]
fn read_worker_numeric_config(_key: &str) -> Option<f64> {
    Option::<f64>::None
}
/// Tag used by the untagged point-to-point helpers (`send`, `receive`, ...).
const DEFAULT_TAG: Tag = 0;
/// Width of the little-endian `u64` length header sent ahead of a chunked payload.
const CHUNK_LENGTH_PREFIX_BYTES: usize = std::mem::size_of::<u64>();
/// Derives the pair of tags used by a chunked transfer on `tag`.
///
/// The length header travels on `tag` itself and the data chunks on `tag + 1`.
/// Because of this, chunked transfers on adjacent tags share a wire tag.
///
/// # Panics
/// Panics if `tag + 1` overflows.
fn split_chunk_tags(tag: Tag) -> (Tag, Tag) {
    match tag.checked_add(1) {
        Some(data_tag) => (tag, data_tag),
        None => panic!("chunked transfer tag overflow"),
    }
}
/// Encodes a payload length as the little-endian `u64` header of a chunked transfer.
fn encode_chunk_length(len: usize) -> [u8; CHUNK_LENGTH_PREFIX_BYTES] {
    // Widen first so the header has the same width on every target.
    let wire_len: u64 = len as u64;
    u64::to_le_bytes(wire_len)
}
/// Decodes the little-endian `u64` length header of a chunked transfer.
///
/// # Panics
/// - If `prefix` is not exactly `CHUNK_LENGTH_PREFIX_BYTES` long.
/// - If the announced length does not fit in this target's `usize`. This can
///   happen on 32-bit targets such as `wasm32` when the sender announces a
///   length of 2^32 or more. Failing loudly is preferable to silently
///   truncating the length and leaving unread chunks queued.
fn decode_chunk_length(prefix: &[u8]) -> usize {
    assert_eq!(
        prefix.len(),
        CHUNK_LENGTH_PREFIX_BYTES,
        "chunked transfer length prefix must be 8 bytes"
    );
    let mut raw = [0_u8; CHUNK_LENGTH_PREFIX_BYTES];
    raw.copy_from_slice(prefix);
    let wire_len = u64::from_le_bytes(raw);
    usize::try_from(wire_len)
        .expect("chunked transfer length does not fit in usize on this target")
}
/// Communicator over all ranks known to a single `Runtime` handle.
///
/// Every point-to-point and collective operation in this module is forwarded
/// to `runtime`.
#[derive(Clone, Debug)]
pub struct SimpleCommunicator {
    /// Runtime used for rank/size queries and all message transport.
    pub(crate) runtime: Runtime,
}
/// Alias for [`SimpleCommunicator`].
pub type SystemCommunicator = SimpleCommunicator;
/// Handle addressing one specific rank of a communicator.
///
/// Sends through this handle go to `rank`. Receives through it only accept
/// messages whose source is `rank`.
#[derive(Clone, Copy, Debug)]
pub struct Process<'a> {
    /// Communicator the rank belongs to.
    pub(crate) communicator: &'a SimpleCommunicator,
    /// The peer rank this handle addresses.
    pub(crate) rank: Rank,
}
/// Receive-only handle that accepts messages from any source rank.
#[derive(Clone, Copy, Debug)]
pub struct AnyProcess<'a> {
    /// Communicator whose runtime the receives are issued on.
    communicator: &'a SimpleCommunicator,
}
/// Lifecycle state of an immediate (non-blocking) request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestState {
    /// Not yet finished; `test`/`wait` may still complete it.
    Pending,
    /// Finished successfully.
    Completed,
    /// `cancel` was called while the request was still pending.
    Canceled,
    /// `free` was called; the request can no longer be used.
    Freed,
}
/// Request handle returned by immediate sends.
///
/// It only tracks a lifecycle state. The `Cell` lets `&self` methods observe
/// the state without requiring `&mut self`.
#[derive(Clone, Debug)]
pub struct ImmediateSendRequest {
    /// Current lifecycle state.
    state: Cell<RequestState>,
}
impl ImmediateSendRequest {
fn completed() -> Self {
Self {
state: Cell::new(RequestState::Completed),
}
}
pub fn state(&self) -> RequestState {
self.state.get()
}
pub fn is_pending(&self) -> bool {
matches!(self.state(), RequestState::Pending)
}
pub fn test(&self) -> bool {
matches!(self.state(), RequestState::Completed)
}
pub fn wait(self) {}
pub fn cancel(&mut self) -> Result<()> {
match self.state.get() {
RequestState::Pending => {
self.state.set(RequestState::Canceled);
Ok(())
}
RequestState::Completed => Err(Error::Protocol(
"cannot cancel a completed immediate send request".to_string(),
)),
RequestState::Canceled => Ok(()),
RequestState::Freed => Err(Error::Protocol(
"cannot cancel a freed immediate send request".to_string(),
)),
}
}
pub fn free(&mut self) -> Result<()> {
if matches!(self.state.get(), RequestState::Freed) {
return Ok(());
}
self.state.set(RequestState::Freed);
Ok(())
}
pub fn test_all(requests: &[Self]) -> bool {
requests.iter().all(|request| request.test())
}
pub fn wait_all(requests: Vec<Self>) {
for request in requests {
request.wait();
}
}
}
/// Request handle for an immediate (non-blocking) receive of a `T`.
///
/// The receive is not posted anywhere. `test` polls the runtime with
/// `try_receive` and `wait` performs a blocking `receive`, both using the
/// stored `source`/`tag` filters.
#[derive(Clone, Debug)]
pub struct ImmediateReceiveRequest<T> {
    /// Runtime handle the receive is issued on.
    runtime: Runtime,
    /// Source rank filter; `AnyProcess` passes `None` here.
    source: Option<Rank>,
    /// Tag filter; `AnyProcess::immediate_receive` passes `None` here.
    tag: Option<Tag>,
    /// Current lifecycle state.
    state: Cell<RequestState>,
    /// Records the payload type without storing a value.
    _marker: PhantomData<T>,
}
impl<T> ImmediateReceiveRequest<T>
where
T: DeserializeOwned,
{
fn new(runtime: Runtime, source: Option<Rank>, tag: Option<Tag>) -> Self {
Self {
runtime,
source,
tag,
state: Cell::new(RequestState::Pending),
_marker: PhantomData,
}
}
pub fn state(&self) -> RequestState {
self.state.get()
}
pub fn is_pending(&self) -> bool {
matches!(self.state(), RequestState::Pending)
}
pub fn test(&self) -> Option<(T, Status)> {
if !self.is_pending() {
return None;
}
self.runtime
.try_receive::<T>(self.source, self.tag)
.expect("jsmpi immediate receive failed")
.map(|result| {
self.state.set(RequestState::Completed);
result
})
}
pub fn wait(&self) -> (T, Status) {
match self.state.get() {
RequestState::Canceled => {
panic!("immediate receive request was canceled before wait")
}
RequestState::Freed => panic!("immediate receive request was freed before wait"),
RequestState::Completed => {
panic!("immediate receive request has already been completed")
}
RequestState::Pending => {}
}
let result = self
.runtime
.receive::<T>(self.source, self.tag)
.expect("jsmpi immediate receive failed");
self.state.set(RequestState::Completed);
result
}
pub fn wait_into(&self, out: &mut T) -> Status {
let (value, status) = self.wait();
*out = value;
status
}
pub fn cancel(&mut self) -> Result<()> {
match self.state.get() {
RequestState::Pending => {
self.state.set(RequestState::Canceled);
Ok(())
}
RequestState::Completed => Err(Error::Protocol(
"cannot cancel a completed immediate receive request".to_string(),
)),
RequestState::Canceled => Ok(()),
RequestState::Freed => Err(Error::Protocol(
"cannot cancel a freed immediate receive request".to_string(),
)),
}
}
pub fn free(&mut self) -> Result<()> {
if matches!(self.state.get(), RequestState::Freed) {
return Ok(());
}
self.state.set(RequestState::Freed);
Ok(())
}
pub fn test_any(requests: &[Self]) -> Option<(usize, (T, Status))> {
for (idx, request) in requests.iter().enumerate() {
if let Some(result) = request.test() {
return Some((idx, result));
}
}
None
}
pub fn wait_any(requests: &[Self]) -> (usize, (T, Status)) {
loop {
if let Some(result) = Self::test_any(requests) {
return result;
}
assert!(
requests.iter().any(|request| request.is_pending()),
"wait_any has no pending requests"
);
#[cfg(not(target_arch = "wasm32"))]
std::thread::yield_now();
}
}
pub fn test_all(requests: &[Self]) -> Vec<Option<(T, Status)>> {
requests.iter().map(|request| request.test()).collect()
}
pub fn wait_all(requests: &[Self]) -> Vec<(T, Status)> {
requests.iter().map(|request| request.wait()).collect()
}
pub fn test_some(requests: &[Self]) -> Vec<(usize, (T, Status))> {
requests
.iter()
.enumerate()
.filter_map(|(idx, request)| request.test().map(|result| (idx, result)))
.collect()
}
pub fn wait_some(requests: &[Self]) -> Vec<(usize, (T, Status))> {
loop {
let ready = Self::test_some(requests);
if !ready.is_empty() {
return ready;
}
assert!(
requests.iter().any(|request| request.is_pending()),
"wait_some has no pending requests"
);
#[cfg(not(target_arch = "wasm32"))]
std::thread::yield_now();
}
}
}
/// Topology queries and collective operations over a group of ranks.
///
/// The method families work as follows:
/// - Methods without a `Result` return use default timeouts. In the
///   `SimpleCommunicator` implementation they panic (via `expect`) on any
///   runtime error.
/// - `_with_timeout` variants apply the timeout to each underlying send/receive.
/// - `_with_retry_timeout` variants re-run the whole operation after a timeout
///   error, up to `retry_policy.max_retries` times. Any other error is
///   returned immediately.
pub trait Communicator {
    /// Rank of the calling process.
    fn rank(&self) -> Rank;
    /// Number of ranks in the communicator.
    fn size(&self) -> Rank;
    /// Handle for point-to-point operations with `rank`.
    fn process_at_rank(&self, rank: Rank) -> Process<'_>;
    /// Handle for receiving from any rank.
    fn any_process(&self) -> AnyProcess<'_>;
    /// Runs the runtime's barrier across the communicator.
    fn barrier(&self);
    /// Collects `value` from every rank into `out` on every rank, ordered by rank.
    fn all_gather_into<T>(&self, value: &T, out: &mut Vec<T>)
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Overwrites `value` on every rank with the value held by `root`.
    fn broadcast_into_from<T>(&self, root: Rank, value: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone;
    /// [`Communicator::broadcast_into_from`] with an explicit per-message timeout.
    fn broadcast_into_from_with_timeout<T>(&self, root: Rank, value: &mut T, timeout: Duration) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Broadcast that retries the whole operation after timeouts.
    fn broadcast_into_from_with_retry_timeout<T>(
        &self,
        root: Rank,
        value: &mut T,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Collects `value` from every rank into `out` on `root`, ordered by rank.
    ///
    /// `out` is left untouched on non-root ranks.
    fn gather_into_root<T>(&self, root: Rank, value: &T, out: &mut Vec<T>)
    where
        T: Serialize + DeserializeOwned + Clone;
    /// [`Communicator::gather_into_root`] with an explicit per-message timeout.
    fn gather_into_root_with_timeout<T>(
        &self,
        root: Rank,
        value: &T,
        out: &mut Vec<T>,
        timeout: Duration,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Gather that retries the whole operation after timeouts.
    fn gather_into_root_with_retry_timeout<T>(
        &self,
        root: Rank,
        value: &T,
        out: &mut Vec<T>,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Sends `input[r]` from `root` to each rank `r`, which stores it in `out`.
    ///
    /// `input` is only read on `root`, where its length must equal the communicator size.
    fn scatter_into_root<T>(&self, root: Rank, input: &[T], out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone;
    /// [`Communicator::scatter_into_root`] with an explicit per-message timeout.
    fn scatter_into_root_with_timeout<T>(
        &self,
        root: Rank,
        input: &[T],
        out: &mut T,
        timeout: Duration,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Scatter that retries the whole operation after timeouts.
    fn scatter_into_root_with_retry_timeout<T>(
        &self,
        root: Rank,
        input: &[T],
        out: &mut T,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// Sum-reduction of `value` into `out` at `root`.
    ///
    /// `SimpleCommunicator` delegates this to the `Root` implementation of the
    /// root's `Process` handle.
    fn reduce_sum_into_root<T>(&self, root: Rank, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;
    /// Sums `value` across all ranks at `root`, then broadcasts the total into `out` on every rank.
    fn all_reduce_sum_into_from<T>(&self, root: Rank, value: &T, out: &mut T)
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;
    /// [`Communicator::all_gather_into`] with an explicit per-message timeout.
    fn all_gather_into_with_timeout<T>(&self, value: &T, out: &mut Vec<T>, timeout: Duration) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone;
    /// All-reduce (sum) that retries the whole operation after timeouts.
    fn all_reduce_sum_into_from_with_retry_timeout<T>(
        &self,
        root: Rank,
        value: &T,
        out: &mut T,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()>
    where
        T: Serialize + DeserializeOwned + Clone + Add<Output = T>;
}
/// Blocking typed sends to a single destination rank.
///
/// The untagged variants use `DEFAULT_TAG`. The `Process` implementation
/// panics if the runtime send fails.
pub trait Destination {
    /// Serializes and sends `value` with `DEFAULT_TAG`.
    fn send<T>(&self, value: &T)
    where
        T: Serialize;
    /// Serializes and sends `value` with `tag`.
    fn send_with_tag<T>(&self, value: &T, tag: Tag)
    where
        T: Serialize;
    /// Sends the whole slice as one message with `DEFAULT_TAG`.
    fn send_slice<T>(&self, values: &[T])
    where
        T: Serialize;
    /// Sends the whole slice as one message with `tag`.
    fn send_slice_with_tag<T>(&self, values: &[T], tag: Tag)
    where
        T: Serialize;
}
/// Non-blocking-style typed sends that return a request handle.
///
/// The `Process` implementation sends eagerly and returns an already-completed request.
pub trait ImmediateDestination {
    /// Immediate send of `value` with `DEFAULT_TAG`.
    fn immediate_send<T>(&self, value: &T) -> ImmediateSendRequest
    where
        T: Serialize;
    /// Immediate send of `value` with `tag`.
    fn immediate_send_with_tag<T>(&self, value: &T, tag: Tag) -> ImmediateSendRequest
    where
        T: Serialize;
}
/// Raw byte sends, including a chunked mode for large payloads.
///
/// A chunked send first transmits an 8-byte little-endian length on `tag`,
/// then the payload in pieces of at most `chunk_size` bytes on `tag + 1`.
pub trait DestinationBytes {
    /// Sends `payload` as one message with `DEFAULT_TAG`.
    fn send_bytes(&self, payload: &[u8]);
    /// Sends `payload` as one message with `tag`.
    fn send_bytes_with_tag(&self, payload: &[u8], tag: Tag);
    /// Chunked send with `DEFAULT_TAG`; a `chunk_size` of 0 is treated as 1.
    fn send_bytes_chunked(&self, payload: &[u8], chunk_size: usize);
    /// Chunked send with `tag`; a `chunk_size` of 0 is treated as 1.
    fn send_bytes_chunked_with_tag(&self, payload: &[u8], tag: Tag, chunk_size: usize);
}
/// Blocking typed receives.
///
/// Tag matching for the untagged methods differs by implementor:
/// - `Process` matches `DEFAULT_TAG` everywhere.
/// - `AnyProcess::receive`/`receive_into` match any tag, while its
///   `receive_vec`/`receive_slice_into` match `DEFAULT_TAG`.
///
/// The implementations in this module panic if the runtime receive fails.
pub trait Source {
    /// Receives one value and its status.
    fn receive<T>(&self) -> (T, Status)
    where
        T: DeserializeOwned;
    /// Receives one value carrying `tag`.
    fn receive_with_tag<T>(&self, tag: Tag) -> (T, Status)
    where
        T: DeserializeOwned;
    /// Like `receive`, but stores the value in `out`.
    fn receive_into<T>(&self, out: &mut T) -> Status
    where
        T: DeserializeOwned;
    /// Like `receive_with_tag`, but stores the value in `out`.
    fn receive_into_with_tag<T>(&self, out: &mut T, tag: Tag) -> Status
    where
        T: DeserializeOwned;
    /// Receives a whole sequence sent as one message with `DEFAULT_TAG`.
    fn receive_vec<T>(&self) -> (Vec<T>, Status)
    where
        T: DeserializeOwned;
    /// Receives a whole sequence sent as one message with `tag`.
    fn receive_vec_with_tag<T>(&self, tag: Tag) -> (Vec<T>, Status)
    where
        T: DeserializeOwned;
    /// Receives a sequence with `DEFAULT_TAG` into `out`; panics on a length mismatch.
    fn receive_slice_into<T>(&self, out: &mut [T]) -> Status
    where
        T: DeserializeOwned;
    /// Receives a sequence with `tag` into `out`; panics on a length mismatch.
    fn receive_slice_into_with_tag<T>(&self, out: &mut [T], tag: Tag) -> Status
    where
        T: DeserializeOwned;
}
/// Creates receive requests that can be polled (`test`) or awaited (`wait`) later.
pub trait ImmediateSource {
    /// Request for the next message.
    ///
    /// `Process` filters on `DEFAULT_TAG`; `AnyProcess` accepts any tag.
    fn immediate_receive<T>(&self) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned;
    /// Request for the next message carrying `tag`.
    fn immediate_receive_with_tag<T>(&self, tag: Tag) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned;
}
/// Raw byte receives, including the counterpart of the chunked byte send.
///
/// The implementations in this module panic if the runtime receive fails.
pub trait SourceBytes {
    /// Receives one byte message with `DEFAULT_TAG`.
    fn receive_bytes(&self) -> (Vec<u8>, Status);
    /// Receives one byte message with `tag`.
    fn receive_bytes_with_tag(&self, tag: Tag) -> (Vec<u8>, Status);
    /// Receives one byte message with `DEFAULT_TAG` into `out`; panics on a length mismatch.
    fn receive_bytes_into(&self, out: &mut [u8]) -> Status;
    /// Receives one byte message with `tag` into `out`; panics on a length mismatch.
    fn receive_bytes_into_with_tag(&self, out: &mut [u8], tag: Tag) -> Status;
    /// Reassembles a chunked transfer sent with `DEFAULT_TAG`.
    fn receive_bytes_chunked(&self) -> (Vec<u8>, Status);
    /// Reassembles a chunked transfer sent with `tag`.
    ///
    /// The returned status carries `tag` rather than the internal chunk tags.
    fn receive_bytes_chunked_with_tag(&self, tag: Tag) -> (Vec<u8>, Status);
}
impl SimpleCommunicator {
pub(crate) fn new(runtime: Runtime) -> Self {
Self { runtime }
}
fn default_collective_timeout() -> Duration {
let configured = read_worker_numeric_config("__jsmpi_collective_timeout_ms")
.map(|v| v.max(1.0) as u64);
Duration::from_millis(configured.unwrap_or(15_000))
}
fn default_collective_retry_policy() -> RetryPolicy {
let mut policy = RetryPolicy::default();
if let Some(v) = read_worker_numeric_config("__jsmpi_collective_max_retries") {
policy.max_retries = v.max(0.0) as u32;
}
policy
}
fn broadcast_with_timeout<T>(&self, root: Rank, value: &mut T, timeout: Duration) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let size = self.size();
if size <= 1 {
return Ok(());
}
let rank = self.rank();
let vrank = (rank - root).rem_euclid(size);
let mut have_value = vrank == 0;
let mut step = 1;
while step < size {
if !have_value && vrank >= step && vrank < step * 2 {
let src_vrank = vrank - step;
let src = (src_vrank + root).rem_euclid(size);
let (received, _status) = self
.runtime
.receive_with_timeout(Some(src), Some(BROADCAST_TAG), timeout)?;
*value = received;
have_value = true;
}
if have_value {
let dst_vrank = vrank + step;
if dst_vrank < size {
let dst = (dst_vrank + root).rem_euclid(size);
self.runtime
.send_with_timeout(rank, dst, BROADCAST_TAG, value, timeout)?;
}
}
step <<= 1;
}
Ok(())
}
}
impl Communicator for SimpleCommunicator {
fn rank(&self) -> Rank {
self.runtime.rank()
}
fn size(&self) -> Rank {
self.runtime.size()
}
fn process_at_rank(&self, rank: Rank) -> Process<'_> {
Process {
communicator: self,
rank,
}
}
fn any_process(&self) -> AnyProcess<'_> {
AnyProcess { communicator: self }
}
fn barrier(&self) {
self.runtime.barrier().expect("jsmpi barrier failed");
}
fn all_gather_into<T>(&self, value: &T, out: &mut Vec<T>)
where
T: Serialize + DeserializeOwned + Clone,
{
self
.all_gather_into_with_timeout(value, out, Self::default_collective_timeout())
.expect("jsmpi all_gather failed");
}
fn broadcast_into_from<T>(&self, root: Rank, value: &mut T)
where
T: Serialize + DeserializeOwned + Clone,
{
self
.broadcast_into_from_with_timeout(root, value, Self::default_collective_timeout())
.expect("jsmpi broadcast failed");
}
fn broadcast_into_from_with_timeout<T>(&self, root: Rank, value: &mut T, timeout: Duration) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
self.broadcast_with_timeout(root, value, timeout)
}
fn broadcast_into_from_with_retry_timeout<T>(
&self,
root: Rank,
value: &mut T,
timeout_per_attempt: Duration,
retry_policy: &RetryPolicy,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let mut attempt = 0_u32;
loop {
match self.broadcast_with_timeout(root, value, timeout_per_attempt) {
Ok(()) => return Ok(()),
Err(Error::Timeout { .. }) if attempt < retry_policy.max_retries => {
attempt = attempt.saturating_add(1);
let backoff = retry_policy.backoff_for_attempt(attempt);
#[cfg(not(target_arch = "wasm32"))]
std::thread::sleep(backoff);
#[cfg(target_arch = "wasm32")]
let _ = backoff;
}
Err(err) => return Err(err),
}
}
}
fn gather_into_root<T>(&self, root: Rank, value: &T, out: &mut Vec<T>)
where
T: Serialize + DeserializeOwned + Clone,
{
self
.gather_into_root_with_timeout(root, value, out, Self::default_collective_timeout())
.expect("jsmpi gather failed");
}
fn gather_into_root_with_timeout<T>(
&self,
root: Rank,
value: &T,
out: &mut Vec<T>,
timeout: Duration,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let self_rank = self.rank();
if self_rank == root {
out.clear();
out.reserve(self.size() as usize);
for src in 0..self.size() {
if src == root {
out.push(value.clone());
} else {
let (received, _status) = self
.runtime
.receive_with_timeout(Some(src), Some(GATHER_TAG), timeout)?;
out.push(received);
}
}
Ok(())
} else {
self.runtime
.send_with_timeout(self_rank, root, GATHER_TAG, value, timeout)
}
}
fn gather_into_root_with_retry_timeout<T>(
&self,
root: Rank,
value: &T,
out: &mut Vec<T>,
timeout_per_attempt: Duration,
retry_policy: &RetryPolicy,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let mut attempt = 0_u32;
loop {
match self.gather_into_root_with_timeout(root, value, out, timeout_per_attempt) {
Ok(()) => return Ok(()),
Err(Error::Timeout { .. }) if attempt < retry_policy.max_retries => {
attempt = attempt.saturating_add(1);
let backoff = retry_policy.backoff_for_attempt(attempt);
#[cfg(not(target_arch = "wasm32"))]
std::thread::sleep(backoff);
#[cfg(target_arch = "wasm32")]
let _ = backoff;
}
Err(err) => return Err(err),
}
}
}
fn scatter_into_root<T>(&self, root: Rank, input: &[T], out: &mut T)
where
T: Serialize + DeserializeOwned + Clone,
{
self
.scatter_into_root_with_timeout(root, input, out, Self::default_collective_timeout())
.expect("jsmpi scatter failed");
}
fn scatter_into_root_with_timeout<T>(
&self,
root: Rank,
input: &[T],
out: &mut T,
timeout: Duration,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let self_rank = self.rank();
if self_rank == root {
if input.len() != self.size() as usize {
return Err(Error::Protocol(
"scatter input length must equal communicator size".to_string(),
));
}
for dst in 0..self.size() {
if dst == root {
*out = input[dst as usize].clone();
} else {
self.runtime
.send_with_timeout(root, dst, SCATTER_TAG, &input[dst as usize], timeout)?;
}
}
Ok(())
} else {
let (received, _status) = self
.runtime
.receive_with_timeout(Some(root), Some(SCATTER_TAG), timeout)?;
*out = received;
Ok(())
}
}
fn scatter_into_root_with_retry_timeout<T>(
&self,
root: Rank,
input: &[T],
out: &mut T,
timeout_per_attempt: Duration,
retry_policy: &RetryPolicy,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let mut attempt = 0_u32;
loop {
match self.scatter_into_root_with_timeout(root, input, out, timeout_per_attempt) {
Ok(()) => return Ok(()),
Err(Error::Timeout { .. }) if attempt < retry_policy.max_retries => {
attempt = attempt.saturating_add(1);
let backoff = retry_policy.backoff_for_attempt(attempt);
#[cfg(not(target_arch = "wasm32"))]
std::thread::sleep(backoff);
#[cfg(target_arch = "wasm32")]
let _ = backoff;
}
Err(err) => return Err(err),
}
}
}
fn reduce_sum_into_root<T>(&self, root: Rank, value: &T, out: &mut T)
where
T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
{
self.process_at_rank(root).reduce_sum_into_root(value, out);
}
fn all_reduce_sum_into_from<T>(&self, root: Rank, value: &T, out: &mut T)
where
T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
{
let policy = Self::default_collective_retry_policy();
self
.all_reduce_sum_into_from_with_retry_timeout(
root,
value,
out,
Self::default_collective_timeout(),
&policy,
)
.expect("jsmpi all_reduce failed");
}
fn all_gather_into_with_timeout<T>(&self, value: &T, out: &mut Vec<T>, timeout: Duration) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone,
{
let self_rank = self.rank();
let root_rank = 0;
let size = self.size();
let mut gathered = Vec::<T>::new();
if self_rank == root_rank {
gathered.clear();
gathered.reserve(size as usize);
gathered.push(value.clone());
for src in 1..size {
let (received, _status) = self
.runtime
.receive_with_timeout(Some(src), Some(GATHER_TAG), timeout)?;
gathered.push(received);
}
} else {
self.runtime
.send_with_timeout(self_rank, root_rank, GATHER_TAG, value, timeout)?;
}
self.broadcast_with_timeout(root_rank, &mut gathered, timeout)?;
*out = gathered;
Ok(())
}
fn all_reduce_sum_into_from_with_retry_timeout<T>(
&self,
root: Rank,
value: &T,
out: &mut T,
timeout_per_attempt: Duration,
retry_policy: &RetryPolicy,
) -> Result<()>
where
T: Serialize + DeserializeOwned + Clone + Add<Output = T>,
{
let mut attempt = 0_u32;
loop {
let result = (|| {
let self_rank = self.rank();
let size = self.size();
if self_rank == root {
let mut reduced = value.clone();
for src in 0..size {
if src == root {
continue;
}
let (received, _status) = self
.runtime
.receive_with_timeout(Some(src), Some(REDUCE_TAG), timeout_per_attempt)?;
reduced = reduced + received;
}
*out = reduced;
} else {
self.runtime.send_with_timeout(
self_rank,
root,
REDUCE_TAG,
value,
timeout_per_attempt,
)?;
}
self.broadcast_with_timeout(root, out, timeout_per_attempt)
})();
match result {
Ok(()) => return Ok(()),
Err(Error::Timeout { .. }) if attempt < retry_policy.max_retries => {
attempt = attempt.saturating_add(1);
let backoff = retry_policy.backoff_for_attempt(attempt);
#[cfg(not(target_arch = "wasm32"))]
std::thread::sleep(backoff);
#[cfg(target_arch = "wasm32")]
let _ = backoff;
}
Err(err) => return Err(err),
}
}
}
}
impl Destination for Process<'_> {
    fn send<T>(&self, value: &T)
    where
        T: Serialize,
    {
        self.send_with_tag(value, DEFAULT_TAG);
    }

    /// Sends `value` from this communicator's rank to `self.rank`; panics if the runtime send fails.
    fn send_with_tag<T>(&self, value: &T, tag: Tag)
    where
        T: Serialize,
    {
        let sender = self.communicator.rank();
        let runtime = &self.communicator.runtime;
        runtime
            .send(sender, self.rank, tag, value)
            .expect("jsmpi send failed");
    }

    fn send_slice<T>(&self, values: &[T])
    where
        T: Serialize,
    {
        self.send_slice_with_tag(values, DEFAULT_TAG);
    }

    /// Sends the whole slice as a single serialized sequence message.
    fn send_slice_with_tag<T>(&self, values: &[T], tag: Tag)
    where
        T: Serialize,
    {
        let sender = self.communicator.rank();
        let runtime = &self.communicator.runtime;
        runtime
            .send(sender, self.rank, tag, &values)
            .expect("jsmpi send failed");
    }
}
impl<'a> ImmediateDestination for Process<'a> {
fn immediate_send<T>(&self, value: &T) -> ImmediateSendRequest
where
T: Serialize,
{
self.immediate_send_with_tag(value, DEFAULT_TAG)
}
fn immediate_send_with_tag<T>(&self, value: &T, tag: Tag) -> ImmediateSendRequest
where
T: Serialize,
{
self.send_with_tag(value, tag);
ImmediateSendRequest::completed()
}
}
impl DestinationBytes for Process<'_> {
    fn send_bytes(&self, payload: &[u8]) {
        self.send_bytes_with_tag(payload, DEFAULT_TAG);
    }

    /// Sends raw bytes to `self.rank`; panics if the runtime send fails.
    fn send_bytes_with_tag(&self, payload: &[u8], tag: Tag) {
        let sender = self.communicator.rank();
        let runtime = &self.communicator.runtime;
        runtime
            .send_bytes(sender, self.rank, tag, payload)
            .expect("jsmpi send failed");
    }

    fn send_bytes_chunked(&self, payload: &[u8], chunk_size: usize) {
        self.send_bytes_chunked_with_tag(payload, DEFAULT_TAG, chunk_size);
    }

    /// Sends the total length on `tag`, then the payload in pieces on `tag + 1`.
    ///
    /// Pieces are at most `chunk_size` bytes; 0 is treated as 1. An empty
    /// payload sends only the length header.
    fn send_bytes_chunked_with_tag(&self, payload: &[u8], tag: Tag, chunk_size: usize) {
        let (length_tag, data_tag) = split_chunk_tags(tag);
        let piece_len = chunk_size.max(1);
        let header = encode_chunk_length(payload.len());
        self.send_bytes_with_tag(&header, length_tag);
        payload
            .chunks(piece_len)
            .for_each(|piece| self.send_bytes_with_tag(piece, data_tag));
    }
}
impl Source for Process<'_> {
    fn receive<T>(&self) -> (T, Status)
    where
        T: DeserializeOwned,
    {
        self.receive_with_tag(DEFAULT_TAG)
    }

    /// Receives from `self.rank` with `tag`; panics if the runtime receive fails.
    fn receive_with_tag<T>(&self, tag: Tag) -> (T, Status)
    where
        T: DeserializeOwned,
    {
        let runtime = &self.communicator.runtime;
        runtime
            .receive(Some(self.rank), Some(tag))
            .expect("jsmpi receive failed")
    }

    fn receive_into<T>(&self, out: &mut T) -> Status
    where
        T: DeserializeOwned,
    {
        self.receive_into_with_tag(out, DEFAULT_TAG)
    }

    fn receive_into_with_tag<T>(&self, out: &mut T, tag: Tag) -> Status
    where
        T: DeserializeOwned,
    {
        let received = self.receive_with_tag::<T>(tag);
        *out = received.0;
        received.1
    }

    fn receive_vec<T>(&self) -> (Vec<T>, Status)
    where
        T: DeserializeOwned,
    {
        self.receive_vec_with_tag(DEFAULT_TAG)
    }

    /// The whole sequence arrives as one message.
    fn receive_vec_with_tag<T>(&self, tag: Tag) -> (Vec<T>, Status)
    where
        T: DeserializeOwned,
    {
        self.receive_with_tag::<Vec<T>>(tag)
    }

    fn receive_slice_into<T>(&self, out: &mut [T]) -> Status
    where
        T: DeserializeOwned,
    {
        self.receive_slice_into_with_tag(out, DEFAULT_TAG)
    }

    /// Panics when the received sequence length differs from `out.len()`.
    fn receive_slice_into_with_tag<T>(&self, out: &mut [T], tag: Tag) -> Status
    where
        T: DeserializeOwned,
    {
        let (values, status) = self.receive_vec_with_tag::<T>(tag);
        assert_eq!(
            values.len(),
            out.len(),
            "receive_slice_into length mismatch: expected {}, got {}",
            out.len(),
            values.len()
        );
        out.iter_mut()
            .zip(values)
            .for_each(|(slot, value)| *slot = value);
        status
    }
}
impl ImmediateSource for Process<'_> {
    fn immediate_receive<T>(&self) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned,
    {
        self.immediate_receive_with_tag::<T>(DEFAULT_TAG)
    }

    /// Builds a pending request filtered on this process's rank and `tag`.
    fn immediate_receive_with_tag<T>(&self, tag: Tag) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned,
    {
        let runtime = self.communicator.runtime.clone();
        ImmediateReceiveRequest::new(runtime, Some(self.rank), Some(tag))
    }
}
impl SourceBytes for Process<'_> {
    fn receive_bytes(&self) -> (Vec<u8>, Status) {
        self.receive_bytes_with_tag(DEFAULT_TAG)
    }

    /// Receives raw bytes from `self.rank`; panics if the runtime receive fails.
    fn receive_bytes_with_tag(&self, tag: Tag) -> (Vec<u8>, Status) {
        let runtime = &self.communicator.runtime;
        runtime
            .receive_bytes(Some(self.rank), Some(tag))
            .expect("jsmpi receive failed")
    }

    fn receive_bytes_into(&self, out: &mut [u8]) -> Status {
        self.receive_bytes_into_with_tag(out, DEFAULT_TAG)
    }

    /// Panics when the message length differs from `out.len()`.
    fn receive_bytes_into_with_tag(&self, out: &mut [u8], tag: Tag) -> Status {
        let (received, status) = self.receive_bytes_with_tag(tag);
        assert_eq!(
            received.len(),
            out.len(),
            "receive_bytes_into length mismatch: expected {}, got {}",
            out.len(),
            received.len()
        );
        out.copy_from_slice(&received);
        status
    }

    fn receive_bytes_chunked(&self) -> (Vec<u8>, Status) {
        self.receive_bytes_chunked_with_tag(DEFAULT_TAG)
    }

    /// Reads the length header on `tag`, then collects chunks on `tag + 1` until
    /// that many bytes have arrived.
    ///
    /// Bytes beyond the announced length are discarded.
    fn receive_bytes_chunked_with_tag(&self, tag: Tag) -> (Vec<u8>, Status) {
        let (length_tag, data_tag) = split_chunk_tags(tag);
        let (header, mut status) = self.receive_bytes_with_tag(length_tag);
        let expected_len = decode_chunk_length(&header);
        let mut payload = Vec::with_capacity(expected_len);
        while payload.len() < expected_len {
            let (piece, _piece_status) = self.receive_bytes_with_tag(data_tag);
            let usable = piece.len().min(expected_len - payload.len());
            payload.extend_from_slice(&piece[..usable]);
        }
        // Report the caller's tag, not the internal length-header tag.
        status.tag = tag;
        (payload, status)
    }
}
impl Source for AnyProcess<'_> {
    /// Receives the next message from any source with any tag.
    ///
    /// Note that `receive_vec` and `receive_slice_into` on this handle match
    /// `DEFAULT_TAG` instead.
    fn receive<T>(&self) -> (T, Status)
    where
        T: DeserializeOwned,
    {
        let runtime = &self.communicator.runtime;
        runtime.receive(None, None).expect("jsmpi receive failed")
    }

    /// Receives from any source, restricted to `tag`.
    fn receive_with_tag<T>(&self, tag: Tag) -> (T, Status)
    where
        T: DeserializeOwned,
    {
        let runtime = &self.communicator.runtime;
        runtime
            .receive(None, Some(tag))
            .expect("jsmpi receive failed")
    }

    /// Stores a message of any tag (see `receive`) into `out`.
    fn receive_into<T>(&self, out: &mut T) -> Status
    where
        T: DeserializeOwned,
    {
        let received = self.receive::<T>();
        *out = received.0;
        received.1
    }

    fn receive_into_with_tag<T>(&self, out: &mut T, tag: Tag) -> Status
    where
        T: DeserializeOwned,
    {
        let received = self.receive_with_tag::<T>(tag);
        *out = received.0;
        received.1
    }

    fn receive_vec<T>(&self) -> (Vec<T>, Status)
    where
        T: DeserializeOwned,
    {
        self.receive_vec_with_tag(DEFAULT_TAG)
    }

    /// The whole sequence arrives as one message.
    fn receive_vec_with_tag<T>(&self, tag: Tag) -> (Vec<T>, Status)
    where
        T: DeserializeOwned,
    {
        self.receive_with_tag::<Vec<T>>(tag)
    }

    fn receive_slice_into<T>(&self, out: &mut [T]) -> Status
    where
        T: DeserializeOwned,
    {
        self.receive_slice_into_with_tag(out, DEFAULT_TAG)
    }

    /// Panics when the received sequence length differs from `out.len()`.
    fn receive_slice_into_with_tag<T>(&self, out: &mut [T], tag: Tag) -> Status
    where
        T: DeserializeOwned,
    {
        let (values, status) = self.receive_vec_with_tag::<T>(tag);
        assert_eq!(
            values.len(),
            out.len(),
            "receive_slice_into length mismatch: expected {}, got {}",
            out.len(),
            values.len()
        );
        out.iter_mut()
            .zip(values)
            .for_each(|(slot, value)| *slot = value);
        status
    }
}
impl ImmediateSource for AnyProcess<'_> {
    /// Pending request with neither a source nor a tag filter.
    fn immediate_receive<T>(&self) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned,
    {
        let runtime = self.communicator.runtime.clone();
        ImmediateReceiveRequest::new(runtime, None, None)
    }

    /// Pending request from any source, filtered on `tag`.
    fn immediate_receive_with_tag<T>(&self, tag: Tag) -> ImmediateReceiveRequest<T>
    where
        T: DeserializeOwned,
    {
        let runtime = self.communicator.runtime.clone();
        ImmediateReceiveRequest::new(runtime, None, Some(tag))
    }
}
impl SourceBytes for AnyProcess<'_> {
    fn receive_bytes(&self) -> (Vec<u8>, Status) {
        self.receive_bytes_with_tag(DEFAULT_TAG)
    }

    /// Receives raw bytes from any source with `tag`; panics if the runtime receive fails.
    fn receive_bytes_with_tag(&self, tag: Tag) -> (Vec<u8>, Status) {
        let runtime = &self.communicator.runtime;
        runtime
            .receive_bytes(None, Some(tag))
            .expect("jsmpi receive failed")
    }

    fn receive_bytes_into(&self, out: &mut [u8]) -> Status {
        self.receive_bytes_into_with_tag(out, DEFAULT_TAG)
    }

    /// Panics when the message length differs from `out.len()`.
    fn receive_bytes_into_with_tag(&self, out: &mut [u8], tag: Tag) -> Status {
        let (received, status) = self.receive_bytes_with_tag(tag);
        assert_eq!(
            received.len(),
            out.len(),
            "receive_bytes_into length mismatch: expected {}, got {}",
            out.len(),
            received.len()
        );
        out.copy_from_slice(&received);
        status
    }

    fn receive_bytes_chunked(&self) -> (Vec<u8>, Status) {
        self.receive_bytes_chunked_with_tag(DEFAULT_TAG)
    }

    /// Accepts the length header from any rank, then reads the data chunks only
    /// from that same rank. This keeps concurrent chunked senders from
    /// interleaving their pieces.
    ///
    /// Bytes beyond the announced length are discarded.
    fn receive_bytes_chunked_with_tag(&self, tag: Tag) -> (Vec<u8>, Status) {
        let (length_tag, data_tag) = split_chunk_tags(tag);
        let (header, mut status) = self.receive_bytes_with_tag(length_tag);
        let expected_len = decode_chunk_length(&header);
        let sender = Some(status.source_rank);
        let runtime = &self.communicator.runtime;
        let mut payload = Vec::with_capacity(expected_len);
        while payload.len() < expected_len {
            let (piece, _piece_status) = runtime
                .receive_bytes(sender, Some(data_tag))
                .expect("jsmpi receive failed");
            let usable = piece.len().min(expected_len - payload.len());
            payload.extend_from_slice(&piece[..usable]);
        }
        // Report the caller's tag, not the internal length-header tag.
        status.tag = tag;
        (payload, status)
    }
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use crate::collective::{
BROADCAST_TAG,
GATHER_TAG,
REDUCE_TAG,
Root,
SCATTER_TAG,
SystemOperation,
};
use crate::point_to_point as p2p;
use crate::runtime::{RetryPolicy, Runtime};
use crate::traits::{
Communicator,
Destination,
DestinationBytes,
ImmediateDestination,
ImmediateSource,
Source,
SourceBytes,
};
use super::{ImmediateReceiveRequest, ImmediateSendRequest, RequestState, SimpleCommunicator};
#[test]
fn process_send_and_receive_with_tag_roundtrip() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
world.process_at_rank(0).send_with_tag(&42_i32, 9);
let (value, status) = world.process_at_rank(0).receive_with_tag::<i32>(9);
assert_eq!(value, 42);
assert_eq!(status.tag, 9);
assert_eq!(status.source_rank, 0);
}
#[test]
fn any_process_receive_with_tag_filters_messages() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let self_rank = world.process_at_rank(0);
self_rank.send_with_tag(&1_i32, 5);
self_rank.send_with_tag(&2_i32, 6);
let (value, status) = world.any_process().receive_with_tag::<i32>(6);
assert_eq!(value, 2);
assert_eq!(status.tag, 6);
let (leftover, leftover_status) = world.any_process().receive_with_tag::<i32>(5);
assert_eq!(leftover, 1);
assert_eq!(leftover_status.tag, 5);
}
#[test]
fn immediate_send_request_reports_completed() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let req = world.process_at_rank(0).immediate_send_with_tag(&55_i32, 12);
assert!(req.test());
req.wait();
let (value, status) = world.process_at_rank(0).receive_with_tag::<i32>(12);
assert_eq!(value, 55);
assert_eq!(status.tag, 12);
}
#[test]
fn immediate_send_test_all_and_wait_all_work() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let req_a = world.process_at_rank(0).immediate_send_with_tag(&1_i32, 81);
let req_b = world.process_at_rank(0).immediate_send_with_tag(&2_i32, 82);
let requests = vec![req_a, req_b];
assert!(ImmediateSendRequest::test_all(&requests));
ImmediateSendRequest::wait_all(requests);
let (a, s1) = world.process_at_rank(0).receive_with_tag::<i32>(81);
let (b, s2) = world.process_at_rank(0).receive_with_tag::<i32>(82);
assert_eq!(a, 1);
assert_eq!(s1.tag, 81);
assert_eq!(b, 2);
assert_eq!(s2.tag, 82);
}
#[test]
fn immediate_send_cancel_and_free_follow_state_rules() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let mut req = world.process_at_rank(0).immediate_send(&1_i32);
assert_eq!(req.state(), RequestState::Completed);
let err = req.cancel().expect_err("cancel should fail after completion");
assert!(matches!(err, crate::Error::Protocol(_)));
req.free().expect("free should be a no-op success");
assert_eq!(req.state(), RequestState::Freed);
}
#[test]
fn immediate_receive_test_then_wait() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let req = world
.process_at_rank(0)
.immediate_receive_with_tag::<i32>(33);
assert!(req.test().is_none());
world.process_at_rank(0).send_with_tag(&99_i32, 33);
let (value, status) = req.wait();
assert_eq!(value, 99);
assert_eq!(status.tag, 33);
}
#[test]
fn immediate_receive_wait_into_writes_output() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let req = world.any_process().immediate_receive_with_tag::<i32>(21);
world.process_at_rank(0).send_with_tag(&77_i32, 21);
let mut out = 0_i32;
let status = req.wait_into(&mut out);
assert_eq!(out, 77);
assert_eq!(status.tag, 21);
}
#[test]
fn immediate_receive_test_any_returns_first_ready_request() {
let world = SimpleCommunicator::new(Runtime::new(0, 1));
let req_a = world
.process_at_rank(0)
.immediate_receive_with_tag::<i32>(41);
let req_b = world
.process_at_rank(0)
.immediate_receive_with_tag::<i32>(42);
let requests = vec![req_a, req_b];
world.process_at_rank(0).send_with_tag(&420_i32, 42);
let (idx, (value, status)) = ImmediateReceiveRequest::test_any(&requests)
.expect("one request should be ready");
assert_eq!(idx, 1);
assert_eq!(value, 420);
assert_eq!(status.tag, 42);
}
/// `wait_any` returns once one request is satisfied, reporting which slot
/// completed along with its payload and status.
#[test]
fn immediate_receive_wait_any_waits_for_next_ready_request() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let first = rank0.immediate_receive_with_tag::<i32>(51);
    let second = rank0.immediate_receive_with_tag::<i32>(52);
    let requests = vec![first, second];
    rank0.send_with_tag(&510_i32, 51);
    let (slot, (payload, status)) = ImmediateReceiveRequest::wait_any(&requests);
    assert_eq!(slot, 0);
    assert_eq!(payload, 510);
    assert_eq!(status.tag, 51);
}
/// `test_all` returns one entry per request, in request order. An entry is
/// `Some((payload, status))` when that request's message has arrived and
/// `None` otherwise.
#[test]
fn immediate_receive_test_all_reports_per_request_readiness() {
    let world = SimpleCommunicator::new(Runtime::new(0, 1));
    let req_a = world
        .process_at_rank(0)
        .immediate_receive_with_tag::<i32>(61);
    let req_b = world
        .process_at_rank(0)
        .immediate_receive_with_tag::<i32>(62);
    let requests = vec![req_a, req_b];
    world.process_at_rank(0).send_with_tag(&610_i32, 61);
    let snapshot = ImmediateReceiveRequest::test_all(&requests);
    assert_eq!(snapshot.len(), 2);
    // Check the status tag as well as the payload, as the other
    // multi-request tests do, so a mismatched status is caught too.
    let (value, status) = snapshot[0]
        .as_ref()
        .expect("request tagged 61 should be ready");
    assert_eq!(*value, 610);
    assert_eq!(status.tag, 61);
    assert!(snapshot[1].is_none());
}
/// `wait_all` returns one `(payload, status)` per request, in the same order
/// as the requests.
#[test]
fn immediate_receive_wait_all_collects_all_results() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let requests = vec![
        rank0.immediate_receive_with_tag::<i32>(71),
        rank0.immediate_receive_with_tag::<i32>(72),
    ];
    rank0.send_with_tag(&710_i32, 71);
    rank0.send_with_tag(&720_i32, 72);
    let results = ImmediateReceiveRequest::wait_all(&requests);
    let expected = [(710, 71), (720, 72)];
    assert_eq!(results.len(), expected.len());
    for (got, &(want_value, want_tag)) in results.iter().zip(expected.iter()) {
        assert_eq!(got.0, want_value);
        assert_eq!(got.1.tag, want_tag);
    }
}
/// `test_some` returns every request that is already satisfied, as
/// `(index, (payload, status))` entries in request order. Still-pending
/// requests are skipped.
#[test]
fn immediate_receive_test_some_returns_all_ready_requests() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let requests: Vec<_> = [73, 74, 75]
        .iter()
        .map(|&tag| rank0.immediate_receive_with_tag::<i32>(tag))
        .collect();
    rank0.send_with_tag(&730_i32, 73);
    rank0.send_with_tag(&750_i32, 75);
    let ready = ImmediateReceiveRequest::test_some(&requests);
    assert_eq!(ready.len(), 2);
    let first = &ready[0];
    assert_eq!(first.0, 0);
    assert_eq!(first.1.0, 730);
    assert_eq!(first.1.1.tag, 73);
    let second = &ready[1];
    assert_eq!(second.0, 2);
    assert_eq!(second.1.0, 750);
    assert_eq!(second.1.1.tag, 75);
}
/// `wait_some` returns once at least one request completes, listing only the
/// completed entries as `(index, (payload, status))`.
#[test]
fn immediate_receive_wait_some_waits_until_any_ready() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let requests = vec![
        rank0.immediate_receive_with_tag::<i32>(76),
        rank0.immediate_receive_with_tag::<i32>(77),
    ];
    rank0.send_with_tag(&770_i32, 77);
    let ready = ImmediateReceiveRequest::wait_some(&requests);
    assert_eq!(ready.len(), 1);
    let only = &ready[0];
    assert_eq!(only.0, 1);
    assert_eq!(only.1.0, 770);
    assert_eq!(only.1.1.tag, 77);
}
/// A pending receive can be cancelled. Freeing it afterwards succeeds and
/// moves it to `Freed`.
#[test]
fn immediate_receive_cancel_and_free_follow_state_rules() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let mut pending = rank0.immediate_receive_with_tag::<i32>(88);
    assert_eq!(pending.state(), RequestState::Pending);

    pending
        .cancel()
        .expect("cancel should succeed while pending");
    assert_eq!(pending.state(), RequestState::Canceled);

    pending
        .free()
        .expect("free should be a no-op success");
    assert_eq!(pending.state(), RequestState::Freed);
}
/// A successful `test` both yields the message and moves the request to
/// `Completed`.
#[test]
fn immediate_receive_test_marks_request_completed() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let pending = rank0.immediate_receive_with_tag::<i32>(89);
    rank0.send_with_tag(&123_i32, 89);
    let payload = pending.test().map(|pair| pair.0);
    assert_eq!(payload, Some(123));
    assert_eq!(pending.state(), RequestState::Completed);
}
/// `receive_into` writes an untagged message into the output buffer. The
/// status reports tag 0.
#[test]
fn receive_into_writes_output_buffer() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    rank0.send(&11_i32);
    let mut buffer = 0_i32;
    let status = rank0.receive_into(&mut buffer);
    assert_eq!(buffer, 11);
    assert_eq!(status.tag, 0);
}
/// `receive_into_with_tag` writes the message sent with that tag into the
/// output buffer.
#[test]
fn receive_into_with_tag_writes_output_buffer() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let tag = 99;
    rank0.send_with_tag(&123_i32, tag);
    let mut buffer = 0_i32;
    let status = rank0.receive_into_with_tag(&mut buffer, tag);
    assert_eq!(buffer, 123);
    assert_eq!(status.tag, tag);
}
/// `any_process().receive_into` accepts a tagged message and reports its tag
/// in the returned status.
#[test]
fn any_process_receive_into_writes_output_buffer() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    comm.process_at_rank(0).send_with_tag(&99_i32, 4);
    let mut buffer = -1_i32;
    let status = comm.any_process().receive_into(&mut buffer);
    assert_eq!(buffer, 99);
    assert_eq!(status.tag, 4);
}
/// A slice sent with a tag comes back intact through `receive_vec_with_tag`.
#[test]
fn process_send_slice_and_receive_vec_roundtrip() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let payload = [1_i32, 2, 3, 4];
    rank0.send_slice_with_tag(&payload, 42);
    let (received, status) = rank0.receive_vec_with_tag::<i32>(42);
    assert_eq!(received, payload);
    assert_eq!(status.tag, 42);
}
/// `any_process().receive_slice_into` fills a caller-provided buffer with an
/// untagged slice message. The status reports tag 0.
#[test]
fn any_process_receive_slice_into_writes_buffer() {
    let world = SimpleCommunicator::new(Runtime::new(0, 1));
    // The payload is only ever borrowed as a slice, so a fixed-size array is
    // enough; a heap-allocated `vec!` is unnecessary (clippy::useless_vec).
    let payload = [8_i32, 9_i32, 10_i32];
    world.process_at_rank(0).send_slice(&payload);
    let mut out = [0_i32; 3];
    let status = world.any_process().receive_slice_into(&mut out);
    assert_eq!(out, payload);
    assert_eq!(status.tag, 0);
}
/// A tagged `receive_slice_into` takes only the message with that tag. The
/// other message stays available for a later receive on its own tag.
#[test]
fn process_receive_slice_into_with_tag_filters_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    rank0.send_slice_with_tag(&[1_i32, 1_i32], 7);
    rank0.send_slice_with_tag(&[2_i32, 2_i32], 8);

    let mut picked = [0_i32; 2];
    let picked_status = rank0.receive_slice_into_with_tag(&mut picked, 8);
    assert_eq!(picked, [2, 2]);
    assert_eq!(picked_status.tag, 8);

    let (leftover, leftover_status) = rank0.receive_vec_with_tag::<i32>(7);
    assert_eq!(leftover, vec![1, 1]);
    assert_eq!(leftover_status.tag, 7);
}
/// Receiving a three-element slice into a two-element buffer must panic.
#[test]
#[should_panic(expected = "receive_slice_into length mismatch")]
fn receive_slice_into_panics_on_length_mismatch() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    rank0.send_slice(&[1_i32, 2_i32, 3_i32]);
    let mut too_short = [0_i32; 2];
    let _ = rank0.receive_slice_into(&mut too_short);
}
/// Raw bytes sent with a tag come back unchanged via `receive_bytes_with_tag`.
#[test]
fn process_send_and_receive_bytes_roundtrip() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    let bytes = [10_u8, 20, 30];
    rank0.send_bytes_with_tag(&bytes, 66);
    let (echoed, status) = rank0.receive_bytes_with_tag(66);
    assert_eq!(echoed, bytes);
    assert_eq!(status.tag, 66);
}
/// A tagged `receive_bytes_into` via `any_process` takes only the matching
/// message. The other message can still be received on its own tag.
#[test]
fn any_process_receive_bytes_into_with_tag_filters_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    rank0.send_bytes_with_tag(&[1_u8, 1_u8], 15);
    rank0.send_bytes_with_tag(&[9_u8, 9_u8], 16);

    let mut picked = [0_u8; 2];
    let picked_status = comm.any_process().receive_bytes_into_with_tag(&mut picked, 16);
    assert_eq!(picked, [9, 9]);
    assert_eq!(picked_status.tag, 16);

    let (leftover, leftover_status) = rank0.receive_bytes_with_tag(15);
    assert_eq!(leftover, vec![1, 1]);
    assert_eq!(leftover_status.tag, 15);
}
/// Receiving three bytes into a two-byte buffer must panic.
#[test]
#[should_panic(expected = "receive_bytes_into length mismatch")]
fn receive_bytes_into_panics_on_length_mismatch() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    rank0.send_bytes(&[1_u8, 2_u8, 3_u8]);
    let mut too_short = [0_u8; 2];
    let _ = rank0.receive_bytes_into(&mut too_short);
}
/// A 130-byte payload sent with a chunk-size argument of 16 is reassembled
/// intact by `receive_bytes_chunked_with_tag`.
#[test]
fn process_send_and_receive_chunked_bytes_roundtrip() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let rank0 = comm.process_at_rank(0);
    // Bytes 0, 1, ..., 129.
    let payload: Vec<u8> = (0_u8..130).collect();
    rank0.send_bytes_chunked_with_tag(&payload, 90, 16);
    let (reassembled, status) = rank0.receive_bytes_chunked_with_tag(90);
    assert_eq!(reassembled, payload);
    assert_eq!(status.tag, 90);
}
/// A chunk-size argument of 0 must not break the transfer. The payload still
/// round-trips, here received through `any_process`.
#[test]
fn chunked_send_with_zero_chunk_size_still_roundtrips() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    // Bytes 5, 4, 3, 2, 1.
    let payload: Vec<u8> = (1_u8..=5).rev().collect();
    comm.process_at_rank(0)
        .send_bytes_chunked_with_tag(&payload, 91, 0);
    let (reassembled, status) = comm.any_process().receive_bytes_chunked_with_tag(91);
    assert_eq!(reassembled, payload);
    assert_eq!(status.tag, 91);
}
/// With a single rank, `all_gather_into` yields just the local contribution.
#[test]
fn all_gather_single_rank_collects_local_value() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let mut gathered: Vec<i32> = Vec::new();
    comm.all_gather_into(&42_i32, &mut gathered);
    assert_eq!(gathered, vec![42]);
}
/// On rank 0 of 3, `all_gather_into` returns the local value followed by the
/// values seeded on `GATHER_TAG` from ranks 1 and 2.
#[test]
fn all_gather_root_accumulates_seeded_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 3));
    let seeds = [
        (1, 10_i32, "seed send from rank 1 should succeed"),
        (2, 20_i32, "seed send from rank 2 should succeed"),
    ];
    for &(from, value, msg) in &seeds {
        comm.runtime.send(from, 0, GATHER_TAG, &value).expect(msg);
    }
    let mut gathered = Vec::<i32>::new();
    comm.all_gather_into(&5_i32, &mut gathered);
    assert_eq!(gathered, vec![5, 10, 20]);
}
/// On a non-root rank, `all_gather_into` returns the vector seeded on
/// `BROADCAST_TAG` from rank 0, ignoring the local value.
#[test]
fn all_gather_non_root_receives_seeded_result() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    let broadcast_payload = vec![7_i32, 8_i32, 9_i32];
    let rt = &comm.runtime;
    rt.send(0, 2, BROADCAST_TAG, &broadcast_payload)
        .expect("seed all_gather result should succeed");
    let mut gathered = Vec::<i32>::new();
    comm.all_gather_into(&99_i32, &mut gathered);
    assert_eq!(gathered, broadcast_payload);
}
/// Broadcasting from rank 0 on a single-rank communicator leaves the value
/// unchanged.
#[test]
fn communicator_broadcast_into_from_single_rank_works() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let original = 321_i32;
    let mut shared = original;
    comm.broadcast_into_from(0, &mut shared);
    assert_eq!(shared, original);
}
/// A non-root rank receives the value seeded from rank 0 through the
/// timeout-bounded broadcast variant.
#[test]
fn communicator_broadcast_with_timeout_non_root_receives_seeded_value() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    comm.runtime
        .send(0, 2, BROADCAST_TAG, &909_i32)
        .expect("seed broadcast value should succeed");
    let deadline = Duration::from_millis(10);
    let mut shared = 0_i32;
    comm.broadcast_into_from_with_timeout(0, &mut shared, deadline)
        .expect("broadcast_with_timeout should succeed");
    assert_eq!(shared, 909);
}
/// On a single-rank communicator, gathering to root 0 yields the local value.
#[test]
fn communicator_gather_into_root_single_rank_works() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let mut gathered: Vec<i32> = Vec::new();
    comm.gather_into_root(0, &17_i32, &mut gathered);
    assert_eq!(gathered, vec![17]);
}
/// The retry/timeout gather on root 0 collects the local value plus the
/// `GATHER_TAG` messages seeded from ranks 1 and 2, in rank order.
#[test]
fn communicator_gather_with_retry_timeout_root_collects_seeded_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 3));
    let rt = &comm.runtime;
    rt.send(1, 0, GATHER_TAG, &70_i32)
        .expect("seed send from rank 1 should succeed");
    rt.send(2, 0, GATHER_TAG, &80_i32)
        .expect("seed send from rank 2 should succeed");

    let deadline = Duration::from_millis(10);
    let policy = RetryPolicy::default();
    let mut gathered = Vec::<i32>::new();
    comm.gather_into_root_with_retry_timeout(0, &60_i32, &mut gathered, deadline, &policy)
        .expect("gather_with_retry_timeout should succeed");
    assert_eq!(gathered, vec![60, 70, 80]);
}
/// On a single-rank communicator, the root keeps its own element of the
/// scatter input.
#[test]
fn communicator_scatter_into_root_single_rank_works() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let parts = [44_i32];
    let mut mine = 0_i32;
    comm.scatter_into_root(0, &parts, &mut mine);
    assert_eq!(mine, 44);
}
/// A non-root rank using the retry/timeout scatter takes the value seeded on
/// `SCATTER_TAG` from root 0 rather than anything from its local input.
#[test]
fn communicator_scatter_with_retry_timeout_non_root_receives_seeded_value() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    comm.runtime
        .send(0, 2, SCATTER_TAG, &515_i32)
        .expect("seed scatter value should succeed");

    let deadline = Duration::from_millis(10);
    let policy = RetryPolicy::default();
    let mut mine = 0_i32;
    comm.scatter_into_root_with_retry_timeout(
        0,
        &[1_i32, 2_i32, 3_i32],
        &mut mine,
        deadline,
        &policy,
    )
    .expect("scatter_with_retry_timeout should succeed");
    assert_eq!(mine, 515);
}
/// Root 0 sums its own value with the `REDUCE_TAG` messages seeded from
/// ranks 1 and 2 (5 + 6 + 9 = 20).
#[test]
fn communicator_reduce_sum_into_root_accumulates_seeded_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 3));
    let seeds = [
        (1, 6_i32, "seed send from rank 1 should succeed"),
        (2, 9_i32, "seed send from rank 2 should succeed"),
    ];
    for &(from, value, msg) in &seeds {
        comm.runtime.send(from, 0, REDUCE_TAG, &value).expect(msg);
    }
    let mut total = 0_i32;
    comm.reduce_sum_into_root(0, &5_i32, &mut total);
    assert_eq!(total, 20);
}
/// A non-root rank's all-reduce result is the value seeded on
/// `BROADCAST_TAG` from root 0.
#[test]
fn communicator_all_reduce_sum_into_from_non_root_receives_seeded_result() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    let rt = &comm.runtime;
    rt.send(0, 2, BROADCAST_TAG, &77_i32)
        .expect("seed all_reduce result should succeed");
    let mut total = 0_i32;
    comm.all_reduce_sum_into_from(0, &1_i32, &mut total);
    assert_eq!(total, 77);
}
/// The timeout-bounded all-gather on a non-root rank returns the vector
/// seeded on `BROADCAST_TAG` from rank 0.
#[test]
fn communicator_all_gather_with_timeout_non_root_receives_seeded_result() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    let broadcast_payload = vec![3_i32, 4_i32, 5_i32];
    comm.runtime
        .send(0, 2, BROADCAST_TAG, &broadcast_payload)
        .expect("seed all_gather result should succeed");
    let deadline = Duration::from_millis(10);
    let mut gathered = Vec::<i32>::new();
    comm.all_gather_into_with_timeout(&9_i32, &mut gathered, deadline)
        .expect("all_gather_with_timeout should succeed");
    assert_eq!(gathered, broadcast_payload);
}
/// The retry/timeout all-reduce on a non-root rank returns the value seeded
/// on `BROADCAST_TAG` from root 0.
#[test]
fn communicator_all_reduce_with_retry_timeout_non_root_receives_seeded_result() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    comm.runtime
        .send(0, 2, BROADCAST_TAG, &31_i32)
        .expect("seed all_reduce result should succeed");

    let deadline = Duration::from_millis(10);
    let policy = RetryPolicy::default();
    let mut total = 0_i32;
    comm.all_reduce_sum_into_from_with_retry_timeout(0, &1_i32, &mut total, deadline, &policy)
        .expect("all_reduce_with_retry_timeout should succeed");
    assert_eq!(total, 31);
}
/// Broadcasting from the sole rank through `Process::broadcast_into` leaves
/// the value unchanged.
#[test]
fn broadcast_single_rank_keeps_value() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let root = comm.process_at_rank(0);
    let mut shared = 123_i32;
    root.broadcast_into(&mut shared);
    assert_eq!(shared, 123);
}
/// `Process::gather_into_root` on a single-rank communicator yields only the
/// local value.
#[test]
fn gather_single_rank_collects_local_value() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let root = comm.process_at_rank(0);
    let mut gathered: Vec<i32> = Vec::new();
    root.gather_into_root(&55_i32, &mut gathered);
    assert_eq!(gathered, vec![55]);
}
/// Gathering on rank 1 of 2 (root 0) must not panic, and leaves the local
/// output vector empty.
#[test]
fn gather_as_non_root_sends_without_panicking() {
    let comm = SimpleCommunicator::new(Runtime::new(1, 2));
    let root = comm.process_at_rank(0);
    let mut gathered: Vec<i32> = Vec::new();
    root.gather_into_root(&88_i32, &mut gathered);
    assert!(gathered.is_empty());
}
/// `Process::scatter_into_root` on a single-rank communicator hands the sole
/// element back to the root.
#[test]
fn scatter_single_rank_sets_local_output() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let root = comm.process_at_rank(0);
    let mut mine = 0_i32;
    root.scatter_into_root(&[123_i32], &mut mine);
    assert_eq!(mine, 123);
}
/// A non-root rank receives its element from the message seeded on
/// `SCATTER_TAG` from root 0, even with an empty input slice.
#[test]
fn scatter_non_root_receives_from_root_tag() {
    let comm = SimpleCommunicator::new(Runtime::new(1, 2));
    let rt = &comm.runtime;
    rt.send(0, 1, SCATTER_TAG, &456_i32)
        .expect("seed send should succeed");
    let root = comm.process_at_rank(0);
    let mut mine = 0_i32;
    root.scatter_into_root(&[], &mut mine);
    assert_eq!(mine, 456);
}
/// A non-root rank's scatter result comes from the root's message, not from
/// the (non-empty) input slice it passes in.
#[test]
fn scatter_non_root_ignores_input_slice() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    comm.runtime
        .send(0, 2, SCATTER_TAG, &789_i32)
        .expect("seed send should succeed");
    let ignored_input = [1_i32, 2_i32, 3_i32];
    let mut mine = 0_i32;
    comm.process_at_rank(0)
        .scatter_into_root(&ignored_input, &mut mine);
    assert_eq!(mine, 789);
}
/// On a two-rank communicator, the root must panic when given a
/// single-element scatter input.
#[test]
#[should_panic(expected = "scatter input length must equal communicator size")]
fn scatter_root_panics_on_invalid_input_length() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 2));
    let root = comm.process_at_rank(0);
    let mut mine = 0_i32;
    root.scatter_into_root(&[1_i32], &mut mine);
}
/// With a single rank, a sum-reduction to the root is just the local value.
#[test]
fn reduce_sum_single_rank_sets_local_output() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let root = comm.process_at_rank(0);
    let mut total = 0_i32;
    root.reduce_sum_into_root(&12_i32, &mut total);
    assert_eq!(total, 12);
}
/// Root 0 adds the `REDUCE_TAG` messages seeded from ranks 1 and 2 to its
/// own value (5 + 7 + 8 = 20).
#[test]
fn reduce_sum_root_accumulates_seeded_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 3));
    let rt = &comm.runtime;
    rt.send(1, 0, REDUCE_TAG, &7_i32)
        .expect("seed send from rank 1 should succeed");
    rt.send(2, 0, REDUCE_TAG, &8_i32)
        .expect("seed send from rank 2 should succeed");
    let mut total = 0_i32;
    comm.process_at_rank(0)
        .reduce_sum_into_root(&5_i32, &mut total);
    assert_eq!(total, 20);
}
/// A sum-reduction on rank 1 of 3 (root 0) must not panic, and leaves the
/// local output untouched.
#[test]
fn reduce_sum_non_root_sends_without_panicking() {
    let comm = SimpleCommunicator::new(Runtime::new(1, 3));
    let sentinel = -1_i32;
    let mut total = sentinel;
    comm.process_at_rank(0)
        .reduce_sum_into_root(&9_i32, &mut total);
    assert_eq!(total, sentinel);
}
/// With a single rank, an all-reduce sum returns the local value.
#[test]
fn all_reduce_sum_single_rank_sets_local_output() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let root = comm.process_at_rank(0);
    let mut total = 0_i32;
    root.all_reduce_sum_into(&33_i32, &mut total);
    assert_eq!(total, 33);
}
/// On root 0, the all-reduce sum combines its own value with the
/// `REDUCE_TAG` messages seeded from ranks 1 and 2 (10 + 4 + 6 = 20).
#[test]
fn all_reduce_sum_root_accumulates_seeded_messages() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 3));
    let seeds = [
        (1, 4_i32, "seed send from rank 1 should succeed"),
        (2, 6_i32, "seed send from rank 2 should succeed"),
    ];
    for &(from, value, msg) in &seeds {
        comm.runtime.send(from, 0, REDUCE_TAG, &value).expect(msg);
    }
    let mut total = 0_i32;
    comm.process_at_rank(0)
        .all_reduce_sum_into(&10_i32, &mut total);
    assert_eq!(total, 20);
}
/// A non-root rank's all-reduce sum returns the value seeded on
/// `BROADCAST_TAG` from root 0.
#[test]
fn all_reduce_sum_non_root_receives_seeded_result() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    comm.runtime
        .send(0, 2, BROADCAST_TAG, &99_i32)
        .expect("seed all_reduce result should succeed");
    let root = comm.process_at_rank(0);
    let mut total = 0_i32;
    root.all_reduce_sum_into(&1_i32, &mut total);
    assert_eq!(total, 99);
}
/// `reduce_into_root` with `SystemOperation::sum()` adds the seeded value from
/// rank 1 to root 0's value (7 + 5 = 12).
#[test]
fn reduce_into_root_sum_alias_works() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 2));
    let rt = &comm.runtime;
    rt.send(1, 0, REDUCE_TAG, &5_i32)
        .expect("seed send from rank 1 should succeed");
    let root = comm.process_at_rank(0);
    let mut total = 0_i32;
    root.reduce_into_root(&7_i32, &mut total, SystemOperation::sum());
    assert_eq!(total, 12);
}
/// `all_reduce_into` with `SystemOperation::sum()` on a non-root rank returns
/// the value seeded on `BROADCAST_TAG` from root 0.
#[test]
fn all_reduce_into_sum_alias_works() {
    let comm = SimpleCommunicator::new(Runtime::new(2, 3));
    let rt = &comm.runtime;
    rt.send(0, 2, BROADCAST_TAG, &21_i32)
        .expect("seed all_reduce result should succeed");
    let root = comm.process_at_rank(0);
    let mut total = 0_i32;
    root.all_reduce_into(&1_i32, &mut total, SystemOperation::sum());
    assert_eq!(total, 21);
}
/// `p2p::Process` can name the type returned by `process_at_rank`, and a
/// tagged send/receive works through it.
#[test]
fn point_to_point_module_aliases_are_usable() {
    let comm = SimpleCommunicator::new(Runtime::new(0, 1));
    let peer: p2p::Process<'_> = comm.process_at_rank(0);
    let tag = 11;
    peer.send_with_tag(&3_i32, tag);
    let reply = peer.receive_with_tag::<i32>(tag);
    assert_eq!(reply.0, 3);
    assert_eq!(reply.1.tag, tag);
}
}