use std::{
cell::RefCell,
collections::HashMap,
rc::Rc,
sync::atomic::{AtomicU32, Ordering},
};
use super::com::*;
use super::js::*;
use js_sys::Array;
use serde::{Deserialize, Serialize};
use tokio::sync::{oneshot, Semaphore};
use wasm_bindgen::{prelude::Closure, JsCast, JsValue, UnwrapThrowExt};
use web_sys::{
Blob, BlobPropertyBag, MessageChannel, MessageEvent, MessagePort, Url, Worker, WorkerOptions,
WorkerType,
};
use crate::{
channel::Channel,
channel_task::ChannelTask,
convert::{from_bytes, to_bytes},
error::{Full, InitError},
func::{WebWorkerChannelFn, WebWorkerFn},
};
/// Signature of the JS `onmessage` handler that a [`WebWorker`] keeps alive.
type Callback = dyn FnMut(MessageEvent);

/// Table of in-flight requests: request id → sender that receives the worker's reply.
///
/// Shared through `Rc` between a [`WebWorker`] and its `onmessage` closure.
type OpenTasks = Rc<RefCell<HashMap<u32, oneshot::Sender<Response>>>>;

/// Handle to a single module Web Worker that executes functions identified by name.
///
/// Every request carries a numeric id, and replies are matched back to callers by that id.
/// Dropping the handle terminates the worker.
pub struct WebWorker {
    /// The underlying JS worker object.
    worker: Worker,
    /// Optional semaphore bounding how many `run`/`try_run` requests are in flight.
    task_limit: Option<Semaphore>,
    /// Source of request ids; bumped once per request.
    current_task: AtomicU32,
    /// Senders waiting for replies, keyed by request id.
    open_tasks: OpenTasks,
    /// Keeps the `onmessage` closure alive for as long as it is installed on `worker`.
    _callback: Closure<Callback>,
}
impl WebWorker {
fn worker_blob(
wasm_path: Option<&str>,
wasm_bg_path: Option<&str>,
has_precompiled_module: bool,
) -> String {
let blob_options = BlobPropertyBag::new();
blob_options.set_type("application/javascript");
let mut wasm_path_owned = None;
let wasm_path = wasm_path.unwrap_or_else(|| {
wasm_path_owned = Some(main_js().as_string().unwrap_throw());
wasm_path_owned.as_ref().unwrap_throw()
});
let wasm_bg_path = match wasm_bg_path {
Some(path) => format!("{{module_or_path: '{path}'}}"),
None => "undefined".to_string(),
};
let worker_js = if has_precompiled_module {
super::js::WORKER_JS_WITH_PRECOMPILED
} else {
super::js::WORKER_JS
};
let code = Array::new();
code.push(&JsValue::from_str(
&worker_js
.replace("{{wasm}}", wasm_path)
.replace("{{wasm_bg}}", &wasm_bg_path),
));
Url::create_object_url_with_blob(
&Blob::new_with_blob_sequence_and_options(&code.into(), &blob_options)
.expect_throw("Couldn't create blob"),
)
.expect_throw("Couldn't create object URL")
}
pub async fn new(task_limit: Option<usize>) -> Result<WebWorker, InitError> {
Self::with_path(None, None, task_limit).await
}
pub async fn with_path(
main_js: Option<&str>,
main_bg_js: Option<&str>,
task_limit: Option<usize>,
) -> Result<WebWorker, InitError> {
Self::with_path_and_module(main_js, main_bg_js, task_limit, None).await
}
pub async fn with_path_and_module(
main_js: Option<&str>,
main_bg_js: Option<&str>,
task_limit: Option<usize>,
wasm_module: Option<js_sys::WebAssembly::Module>,
) -> Result<WebWorker, InitError> {
let worker_options = WorkerOptions::new();
worker_options.set_type(WorkerType::Module);
let script_url = WebWorker::worker_blob(main_js, main_bg_js, wasm_module.is_some());
let worker = Worker::new_with_options(&script_url, &worker_options)
.map_err(InitError::WebWorkerCreation)?;
if let Some(module) = wasm_module {
let init_msg = js_sys::Object::new();
js_sys::Reflect::set(
&init_msg,
&JsValue::from_str("type"),
&JsValue::from_str("wasm_module"),
)
.expect_throw("Could not set type");
js_sys::Reflect::set(&init_msg, &JsValue::from_str("module"), &module)
.expect_throw("Could not set module");
worker
.post_message(&init_msg)
.expect_throw("Could not send WASM module to worker");
}
let (tx, rx) = oneshot::channel();
let handler = Closure::once(move |event: MessageEvent| {
let data = event.data();
let post_init: PostInit = serde_wasm_bindgen::from_value(data)
.expect_throw("Error deserializing post init data");
let _ = tx.send(post_init);
});
worker.set_onmessage(Some(handler.as_ref().unchecked_ref()));
let post_init = rx.await.expect_throw("WebWorker init sender dropped");
if !post_init.success {
return Err(InitError::WebWorkerModuleLoading(
post_init
.message
.expect_throw("Post init should have error message"),
));
}
let tasks = Rc::new(RefCell::new(HashMap::new()));
let callback_handle = Self::callback(Rc::clone(&tasks));
worker.set_onmessage(Some(callback_handle.as_ref().unchecked_ref()));
Ok(WebWorker {
worker,
task_limit: task_limit.map(|limit| Semaphore::new(limit)),
current_task: AtomicU32::new(0),
open_tasks: tasks,
_callback: callback_handle,
})
}
fn callback(
tasks: Rc<RefCell<HashMap<u32, oneshot::Sender<Response>>>>,
) -> Closure<dyn FnMut(MessageEvent)> {
Closure::new(move |event: MessageEvent| {
let data = event.data();
let response: Response =
serde_wasm_bindgen::from_value(data).expect_throw("Could not deserialize response");
let mut tasks_wg = tasks.borrow_mut();
if let Some(channel) = tasks_wg.remove(&response.id) {
let _ = channel.send(response);
}
})
}
pub async fn run<T, R>(&self, func: WebWorkerFn<T, R>, arg: &T) -> R
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
self.run_internal(func, arg).await
}
pub async fn run_channel<T, R>(&self, func: WebWorkerChannelFn<T, R>, arg: &T) -> ChannelTask<R>
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
self.run_channel_internal(func, arg).await
}
pub async fn try_run<T, R>(&self, func: WebWorkerFn<T, R>, arg: &T) -> Result<R, Full>
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
self.try_run_internal(func, arg).await
}
pub async fn run_bytes(
&self,
func: WebWorkerFn<Box<[u8]>, Box<[u8]>>,
arg: &Box<[u8]>,
) -> Box<[u8]> {
self.run_internal(func, arg).await
}
pub async fn try_run_bytes(
&self,
func: WebWorkerFn<Box<[u8]>, Box<[u8]>>,
arg: &Box<[u8]>,
) -> Result<Box<[u8]>, Full> {
self.try_run_internal(func, arg).await
}
pub(crate) async fn try_run_internal<T, R>(
&self,
func: WebWorkerFn<T, R>,
arg: &T,
) -> Result<R, Full>
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let _permit = if let Some(ref s) = self.task_limit {
Some(match s.try_acquire() {
Ok(permit) => permit,
Err(_) => return Err(Full),
})
} else {
None
};
Ok(self.force_run(func.name, arg, false, None).await)
}
pub(crate) async fn run_internal<T, R>(&self, func: WebWorkerFn<T, R>, arg: &T) -> R
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let _permit = if let Some(ref s) = self.task_limit {
Some(s.acquire().await.unwrap())
} else {
None
};
self.force_run(func.name, arg, false, None).await
}
pub(crate) async fn run_channel_internal<T, R>(
&self,
func: WebWorkerChannelFn<T, R>,
arg: &T,
) -> ChannelTask<R>
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let _permit = if let Some(ref s) = self.task_limit {
Some(s.acquire().await.unwrap())
} else {
None
};
let msg_channel = MessageChannel::new().expect_throw("Could not create MessageChannel");
let channel = Channel::from(msg_channel.port1());
let worker_port = msg_channel.port2();
let result_rx = self.send_channel_request(func.name, arg, worker_port);
ChannelTask::new(channel, result_rx)
}
async fn force_run<T, R>(
&self,
func_name: &'static str,
arg: &T,
is_channel: bool,
port: Option<MessagePort>,
) -> R
where
T: Serialize + for<'de> Deserialize<'de>,
R: Serialize + for<'de> Deserialize<'de>,
{
let id = self.current_task.fetch_add(1, Ordering::Relaxed);
let request = Request {
id,
func_name,
is_channel,
arg: to_bytes(arg),
};
let res = self.send_request(id, request, port).await;
from_bytes(&res)
}
async fn send_request(&self, id: u32, request: Request, port: Option<MessagePort>) -> Vec<u8> {
let (sender, receiver) = oneshot::channel();
self.open_tasks.borrow_mut().insert(id, sender);
if let Some(port) = port {
let transfer = Array::new();
transfer.push(&port);
self.worker
.post_message_with_transfer(
&serde_wasm_bindgen::to_value(&request)
.expect_throw("Could not serialize request"),
&transfer,
)
.expect_throw("WebWorker gone");
} else {
self.worker
.post_message(
&serde_wasm_bindgen::to_value(&request)
.expect_throw("Could not serialize request"),
)
.expect_throw("WebWorker gone");
}
receiver
.await
.expect_throw("WebWorker gone")
.response
.expect_throw("Could not find function")
}
fn send_channel_request<T>(
&self,
func_name: &'static str,
arg: &T,
port: MessagePort,
) -> oneshot::Receiver<Vec<u8>>
where
T: Serialize + for<'de> Deserialize<'de>,
{
let id = self.current_task.fetch_add(1, Ordering::Relaxed);
let request = Request {
id,
func_name,
is_channel: true,
arg: to_bytes(arg),
};
let (sender, receiver) = oneshot::channel();
self.open_tasks.borrow_mut().insert(id, sender);
let transfer = Array::new();
transfer.push(&port);
self.worker
.post_message_with_transfer(
&serde_wasm_bindgen::to_value(&request).expect_throw("Could not serialize request"),
&transfer,
)
.expect_throw("WebWorker gone");
let (byte_sender, byte_receiver) = oneshot::channel();
wasm_bindgen_futures::spawn_local(async move {
if let Ok(response) = receiver.await {
let _ = byte_sender.send(response.response.expect("Could not find function"));
}
});
byte_receiver
}
pub fn capacity(&self) -> Option<usize> {
self.task_limit.as_ref().map(|s| s.available_permits())
}
pub fn current_load(&self) -> usize {
self.open_tasks.borrow().len()
}
}
impl Drop for WebWorker {
    /// Stops the underlying JS worker when this handle goes away.
    fn drop(&mut self) {
        let Self { worker, .. } = self;
        worker.terminate();
    }
}