#![allow(clippy::unused_unit)]
use futures::channel::oneshot;
use futures::future::BoxFuture;
use futures::StreamExt;
use futures::{channel::mpsc, Future};
use js_sys::{JsString, Promise};
use log::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use wasm_bindgen::{prelude::*, JsCast};
use web_sys::{DedicatedWorkerGlobalScope, WorkerOptions, WorkerType};
// Compile-time check only: the empty impl below is accepted by the compiler
// only if `ThreadPool` is both `Send` and `Sync`, so a change that breaks
// either bound fails the build here.
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}
/// Handle to a pool of web workers that run futures.
///
/// Every handle points at the same shared `PoolState`. Cloning a handle
/// increments `PoolState::cnt`; dropping the last handle queues one
/// `Message::Close` per worker.
#[wasm_bindgen]
pub struct ThreadPool {
// State shared between all handles and the workers started by `new`.
state: Arc<PoolState>,
}
impl Clone for ThreadPool {
fn clone(&self) -> Self {
self.state.cnt.fetch_add(1, Ordering::Relaxed);
Self {
state: self.state.clone(),
}
}
}
impl Drop for ThreadPool {
    /// Releases this handle; the last handle asks every worker to shut down.
    ///
    /// `cnt` counts live `ThreadPool` handles. When this drop takes it from
    /// one to zero, one `Message::Close` is queued per worker (`size` in
    /// total).
    ///
    /// Queuing is best-effort. `start_send` returns an error when the bounded
    /// channel cannot accept the message, and a destructor must not panic, so
    /// such errors are discarded instead of unwrapped. A worker whose `Close`
    /// could not be queued keeps running.
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value; anything other than 1 means
        // other handles are still alive.
        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) != 1 {
            return;
        }

        let mut sender = self.state.tx.lock();
        for _ in 0..self.state.size {
            // Deliberately ignored: see the best-effort note above.
            let _ = sender.start_send(Message::Close);
        }
    }
}
/// Stateless helper object passed to `startWorker` as its last argument.
///
/// Its only method, exported to JS as `mainJS`, returns `import.meta.url`.
#[wasm_bindgen]
pub struct LoaderHelper {}
#[wasm_bindgen]
impl LoaderHelper {
    /// Returns the value of `import.meta.url`; exported to JS as `mainJS`.
    #[wasm_bindgen(js_name = mainJS)]
    pub fn main_js(&self) -> JsString {
        #[wasm_bindgen]
        extern "C" {
            // Binds the JS expression `import.meta.url` as a static.
            #[wasm_bindgen(js_namespace = ["import", "meta"], js_name = url)]
            static URL: JsString;
        }

        let module_url: &JsString = &URL;
        module_url.clone()
    }
}
// JS glue imported from the `/worker.js` module.
#[wasm_bindgen(module = "/worker.js")]
extern "C" {
// Binding for the JS function `startWorker`. `ThreadPool::new` calls it once
// per worker with `wasm_bindgen::module()`, `wasm_bindgen::memory()`, a raw
// `Arc<PoolState>` pointer encoded as a `u32`, the worker options and a
// `LoaderHelper`, then awaits the returned promise.
// NOTE(review): worker.js is not visible here; presumably it forwards
// `shared_data` to `worker_entry_point` inside the new worker. Confirm.
#[wasm_bindgen(js_name = "startWorker")]
fn start_worker(
module: JsValue,
memory: JsValue,
shared_data: JsValue,
opts: WorkerOptions,
builder: LoaderHelper,
) -> Promise;
}
impl ThreadPool {
    /// Creates a pool backed by `size` web workers.
    ///
    /// Tasks travel over a bounded channel created with a buffer of 64.
    /// Workers are named `Worker-<index>` and are started one after another:
    /// each `startWorker` promise is awaited before the next worker is
    /// launched. With `size == 0` no worker is started.
    ///
    /// # Errors
    ///
    /// Returns the rejection value of the first `startWorker` promise that
    /// fails; workers after it are not started.
    pub async fn new(size: usize) -> Result<ThreadPool, JsValue> {
        let (sender, receiver) = mpsc::channel(64);
        let shared = Arc::new(PoolState {
            tx: parking_lot::Mutex::new(sender),
            rx: tokio::sync::Mutex::new(receiver),
            // The handle returned from this function is the first one.
            cnt: AtomicUsize::new(1),
            size,
        });
        let pool = ThreadPool { state: shared };

        for worker_idx in 0..size {
            let worker_name = format!("Worker-{}", worker_idx);
            let mut worker_opts = WorkerOptions::new();
            worker_opts.type_(WorkerType::Module);
            worker_opts.name(&worker_name);

            // Hand one strong reference to the worker as a raw address;
            // `worker_entry_point` turns it back into an `Arc`.
            // NOTE(review): if `startWorker` rejects, this reference is not
            // reclaimed here. Confirm whether worker.js always consumes it.
            let state_ptr = Arc::into_raw(Arc::clone(&pool.state));
            let started = start_worker(
                wasm_bindgen::module(),
                wasm_bindgen::memory(),
                JsValue::from(state_ptr as u32),
                worker_opts,
                LoaderHelper {},
            );
            wasm_bindgen_futures::JsFuture::from(started).await?;
        }

        Ok(pool)
    }

    /// Creates a pool with one worker per `navigator.hardwareConcurrency`,
    /// using at least one worker.
    ///
    /// # Errors
    ///
    /// Fails in the same way as [`ThreadPool::new`].
    pub async fn max_threads() -> Result<Self, JsValue> {
        #[wasm_bindgen]
        extern "C" {
            #[wasm_bindgen(js_namespace = navigator, js_name = hardwareConcurrency)]
            static HARDWARE_CONCURRENCY: usize;
        }

        let reported = *HARDWARE_CONCURRENCY;
        Self::new(reported.max(1)).await
    }

    /// Queues `future` to be run by one of the workers, discarding its result.
    ///
    /// # Panics
    ///
    /// Panics if the task cannot be queued, because `PoolState::send` unwraps
    /// the result of `start_send`.
    pub fn spawn_ok<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let task = Box::pin(future);
        self.state.send(Message::Run(task));
    }

    /// Queues `future` on the pool and returns a future for its output.
    ///
    /// The output is sent back over a oneshot channel. The returned future
    /// resolves to `Err(oneshot::Canceled)` if the sending side is dropped
    /// without delivering a value.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ThreadPool::spawn_ok`].
    pub fn spawn<Fut>(
        &self,
        future: Fut,
    ) -> impl Future<Output = Result<Fut::Output, oneshot::Canceled>> + 'static
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let (result_tx, result_rx) = oneshot::channel();
        self.spawn_ok(async move {
            // The caller may have dropped the receiver already; that is not
            // an error for the task itself.
            let _ = result_tx.send(future.await);
        });
        result_rx
    }
}
// Messages sent from pool handles to the worker loops.
enum Message {
// A boxed task; the receiving worker hands it to `spawn_local`.
Run(BoxFuture<'static, ()>),
// Tells the receiving worker loop to stop and close its worker.
Close,
}
/// State shared by all `ThreadPool` handles and all workers of one pool.
pub struct PoolState {
// Sending half of the task channel, behind a `parking_lot` mutex.
tx: parking_lot::Mutex<mpsc::Sender<Message>>,
// Receiving half of the task channel, behind an async `tokio` mutex that
// worker loops lock before each receive.
rx: tokio::sync::Mutex<mpsc::Receiver<Message>>,
// Number of live `ThreadPool` handles: starts at 1, incremented by `clone`,
// decremented by `drop`.
cnt: AtomicUsize,
// Number of workers requested in `ThreadPool::new`; also the number of
// `Close` messages sent when the last handle is dropped.
size: usize,
}
impl PoolState {
    /// Queues `msg` on the task channel.
    ///
    /// # Panics
    ///
    /// Panics if `start_send` returns an error.
    fn send(&self, msg: Message) {
        let mut sender = self.tx.lock();
        sender.start_send(msg).unwrap();
    }

    /// Starts the message loop of the current worker.
    ///
    /// The loop is spawned as a local task. It locks the receiver, waits for
    /// the next message, and hands every `Run` task to `spawn_local`. It ends
    /// on `Close` or when the channel yields `None`, then logs and closes the
    /// worker's global scope.
    fn work(slf: Arc<PoolState>) {
        wasm_bindgen_futures::spawn_local(async move {
            let scope = js_sys::global().unchecked_into::<DedicatedWorkerGlobalScope>();

            while let Some(msg) = slf.rx.lock().await.next().await {
                match msg {
                    Message::Close => break,
                    Message::Run(task) => wasm_bindgen_futures::spawn_local(task),
                }
            }

            info!("{}: Shutting down", scope.name());
            scope.close();
        });
    }
}
/// Entry point run inside a worker: rebuilds the shared state from
/// `state_ptr` and starts the worker's message loop.
///
/// `state_ptr` must be the address produced by `Arc::into_raw` in
/// `ThreadPool::new`, and must be passed to this function exactly once.
#[wasm_bindgen(skip_typescript)]
pub fn worker_entry_point(state_ptr: u32) {
    let raw_state = state_ptr as *const PoolState;
    // SAFETY: relies on the caller passing an address obtained from
    // `Arc::into_raw` on an `Arc<PoolState>` whose strong reference has not
    // been reclaimed yet; this call takes over that reference.
    // NOTE(review): worker.js is not visible here. Confirm that it forwards
    // the `shared_data` value unchanged and only once per worker.
    let state = unsafe { Arc::from_raw(raw_state) };

    let worker_name = {
        let scope: DedicatedWorkerGlobalScope = js_sys::global().unchecked_into();
        scope.name()
    };
    debug!("{}: Entry", worker_name);

    PoolState::work(state);
}