use std::{
cell::UnsafeCell,
fmt,
marker::PhantomData,
mem,
panic::{catch_unwind, AssertUnwindSafe},
rc::Rc,
sync::{Arc, Mutex},
thread::{Result, Thread},
};
use scoped::ScopeData;
pub use scoped::{scope, Scope, ScopedJoinHandle};
use signal::Signal;
use utils::SpinLockMutex;
pub use utils::{available_parallelism, get_wasm_bindgen_shim_script_path, get_worker_script, is_web_worker_thread};
use wasm_bindgen::prelude::*;
use web_sys::{DedicatedWorkerGlobalScope, Worker, WorkerOptions, WorkerType};
mod scoped;
mod signal;
mod utils;
/// The thread body handed to a newly created worker.
///
/// `Builder::spawn_for_context` boxes this and posts the raw pointer to the
/// worker; `wasm_thread_entry_point` reclaims it with `Box::from_raw` and runs it.
struct WebWorkerContext {
/// Type-erased thread body. `Builder::spawn_unchecked_` transmutes its lifetime
/// to `'static`, so the borrows it captures are only valid per that function's
/// safety contract.
func: Box<dyn FnOnce() + Send>,
}
/// Runs the thread body whose address was posted by `Builder::spawn_for_context`,
/// then notifies the parent that this worker has finished.
///
/// NOTE(review): presumably invoked by the worker script from
/// `get_worker_script` with the pointer from the init message — confirm there.
#[wasm_bindgen]
pub fn wasm_thread_entry_point(ptr: u32) {
    // SAFETY: `ptr` is the address produced by `Box::into_raw` in
    // `spawn_for_context`, so we uniquely own the allocation here.
    let context = unsafe { Box::from_raw(ptr as *mut WebWorkerContext) };
    let body = context.func;
    body();
    WorkerMessage::ThreadComplete.post();
}
/// A spawn request forwarded from a worker to the thread that created it.
///
/// When `spawn_unchecked_` runs inside a worker it does not create the new
/// worker itself. Instead it posts this request (wrapped in
/// `WorkerMessage::SpawnThread`), and the parent's `onmessage` handler calls
/// [`BuilderRequest::spawn`].
struct BuilderRequest {
/// Configuration for the worker to create.
builder: Builder,
/// The thread body the new worker will run.
context: WebWorkerContext,
}
impl BuilderRequest {
pub unsafe fn spawn(self) {
self.builder.spawn_for_context(self.context);
}
}
/// Messages a worker posts to its parent, sent as a raw `Box<WorkerMessage>` pointer
/// and consumed by the `onmessage` handler installed in `Builder::spawn_for_context`.
enum WorkerMessage {
/// Ask the parent to create a new worker on this worker's behalf.
SpawnThread(BuilderRequest),
/// This worker's thread body has returned; the parent drops its `Rc<Worker>` clone.
ThreadComplete,
}
impl WorkerMessage {
    /// Sends this message from the current worker to its parent.
    ///
    /// Ownership of the boxed message is transferred as a raw pointer into wasm
    /// memory. The parent's `onmessage` handler (installed by
    /// `Builder::spawn_for_context`) reconstructs and consumes it.
    ///
    /// # Panics
    /// Panics if the current global object is not a `DedicatedWorkerGlobalScope`
    /// (i.e. this is not running in a dedicated worker), or if posting fails.
    pub fn post(self) {
        // `js_sys::global()` yields the same global object as `eval("self")`
        // but does not require the page's CSP to permit `eval`.
        let scope = js_sys::global()
            .dyn_into::<DedicatedWorkerGlobalScope>()
            .unwrap();

        let req = Box::into_raw(Box::new(self));
        scope
            .post_message(&JsValue::from(req as u32))
            .map_err(|e| {
                // SAFETY: posting failed, so no receiver will ever reconstruct `req`.
                // We still own it and free it instead of leaking it.
                drop(unsafe { Box::from_raw(req) });
                e
            })
            .unwrap();
    }
}
/// Template cloned by `Builder::default()`, registered via `Builder::set_default`.
/// `None` means `Builder::empty()` is used instead.
static DEFAULT_BUILDER: Mutex<Option<Builder>> = Mutex::new(None);
/// Configuration for spawning a thread backed by a Web Worker.
#[derive(Debug, Clone)]
pub struct Builder {
/// Worker name. Combined with `prefix` as `"prefix:name"` when both are set.
name: Option<String>,
/// Worker-name prefix. Without a `name`, a random number is appended instead.
prefix: Option<String>,
/// Explicit worker script URL. When unset, `get_worker_script` supplies one.
worker_script_url: Option<String>,
/// Requested stack size. NOTE(review): `spawn_for_context` ignores this field.
stack_size: Option<usize>,
/// wasm-bindgen JS shim URL, forwarded to `get_worker_script`.
wasm_bindgen_shim_url: Option<String>,
}
impl Default for Builder {
    /// Returns a copy of the builder registered with `Builder::set_default`,
    /// or `Builder::empty()` when nothing has been registered.
    fn default() -> Self {
        let registered = DEFAULT_BUILDER.lock_spin().unwrap();
        match &*registered {
            Some(template) => template.clone(),
            None => Self::empty(),
        }
    }
}
impl Builder {
    /// Creates a builder from the registered default (see [`Builder::set_default`]),
    /// or an empty builder if no default has been registered.
    pub fn new() -> Builder {
        Builder::default()
    }

    /// Creates a builder with every option unset, ignoring any registered default.
    pub fn empty() -> Builder {
        Self {
            name: None,
            prefix: None,
            worker_script_url: None,
            stack_size: None,
            wasm_bindgen_shim_url: None,
        }
    }

    /// Registers this builder as the template that [`Builder::new`] and
    /// [`Builder::default`] clone from now on.
    pub fn set_default(self) {
        *DEFAULT_BUILDER.lock_spin().unwrap() = Some(self);
    }

    /// Sets the worker-name prefix.
    ///
    /// With a [`name`](Builder::name), the worker is named `"{prefix}:{name}"`.
    /// Without one, a random number takes the place of the name.
    pub fn prefix(mut self, prefix: String) -> Builder {
        self.prefix = Some(prefix);
        self
    }

    /// Sets the script URL the worker is created from. This overrides the script
    /// that `get_worker_script` would otherwise produce.
    pub fn worker_script_url(mut self, worker_script_url: String) -> Builder {
        self.worker_script_url = Some(worker_script_url);
        self
    }

    /// Sets the worker name (passed to `WorkerOptions::set_name`).
    pub fn name(mut self, name: String) -> Builder {
        self.name = Some(name);
        self
    }

    /// Records a requested stack size.
    ///
    /// NOTE(review): `spawn_for_context` ignores this field, so it currently has
    /// no visible effect. Confirm whether that is intended.
    pub fn stack_size(mut self, size: usize) -> Builder {
        self.stack_size = Some(size);
        self
    }

    /// Sets the wasm-bindgen JS shim URL. It is forwarded to `get_worker_script`
    /// when no explicit worker script URL has been set.
    pub fn wasm_bindgen_shim_url(mut self, url: String) -> Builder {
        self.wasm_bindgen_shim_url = Some(url);
        self
    }

    /// Spawns a new thread backed by a Web Worker and returns its [`JoinHandle`].
    ///
    /// # Errors
    /// The signature allows for spawn errors, but the current implementation
    /// always returns `Ok`. Failure to create or start the worker panics instead.
    pub fn spawn<F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
    where
        F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static,
    {
        // SAFETY: `F` and `T` are `'static`, so the thread cannot outlive any
        // borrow captured by the closure or held by its result.
        unsafe { self.spawn_unchecked(f) }
    }

    /// Like [`Builder::spawn`], but without the `'static` bounds on `F` and `T`.
    ///
    /// # Safety
    /// The caller must ensure the spawned thread does not outlive any reference
    /// captured by `f` or contained in `T`. For example, join the returned handle
    /// before any borrowed data is dropped.
    pub unsafe fn spawn_unchecked<'a, F, T>(self, f: F) -> std::io::Result<JoinHandle<T>>
    where
        F: FnOnce() -> T,
        F: Send + 'a,
        T: Send + 'a,
    {
        Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?))
    }

    /// Shared implementation of unscoped and scoped spawning.
    ///
    /// Allocates the result `Packet` and completion `Signal`. It then wraps `f`
    /// in a closure that stores its panic-caught result and signals. Finally it
    /// either creates the worker directly or, when already inside a worker, asks
    /// the parent to create it via `WorkerMessage::SpawnThread`.
    ///
    /// # Safety
    /// Same contract as [`Builder::spawn_unchecked`]. The closure's lifetime is
    /// erased to `'static` before it is handed to the worker.
    pub(crate) unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
        self,
        f: F,
        scope_data: Option<Arc<ScopeData>>,
    ) -> std::io::Result<JoinInner<'scope, T>>
    where
        F: FnOnce() -> T,
        F: Send + 'a,
        T: Send + 'a,
        'scope: 'a,
    {
        let my_signal = Arc::new(Signal::new());
        let their_signal = my_signal.clone();

        let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
            scope: scope_data,
            result: UnsafeCell::new(None),
            _marker: PhantomData,
        });
        let their_packet = my_packet.clone();

        // Holds `f` inside `MaybeUninit`, so the compiler does not assume the
        // borrows it captures stay valid merely because the wrapper is alive.
        // `Drop` still drops the wrapped value if the closure never runs.
        #[repr(transparent)]
        struct MaybeDangling<T>(mem::MaybeUninit<T>);
        impl<T> MaybeDangling<T> {
            fn new(x: T) -> Self {
                MaybeDangling(mem::MaybeUninit::new(x))
            }
            fn into_inner(self) -> T {
                // SAFETY: always initialised by `new`; `forget` prevents a double drop.
                let ret = unsafe { self.0.assume_init_read() };
                mem::forget(self);
                ret
            }
        }
        impl<T> Drop for MaybeDangling<T> {
            fn drop(&mut self) {
                // SAFETY: always initialised by `new`, and not yet moved out
                // (`into_inner` forgets `self`).
                unsafe { self.0.assume_init_drop() };
            }
        }

        let f = MaybeDangling::new(f);
        let main = Box::new(move || {
            let f = f.into_inner();
            let try_result = catch_unwind(AssertUnwindSafe(|| f()));
            // SAFETY: the joiner reads `result` only after `their_signal` fires
            // below, so this write never races with a read.
            unsafe { *their_packet.result.get() = Some(try_result) };
            // Release our reference *before* signalling. Once the joiner wakes it
            // is the sole owner, so its `Arc::get_mut` succeeds.
            drop(their_packet);
            their_signal.signal();
        });

        let context = WebWorkerContext {
            // SAFETY: erases `'a` to `'static`. The caller guarantees the thread does
            // not outlive the captured borrows (see `# Safety`).
            func: mem::transmute::<Box<dyn FnOnce() + Send + 'a>, Box<dyn FnOnce() + Send + 'static>>(main),
        };

        // Register with the scope *before* dispatching. `Packet::drop` always
        // decrements the scope's counter. If dispatching panics and unwinds,
        // both packet references are dropped and the decrement must match an
        // increment that already happened.
        if let Some(scope) = &my_packet.scope {
            scope.increment_num_running_threads();
        }

        if is_web_worker_thread() {
            // Inside a worker: ask the parent to create the new worker for us.
            WorkerMessage::SpawnThread(BuilderRequest { builder: self, context }).post();
        } else {
            self.spawn_for_context(context);
        }

        Ok(JoinInner {
            signal: my_signal,
            packet: my_packet,
        })
    }

    /// Creates a Web Worker for `ctx` on the current thread. It then sends the
    /// worker the wasm module, the wasm memory and the context pointer.
    ///
    /// The worker's `onmessage` handler services the `WorkerMessage`s it posts
    /// back: nested spawn requests and completion notices.
    ///
    /// # Panics
    /// Panics if the worker cannot be created or the init message cannot be posted.
    ///
    /// # Safety
    /// `ctx.func` may have had its lifetime erased (see `spawn_unchecked_`). The
    /// caller must ensure its borrows outlive the worker.
    unsafe fn spawn_for_context(self, ctx: WebWorkerContext) {
        let Builder {
            name,
            prefix,
            worker_script_url,
            wasm_bindgen_shim_url,
            ..
        } = self;

        // Build the default script lazily. `unwrap_or` would call
        // `get_worker_script` even when an explicit URL was supplied.
        let script = worker_script_url.unwrap_or_else(|| get_worker_script(wasm_bindgen_shim_url));

        let options = WorkerOptions::new();
        match (name, prefix) {
            (Some(name), Some(prefix)) => {
                options.set_name(&format!("{}:{}", prefix, name));
            }
            (Some(name), None) => {
                options.set_name(&name);
            }
            (None, Some(prefix)) => {
                // No explicit name: keep the prefix and disambiguate with a random number.
                let random = (js_sys::Math::random() * 10e10) as u64;
                options.set_name(&format!("{}:{}", prefix, random));
            }
            (None, None) => {}
        }

        #[cfg(feature = "es_modules")]
        {
            js_sys::eval(include_str!("js/module_workers_polyfill.min.js")).unwrap();
            options.set_type(WorkerType::Module);
        }
        #[cfg(not(feature = "es_modules"))]
        {
            options.set_type(WorkerType::Classic);
        }

        let worker = Rc::new(Worker::new_with_options(script.as_str(), &options).unwrap());

        // The handler keeps its own `Rc<Worker>` until the worker reports completion.
        let mut their_worker = Some(worker.clone());
        let callback = Closure::wrap(Box::new(move |x: &web_sys::MessageEvent| {
            // SAFETY (assumed): every message from the worker is a pointer produced
            // by `Box::into_raw` in `WorkerMessage::post` and is posted exactly once.
            let req = Box::from_raw(x.data().as_f64().unwrap() as u32 as *mut WorkerMessage);
            match *req {
                WorkerMessage::SpawnThread(builder) => {
                    builder.spawn();
                }
                WorkerMessage::ThreadComplete => {
                    their_worker.take();
                }
            }
        }) as Box<dyn FnMut(&web_sys::MessageEvent)>);
        worker.set_onmessage(Some(callback.as_ref().unchecked_ref()));
        // Leak the closure so it stays valid for as long as the worker may invoke it.
        callback.forget();

        let ctx_ptr = Box::into_raw(Box::new(ctx));
        let init = js_sys::Array::new();
        init.push(&wasm_bindgen::module());
        init.push(&wasm_bindgen::memory());
        init.push(&JsValue::from(ctx_ptr as u32));
        worker
            .post_message(&init)
            .map_err(|e| {
                // The worker never received the pointer, so we still own the
                // context. Free it rather than leak it before panicking.
                drop(Box::from_raw(ctx_ptr));
                e
            })
            .unwrap();
    }
}
/// Result slot shared between a `JoinInner` and the spawned thread closure.
///
/// The closure writes `result` once, drops its reference, and then signals.
/// The joiner reads `result` only after that signal, through `Arc::get_mut`.
struct Packet<'scope, T> {
/// Owning scope for scoped threads; `Drop` reports this thread's exit to it.
scope: Option<Arc<ScopeData>>,
/// The thread's return value or caught panic payload; `None` until it finishes.
result: UnsafeCell<Option<Result<T>>>,
/// Ties the packet to `'scope` without storing a reference.
_marker: PhantomData<Option<&'scope ScopeData>>,
}
// SAFETY: `result` (an `UnsafeCell`) is the only part that would otherwise block
// `Sync`. The spawned closure writes it exactly once, then drops its `Arc` and
// signals. `JoinInner::join*` reads it only after waiting on that signal, through
// `Arc::get_mut`, and `Drop` accesses it through `&mut self`. The cell is never
// accessed concurrently; values of `T` only move between threads, hence `T: Send`.
unsafe impl<'scope, T: Send> Sync for Packet<'scope, T> {}
impl<'scope, T> Drop for Packet<'scope, T> {
    /// Drops the stored result, panicking if that drop itself panics. It then
    /// tells the owning scope, if any, that this thread is gone and whether it
    /// left an unobserved panic.
    fn drop(&mut self) {
        let slot = self.result.get_mut();
        // A stored `Err` here means the thread panicked and nobody joined it.
        let unhandled_panic = matches!(*slot, Some(Err(_)));
        let drop_outcome = catch_unwind(AssertUnwindSafe(|| *slot = None));
        if drop_outcome.is_err() {
            panic!("thread result panicked on drop");
        }
        if let Some(scope) = self.scope.as_ref() {
            scope.decrement_num_running_threads(unhandled_panic);
        }
    }
}
/// Joiner-side handle to a spawned thread's result; `JoinHandle` wraps it.
pub(crate) struct JoinInner<'scope, T> {
/// Shared result slot. The thread closure releases its clone before signalling.
packet: Arc<Packet<'scope, T>>,
/// Fired by the thread closure after it has stored its result.
signal: Arc<Signal>,
}
impl<'scope, T> JoinInner<'scope, T> {
    /// Blocks until the thread signals completion, then returns its result.
    /// An `Err` carries the panic payload if the thread panicked.
    pub fn join(mut self) -> Result<T> {
        self.signal.wait();
        self.take_finished_result()
    }

    /// Async counterpart of [`JoinInner::join`]: awaits the completion signal
    /// instead of blocking.
    pub async fn join_async(mut self) -> Result<T> {
        self.signal.wait_async().await;
        self.take_finished_result()
    }

    /// Moves the result out of the packet.
    ///
    /// Call this only after the completion signal. The thread closure drops its
    /// packet reference before signalling, so `Arc::get_mut` succeeds here.
    fn take_finished_result(&mut self) -> Result<T> {
        let packet = Arc::get_mut(&mut self.packet).unwrap();
        packet.result.get_mut().take().unwrap()
    }
}
/// Handle to an unscoped spawned thread, used to wait for and retrieve its result.
///
/// Nothing in this file terminates the worker when the handle is dropped.
pub struct JoinHandle<T>(JoinInner<'static, T>);
impl<T> JoinHandle<T> {
pub fn thread(&self) -> &Thread {
unimplemented!();
}
pub fn join(self) -> Result<T> {
self.0.join()
}
pub async fn join_async(self) -> Result<T> {
self.0.join_async().await
}
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
}
}
impl<T> fmt::Debug for JoinHandle<T> {
    /// Writes an opaque placeholder; the handle's internals are not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.pad("JoinHandle { .. }")
    }
}
/// Spawns a new Web Worker-backed thread with [`Builder::new`] settings and
/// returns a [`JoinHandle`] for it.
///
/// # Panics
/// Panics if [`Builder::spawn`] returns an error.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let builder = Builder::new();
    builder.spawn(f).expect("failed to spawn thread")
}