use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
#[cfg(target_arch = "wasm32")]
use gloo_timers::future::TimeoutFuture;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::error::{Error, Result};
use crate::runtime::browser;
use crate::runtime::protocol::{decode, encode, Envelope};
use crate::datatype::{Rank, Status, Tag};
/// Retry schedule used by the `*_with_retry_timeout` operations.
///
/// An attempt is retried only when it fails with a timeout, and at most
/// `max_retries` times after the first attempt. Before retry `n` (1-based) the
/// caller sleeps for [`RetryPolicy::backoff_for_attempt`] with `attempt = n`.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Number of retries allowed after the first attempt
    /// (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Delay before the first retry. A zero value disables backoff entirely.
    pub initial_backoff: Duration,
    /// Upper bound applied to every computed delay.
    pub max_backoff: Duration,
    /// Growth factor between consecutive delays; values below 1 are treated
    /// as 1 (constant backoff).
    pub backoff_multiplier: u32,
}

impl Default for RetryPolicy {
    /// Two retries with a 2 ms initial delay that doubles, capped at 20 ms.
    fn default() -> Self {
        Self {
            max_retries: 2,
            initial_backoff: Duration::from_millis(2),
            max_backoff: Duration::from_millis(20),
            backoff_multiplier: 2,
        }
    }
}

impl RetryPolicy {
    /// Returns the delay to wait before retry number `attempt` (1-based).
    ///
    /// `attempt == 0` (the initial try) and a zero `initial_backoff` both
    /// yield [`Duration::ZERO`]. Otherwise the delay is
    /// `initial_backoff * backoff_multiplier^(attempt - 1)`, saturating on
    /// overflow and capped at `max_backoff`.
    ///
    /// The computation stays in [`Duration`] arithmetic, so sub-millisecond
    /// delays are preserved rather than truncated to whole milliseconds.
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 || self.initial_backoff.is_zero() {
            return Duration::ZERO;
        }
        let multiplier = self.backoff_multiplier.max(1);
        let mut backoff = self.initial_backoff;
        for _ in 1..attempt {
            // Once the delay reaches the cap (or cannot grow at all), further
            // multiplications cannot change the capped result. Stop early
            // instead of looping up to `attempt` times (which may be u32::MAX).
            if multiplier == 1 || backoff >= self.max_backoff {
                break;
            }
            backoff = backoff.saturating_mul(multiplier);
        }
        backoff.min(self.max_backoff)
    }
}
/// Handle to the message-passing runtime for one rank of a communicator.
///
/// Cloning is cheap: every clone shares the same underlying state (including
/// the local mailbox) through an [`Arc`].
#[derive(Clone, Debug)]
pub struct Runtime {
    inner: Arc<RuntimeState>,
}
/// State shared by all clones of a [`Runtime`] handle.
#[derive(Debug)]
struct RuntimeState {
    // Rank of the local participant.
    rank: Rank,
    // Number of ranks in the communicator.
    size: Rank,
    // Locally delivered messages (sends addressed to this rank, or any send in
    // a single-rank communicator), kept in arrival order. The mutex lets
    // clones on other threads push and take envelopes.
    mailbox: Mutex<VecDeque<Envelope>>,
}
impl Runtime {
    /// Builds a runtime from the ambient environment.
    ///
    /// On `wasm32` the rank and communicator size come from
    /// `browser::detect_environment`. On other targets they are read from the
    /// `JSMPI_RANK` and `JSMPI_SIZE` environment variables. An absent
    /// variable defaults to rank `0` of a size-`1` communicator.
    ///
    /// # Errors
    ///
    /// Propagates failures from environment detection. Returns
    /// [`Error::Protocol`] when a variable is present but is not a valid rank,
    /// or when the resulting pair is inconsistent (`size < 1`, or `rank`
    /// outside `0..size`).
    pub fn detect() -> Result<Self> {
        #[cfg(target_arch = "wasm32")]
        let (rank, size) = {
            let env = browser::detect_environment()?;
            (env.rank, env.size)
        };
        #[cfg(not(target_arch = "wasm32"))]
        let (rank, size) = (
            Self::rank_from_env("JSMPI_RANK", 0)?,
            Self::rank_from_env("JSMPI_SIZE", 1)?,
        );
        Self::validate_topology(rank, size)?;
        Ok(Self::new(rank, size))
    }

    /// Creates a runtime for `rank` in a communicator of `size` ranks, with an
    /// empty local mailbox.
    ///
    /// The pair is not validated here; [`Runtime::detect`] validates it.
    pub fn new(rank: Rank, size: Rank) -> Self {
        Self {
            inner: Arc::new(RuntimeState {
                rank,
                size,
                mailbox: Mutex::new(VecDeque::new()),
            }),
        }
    }

    /// Rank of the local participant.
    pub fn rank(&self) -> Rank {
        self.inner.rank
    }

    /// Number of ranks in the communicator.
    pub fn size(&self) -> Rank {
        self.inner.size
    }

    /// Serializes `value` with `encode` and sends it (see [`Runtime::send_bytes`]).
    pub fn send<T>(&self, src: Rank, dst: Rank, tag: Tag, value: &T) -> Result<()>
    where
        T: Serialize,
    {
        self.send_bytes(src, dst, tag, &encode(value)?)
    }

    /// Sends a raw `payload` from `src` to `dst` with `tag`.
    ///
    /// Messages addressed to the local rank, and every message in a
    /// single-rank communicator, are queued in the local mailbox. Everything
    /// else is handed to `browser::post_envelope`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Protocol`] if either rank is outside `0..size` or the
    /// mailbox lock is poisoned. Otherwise returns any error from envelope
    /// validation or posting.
    pub fn send_bytes(&self, src: Rank, dst: Rank, tag: Tag, payload: &[u8]) -> Result<()> {
        let size = self.size();
        if src < 0 || src >= size {
            return Err(Error::Protocol(format!(
                "invalid source rank {src}, communicator size={size}"
            )));
        }
        if dst < 0 || dst >= size {
            return Err(Error::Protocol(format!(
                "invalid destination rank {dst}, communicator size={size}"
            )));
        }
        let envelope = Envelope::new(src, dst, tag, payload.to_vec());
        envelope.validate(size)?;
        if dst == self.rank() || size == 1 {
            // Local delivery: queue instead of posting through the browser bridge.
            self.lock_mailbox()?.push_back(envelope);
            return Ok(());
        }
        browser::post_envelope(&envelope)
    }

    /// Like [`Runtime::send`], but fails with [`Error::Timeout`] when `timeout`
    /// is zero.
    ///
    /// A non-zero timeout is not otherwise enforced; the call delegates
    /// directly to [`Runtime::send`].
    pub fn send_with_timeout<T>(
        &self,
        src: Rank,
        dst: Rank,
        tag: Tag,
        value: &T,
        timeout: Duration,
    ) -> Result<()>
    where
        T: Serialize,
    {
        if timeout.is_zero() {
            return Err(Error::Timeout {
                operation: "send",
                timeout_ms: 0,
            });
        }
        self.send(src, dst, tag, value)
    }

    /// Calls [`Runtime::send_with_timeout`], retrying timeouts according to
    /// `retry_policy`. Non-timeout errors are returned immediately.
    pub fn send_with_retry_timeout<T>(
        &self,
        src: Rank,
        dst: Rank,
        tag: Tag,
        value: &T,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()>
    where
        T: Serialize,
    {
        Self::retry_on_timeout(retry_policy, || {
            self.send_with_timeout(src, dst, tag, value, timeout_per_attempt)
        })
    }

    /// Blocks until a message matching `source`/`tag` arrives, then
    /// deserializes it with `decode`. A `None` filter matches any value.
    pub fn receive<T>(&self, source: Option<Rank>, tag: Option<Tag>) -> Result<(T, Status)>
    where
        T: DeserializeOwned,
    {
        let (payload, status) = self.receive_bytes(source, tag)?;
        let value = decode(&payload)?;
        Ok((value, status))
    }

    /// Blocks until a raw message matching `source`/`tag` arrives.
    ///
    /// The local mailbox is always checked first. On non-`wasm32` targets
    /// this busy-polls [`Runtime::try_receive_bytes`], yielding the thread
    /// between polls without sleeping. On `wasm32` it falls back to
    /// `browser::recv_envelope_blocking`.
    pub fn receive_bytes(&self, source: Option<Rank>, tag: Option<Tag>) -> Result<(Vec<u8>, Status)> {
        loop {
            #[cfg(not(target_arch = "wasm32"))]
            {
                // try_receive_bytes already checks the local mailbox first.
                if let Some(received) = self.try_receive_bytes(source, tag)? {
                    return Ok(received);
                }
                std::thread::yield_now();
            }
            #[cfg(target_arch = "wasm32")]
            {
                if let Some(envelope) = self.take_from_mailbox(source, tag)? {
                    return Ok(Self::into_parts(envelope));
                }
                if let Some(envelope) = browser::recv_envelope_blocking(source, tag)? {
                    envelope.validate(self.size())?;
                    return Ok(Self::into_parts(envelope));
                }
            }
        }
    }

    /// Non-blocking receive of a raw message.
    ///
    /// Checks the local mailbox first, then `browser::recv_envelope`. Returns
    /// `Ok(None)` when nothing matches `source`/`tag`.
    pub fn try_receive_bytes(
        &self,
        source: Option<Rank>,
        tag: Option<Tag>,
    ) -> Result<Option<(Vec<u8>, Status)>> {
        if let Some(envelope) = self.take_from_mailbox(source, tag)? {
            return Ok(Some(Self::into_parts(envelope)));
        }
        if let Some(envelope) = browser::recv_envelope(source, tag)? {
            envelope.validate(self.size())?;
            return Ok(Some(Self::into_parts(envelope)));
        }
        Ok(None)
    }

    /// Polls [`Runtime::try_receive`] until a matching message arrives or
    /// `timeout` elapses.
    ///
    /// A zero timeout fails immediately with [`Error::Timeout`]. If
    /// `now + timeout` cannot be represented as an `Instant`, the call waits
    /// indefinitely instead of panicking on overflow.
    ///
    /// NOTE(review): `std::time::Instant::now` panics on
    /// `wasm32-unknown-unknown`. If that target is supported, callers there
    /// should use `receive_with_timeout_async` instead — confirm.
    pub fn receive_with_timeout<T>(
        &self,
        source: Option<Rank>,
        tag: Option<Tag>,
        timeout: Duration,
    ) -> Result<(T, Status)>
    where
        T: DeserializeOwned,
    {
        if timeout.is_zero() {
            return Err(Error::Timeout {
                operation: "receive",
                timeout_ms: 0,
            });
        }
        // `Instant + Duration` panics on overflow; treat an unrepresentable
        // deadline as "no deadline".
        let deadline = Instant::now().checked_add(timeout);
        loop {
            if let Some(received) = self.try_receive(source, tag)? {
                return Ok(received);
            }
            if matches!(deadline, Some(limit) if Instant::now() >= limit) {
                return Err(Error::Timeout {
                    operation: "receive",
                    timeout_ms: Self::duration_millis(timeout),
                });
            }
            #[cfg(not(target_arch = "wasm32"))]
            std::thread::yield_now();
        }
    }

    /// Calls [`Runtime::receive_with_timeout`], retrying timeouts according
    /// to `retry_policy`. Non-timeout errors are returned immediately.
    pub fn receive_with_retry_timeout<T>(
        &self,
        source: Option<Rank>,
        tag: Option<Tag>,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<(T, Status)>
    where
        T: DeserializeOwned,
    {
        Self::retry_on_timeout(retry_policy, || {
            self.receive_with_timeout::<T>(source, tag, timeout_per_attempt)
        })
    }

    /// Async receive with a fixed 15 second timeout
    /// (see [`Runtime::receive_with_timeout_async`]).
    #[cfg(target_arch = "wasm32")]
    pub async fn receive_async<T>(&self, source: Option<Rank>, tag: Option<Tag>) -> Result<(T, Status)>
    where
        T: DeserializeOwned,
    {
        self.receive_with_timeout_async(source, tag, Duration::from_secs(15))
            .await
    }

    /// Async barrier with a fixed 15 second timeout
    /// (see [`Runtime::barrier_with_timeout_async`]).
    #[cfg(target_arch = "wasm32")]
    pub async fn barrier_async(&self) -> Result<()> {
        self.barrier_with_timeout_async(Duration::from_secs(15)).await
    }

    /// Async receive: polls [`Runtime::try_receive`], awaiting a 1 ms timer
    /// between polls, until a message matches or `timeout` is exhausted.
    ///
    /// NOTE(review): elapsed time is approximated by counting timer ticks.
    /// Browsers may fire timers later than requested (clamping, background-tab
    /// throttling), so the real wait can be considerably longer than `timeout`.
    #[cfg(target_arch = "wasm32")]
    pub async fn receive_with_timeout_async<T>(
        &self,
        source: Option<Rank>,
        tag: Option<Tag>,
        timeout: Duration,
    ) -> Result<(T, Status)>
    where
        T: DeserializeOwned,
    {
        if timeout.is_zero() {
            return Err(Error::Timeout {
                operation: "receive",
                timeout_ms: 0,
            });
        }
        let timeout_ms = Self::duration_millis(timeout).max(1);
        let mut elapsed_ms = 0_u64;
        loop {
            if let Some(received) = self.try_receive(source, tag)? {
                return Ok(received);
            }
            if elapsed_ms >= timeout_ms {
                return Err(Error::Timeout {
                    operation: "receive",
                    timeout_ms,
                });
            }
            TimeoutFuture::new(1).await;
            elapsed_ms = elapsed_ms.saturating_add(1);
        }
    }

    /// Async barrier.
    ///
    /// Registers arrival with `browser::barrier_begin`, then polls
    /// `browser::barrier_ready`, awaiting a 1 ms timer between polls. A
    /// single-rank communicator returns immediately.
    ///
    /// NOTE(review): the timeout is tracked by counting timer ticks, so the
    /// real wait can exceed `timeout` when browser timers fire late.
    #[cfg(target_arch = "wasm32")]
    pub async fn barrier_with_timeout_async(&self, timeout: Duration) -> Result<()> {
        if self.size() <= 1 {
            return Ok(());
        }
        if timeout.is_zero() {
            return Err(Error::Timeout {
                operation: "barrier",
                timeout_ms: 0,
            });
        }
        let version = browser::barrier_begin(self.rank(), self.size())?;
        let timeout_ms = Self::duration_millis(timeout).max(1);
        let mut elapsed_ms = 0_u64;
        loop {
            if browser::barrier_ready(version)? {
                return Ok(());
            }
            if elapsed_ms >= timeout_ms {
                return Err(Error::Timeout {
                    operation: "barrier",
                    timeout_ms,
                });
            }
            TimeoutFuture::new(1).await;
            elapsed_ms = elapsed_ms.saturating_add(1);
        }
    }

    /// Non-blocking typed receive.
    ///
    /// Returns `Ok(None)` when nothing matches; otherwise decodes the payload
    /// with `decode`.
    pub fn try_receive<T>(&self, source: Option<Rank>, tag: Option<Tag>) -> Result<Option<(T, Status)>>
    where
        T: DeserializeOwned,
    {
        if let Some((payload, status)) = self.try_receive_bytes(source, tag)? {
            let value = decode(&payload)?;
            return Ok(Some((value, status)));
        }
        Ok(None)
    }

    /// Waits for all ranks through `browser::barrier` with no timeout.
    ///
    /// A single-rank communicator returns immediately. If `browser::barrier`
    /// reports failure, returns [`Error::Timeout`] with `timeout_ms: 0`.
    pub fn barrier(&self) -> Result<()> {
        if self.size() <= 1 {
            return Ok(());
        }
        let ok = browser::barrier(self.rank(), self.size(), None)?;
        if ok {
            Ok(())
        } else {
            Err(Error::Timeout {
                operation: "barrier",
                timeout_ms: 0,
            })
        }
    }

    /// Waits for all ranks through `browser::barrier` with a timeout.
    ///
    /// The timeout is passed as whole milliseconds, clamped to `1..=u32::MAX`.
    /// A zero timeout fails immediately; a single-rank communicator returns
    /// immediately.
    pub fn barrier_with_timeout(&self, timeout: Duration) -> Result<()> {
        if self.size() <= 1 {
            return Ok(());
        }
        if timeout.is_zero() {
            return Err(Error::Timeout {
                operation: "barrier",
                timeout_ms: 0,
            });
        }
        let browser_timeout_ms = u32::try_from(timeout.as_millis())
            .unwrap_or(u32::MAX)
            .max(1);
        let ok = browser::barrier(self.rank(), self.size(), Some(browser_timeout_ms))?;
        if ok {
            Ok(())
        } else {
            Err(Error::Timeout {
                operation: "barrier",
                timeout_ms: Self::duration_millis(timeout),
            })
        }
    }

    /// Calls [`Runtime::barrier_with_timeout`], retrying timeouts according to
    /// `retry_policy`.
    ///
    /// NOTE(review): whether re-entering `browser::barrier` after a timed-out
    /// attempt is safe (e.g. does not register this rank's arrival twice)
    /// depends on `browser::barrier` — confirm.
    pub fn barrier_with_retry_timeout(
        &self,
        timeout_per_attempt: Duration,
        retry_policy: &RetryPolicy,
    ) -> Result<()> {
        Self::retry_on_timeout(retry_policy, || self.barrier_with_timeout(timeout_per_attempt))
    }

    /// Removes and returns the oldest mailbox envelope matching `source`/`tag`
    /// (a `None` filter matches anything).
    fn take_from_mailbox(&self, source: Option<Rank>, tag: Option<Tag>) -> Result<Option<Envelope>> {
        let mut mailbox = self.lock_mailbox()?;
        let index = mailbox.iter().position(|envelope| {
            source.map_or(true, |expected| envelope.src == expected)
                && tag.map_or(true, |expected| envelope.tag == expected)
        });
        Ok(index.and_then(|index| mailbox.remove(index)))
    }

    /// Locks the local mailbox, mapping lock poisoning to [`Error::Protocol`].
    fn lock_mailbox(&self) -> Result<std::sync::MutexGuard<'_, VecDeque<Envelope>>> {
        self.inner
            .mailbox
            .lock()
            .map_err(|_| Error::Protocol("mailbox lock poisoned".to_string()))
    }

    /// Splits an envelope into its payload and receive status.
    fn into_parts(envelope: Envelope) -> (Vec<u8>, Status) {
        let status = Status {
            source_rank: envelope.src,
            tag: envelope.tag,
        };
        (envelope.payload, status)
    }

    /// Runs `attempt_once` and retries on [`Error::Timeout`] while the policy
    /// allows, sleeping `backoff_for_attempt(n)` before retry `n`.
    ///
    /// Any other result, success or error, is returned immediately.
    fn retry_on_timeout<R, F>(retry_policy: &RetryPolicy, mut attempt_once: F) -> Result<R>
    where
        F: FnMut() -> Result<R>,
    {
        let mut attempt = 0_u32;
        loop {
            match attempt_once() {
                Err(Error::Timeout { .. }) if attempt < retry_policy.max_retries => {
                    attempt = attempt.saturating_add(1);
                    sleep_backoff(retry_policy.backoff_for_attempt(attempt));
                }
                other => return other,
            }
        }
    }

    /// Whole milliseconds in `duration`, saturating at `u64::MAX`.
    fn duration_millis(duration: Duration) -> u64 {
        u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
    }

    /// Checks that `size >= 1` and `0 <= rank < size`.
    fn validate_topology(rank: Rank, size: Rank) -> Result<()> {
        if size < 1 {
            return Err(Error::Protocol(format!(
                "invalid communicator size {size}"
            )));
        }
        if rank < 0 || rank >= size {
            return Err(Error::Protocol(format!(
                "invalid rank {rank}, communicator size={size}"
            )));
        }
        Ok(())
    }

    /// Reads environment variable `name` as a rank.
    ///
    /// Returns `default` only when the variable is absent. A present but
    /// malformed (or non-Unicode) value is an error rather than a silent
    /// fallback, since a wrong fallback could give two processes the same
    /// rank.
    #[cfg(not(target_arch = "wasm32"))]
    fn rank_from_env(name: &str, default: Rank) -> Result<Rank> {
        match std::env::var(name) {
            Ok(raw) => {
                let parsed = raw.trim().parse::<Rank>();
                parsed.map_err(|_| {
                    Error::Protocol(format!(
                        "environment variable {name}={raw:?} is not a valid rank"
                    ))
                })
            }
            Err(std::env::VarError::NotPresent) => Ok(default),
            Err(std::env::VarError::NotUnicode(_)) => Err(Error::Protocol(format!(
                "environment variable {name} is not valid unicode"
            ))),
        }
    }
}
/// Blocks the current thread for a retry backoff interval.
///
/// A zero `duration` returns immediately. On `wasm32` the call is a no-op,
/// because the thread is not put to sleep there.
fn sleep_backoff(duration: Duration) {
    #[cfg(target_arch = "wasm32")]
    let _ = duration;
    #[cfg(not(target_arch = "wasm32"))]
    {
        if !duration.is_zero() {
            std::thread::sleep(duration);
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the local (mailbox-backed) paths of `Runtime` and for
    // `RetryPolicy` backoff arithmetic.
    use std::thread;
    use std::time::{Duration, Instant};
    use crate::error::Error;
    use super::{RetryPolicy, Runtime};

    /// Polling an empty runtime yields `Ok(None)` instead of blocking.
    #[test]
    fn try_receive_returns_none_when_empty() {
        let rt = Runtime::new(0, 2);
        let polled = rt.try_receive::<i32>(None, None);
        assert!(polled.expect("try_receive should not fail").is_none());
    }

    /// A queued message is only returned when both source and tag match.
    #[test]
    fn try_receive_honors_source_and_tag() {
        let rt = Runtime::new(0, 3);
        rt.send(2, 0, 7, &123_i32)
            .expect("send should succeed in local runtime");

        let wrong_tag = rt
            .try_receive::<i32>(Some(2), Some(8))
            .expect("try_receive should not fail");
        assert!(wrong_tag.is_none());

        let (value, status) = rt
            .try_receive::<i32>(Some(2), Some(7))
            .expect("try_receive should not fail")
            .expect("message should match expected source and tag");
        assert_eq!(value, 123);
        assert_eq!(status.source_rank, 2);
        assert_eq!(status.tag, 7);
    }

    /// `receive` waits for a message sent later from another thread.
    #[test]
    fn receive_blocks_until_message_available() {
        let rt = Runtime::new(0, 2);
        let delayed_sender = {
            let handle = rt.clone();
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(20));
                handle
                    .send(0, 0, 3, &77_i32)
                    .expect("delayed send should succeed");
            })
        };

        let (value, status) = rt
            .receive::<i32>(Some(0), Some(3))
            .expect("blocking receive should eventually succeed");
        delayed_sender.join().expect("sender thread should finish");

        assert_eq!(value, 77);
        assert_eq!(status.source_rank, 0);
        assert_eq!(status.tag, 3);
    }

    /// With nothing to receive, a short timeout surfaces as `Error::Timeout`.
    #[test]
    fn receive_with_timeout_returns_timeout_error() {
        let rt = Runtime::new(0, 2);
        let outcome = rt.receive_with_timeout::<i32>(None, None, Duration::from_millis(1));
        if let Err(Error::Timeout { operation, .. }) = outcome {
            assert_eq!(operation, "receive");
        } else {
            panic!("expected receive timeout error");
        }
    }

    /// Destinations outside the communicator are rejected with a protocol error.
    #[test]
    fn send_rejects_invalid_destination_rank() {
        let rt = Runtime::new(0, 2);
        let failure = rt
            .send(0, 5, 0, &1_i32)
            .expect_err("invalid destination should fail");
        if let Error::Protocol(message) = failure {
            assert!(message.contains("invalid destination rank"));
        } else {
            panic!("expected protocol error");
        }
    }

    /// A zero send timeout fails without attempting delivery.
    #[test]
    fn send_with_zero_timeout_fails_fast() {
        let rt = Runtime::new(0, 2);
        let failure = rt
            .send_with_timeout(0, 0, 0, &1_i32, Duration::ZERO)
            .expect_err("zero-timeout send should fail");
        if let Error::Timeout { operation, .. } = failure {
            assert_eq!(operation, "send");
        } else {
            panic!("expected timeout error");
        }
    }

    /// Retries give up after the policy's budget and report the last timeout.
    #[test]
    fn receive_with_retry_timeout_stops_after_budget() {
        let rt = Runtime::new(0, 2);
        let policy = RetryPolicy {
            max_retries: 2,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(2),
            backoff_multiplier: 2,
        };

        let clock = Instant::now();
        let outcome =
            rt.receive_with_retry_timeout::<i32>(None, None, Duration::from_millis(1), &policy);
        let elapsed = clock.elapsed();

        if let Err(Error::Timeout { operation, .. }) = outcome {
            assert_eq!(operation, "receive");
        } else {
            panic!("expected timeout after retries");
        }
        assert!(elapsed < Duration::from_millis(200));
    }

    /// Raw bytes survive a local send/receive unchanged, with matching status.
    #[test]
    fn send_and_receive_bytes_roundtrip() {
        let rt = Runtime::new(0, 1);
        let bytes = [1_u8, 2_u8, 3_u8, 4_u8];
        rt.send_bytes(0, 0, 12, &bytes)
            .expect("raw byte send should succeed");

        let (echoed, status) = rt
            .receive_bytes(Some(0), Some(12))
            .expect("raw byte receive should succeed");
        assert_eq!(echoed, bytes);
        assert_eq!(status.source_rank, 0);
        assert_eq!(status.tag, 12);
    }

    /// Backoff grows geometrically and then stays at `max_backoff`.
    #[test]
    fn retry_policy_backoff_is_capped() {
        let policy = RetryPolicy {
            max_retries: 4,
            initial_backoff: Duration::from_millis(3),
            max_backoff: Duration::from_millis(10),
            backoff_multiplier: 3,
        };
        let expectations = [(1, 3), (2, 9), (3, 10), (4, 10)];
        for &(attempt, expected_ms) in expectations.iter() {
            assert_eq!(
                policy.backoff_for_attempt(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }
}