use std::{borrow::Borrow, cell::RefCell, rc::Rc};
use futures::future::join_all;
use js_sys::wasm_bindgen::{prelude::wasm_bindgen, UnwrapThrowExt};
use scheduler::Scheduler;
pub use scheduler::Strategy;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::Closure;
use wasm_bindgen::JsCast;
use wasm_bindgen_futures::JsFuture;
use web_sys::window;
use crate::{
channel_task::ChannelTask,
error::InitError,
func::{WebWorkerChannelFn, WebWorkerFn},
WebWorker,
};
mod scheduler;
/// Configuration used to construct a [`WebWorkerPool`].
///
/// Every field is optional; unset fields fall back to the defaults described
/// on each field. The struct is exported to JavaScript through `wasm_bindgen`.
#[wasm_bindgen(getter_with_clone)]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct WorkerPoolOptions {
/// Path handed to every worker as its first path argument. When
/// precompilation is requested and `path_bg` is unset, the wasm URL is
/// derived from this value.
pub path: Option<String>,
/// Path handed to every worker as its second path argument. When
/// precompilation is requested, this is used directly as the wasm file URL.
pub path_bg: Option<String>,
/// Scheduling strategy for the pool; `Strategy::default()` when unset.
pub strategy: Option<Strategy>,
/// Number of worker slots. When unset, `navigator.hardwareConcurrency` of
/// the current window is used. The pool always creates at least one slot.
pub num_workers: Option<usize>,
/// When `Some(true)` and no module was supplied in `wasm_module`, the pool
/// fetches and compiles the wasm binary once and passes the compiled module
/// to each worker it creates.
pub precompile_wasm: Option<bool>,
/// When set, an interval timer (period: half this value, at least 1 ms)
/// empties the slot of any worker that has no load and has been inactive
/// for at least this many milliseconds.
pub idle_timeout_ms: Option<u32>,
/// Crate-internal: an already compiled module to hand to the workers.
pub(crate) wasm_module: Option<js_sys::WebAssembly::Module>,
}
#[wasm_bindgen]
impl WorkerPoolOptions {
    /// Creates an options object with every field unset.
    ///
    /// Exposed to JavaScript as the class constructor.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self::default()
    }
}
impl WorkerPoolOptions {
    /// Borrows the configured `path`, if one was set.
    fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }

    /// Borrows the configured `path_bg`, if one was set.
    fn path_bg(&self) -> Option<&str> {
        self.path_bg.as_deref()
    }

    /// Returns the configured strategy, or the default strategy when unset.
    fn strategy(&self) -> Strategy {
        self.strategy.unwrap_or_default()
    }

    /// Returns the requested worker count, falling back to the window's
    /// `navigator.hardwareConcurrency` when none was requested.
    ///
    /// Throws (via `expect_throw`) if the fallback is needed and there is no
    /// `window` object.
    fn num_workers(&self) -> usize {
        if let Some(requested) = self.num_workers {
            return requested;
        }
        let navigator = window().expect_throw("Window missing").navigator();
        navigator.hardware_concurrency() as usize
    }
}
/// State of one position in the pool's fixed-size slot list.
enum WorkerSlot {
/// A worker exists in this slot and can be scheduled.
Active(WebWorker),
/// A replacement worker is being created for this slot; set by
/// `acquire_worker` while it awaits the new worker.
Creating,
/// No worker occupies the slot: the idle checker emptied it, or an attempt
/// to recreate the worker failed.
Empty,
}
/// A fixed-size pool of web workers with pluggable scheduling and optional
/// termination of idle workers.
pub struct WebWorkerPool {
/// One entry per worker slot; shared with the idle-checker closure.
slots: Rc<Vec<RefCell<WorkerSlot>>>,
/// Number of slots created at construction time (always at least 1).
num_slots: usize,
/// Picks the slot a task runs on, given the current per-slot loads.
scheduler: Scheduler,
// NOTE(review): this field is read in `acquire_worker`, so the
// `allow(dead_code)` below looks stale — confirm before removing.
#[allow(dead_code)]
/// Compiled wasm module (if any) passed to every worker the pool creates.
wasm_module: Option<js_sys::WebAssembly::Module>,
/// `path` option, kept so emptied slots can be refilled later.
pool_path: Option<String>,
/// `path_bg` option, kept so emptied slots can be refilled later.
pool_path_bg: Option<String>,
/// Keeps the idle-checker callback alive for as long as the pool lives.
_idle_checker_cb: Option<Closure<dyn FnMut()>>,
/// Handle of the idle-checker interval; cleared in `Drop`.
_idle_checker_id: Option<i32>,
/// Signalled by `acquire_worker` whenever an attempt to recreate a worker
/// finishes, so callers waiting for a usable slot re-check.
worker_ready: tokio::sync::Notify,
}
impl Drop for WebWorkerPool {
    /// Cancels the idle-checker interval, if one was registered and a
    /// `window` is still available.
    fn drop(&mut self) {
        let pending_timer = self
            ._idle_checker_id
            .and_then(|id| web_sys::window().map(|win| (win, id)));
        if let Some((win, id)) = pending_timer {
            win.clear_interval_with_handle(id);
        }
    }
}
impl WebWorkerPool {
/// Creates a pool with default options.
///
/// # Errors
/// Returns whatever error `with_options` reports.
pub async fn new() -> Result<Self, InitError> {
    let defaults = WorkerPoolOptions::default();
    Self::with_options(defaults).await
}
pub async fn with_strategy(strategy: Strategy) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
strategy: Some(strategy),
..Default::default()
})
.await
}
pub async fn with_num_workers(num_workers: usize) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
num_workers: Some(num_workers),
..Default::default()
})
.await
}
pub async fn with_path(path: String) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
path: Some(path),
..Default::default()
})
.await
}
/// Creates a pool configured by `options`.
///
/// All workers are started concurrently. If `idle_timeout_ms` is set, an
/// interval timer is registered on the window that empties the slot of any
/// worker with zero load that has been inactive for at least that long.
///
/// # Errors
/// Returns an error if wasm precompilation was requested and failed, or if
/// any worker fails to initialise.
///
/// Throws (via `expect_throw`) if a `window` is required but missing, or if
/// the interval cannot be registered.
pub async fn with_options(mut options: WorkerPoolOptions) -> Result<Self, InitError> {
    // A module supplied by the caller wins; otherwise compile one on request.
    let wasm_module = match options.wasm_module.take() {
        None if options.precompile_wasm.unwrap_or(false) => {
            Some(Self::precompile_wasm(&options).await?)
        }
        supplied => supplied,
    };

    let num_slots = options.num_workers().max(1);
    let (js_path, bg_path) = (options.path(), options.path_bg());
    let spawned = join_all((0..num_slots).map(|_| {
        WebWorker::with_path_and_module(js_path, bg_path, None, wasm_module.clone())
    }))
    .await;

    // The first failed worker aborts construction.
    let slots: Rc<Vec<RefCell<WorkerSlot>>> = Rc::new(
        spawned
            .into_iter()
            .map(|outcome| outcome.map(|worker| RefCell::new(WorkerSlot::Active(worker))))
            .collect::<Result<Vec<_>, _>>()?,
    );

    let (idle_checker_cb, idle_checker_id) = match options.idle_timeout_ms {
        None => (None, None),
        Some(timeout) => {
            let watched = Rc::clone(&slots);
            let idle_limit = timeout as f64;
            let cb = Closure::<dyn FnMut()>::new(move || {
                let now = js_sys::Date::now();
                for slot in watched.iter() {
                    // The shared borrow ends with this statement, before the
                    // mutable borrow below is taken.
                    let expired = match &*slot.borrow() {
                        WorkerSlot::Active(w) => {
                            w.current_load() == 0 && (now - w.last_active()) >= idle_limit
                        }
                        _ => false,
                    };
                    if expired {
                        *slot.borrow_mut() = WorkerSlot::Empty;
                    }
                }
            });
            // Check twice per timeout period, at least every millisecond.
            let period_ms = (timeout / 2).max(1).min(i32::MAX as u32) as i32;
            let id = window()
                .expect_throw("Window missing")
                .set_interval_with_callback_and_timeout_and_arguments_0(
                    cb.as_ref().unchecked_ref(),
                    period_ms,
                )
                .expect_throw("Could not set interval");
            (Some(cb), Some(id))
        }
    };

    let scheduler = Scheduler::new(options.strategy());
    Ok(Self {
        slots,
        num_slots,
        scheduler,
        wasm_module,
        pool_path: options.path.clone(),
        pool_path_bg: options.path_bg.clone(),
        _idle_checker_cb: idle_checker_cb,
        _idle_checker_id: idle_checker_id,
        worker_ready: tokio::sync::Notify::new(),
    })
}
/// Runs `func` with `arg` on a worker selected by the pool and returns the
/// function's result.
///
/// # Panics
/// Panics if the selected slot was emptied earlier and recreating its worker
/// fails.
pub async fn run<T, R>(&self, func: WebWorkerFn<T, R>, arg: &T) -> R
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
self.run_internal(func, arg).await
}
/// Starts the channel function `func` with `arg` on a worker selected by the
/// pool and returns the resulting [`ChannelTask`].
///
/// # Panics
/// Panics if the selected slot was emptied earlier and recreating its worker
/// fails.
pub async fn run_channel<T, R>(&self, func: WebWorkerChannelFn<T, R>, arg: &T) -> ChannelTask<R>
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
self.run_channel_internal(func, arg).await
}
/// Byte-slice variant of `run`: runs `func` on a worker selected by the pool
/// with a boxed byte slice as input and output.
///
/// # Panics
/// Panics if the selected slot was emptied earlier and recreating its worker
/// fails.
pub async fn run_bytes(
&self,
func: WebWorkerFn<Box<[u8]>, Box<[u8]>>,
arg: &Box<[u8]>,
) -> Box<[u8]> {
self.run_internal(func, arg).await
}
/// Returns the index of a slot holding an active worker.
///
/// Each round first asks the scheduler to pick among the active workers. If
/// it picks none, the first empty slot (if any) is refilled with a new
/// worker and returned; otherwise the task waits until another caller
/// finishes recreating a worker and then tries again.
///
/// # Panics
/// Panics if recreating a worker fails.
async fn acquire_worker(&self) -> usize {
    loop {
        let current_loads = self.compute_loads();
        if let Some(id) = self.scheduler.schedule(&current_loads) {
            return id;
        }

        let vacant = self
            .slots
            .iter()
            .position(|slot| matches!(&*slot.borrow(), WorkerSlot::Empty));
        let slot_id = match vacant {
            Some(index) => index,
            None => {
                // Nothing to refill: wait for another caller's worker.
                self.worker_ready.notified().await;
                continue;
            }
        };

        // Claim the slot so concurrent callers do not refill it as well.
        *self.slots[slot_id].borrow_mut() = WorkerSlot::Creating;
        let created = WebWorker::with_path_and_module(
            self.pool_path.as_deref(),
            self.pool_path_bg.as_deref(),
            None,
            self.wasm_module.clone(),
        )
        .await;

        let recreated = created.is_ok();
        *self.slots[slot_id].borrow_mut() = match created {
            Ok(worker) => WorkerSlot::Active(worker),
            Err(_) => WorkerSlot::Empty,
        };
        // Wake waiters on success and on failure so they re-check the slots.
        self.worker_ready.notify_waiters();
        if !recreated {
            panic!("Couldn't recreate worker");
        }
        return slot_id;
    }
}
/// Returns one entry per slot: `Some(load)` for an active worker, `None`
/// for a slot that is empty or still being created.
fn compute_loads(&self) -> Vec<Option<usize>> {
    let mut loads = Vec::with_capacity(self.slots.len());
    for slot in self.slots.iter() {
        let load = match &*slot.borrow() {
            WorkerSlot::Active(worker) => Some(worker.current_load()),
            _ => None,
        };
        loads.push(load);
    }
    loads
}
/// Acquires a worker and runs `func` on it with the value borrowed from
/// `arg`, returning the function's result.
///
/// The slot stays immutably borrowed until the worker call completes, which
/// is why the clippy lint is silenced here.
///
/// # Panics
/// Panics if `acquire_worker` has to recreate a worker and that fails.
#[allow(clippy::await_holding_refcell_ref)]
pub(crate) async fn run_internal<T, R, A>(&self, func: WebWorkerFn<T, R>, arg: A) -> R
where
    A: Borrow<T>,
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let worker_id = self.acquire_worker().await;
    let guard = self.slots[worker_id].borrow();
    if let WorkerSlot::Active(worker) = &*guard {
        worker.run_internal(func, arg.borrow()).await
    } else {
        unreachable!("acquire_worker guarantees Active slot")
    }
}
/// Acquires a worker and starts the channel function `func` on it with
/// `arg`, returning the resulting `ChannelTask`.
///
/// The slot stays immutably borrowed until the worker call returns, which is
/// why the clippy lint is silenced here.
///
/// # Panics
/// Panics if `acquire_worker` has to recreate a worker and that fails.
#[allow(clippy::await_holding_refcell_ref)]
pub(crate) async fn run_channel_internal<T, R>(
    &self,
    func: WebWorkerChannelFn<T, R>,
    arg: &T,
) -> ChannelTask<R>
where
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let worker_id = self.acquire_worker().await;
    let guard = self.slots[worker_id].borrow();
    if let WorkerSlot::Active(worker) = &*guard {
        worker.run_channel_internal(func, arg).await
    } else {
        unreachable!("acquire_worker guarantees Active slot")
    }
}
/// Returns the combined load of all active workers; slots without an active
/// worker contribute nothing.
pub fn current_load(&self) -> usize {
    let mut total = 0;
    for slot in self.slots.iter() {
        if let WorkerSlot::Active(worker) = &*slot.borrow() {
            total += worker.current_load();
        }
    }
    total
}
/// Returns the number of worker slots the pool was created with. This counts
/// slots, not live workers; see `num_active_workers` for the latter.
pub fn num_workers(&self) -> usize {
self.num_slots
}
/// Returns how many slots currently hold an active worker.
pub fn num_active_workers(&self) -> usize {
    self.slots
        .iter()
        .map(|slot| match &*slot.borrow() {
            WorkerSlot::Active(_) => 1,
            _ => 0,
        })
        .sum()
}
pub async fn with_precompiled_wasm() -> Result<Self, InitError> {
let mut options = WorkerPoolOptions::new();
options.precompile_wasm = Some(true);
Self::with_options(options).await
}
/// Fetches and compiles the wasm binary once so the resulting module can be
/// shared with every worker.
///
/// The URL is `options.path_bg` when set. Otherwise it is derived from the
/// JS path (`options.path`, or the crate's default main JS) by replacing a
/// trailing `.js` with `_bg.wasm`, or by appending `_bg.wasm` when the path
/// does not end in `.js`.
///
/// # Errors
/// Returns `InitError::WebWorkerModuleLoading` if the fetch fails, the server
/// answers with a non-success status, the body cannot be read, or the bytes
/// do not compile.
async fn precompile_wasm(
    options: &WorkerPoolOptions,
) -> Result<js_sys::WebAssembly::Module, InitError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen::UnwrapThrowExt;

    /// Maps a JS path to the path of its `_bg.wasm` companion file.
    fn companion_wasm_path(js_path: &str) -> String {
        // Only a trailing ".js" is swapped. `str::replace` would also rewrite
        // ".js" inside a host or directory name and corrupt the URL.
        let stem = js_path.strip_suffix(".js").unwrap_or(js_path);
        format!("{}_bg.wasm", stem)
    }

    let wasm_path = match (options.path_bg(), options.path()) {
        (Some(bg_path), _) => bg_path.to_string(),
        (None, Some(js_path)) => companion_wasm_path(js_path),
        (None, None) => {
            let js_path = crate::webworker::js::main_js().as_string().unwrap_throw();
            companion_wasm_path(&js_path)
        }
    };

    let window = web_sys::window().unwrap_throw();
    let resp_value = JsFuture::from(window.fetch_with_str(&wasm_path))
        .await
        .map_err(|e| {
            InitError::WebWorkerModuleLoading(format!(
                "Failed to fetch WASM from '{}': {:?}. Check that path_bg points to the correct WASM file URL.",
                wasm_path, e
            ))
        })?;
    let resp: web_sys::Response = resp_value.unchecked_into();

    // `fetch` resolves for HTTP error statuses too. Reject them here so the
    // caller does not get a misleading compile error for an error page.
    if !resp.ok() {
        return Err(InitError::WebWorkerModuleLoading(format!(
            "Failed to fetch WASM from '{}': HTTP status {}. Check that path_bg points to the correct WASM file URL.",
            wasm_path,
            resp.status()
        )));
    }

    let buffer_promise = resp.array_buffer().map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to read WASM bytes from '{}': {:?}",
            wasm_path, e
        ))
    })?;
    let array_buffer = JsFuture::from(buffer_promise).await.map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to read WASM bytes from '{}': {:?}",
            wasm_path, e
        ))
    })?;

    let compile_promise = js_sys::WebAssembly::compile(&array_buffer);
    let module_value = JsFuture::from(compile_promise).await.map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to compile WASM from '{}': {:?}. This usually means the file is not a valid WASM binary or the URL returned an error page.",
            wasm_path, e
        ))
    })?;
    Ok(module_value.into())
}
}