use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use futures::channel::oneshot;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use wasm_bindgen::closure::Closure;
use wasm_bindgen::{JsCast, JsValue};
use web_sys::js_sys::{Array, Function, Object, Reflect};
use web_sys::{MessageEvent, Worker, WorkerOptions, WorkerType};
use crate::workers::blob::{
awsm_bundle_url, current_wasm_module, new_worker_from_js, WORKER_BOOTSTRAP_JS,
};
/// Errors produced while bootstrapping a [`WorkerPool`] or running jobs on it.
#[derive(Debug, Error)]
pub enum WorkerPoolError {
    /// Obtaining the wasm module, spawning a worker, building/posting the
    /// init message, or waiting for a worker's ready signal failed.
    #[error("worker bootstrap failed: {0}")]
    Bootstrap(String),
    /// `postMessage` threw while dispatching a job to a worker.
    #[error("worker postMessage failed: {0}")]
    PostMessage(String),
    /// The worker replied with an error for the job, or the worker's
    /// `onerror` handler failed the job while it was in flight.
    #[error("worker job failed: {0}")]
    JobFailed(String),
    /// The job's `NAME` was not found by `crate::workers::entry::is_registered`.
    #[error("worker job not registered: {0}")]
    UnknownJob(&'static str),
    /// Serializing the job input or deserializing the job output failed.
    #[error("worker serialization error: {0}")]
    Serde(String),
    /// The result channel was closed before any result was sent.
    #[error("worker channel dropped before result")]
    ChannelDropped,
}
impl WorkerPoolError {
    /// Extracts a human-readable message from a thrown JS value.
    ///
    /// Preference order: the value itself when it is a JS string, then its
    /// `message` property when that is a string, and finally the `Debug`
    /// rendering of the `JsValue`.
    fn js_message(err: JsValue) -> String {
        if let Some(text) = err.as_string() {
            return text;
        }
        match Reflect::get(&err, &JsValue::from_str("message")) {
            Ok(message) => message.as_string().unwrap_or_else(|| format!("{err:?}")),
            Err(_) => format!("{err:?}"),
        }
    }

    /// Wraps a JS error as [`WorkerPoolError::Bootstrap`], formatted as
    /// `"{prefix}: {message}"`.
    fn bootstrap_from_js(prefix: &'static str, err: JsValue) -> Self {
        let detail = Self::js_message(err);
        Self::Bootstrap(format!("{prefix}: {detail}"))
    }

    /// Wraps a JS error as [`WorkerPoolError::PostMessage`], formatted as
    /// `"{prefix}: {message}"`.
    fn post_message_from_js(prefix: &'static str, err: JsValue) -> Self {
        let detail = Self::js_message(err);
        Self::PostMessage(format!("{prefix}: {detail}"))
    }
}
/// A unit of work that can be dispatched to a [`WorkerPool`].
///
/// `NAME` identifies the job in the `awsm-job` message and must be registered
/// (see `WorkerPool::register`) before the pool will dispatch it.
pub trait WorkerJob: 'static {
    /// Identifier sent to the worker alongside the serialized input.
    const NAME: &'static str;
    /// Job input; serialized with `serde_wasm_bindgen` before being posted.
    type Input: Serialize + DeserializeOwned + 'static;
    /// Job output; decoded on the dispatching side by `from_response_message`.
    type Output: Serialize + DeserializeOwned + 'static;

    /// Runs the job for `input`. Presumably invoked on the worker side via the
    /// job registry (not visible in this file) — confirm against `workers::entry`.
    fn execute(
        input: Self::Input,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<Self::Output>>>>;

    /// Encodes `output` as a response payload plus a transfer list.
    ///
    /// The default implementation serializes with `serde_wasm_bindgen` and
    /// returns an empty transfer list.
    ///
    /// # Errors
    /// Returns `"serialize output: …"` when serialization fails.
    fn into_response_message(
        output: Self::Output,
    ) -> Result<(JsValue, web_sys::js_sys::Array), String> {
        match serde_wasm_bindgen::to_value(&output) {
            Ok(payload) => Ok((payload, web_sys::js_sys::Array::new())),
            Err(err) => Err(format!("serialize output: {err}")),
        }
    }

    /// Decodes a response payload back into `Self::Output`.
    ///
    /// # Errors
    /// Returns `"deserialize output: …"` when deserialization fails.
    fn from_response_message(payload: JsValue) -> Result<Self::Output, String> {
        let decoded: Result<Self::Output, _> = serde_wasm_bindgen::from_value(payload);
        decoded.map_err(|err| format!("deserialize output: {err}"))
    }
}
/// How a [`WorkerPool`] creates its workers and which glue URL it sends them.
#[derive(Default)]
pub enum WorkerPoolBootstrap {
    /// Spawn module workers from `WORKER_BOOTSTRAP_JS` and send the glue URL
    /// returned by `awsm_bundle_url()` in the init message.
    #[default]
    Auto,
    /// Spawn workers the same way as `Auto`, but send `bundle_url` as the glue URL.
    ModuleUrl { bundle_url: String },
    /// Create each worker with this caller-supplied factory; the init message
    /// then carries an empty glue URL.
    Custom(Box<dyn Fn() -> Result<Worker, JsValue> + 'static>),
}
/// Snapshot of pool counters, returned by `WorkerPool::stats`.
#[derive(Debug, Default, Clone, Copy)]
pub struct WorkerPoolStats {
    /// Number of workers, set once when construction succeeds. Nothing in
    /// this file decrements it when a worker later errors.
    pub workers_alive: usize,
    /// Jobs whose message was posted to a worker without error.
    pub jobs_dispatched: u64,
    /// Jobs resolved by an `awsm-result` reply.
    pub jobs_completed: u64,
    /// Jobs resolved by an `awsm-error` reply, or failed by the owning
    /// worker's `onerror` handler.
    pub jobs_failed: u64,
    /// Cumulative (not average) round-trip time in milliseconds, summed over
    /// successfully completed jobs; not updated when `perf_now_ms` returns 0.
    pub job_round_trip_ms: f64,
}
/// Bookkeeping for one in-flight job, stored in `WorkerPool::pending` under its job id.
struct PendingEntry {
    /// Resolves the awaiting dispatch: `Ok(payload)` on success, `Err(message)` on failure.
    sender: oneshot::Sender<Result<JsValue, JsValue>>,
    /// Index of the worker the job was posted to; that worker's `onerror`
    /// handler fails every entry carrying its index.
    worker_idx: usize,
}
/// A fixed-size pool of Web Workers that run registered [`WorkerJob`]s.
///
/// Jobs are assigned to workers round-robin, and replies are matched back to
/// the waiting caller by job id.
pub struct WorkerPool {
    /// The spawned workers; terminated when the pool is dropped.
    workers: Vec<WorkerSlot>,
    /// Round-robin cursor, taken modulo `workers.len()`.
    next_worker: AtomicUsize,
    /// Next job id (starts at 1); sent to workers as a decimal string.
    next_job_id: AtomicU64,
    /// In-flight jobs keyed by id, shared with every worker's message/error handler.
    pending: Arc<Mutex<HashMap<u64, PendingEntry>>>,
    /// Counters shared with the handlers; read via `stats()`.
    stats: Arc<Mutex<WorkerPoolStats>>,
    // The JS `onmessage`/`onerror` handlers point at these closures; they are
    // held here (never read) so they stay alive as long as the pool does.
    _onmessage_closures: Vec<Closure<dyn FnMut(MessageEvent)>>,
    _onerror_closures: Vec<Closure<dyn FnMut(JsValue)>>,
}
/// One pool entry: the JS `Worker` handle jobs are posted to.
struct WorkerSlot {
    worker: Worker,
}
impl WorkerPool {
pub async fn with_workers(worker_count: Option<usize>) -> Result<Self, WorkerPoolError> {
let count = worker_count.unwrap_or_else(default_worker_count);
Self::new(WorkerPoolBootstrap::default(), count).await
}
pub async fn new(
bootstrap: WorkerPoolBootstrap,
worker_count: usize,
) -> Result<Self, WorkerPoolError> {
let worker_count = worker_count.max(1);
let glue_url = match &bootstrap {
WorkerPoolBootstrap::Auto => awsm_bundle_url(),
WorkerPoolBootstrap::ModuleUrl { bundle_url } => bundle_url.clone(),
WorkerPoolBootstrap::Custom(_) => String::new(),
};
let wasm_module = current_wasm_module()
.map_err(|err| WorkerPoolError::bootstrap_from_js("current_wasm_module", err))?;
let pending: Arc<Mutex<HashMap<u64, PendingEntry>>> = Arc::new(Mutex::new(HashMap::new()));
let stats = Arc::new(Mutex::new(WorkerPoolStats::default()));
let mut workers = Vec::with_capacity(worker_count);
let mut onmessage_closures = Vec::with_capacity(worker_count);
let mut onerror_closures = Vec::with_capacity(worker_count);
let mut ready_futures = Vec::with_capacity(worker_count);
for i in 0..worker_count {
let worker = match &bootstrap {
WorkerPoolBootstrap::Auto | WorkerPoolBootstrap::ModuleUrl { .. } => {
let opts = WorkerOptions::new();
opts.set_type(WorkerType::Module);
new_worker_from_js(WORKER_BOOTSTRAP_JS, Some(opts))
.map_err(|err| WorkerPoolError::bootstrap_from_js("worker spawn", err))?
}
WorkerPoolBootstrap::Custom(factory) => factory().map_err(|err| {
WorkerPoolError::bootstrap_from_js("custom worker factory", err)
})?,
};
let (ready_tx, ready_rx) = oneshot::channel::<Result<(), String>>();
let ready_cell: Arc<Mutex<Option<oneshot::Sender<Result<(), String>>>>> =
Arc::new(Mutex::new(Some(ready_tx)));
let pending_clone = Arc::clone(&pending);
let stats_clone = Arc::clone(&stats);
let ready_cell_clone = Arc::clone(&ready_cell);
let label = format!("awsm-worker-{i}");
let onmessage = Closure::<dyn FnMut(MessageEvent)>::new(move |e: MessageEvent| {
let data = e.data();
let kind = Reflect::get(&data, &JsValue::from_str("kind"))
.ok()
.and_then(|v| v.as_string())
.unwrap_or_default();
match kind.as_str() {
"awsm-ready" => {
if let Some(tx) = ready_cell_clone.lock().unwrap().take() {
let _ = tx.send(Ok(()));
}
}
"awsm-init-error" => {
let msg = Reflect::get(&data, &JsValue::from_str("message"))
.ok()
.and_then(|v| v.as_string())
.unwrap_or_else(|| "unknown init error".to_string());
if let Some(tx) = ready_cell_clone.lock().unwrap().take() {
let _ = tx.send(Err(msg));
}
}
"awsm-result" => {
let id = parse_job_id(&data);
let payload = Reflect::get(&data, &JsValue::from_str("payload"))
.unwrap_or(JsValue::UNDEFINED);
if let Some(id) = id {
if let Some(entry) = pending_clone.lock().unwrap().remove(&id) {
let _ = entry.sender.send(Ok(payload));
stats_clone.lock().unwrap().jobs_completed += 1;
}
}
}
"awsm-error" => {
let id = parse_job_id(&data);
let msg = Reflect::get(&data, &JsValue::from_str("message"))
.ok()
.and_then(|v| v.as_string())
.unwrap_or_else(|| "unknown job error".to_string());
if let Some(id) = id {
if let Some(entry) = pending_clone.lock().unwrap().remove(&id) {
let _ = entry.sender.send(Err(JsValue::from_str(&msg)));
stats_clone.lock().unwrap().jobs_failed += 1;
}
} else {
tracing::warn!("{label}: worker error without id: {msg}");
}
}
other => {
tracing::debug!("{label}: unknown worker message kind: {other:?}");
}
}
});
worker.set_onmessage(Some(onmessage.as_ref().unchecked_ref::<Function>()));
onmessage_closures.push(onmessage);
let onerror_label = format!("awsm-worker-{i}");
let onerror_worker_idx = i;
let onerror_ready_cell = Arc::clone(&ready_cell);
let onerror_pending = Arc::clone(&pending);
let onerror_stats = Arc::clone(&stats);
let onerror = Closure::<dyn FnMut(JsValue)>::new(move |err: JsValue| {
let msg = WorkerPoolError::js_message(err);
tracing::warn!("{onerror_label}: worker onerror: {msg}");
if let Some(tx) = onerror_ready_cell.lock().unwrap().take() {
let _ = tx.send(Err(format!("worker onerror during init: {msg}")));
}
let mut pending = onerror_pending.lock().unwrap();
let mut drained_count: u64 = 0;
pending.retain(|_id, entry| {
if entry.worker_idx == onerror_worker_idx {
let (placeholder_tx, _) = oneshot::channel::<Result<JsValue, JsValue>>();
let sender = std::mem::replace(&mut entry.sender, placeholder_tx);
let _ = sender.send(Err(JsValue::from_str(&format!(
"{onerror_label}: worker errored: {msg}"
))));
drained_count += 1;
false
} else {
true
}
});
if drained_count > 0 {
onerror_stats.lock().unwrap().jobs_failed += drained_count;
}
});
worker.set_onerror(Some(onerror.as_ref().unchecked_ref::<Function>()));
onerror_closures.push(onerror);
let init_msg = Object::new();
Reflect::set(
&init_msg,
&JsValue::from_str("kind"),
&JsValue::from_str("awsm-init"),
)
.map_err(|err| WorkerPoolError::bootstrap_from_js("init msg", err))?;
Reflect::set(&init_msg, &JsValue::from_str("wasm_module"), &wasm_module)
.map_err(|err| WorkerPoolError::bootstrap_from_js("init msg", err))?;
Reflect::set(
&init_msg,
&JsValue::from_str("glue_url"),
&JsValue::from_str(&glue_url),
)
.map_err(|err| WorkerPoolError::bootstrap_from_js("init msg", err))?;
worker
.post_message(&init_msg)
.map_err(|err| WorkerPoolError::bootstrap_from_js("init postMessage", err))?;
workers.push(WorkerSlot { worker });
ready_futures.push(ready_rx);
}
for (i, rx) in ready_futures.into_iter().enumerate() {
match rx.await {
Ok(Ok(())) => {}
Ok(Err(msg)) => {
return Err(WorkerPoolError::Bootstrap(format!(
"worker #{i} init: {msg}"
)));
}
Err(_) => {
return Err(WorkerPoolError::Bootstrap(format!(
"worker #{i} ready channel dropped"
)));
}
}
}
stats.lock().unwrap().workers_alive = workers.len();
Ok(Self {
workers,
next_worker: AtomicUsize::new(0),
next_job_id: AtomicU64::new(1),
pending,
stats,
_onmessage_closures: onmessage_closures,
_onerror_closures: onerror_closures,
})
}
pub fn register<J: WorkerJob>(&self) {
crate::workers::entry::register::<J>();
}
pub async fn dispatch<J: WorkerJob>(
&self,
input: J::Input,
) -> Result<J::Output, WorkerPoolError> {
self.dispatch_inner::<J>(input, None).await
}
pub async fn dispatch_with_transfer<J: WorkerJob>(
&self,
input: J::Input,
transfer: Array,
) -> Result<J::Output, WorkerPoolError> {
self.dispatch_inner::<J>(input, Some(transfer)).await
}
async fn dispatch_inner<J: WorkerJob>(
&self,
input: J::Input,
transfer: Option<Array>,
) -> Result<J::Output, WorkerPoolError> {
if !crate::workers::entry::is_registered(J::NAME) {
return Err(WorkerPoolError::UnknownJob(J::NAME));
}
let id = self.next_job_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel::<Result<JsValue, JsValue>>();
let worker_idx = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.workers.len();
let worker = &self.workers[worker_idx].worker;
let input_js = serde_wasm_bindgen::to_value(&input)
.map_err(|err| WorkerPoolError::Serde(format!("input: {err}")))?;
let msg = Object::new();
let _ = Reflect::set(
&msg,
&JsValue::from_str("kind"),
&JsValue::from_str("awsm-job"),
);
let _ = Reflect::set(
&msg,
&JsValue::from_str("id"),
&JsValue::from_str(&id.to_string()),
);
let _ = Reflect::set(
&msg,
&JsValue::from_str("name"),
&JsValue::from_str(J::NAME),
);
let _ = Reflect::set(&msg, &JsValue::from_str("input"), &input_js);
let post_result = match transfer {
Some(arr) => worker.post_message_with_transfer(&msg, &arr),
None => worker.post_message(&msg),
};
post_result.map_err(|err| {
WorkerPoolError::post_message_from_js("dispatch postMessage", err)
})?;
self.pending.lock().unwrap().insert(
id,
PendingEntry {
sender: tx,
worker_idx,
},
);
self.stats.lock().unwrap().jobs_dispatched += 1;
let dispatched_at_ms = perf_now_ms();
match rx.await {
Ok(Ok(payload)) => {
if dispatched_at_ms > 0.0 {
let delta = (perf_now_ms() - dispatched_at_ms).max(0.0);
self.stats.lock().unwrap().job_round_trip_ms += delta;
}
J::from_response_message(payload).map_err(WorkerPoolError::Serde)
}
Ok(Err(err)) => {
let msg = err.as_string().unwrap_or_else(|| format!("{err:?}"));
Err(WorkerPoolError::JobFailed(msg))
}
Err(_) => Err(WorkerPoolError::ChannelDropped),
}
}
pub fn stats(&self) -> WorkerPoolStats {
*self.stats.lock().unwrap()
}
}
impl Drop for WorkerPool {
    /// Terminates every worker the pool owns.
    fn drop(&mut self) {
        self.workers
            .iter()
            .for_each(|slot| slot.worker.terminate());
    }
}
/// Current time from `crate::web_global::performance()`'s `now()` in
/// milliseconds, or `0.0` when that helper yields nothing.
fn perf_now_ms() -> f64 {
    crate::web_global::performance().map_or(0.0, |perf| perf.now())
}
/// Reads the `id` property of a worker message and parses it as a `u64`.
///
/// Job ids travel as decimal strings. Returns `None` when the property is
/// missing, is not a string, or does not parse as a `u64`.
pub(crate) fn parse_job_id(data: &JsValue) -> Option<u64> {
    let raw = Reflect::get(data, &JsValue::from_str("id")).ok()?;
    let text = raw.as_string()?;
    text.parse::<u64>().ok()
}
/// Chooses a worker count from the navigator's hardware concurrency, clamped to `1..=4`.
///
/// Falls back to 4 when `web_sys::window()` is `None` (for example inside a
/// worker); a reported concurrency of 0 becomes 1.
fn default_worker_count() -> usize {
    let reported = match web_sys::window() {
        Some(window) => window.navigator().hardware_concurrency() as usize,
        None => 4,
    };
    reported.clamp(1, 4)
}