use std::borrow::Borrow;
use futures::future::join_all;
use js_sys::wasm_bindgen::{prelude::wasm_bindgen, UnwrapThrowExt};
use scheduler::Scheduler;
pub use scheduler::Strategy;
use serde::{Deserialize, Serialize};
use wasm_bindgen_futures::JsFuture;
use web_sys::window;
use crate::{
channel_task::ChannelTask,
error::InitError,
func::{WebWorkerChannelFn, WebWorkerFn},
WebWorker,
};
mod scheduler;
/// Configuration for building a [`WebWorkerPool`].
///
/// Every field is optional; an unset field falls back to the default applied
/// by the private accessors on this type or by [`WebWorkerPool::with_options`].
#[wasm_bindgen(getter_with_clone)]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct WorkerPoolOptions {
/// Forwarded as the first argument of `WebWorker::with_path_and_module` for
/// every worker. When WASM precompilation is requested and `path_bg` is
/// unset, the WASM URL is derived from this value by appending `_bg.wasm`.
pub path: Option<String>,
/// Forwarded as the second argument of `WebWorker::with_path_and_module`.
/// When set, it is also the URL fetched for WASM precompilation.
pub path_bg: Option<String>,
/// Strategy handed to the pool's scheduler; `Strategy::default()` when unset.
pub strategy: Option<Strategy>,
/// Number of workers to create. When unset, the browser's
/// `navigator.hardwareConcurrency` value is used.
pub num_workers: Option<usize>,
/// When `Some(true)` and no `wasm_module` was supplied, the pool fetches and
/// compiles the WASM module once and hands a clone to every worker.
pub precompile_wasm: Option<bool>,
/// Already compiled module to hand to the workers; when present it is used
/// as-is and no precompilation is performed.
pub(crate) wasm_module: Option<js_sys::WebAssembly::Module>,
}
#[wasm_bindgen]
impl WorkerPoolOptions {
    /// Creates an options object in which every field is unset (`None`).
    ///
    /// Exposed to JavaScript as the constructor of `WorkerPoolOptions`.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self::default()
    }
}
impl WorkerPoolOptions {
fn path(&self) -> Option<&str> {
self.path.as_deref()
}
fn path_bg(&self) -> Option<&str> {
self.path_bg.as_deref()
}
fn strategy(&self) -> Strategy {
self.strategy.unwrap_or_default()
}
fn num_workers(&self) -> usize {
self.num_workers.unwrap_or_else(|| {
window()
.expect_throw("Window missing")
.navigator()
.hardware_concurrency() as usize
})
}
}
/// A pool of [`WebWorker`]s together with the scheduler that picks which
/// worker receives each task.
pub struct WebWorkerPool {
/// The workers owned by this pool; tasks index into this vector using the
/// id returned by the scheduler.
workers: Vec<WebWorker>,
/// Chooses the worker index for each submitted task.
scheduler: Scheduler,
// The module is stored at construction but not read by any code visible in
// this file, hence the lint exception.
// NOTE(review): presumably retained so the compiled module stays available
// for the lifetime of the pool; confirm before removing.
#[allow(dead_code)]
wasm_module: Option<js_sys::WebAssembly::Module>,
}
impl WebWorkerPool {
/// Creates a pool using the default [`WorkerPoolOptions`].
///
/// # Errors
/// Returns whatever error [`Self::with_options`] reports.
pub async fn new() -> Result<Self, InitError> {
    let options = WorkerPoolOptions::default();
    Self::with_options(options).await
}
pub async fn with_strategy(strategy: Strategy) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
strategy: Some(strategy),
..Default::default()
})
.await
}
pub async fn with_num_workers(num_workers: usize) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
num_workers: Some(num_workers),
..Default::default()
})
.await
}
pub async fn with_path(path: String) -> Result<Self, InitError> {
Self::with_options(WorkerPoolOptions {
path: Some(path),
..Default::default()
})
.await
}
/// Creates a pool from an explicit set of options.
///
/// The WASM module handed to the workers is, in order of precedence:
/// the module already present in `options`, a freshly precompiled module
/// when `precompile_wasm` is `Some(true)`, or `None`.
///
/// All workers are initialised concurrently; each one receives the
/// configured `path`, `path_bg` and a clone of the module.
///
/// # Errors
/// Returns the precompilation error, or the first error (in worker order)
/// reported while creating a worker.
pub async fn with_options(mut options: WorkerPoolOptions) -> Result<Self, InitError> {
    let precompile_requested = options.precompile_wasm.unwrap_or(false);
    let wasm_module = match options.wasm_module.take() {
        Some(module) => Some(module),
        None if precompile_requested => Some(Self::precompile_wasm(&options).await?),
        None => None,
    };

    // Build every initialisation future up front, then drive them together.
    let worker_count = options.num_workers();
    let mut pending = Vec::with_capacity(worker_count);
    for _ in 0..worker_count {
        pending.push(WebWorker::with_path_and_module(
            options.path(),
            options.path_bg(),
            None,
            wasm_module.clone(),
        ));
    }

    let outcomes = join_all(pending).await;
    let workers = outcomes.into_iter().collect::<Result<Vec<_>, _>>()?;
    let scheduler = Scheduler::new(options.strategy());

    Ok(Self {
        workers,
        scheduler,
        wasm_module,
    })
}
/// Runs `func` with `arg` on a worker selected by the pool's scheduler and
/// resolves to the function's result.
///
/// Thin public wrapper around `run_internal`.
pub async fn run<T, R>(&self, func: WebWorkerFn<T, R>, arg: &T) -> R
where
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let result = self.run_internal(func, arg).await;
    result
}
/// Starts the channel function `func` with `arg` on a worker selected by
/// the pool's scheduler and resolves to the resulting [`ChannelTask`].
///
/// Thin public wrapper around `run_channel_internal`.
pub async fn run_channel<T, R>(&self, func: WebWorkerChannelFn<T, R>, arg: &T) -> ChannelTask<R>
where
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let task = self.run_channel_internal(func, arg).await;
    task
}
/// Runs a bytes-in / bytes-out function on a worker selected by the pool's
/// scheduler and resolves to the returned bytes.
///
/// Thin public wrapper around `run_internal`.
pub async fn run_bytes(
    &self,
    func: WebWorkerFn<Box<[u8]>, Box<[u8]>>,
    arg: &Box<[u8]>,
) -> Box<[u8]> {
    let output = self.run_internal(func, arg).await;
    output
}
/// Asks the scheduler for a worker index, then runs `func` on that worker
/// with the borrowed form of `arg`.
///
/// # Panics
/// Panics if the index returned by the scheduler is out of bounds for the
/// pool's worker list.
pub(crate) async fn run_internal<T, R, A>(&self, func: WebWorkerFn<T, R>, arg: A) -> R
where
    A: Borrow<T>,
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let index = self.scheduler.schedule(self);
    let worker = &self.workers[index];
    worker.run_internal(func, arg.borrow()).await
}
/// Asks the scheduler for a worker index, then starts the channel function
/// `func` on that worker with `arg`.
///
/// # Panics
/// Panics if the index returned by the scheduler is out of bounds for the
/// pool's worker list.
pub(crate) async fn run_channel_internal<T, R>(
    &self,
    func: WebWorkerChannelFn<T, R>,
    arg: &T,
) -> ChannelTask<R>
where
    T: Serialize + for<'de> Deserialize<'de>,
    R: Serialize + for<'de> Deserialize<'de>,
{
    let index = self.scheduler.schedule(self);
    let worker = &self.workers[index];
    worker.run_channel_internal(func, arg).await
}
pub fn current_load(&self) -> usize {
self.workers.iter().map(WebWorker::current_load).sum()
}
pub fn num_workers(&self) -> usize {
self.workers.len()
}
pub async fn with_precompiled_wasm() -> Result<Self, InitError> {
let mut options = WorkerPoolOptions::new();
options.precompile_wasm = Some(true);
Self::with_options(options).await
}
/// Fetches the WASM binary and compiles it into a
/// `WebAssembly.Module` that can be handed to every worker.
///
/// The URL is chosen as follows:
/// 1. `options.path_bg`, when set, is used verbatim;
/// 2. otherwise it is derived from `options.path`;
/// 3. otherwise it is derived from the main JS path reported by the crate.
///
/// Derivation replaces a trailing `.js` with `_bg.wasm`, or appends
/// `_bg.wasm` when the path does not end in `.js`.
///
/// # Errors
/// Returns [`InitError::WebWorkerModuleLoading`] when the fetch fails, the
/// server answers with a non-success HTTP status, the body cannot be read,
/// or the bytes do not compile as WebAssembly.
async fn precompile_wasm(
    options: &WorkerPoolOptions,
) -> Result<js_sys::WebAssembly::Module, InitError> {
    use wasm_bindgen::{JsCast, UnwrapThrowExt};

    /// Maps a JS glue path to its companion `_bg.wasm` path.
    ///
    /// Only a trailing `.js` extension is replaced; other `.js` occurrences
    /// (e.g. a `cdn.jsdelivr.net` host name) are left untouched.
    fn wasm_path_for(js_path: &str) -> String {
        let stem = js_path.strip_suffix(".js").unwrap_or(js_path);
        format!("{}_bg.wasm", stem)
    }

    let wasm_path = match (options.path_bg(), options.path()) {
        (Some(bg_path), _) => bg_path.to_string(),
        (None, Some(js_path)) => wasm_path_for(js_path),
        (None, None) => {
            let js_path = crate::webworker::js::main_js().as_string().unwrap_throw();
            wasm_path_for(&js_path)
        }
    };

    let window = web_sys::window().unwrap_throw();
    let resp_value = JsFuture::from(window.fetch_with_str(&wasm_path))
        .await
        .map_err(|e| {
            InitError::WebWorkerModuleLoading(format!(
                "Failed to fetch WASM from '{}': {:?}. Check that path_bg points to the correct WASM file URL.",
                wasm_path, e
            ))
        })?;
    let resp: web_sys::Response = resp_value.unchecked_into();

    // `fetch` resolves even for 4xx/5xx answers; reject those here so an
    // HTML error page is never handed to the WebAssembly compiler.
    if !resp.ok() {
        return Err(InitError::WebWorkerModuleLoading(format!(
            "Failed to fetch WASM from '{}': HTTP status {}. Check that path_bg points to the correct WASM file URL.",
            wasm_path,
            resp.status()
        )));
    }

    let buffer_promise = resp.array_buffer().map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to read WASM bytes from '{}': {:?}",
            wasm_path, e
        ))
    })?;
    let array_buffer = JsFuture::from(buffer_promise).await.map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to read WASM bytes from '{}': {:?}",
            wasm_path, e
        ))
    })?;

    let compile_promise = js_sys::WebAssembly::compile(&array_buffer);
    let module_value = JsFuture::from(compile_promise).await.map_err(|e| {
        InitError::WebWorkerModuleLoading(format!(
            "Failed to compile WASM from '{}': {:?}. This usually means the file is not a valid WASM binary or the URL returned an error page.",
            wasm_path, e
        ))
    })?;
    Ok(module_value.into())
}
}