#![cfg(target_arch = "wasm32")]
use std::cell::RefCell;
use std::collections::HashSet;
use std::rc::Rc;
use js_sys::{Object, Reflect, SharedArrayBuffer};
use serde_wasm_bindgen;
use wasm_bindgen::closure::Closure;
use wasm_bindgen::{JsCast, JsValue};
use web_sys::{ErrorEvent, MessageEvent, Worker, WorkerOptions, WorkerType};
use crate::runtime::protocol::{Envelope, PROTOCOL_VERSION};
/// Handle to a set of spawned workers, returned by `create_job` and
/// `create_job_with_options`.
///
/// Calling `terminate` or dropping the handle detaches the event handlers
/// from every worker and terminates it.
pub struct Job {
// Worker handles; the constructor indexes this vector by rank.
workers: Rc<Vec<Worker>>,
// Never read: stored so the `onmessage` closures live as long as the job.
_message_handlers: Vec<Closure<dyn FnMut(MessageEvent)>>,
// Never read: stored so the `onerror` closures live as long as the job.
_error_handlers: Vec<Closure<dyn FnMut(ErrorEvent)>>,
}
/// Optional tuning knobs forwarded to each worker in its `jsmpi:init` message.
///
/// A field left as `None` is omitted from the init message entirely.
#[derive(Clone, Debug, Default)]
pub struct JobOptions {
/// Sent to workers as `collectiveTimeoutMs` when set.
pub collective_timeout_ms: Option<u32>,
/// Sent to workers as `collectiveMaxRetries` when set.
pub collective_max_retries: Option<u32>,
}
impl Job {
    /// Detaches the event handlers from every worker of this job and
    /// terminates it. Safe to call more than once.
    pub fn terminate(&self) {
        terminate_all(self.workers.as_slice());
    }
}
impl Drop for Job {
    /// Terminates all workers when the handle goes out of scope, so no
    /// worker keeps running after its `Job` has been dropped.
    fn drop(&mut self) {
        terminate_all(&self.workers[..]);
    }
}
/// Clears the `onmessage` and `onerror` handlers of each worker and then
/// terminates it.
fn terminate_all(workers: &[Worker]) {
    workers.iter().for_each(|worker| {
        // Detach the handlers first, then stop the worker.
        worker.set_onmessage(None);
        worker.set_onerror(None);
        worker.terminate();
    });
}
/// Starts a job with default [`JobOptions`].
///
/// Thin wrapper around `create_job_with_options`; all arguments are passed
/// through unchanged.
pub fn create_job(
    worker_url: &str,
    module_url: &str,
    ranks: u32,
    on_log: impl Fn(String, String) + 'static,
    on_state_change: impl Fn(String) + 'static,
    on_complete: impl Fn(u32, u32) + 'static,
) -> Job {
    let default_options = JobOptions::default();
    create_job_with_options(
        worker_url,
        module_url,
        ranks,
        default_options,
        on_log,
        on_state_change,
        on_complete,
    )
}
pub fn create_job_with_options(
worker_url: &str,
module_url: &str,
ranks: u32,
options: JobOptions,
on_log: impl Fn(String, String) + 'static,
on_state_change: impl Fn(String) + 'static,
on_complete: impl Fn(u32, u32) + 'static,
) -> Job {
let on_log: Rc<dyn Fn(String, String)> = Rc::new(on_log);
let on_state_change: Rc<dyn Fn(String)> = Rc::new(on_state_change);
let on_complete: Rc<dyn Fn(u32, u32)> = Rc::new(on_complete);
let job_id = format!("job-{}", js_sys::Date::now() as u64);
let barrier_ranks: Rc<RefCell<HashSet<u32>>> = Rc::new(RefCell::new(HashSet::new()));
let barrier_started_at_ms: Rc<RefCell<Option<f64>>> = Rc::new(RefCell::new(None));
let finished_ranks: Rc<RefCell<u32>> = Rc::new(RefCell::new(0));
let stopped: Rc<RefCell<bool>> = Rc::new(RefCell::new(false));
let metrics: Rc<RefCell<RuntimeMetrics>> = Rc::new(RefCell::new(RuntimeMetrics::new(
ranks,
js_sys::Date::now(),
)));
let opts = WorkerOptions::new();
opts.set_type(WorkerType::Module);
let raw_workers: Vec<Worker> = (0..ranks)
.map(|_| Worker::new_with_options(worker_url, &opts).expect("Worker::new failed"))
.collect();
let workers_rc: Rc<Vec<Worker>> = Rc::new(raw_workers);
let mut message_handlers: Vec<Closure<dyn FnMut(MessageEvent)>> = Vec::new();
let mut error_handlers: Vec<Closure<dyn FnMut(ErrorEvent)>> = Vec::new();
for rank in 0..ranks {
let workers_clone = Rc::clone(&workers_rc);
let barrier_clone = Rc::clone(&barrier_ranks);
let barrier_started_clone = Rc::clone(&barrier_started_at_ms);
let finished_clone = Rc::clone(&finished_ranks);
let stopped_clone = Rc::clone(&stopped);
let metrics_clone = Rc::clone(&metrics);
let job_id_clone = job_id.clone();
let on_log_clone = Rc::clone(&on_log);
let on_state_clone = Rc::clone(&on_state_change);
let on_complete_clone = Rc::clone(&on_complete);
let handler = Closure::<dyn FnMut(MessageEvent)>::new(move |evt: MessageEvent| {
handle_worker_message(
evt,
rank,
ranks,
&job_id_clone,
&workers_clone,
&barrier_clone,
&barrier_started_clone,
&finished_clone,
&stopped_clone,
&metrics_clone,
&on_log_clone,
&on_state_clone,
&on_complete_clone,
);
});
workers_rc[rank as usize].set_onmessage(Some(handler.as_ref().unchecked_ref()));
message_handlers.push(handler);
let on_log_error = Rc::clone(&on_log);
let on_state_error = Rc::clone(&on_state_change);
let stopped_error = Rc::clone(&stopped);
let metrics_error = Rc::clone(&metrics);
let job_id_error = job_id.clone();
let error_handler = Closure::<dyn FnMut(ErrorEvent)>::new(move |evt: ErrorEvent| {
let mut m = metrics_error.borrow_mut();
m.error_count += 1;
emit_structured_log(
&on_log_error,
"error",
&job_id_error,
Some(rank),
"lifecycle",
"worker_crash",
Some("WORKER_CRASH"),
&format!("{} ({}:{})", evt.message(), evt.filename(), evt.lineno()),
);
mark_stopped_once(
&stopped_error,
&on_state_error,
&on_log_error,
&job_id_error,
rank,
&metrics_error,
);
});
workers_rc[rank as usize].set_onerror(Some(error_handler.as_ref().unchecked_ref()));
error_handlers.push(error_handler);
let control_buffer = make_control_buffer();
if control_buffer.is_none() {
emit_structured_log(
&on_log,
"error",
&job_id,
Some(rank),
"runtime",
"capability_check",
Some("CAPABILITY_MISSING_SAB"),
"SharedArrayBuffer is unavailable; serve with COOP/COEP headers for blocking receive/barrier.",
);
}
metrics.borrow_mut().mark_init_sent(rank, js_sys::Date::now());
let init_msg = build_init_msg(
rank,
ranks,
module_url,
&job_id,
control_buffer.as_ref(),
&options,
);
workers_rc[rank as usize]
.post_message(&init_msg)
.expect("postMessage failed");
}
(on_state_change)("running".into());
emit_structured_log(
&on_log,
"status",
&job_id,
None,
"lifecycle",
"job_started",
None,
&format!("spawned {} workers", ranks),
);
Job {
workers: workers_rc,
_message_handlers: message_handlers,
_error_handlers: error_handlers,
}
}
/// Moves the job into the `"stopped"` state the first time it is called.
///
/// On that first call it reports `"stopped"` through `on_state_change` and
/// emits a `job_stopped` metrics line attributed to `rank`. Later calls do
/// nothing.
fn mark_stopped_once(
    stopped: &Rc<RefCell<bool>>,
    on_state_change: &Rc<dyn Fn(String)>,
    on_log: &Rc<dyn Fn(String, String)>,
    job_id: &str,
    rank: u32,
    metrics: &Rc<RefCell<RuntimeMetrics>>,
) {
    // Set the flag and learn its previous value in one step.
    let was_stopped = stopped.replace(true);
    if was_stopped {
        return;
    }
    on_state_change("stopped".into());
    emit_metrics(on_log, job_id, metrics, "job_stopped", Some(rank));
}
/// Counters and timestamps collected over the lifetime of one job.
#[derive(Debug)]
struct RuntimeMetrics {
    /// Number of ranks the job was created with.
    total_ranks: u32,
    /// Timestamp (ms) at which the metrics were created.
    started_at_ms: f64,
    /// Per-rank timestamp of when the init message was sent.
    init_sent_at_ms: Vec<Option<f64>>,
    /// Per-rank timestamp of when the worker reported ready.
    ready_at_ms: Vec<Option<f64>>,
    /// Number of completed barrier rounds.
    barrier_rounds: u32,
    /// Accumulated barrier wait time across all rounds.
    barrier_wait_total_ms: f64,
    /// Accumulated time spent dispatching routed messages.
    transport_dispatch_total_ms: f64,
    /// Number of routed messages.
    transport_dispatch_count: u64,
    /// Number of errors observed.
    error_count: u64,
    /// Number of events observed.
    event_count: u64,
}

impl RuntimeMetrics {
    /// Creates zeroed metrics with one empty timestamp slot per rank.
    fn new(total_ranks: u32, started_at_ms: f64) -> Self {
        let slot_count = total_ranks as usize;
        Self {
            total_ranks,
            started_at_ms,
            init_sent_at_ms: vec![None; slot_count],
            ready_at_ms: vec![None; slot_count],
            barrier_rounds: 0,
            barrier_wait_total_ms: 0.0,
            transport_dispatch_total_ms: 0.0,
            transport_dispatch_count: 0,
            error_count: 0,
            event_count: 0,
        }
    }

    /// Writes `now_ms` into the slot for `rank`; out-of-range ranks are ignored.
    fn stamp(slots: &mut [Option<f64>], rank: u32, now_ms: f64) {
        if let Some(slot) = slots.get_mut(rank as usize) {
            *slot = Some(now_ms);
        }
    }

    /// Records when the init message for `rank` was sent.
    fn mark_init_sent(&mut self, rank: u32, now_ms: f64) {
        Self::stamp(&mut self.init_sent_at_ms, rank, now_ms);
    }

    /// Records when `rank` reported ready.
    fn mark_ready(&mut self, rank: u32, now_ms: f64) {
        Self::stamp(&mut self.ready_at_ms, rank, now_ms);
    }

    /// Mean of `ready - init_sent` over ranks that have both timestamps, or
    /// `None` when no rank has both.
    fn avg_startup_ms(&self) -> Option<f64> {
        let (sum, samples) = self
            .init_sent_at_ms
            .iter()
            .zip(&self.ready_at_ms)
            .take(self.total_ranks as usize)
            .filter_map(|pair| match pair {
                (Some(sent), Some(ready)) => Some(ready - sent),
                _ => None,
            })
            .fold((0.0_f64, 0usize), |(sum, samples), delta| {
                (sum + delta, samples + 1)
            });
        (samples != 0).then(|| sum / samples as f64)
    }

    /// Mean dispatch time per routed message, or `None` before any dispatch.
    fn avg_dispatch_ms(&self) -> Option<f64> {
        match self.transport_dispatch_count {
            0 => None,
            dispatched => Some(self.transport_dispatch_total_ms / dispatched as f64),
        }
    }

    /// Mean wait per barrier round, or `None` before the first round completes.
    fn avg_barrier_wait_ms(&self) -> Option<f64> {
        match self.barrier_rounds {
            0 => None,
            rounds => Some(self.barrier_wait_total_ms / rounds as f64),
        }
    }

    /// Ratio of errors to events; `0.0` when no event has been counted.
    fn failure_rate(&self) -> f64 {
        match self.event_count {
            0 => 0.0,
            events => self.error_count as f64 / events as f64,
        }
    }
}
/// Extracts a leading error code from `text`.
///
/// The candidate is everything before the first `:` (or the whole text when
/// there is no colon), trimmed. It is returned only when it consists solely
/// of ASCII uppercase letters and underscores and contains at least one
/// underscore; otherwise the result is `None`.
fn extract_error_code(text: &str) -> Option<String> {
    let candidate = text
        .split_once(':')
        .map_or(text, |(head, _rest)| head)
        .trim();
    let has_separator = candidate.contains('_');
    let only_code_chars = candidate
        .chars()
        .all(|ch| matches!(ch, '_' | 'A'..='Z'));
    (has_separator && only_code_chars).then(|| candidate.to_owned())
}
/// Formats one structured log line and passes it to `on_log` together with
/// `kind`.
///
/// The line has the shape
/// `job_id=… rank=… phase=… event=… error_code=… message=…`; a missing rank
/// or error code is rendered as `-`.
fn emit_structured_log(
    on_log: &Rc<dyn Fn(String, String)>,
    kind: &str,
    job_id: &str,
    rank: Option<u32>,
    phase: &str,
    event: &str,
    error_code: Option<&str>,
    message: &str,
) {
    let rank_field = match rank {
        Some(value) => value.to_string(),
        None => String::from("-"),
    };
    let code_field = error_code.unwrap_or("-");
    let line = format!(
        "job_id={job_id} rank={rank_field} phase={phase} event={event} error_code={code_field} message={message}"
    );
    on_log(kind.to_owned(), line);
}
/// Emits a `status` log line in the `metrics` phase summarising the job's
/// averages, failure rate and uptime.
///
/// Emitting the summary itself counts as one event in the metrics.
fn emit_metrics(
    on_log: &Rc<dyn Fn(String, String)>,
    job_id: &str,
    metrics: &Rc<RefCell<RuntimeMetrics>>,
    event: &str,
    rank: Option<u32>,
) {
    /// Renders an optional average with the given precision, or `-` when absent.
    fn render(value: Option<f64>, precision: usize) -> String {
        match value {
            Some(v) => format!("{:.*}", precision, v),
            None => "-".to_string(),
        }
    }

    let mut snapshot = metrics.borrow_mut();
    snapshot.event_count += 1;

    let startup_avg = render(snapshot.avg_startup_ms(), 2);
    let dispatch_avg = render(snapshot.avg_dispatch_ms(), 4);
    let barrier_avg = render(snapshot.avg_barrier_wait_ms(), 2);
    let uptime_ms = js_sys::Date::now() - snapshot.started_at_ms;
    let failure_rate = snapshot.failure_rate();

    let summary = format!(
        "startup_ms_avg={startup_avg} dispatch_ms_avg={dispatch_avg} barrier_wait_ms_avg={barrier_avg} failure_rate={failure_rate:.4} uptime_ms={uptime_ms:.0}"
    );
    emit_structured_log(
        on_log, "status", job_id, rank, "metrics", event, None, &summary,
    );
}
/// Handles one message posted by the worker with the given `rank`.
///
/// The message's `type` string selects the action:
/// * `jsmpi:ready` - records the ready timestamp for `rank`.
/// * `jsmpi:send` - decodes and validates the envelope, then forwards it to
///   the destination worker as `jsmpi:deliver`.
/// * `jsmpi:barrier` - records the arrival; once all `total_ranks` ranks have
///   arrived, posts `jsmpi:barrier-release` to every worker.
/// * `jsmpi:log` - re-emits the worker's text under the worker-supplied level.
/// * `jsmpi:error` - logs the error and moves the job to the stopped state.
/// * `jsmpi:finished` - counts the completion; once the count equals
///   `total_ranks`, reports `"finished"` and calls `on_complete`.
///
/// Any other (or missing) `type` is ignored. Every message, including ignored
/// ones, increments the event counter.
///
/// Each `RefCell` borrow below is a statement-scoped temporary or lives in a
/// small explicit block, so no borrow is still held when `mark_stopped_once`
/// or `emit_metrics` borrows the metrics again.
// The argument list mirrors the per-job shared state captured by the closure.
#[allow(clippy::too_many_arguments)]
fn handle_worker_message(
evt: MessageEvent,
rank: u32,
total_ranks: u32,
job_id: &str,
workers: &Rc<Vec<Worker>>,
barrier_ranks: &Rc<RefCell<HashSet<u32>>>,
barrier_started_at_ms: &Rc<RefCell<Option<f64>>>,
finished_ranks: &Rc<RefCell<u32>>,
stopped: &Rc<RefCell<bool>>,
metrics: &Rc<RefCell<RuntimeMetrics>>,
on_log: &Rc<dyn Fn(String, String)>,
on_state_change: &Rc<dyn Fn(String)>,
on_complete: &Rc<dyn Fn(u32, u32)>,
) {
metrics.borrow_mut().event_count += 1;
let data = evt.data();
// A missing or non-string `type` becomes "" and falls through to the `_` arm.
let msg_type = get_str(&data, "type").unwrap_or_default();
match msg_type.as_str() {
"jsmpi:ready" => {
let now = js_sys::Date::now();
metrics.borrow_mut().mark_ready(rank, now);
emit_structured_log(
on_log,
"status",
job_id,
Some(rank),
"startup",
"worker_ready",
None,
"worker initialized",
);
}
"jsmpi:send" => {
// Keep the raw JS value: it is what gets forwarded; the decoded copy is
// used only for validation and logging.
let envelope_value = get_obj(&data, "envelope");
let decoded = serde_wasm_bindgen::from_value::<Envelope>(envelope_value.clone());
match decoded {
Ok(envelope) => {
// A protocol version mismatch is fatal: log it and stop the job.
if envelope.protocol_version != PROTOCOL_VERSION {
metrics.borrow_mut().error_count += 1;
emit_structured_log(
on_log,
"error",
job_id,
Some(rank),
"protocol",
"version_check",
Some("PROTO_UNSUPPORTED_VERSION"),
&format!(
"expected={}, got={}",
PROTOCOL_VERSION, envelope.protocol_version
),
);
mark_stopped_once(stopped, on_state_change, on_log, job_id, rank, metrics);
return;
}
// An out-of-range destination drops this message but does not stop the job.
if envelope.dst < 0 || (envelope.dst as usize) >= workers.len() {
metrics.borrow_mut().error_count += 1;
emit_structured_log(
on_log,
"error",
job_id,
Some(rank),
"protocol",
"envelope_validate",
Some("PROTO_INVALID_ENVELOPE"),
&format!("invalid destination rank {}", envelope.dst),
);
return;
}
let started_ms = js_sys::Date::now();
let dst = envelope.dst as usize;
let deliver = build_deliver_msg(&envelope_value);
// Best effort: a failed post is discarded rather than reported.
workers[dst].post_message(&deliver).ok();
let elapsed_ms = js_sys::Date::now() - started_ms;
{
// Scoped so the mutable borrow ends before the log callback runs.
let mut m = metrics.borrow_mut();
m.transport_dispatch_count += 1;
m.transport_dispatch_total_ms += elapsed_ms;
}
emit_structured_log(
on_log,
"transport",
job_id,
Some(rank),
"transport",
"send_route",
None,
&format!(
"src={} dst={} tag={} dispatch_ms={:.4}",
envelope.src, dst, envelope.tag, elapsed_ms
),
);
}
Err(err) => {
// An undecodable envelope is logged and dropped; the job keeps running.
metrics.borrow_mut().error_count += 1;
emit_structured_log(
on_log,
"error",
job_id,
Some(rank),
"protocol",
"decode",
Some("PROTO_INVALID_ENVELOPE"),
&format!("decode failed: {err}"),
);
}
}
}
"jsmpi:barrier" => {
let now = js_sys::Date::now();
// The first arrival of a round starts the wait clock.
if barrier_ranks.borrow().is_empty() {
*barrier_started_at_ms.borrow_mut() = Some(now);
}
// A set, so a rank that reports twice in one round is counted once.
barrier_ranks.borrow_mut().insert(rank);
let current = barrier_ranks.borrow().len() as u32;
emit_structured_log(
on_log,
"barrier",
job_id,
Some(rank),
"collective",
"barrier_reached",
None,
&format!("reached ({}/{})", current, total_ranks),
);
if current == total_ranks {
let release = build_simple_msg("jsmpi:barrier-release");
for w in workers.iter() {
// Best effort: a failed post is discarded rather than reported.
w.post_message(&release).ok();
}
if let Some(started) = *barrier_started_at_ms.borrow() {
let wait_ms = js_sys::Date::now() - started;
let mut m = metrics.borrow_mut();
m.barrier_rounds += 1;
m.barrier_wait_total_ms += wait_ms;
}
// Reset the round state so the next barrier starts from scratch.
barrier_ranks.borrow_mut().clear();
*barrier_started_at_ms.borrow_mut() = None;
emit_structured_log(
on_log,
"status",
job_id,
None,
"collective",
"barrier_release",
None,
"all ranks reached the barrier",
);
}
}
"jsmpi:log" => {
// The worker-supplied level is used as the log kind; it defaults to "log".
let level = get_str(&data, "level").unwrap_or_else(|| "log".into());
let text = get_str(&data, "text").unwrap_or_default();
emit_structured_log(
on_log,
&level,
job_id,
Some(rank),
"runtime",
"worker_log",
None,
&text,
);
}
"jsmpi:error" => {
let text = get_str(&data, "text").unwrap_or_default();
// Reuse a leading `CODE_LIKE_THIS:` prefix of the text as the error code, if any.
let error_code = extract_error_code(&text);
metrics.borrow_mut().error_count += 1;
emit_structured_log(
on_log,
"error",
job_id,
Some(rank),
"runtime",
"worker_error",
error_code.as_deref(),
&text,
);
mark_stopped_once(stopped, on_state_change, on_log, job_id, rank, metrics);
}
"jsmpi:finished" => {
// Read and write in separate statements so the shared borrow is released
// before the mutable one is taken.
let prev = *finished_ranks.borrow();
*finished_ranks.borrow_mut() = prev + 1;
let done = prev + 1;
emit_structured_log(
on_log,
"status",
job_id,
Some(rank),
"lifecycle",
"worker_finished",
None,
&format!("({}/{})", done, total_ranks),
);
if done == total_ranks {
on_state_change("finished".into());
on_complete(done, total_ranks);
emit_metrics(on_log, job_id, metrics, "job_finished", None);
}
}
// Unknown message types are ignored.
_ => {}
}
}
/// Reads property `key` from `obj` and returns it when it is a JS string.
///
/// Returns `None` if the property lookup fails or the value is not a string.
fn get_str(obj: &JsValue, key: &str) -> Option<String> {
    let property = JsValue::from_str(key);
    match Reflect::get(obj, &property) {
        Ok(value) => value.as_string(),
        Err(_) => None,
    }
}
fn get_obj(obj: &JsValue, key: &str) -> JsValue {
Reflect::get(obj, &JsValue::from_str(key))
.unwrap_or(JsValue::UNDEFINED)
}
/// Builds a JS object of the form `{ type: msg_type }`.
fn build_simple_msg(msg_type: &str) -> JsValue {
    let message = Object::new();
    let type_key = JsValue::from_str("type");
    let type_value = JsValue::from_str(msg_type);
    Reflect::set(&message, &type_key, &type_value).unwrap();
    JsValue::from(message)
}
fn build_deliver_msg(envelope: &JsValue) -> JsValue {
let obj = Object::new();
Reflect::set(&obj, &"type".into(), &"jsmpi:deliver".into()).unwrap();
Reflect::set(&obj, &"envelope".into(), envelope).unwrap();
obj.into()
}
fn build_init_msg(
rank: u32,
size: u32,
module_url: &str,
job_id: &str,
control_buffer: Option<&SharedArrayBuffer>,
options: &JobOptions,
) -> JsValue {
let obj = Object::new();
Reflect::set(&obj, &"type".into(), &"jsmpi:init".into()).unwrap();
Reflect::set(&obj, &"rank".into(), &JsValue::from_f64(rank as f64)).unwrap();
Reflect::set(&obj, &"size".into(), &JsValue::from_f64(size as f64)).unwrap();
Reflect::set(&obj, &"jobId".into(), &JsValue::from_str(job_id)).unwrap();
Reflect::set(&obj, &"moduleUrl".into(), &JsValue::from_str(module_url)).unwrap();
if let Some(timeout_ms) = options.collective_timeout_ms {
Reflect::set(
&obj,
&"collectiveTimeoutMs".into(),
&JsValue::from_f64(timeout_ms as f64),
)
.unwrap();
}
if let Some(max_retries) = options.collective_max_retries {
Reflect::set(
&obj,
&"collectiveMaxRetries".into(),
&JsValue::from_f64(max_retries as f64),
)
.unwrap();
}
if let Some(sab) = control_buffer {
Reflect::set(&obj, &"controlBuffer".into(), sab).unwrap();
}
obj.into()
}
fn make_control_buffer() -> Option<SharedArrayBuffer> {
let global = js_sys::global();
let sab_ctor = Reflect::get(&global, &"SharedArrayBuffer".into()).ok()?;
if sab_ctor.is_undefined() || sab_ctor.is_null() {
return None;
}
Some(SharedArrayBuffer::new(8))
}