use rong::{Source, spawn_local, *};
use std::cell::RefCell;
use std::path::PathBuf;
use std::rc::Rc;
use std::sync::{
Arc, OnceLock,
atomic::{AtomicBool, Ordering},
};
use tokio::sync::mpsc;
use tracing::{error, warn};
/// Hook that prepares a worker's freshly created `JSContext` before its script
/// is evaluated. Returning an error aborts that worker before its script runs.
type WorkerInitializer = Box<dyn Fn(&JSContext) -> JSResult<()> + Send + Sync>;
/// Process-wide worker initializer, settable at most once (via `set_initializer`).
/// When it is unset, each worker gets only `rong_console` installed in its place.
static WORKER_INITIALIZER: OnceLock<WorkerInitializer> = OnceLock::new();
pub fn set_initializer<F>(f: F)
where
F: Fn(&JSContext) -> JSResult<()> + Send + Sync + 'static,
{
let _ = WORKER_INITIALIZER.set(Box::new(f));
}
/// Makes the `Worker` class available to scripts running in `ctx`.
///
/// # Errors
/// Propagates any failure from registering the class.
pub fn init(ctx: &JSContext) -> JSResult<()> {
    ctx.register_class::<Worker>().map(|_| ())
}
/// Serializes a JS value to JSON text so it can cross the thread boundary.
///
/// Objects are serialized with `JSObject::to_json_string`. Any other value is
/// passed to the context's global `JSON.stringify`, and the result is
/// converted to a Rust `String`. Errors from either path are returned.
///
/// NOTE(review): values that `JSON.stringify` maps to `undefined` presumably
/// fail the `String` conversion and surface as an error. Confirm this is the
/// intended behavior.
fn js_value_to_json(ctx: &JSContext, data: &JSValue) -> JSResult<String> {
    match data.clone().into_object() {
        Some(obj) => obj.to_json_string(),
        None => {
            let stringify: JSFunc = ctx
                .global()
                .get::<_, JSObject>("JSON")?
                .get::<_, JSFunc>("stringify")?;
            stringify.call::<_, String>(None, (data.clone(),))
        }
    }
}
/// Messages sent from the owning `Worker` handle to the worker thread.
enum ToWorker {
    /// A `postMessage` payload, already serialized to JSON text.
    Message(String),
    /// Asks the worker's message loop to exit.
    Terminate,
}
/// Messages sent from the worker thread back to the owning `Worker` handle.
enum FromWorker {
    /// A worker-side `postMessage` payload as JSON text. The owner parses it
    /// and hands it to its `onmessage` handler as `event.data`.
    Message(String),
    /// A human-readable failure description. The owner delivers it to
    /// `onerror` as `event.message`, or logs it when no handler is set.
    Error(String),
}
/// JS-visible `Worker` class: the owner-side handle to a script running on its
/// own OS thread. The two sides exchange messages (JSON text) over a pair of
/// bounded channels created in the constructor.
#[js_export]
pub struct Worker {
    /// Outgoing queue of messages / terminate requests for the worker thread.
    to_worker: mpsc::Sender<ToWorker>,
    /// Incoming messages and errors from the worker thread. Behind an async
    /// mutex so the polling task can hold it across `recv().await`.
    from_worker: Arc<tokio::sync::Mutex<mpsc::Receiver<FromWorker>>>,
    /// Handler set through the `onmessage` setter. Shared with the polling task.
    message_handler: Rc<RefCell<Option<JSFunc>>>,
    /// Handler set through the `onerror` setter. Shared with the polling task.
    error_handler: Rc<RefCell<Option<JSFunc>>>,
    /// Stop flag, also handed to the worker thread by the constructor. Checked
    /// by `postMessage`, the polling loop and the worker's message loop.
    terminated: Arc<AtomicBool>,
    /// Ensures the polling task is spawned at most once.
    polling_started: Arc<AtomicBool>,
    /// Handle of the polling task, kept so `terminate()` can abort it.
    polling_handle: Rc<RefCell<Option<tokio::task::JoinHandle<()>>>>,
    /// Join handle of the worker OS thread. Taken (at most once) by `terminate()`.
    thread_handle: Arc<std::sync::Mutex<Option<std::thread::JoinHandle<()>>>>,
}
#[js_class]
impl Worker {
#[js_method(constructor)]
fn new(_ctx: JSContext, path: String) -> JSResult<Self> {
let (to_worker_tx, to_worker_rx) = mpsc::channel::<ToWorker>(256);
let (from_worker_tx, from_worker_rx) = mpsc::channel::<FromWorker>(256);
let script_path = if PathBuf::from(&path).is_absolute() {
PathBuf::from(&path)
} else {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join(&path)
};
let terminated = Arc::new(AtomicBool::new(false));
let terminated_thread = terminated.clone();
let thread_handle = std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create worker tokio runtime");
rt.block_on(async {
let local = tokio::task::LocalSet::new();
local
.run_until(Self::run_worker_thread(
script_path,
to_worker_rx,
from_worker_tx,
terminated_thread,
))
.await;
});
});
Ok(Worker {
to_worker: to_worker_tx,
from_worker: Arc::new(tokio::sync::Mutex::new(from_worker_rx)),
message_handler: Rc::new(RefCell::new(None)),
error_handler: Rc::new(RefCell::new(None)),
terminated,
polling_started: Arc::new(AtomicBool::new(false)),
polling_handle: Rc::new(RefCell::new(None)),
thread_handle: Arc::new(std::sync::Mutex::new(Some(thread_handle))),
})
}
#[js_method(rename = "postMessage")]
fn post_message(&self, ctx: JSContext, data: JSValue) -> JSResult<()> {
if self.terminated.load(Ordering::Acquire) {
return Ok(());
}
let json = js_value_to_json(&ctx, &data)?;
let tx = self.to_worker.clone();
spawn_local(async move {
let _ = tx.send(ToWorker::Message(json)).await;
});
Ok(())
}
#[js_method]
fn terminate(&self) -> JSResult<()> {
if self.terminated.swap(true, Ordering::AcqRel) {
return Ok(());
}
if let Some(handle) = self.polling_handle.borrow_mut().take() {
handle.abort();
}
let tx = self.to_worker.clone();
spawn_local(async move {
let _ = tx.send(ToWorker::Terminate).await;
});
let thread_handle = self
.thread_handle
.lock()
.ok()
.and_then(|mut guard| guard.take());
if let Some(handle) = thread_handle {
RongExecutor::global().spawn_blocking(move || {
let _ = handle.join();
});
}
Ok(())
}
#[js_method(setter, rename = "onmessage")]
fn set_onmessage(&self, ctx: JSContext, handler: JSFunc) -> JSResult<()> {
*self.message_handler.borrow_mut() = Some(handler);
self.ensure_polling(ctx);
Ok(())
}
#[js_method(setter, rename = "onerror")]
fn set_onerror(&self, ctx: JSContext, handler: JSFunc) -> JSResult<()> {
*self.error_handler.borrow_mut() = Some(handler);
self.ensure_polling(ctx);
Ok(())
}
fn ensure_polling(&self, ctx: JSContext) {
if self.polling_started.swap(true, Ordering::AcqRel) {
return;
}
let from_worker = self.from_worker.clone();
let message_handler = self.message_handler.clone();
let error_handler = self.error_handler.clone();
let terminated = self.terminated.clone();
let polling_handle_slot = self.polling_handle.clone();
let polling_handle = spawn_local(async move {
loop {
if terminated.load(Ordering::Acquire) {
break;
}
let msg = {
let mut rx = from_worker.lock().await;
rx.recv().await
};
match msg {
Some(FromWorker::Message(json_str)) => {
if terminated.load(Ordering::Acquire) {
break;
}
match JSObject::from_json_string(&ctx, &json_str) {
Ok(value) => {
let handler = message_handler.borrow().clone();
if let Some(func) = handler {
let event = JSObject::new(&ctx);
event.set("data", value).ok();
if let Err(e) = func.call_async::<_, ()>(None, (event,)).await {
let err_handler = error_handler.borrow().clone();
if let Some(err_fn) = err_handler {
let err_message = worker_error_message(&ctx, e);
let err_event =
worker_error_event(&ctx, err_message.as_str());
let _ = err_fn
.call_async::<_, ()>(None, (err_event,))
.await;
} else {
error!(target: "rong", error = ?e, "worker onmessage handler failed");
}
}
}
}
Err(e) => {
warn!(target: "rong", error = ?e, "worker failed to deserialize JSON message");
}
}
}
Some(FromWorker::Error(message)) => {
let err_handler = error_handler.borrow().clone();
if let Some(err_fn) = err_handler {
let err_event = worker_error_event(&ctx, &message);
let _ = err_fn.call_async::<_, ()>(None, (err_event,)).await;
} else {
error!(target: "rong", message = %message, "worker emitted error event without handler");
}
}
None => break,
}
}
});
*polling_handle_slot.borrow_mut() = Some(polling_handle);
}
#[js_method(gc_mark)]
fn gc_mark_with<F>(&self, mut mark_fn: F)
where
F: FnMut(&JSValue),
{
for slot in [&self.message_handler, &self.error_handler] {
if let Some(handler) = slot.borrow().clone() {
mark_fn(handler.as_js_value());
}
}
}
}
impl Worker {
async fn run_worker_thread(
script_path: PathBuf,
mut to_worker_rx: mpsc::Receiver<ToWorker>,
from_worker_tx: mpsc::Sender<FromWorker>,
terminated: Arc<AtomicBool>,
) {
let runtime = RongJS::runtime();
let ctx = runtime.context();
if let Some(initializer) = WORKER_INITIALIZER.get() {
if let Err(e) = initializer(&ctx) {
let _ = from_worker_tx
.send(FromWorker::Error(format!(
"initializer failed: {}",
worker_error_message(&ctx, e)
)))
.await;
return;
}
} else {
rong_console::init(&ctx).ok();
}
let tx = from_worker_tx.clone();
let post_ctx = ctx.clone();
let post_message_fn = JSFunc::new(&ctx, move |data: JSValue| {
let c = post_ctx.clone();
let t = tx.clone();
spawn_local(async move {
match js_value_to_json(&c, &data) {
Ok(json) => {
let _ = t.send(FromWorker::Message(json)).await;
}
Err(e) => {
let _ = t
.send(FromWorker::Error(format!(
"postMessage serialization failed: {}",
worker_error_message(&c, e)
)))
.await;
}
}
});
});
ctx.global().set("postMessage", post_message_fn).ok();
let terminated_close = terminated.clone();
let close_fn = JSFunc::new(&ctx, move || {
terminated_close.store(true, Ordering::Release);
});
let global = ctx.global();
global.set("close", close_fn).ok();
global.set("self", global.clone()).ok();
match Source::from_path(&ctx, &script_path).await {
Ok(source) => {
if let Err(e) = ctx.eval_async::<()>(source).await {
let _ = from_worker_tx
.send(FromWorker::Error(format!(
"script error in {:?}: {}",
script_path,
worker_error_message(&ctx, e)
)))
.await;
return;
}
}
Err(e) => {
let _ = from_worker_tx
.send(FromWorker::Error(format!(
"failed to load {:?}: {}",
script_path, e
)))
.await;
return;
}
}
loop {
if terminated.load(Ordering::Acquire) {
break;
}
match to_worker_rx.recv().await {
Some(ToWorker::Message(json_str)) => {
match JSObject::from_json_string(&ctx, &json_str) {
Ok(data) => {
if let Ok(handler) = ctx.global().get::<_, JSValue>("onmessage")
&& let Ok(func) = handler.to_rust::<JSFunc>()
{
let event = JSObject::new(&ctx);
event.set("data", data).ok();
if let Err(e) = func.call_async::<_, ()>(None, (event,)).await {
let _ = from_worker_tx
.send(FromWorker::Error(format!(
"worker onmessage handler error: {}",
worker_error_message(&ctx, e)
)))
.await;
}
}
}
Err(e) => {
let _ = from_worker_tx
.send(FromWorker::Error(format!(
"JSON deserialization failed: {}",
e
)))
.await;
}
}
}
Some(ToWorker::Terminate) | None => break,
}
}
}
}
/// Turns a `RongJSError` raised in `ctx` into a human-readable message: the
/// host error's `message` when one is available, otherwise "Worker error".
fn worker_error_message(ctx: &JSContext, err: RongJSError) -> String {
    let host_error = err.into_host_in(ctx).into_host_error();
    match host_error {
        Some(host) => host.message,
        None => String::from("Worker error"),
    }
}
/// Builds the event object handed to an `onerror` handler:
/// `{ type: "error", message }`. Failures to set a property are ignored.
fn worker_error_event(ctx: &JSContext, message: &str) -> JSObject {
    let event = JSObject::new(ctx);
    for (key, value) in [("type", "error"), ("message", message)] {
        let _ = event.set(key, value);
    }
    event
}
#[cfg(test)]
mod tests {
    use super::*;
    use rong_test::*;
    // End-to-end test: runs the JS suite in `worker.js` on a main context that
    // has `Worker`, console, assert and timers installed. Workers spawned by the
    // suite get console and timers through the registered initializer.
    #[test]
    fn test_worker() {
        // Workspace root, two levels above this crate's manifest directory.
        let workspace_root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("../..")
            .canonicalize()
            .expect("workspace root");
        // Relative worker script paths are resolved against the cwd (see
        // `Worker::new`), so pin it to the workspace root.
        // NOTE(review): this mutates process-global state and can race with
        // other tests running in parallel in the same process.
        std::env::set_current_dir(&workspace_root).expect("set cwd");
        // First registration wins process-wide (OnceLock).
        set_initializer(|ctx| {
            rong_console::init(ctx)?;
            rong_timer::init(ctx)?;
            Ok(())
        });
        async_run!(|ctx: JSContext| async move {
            init(&ctx)?;
            rong_console::init(&ctx)?;
            rong_assert::init(&ctx)?;
            rong_timer::init(&ctx)?;
            // Presumably resolved relative to the cwd set above — confirm
            // against `UnitJSRunner::load_script`.
            let passed = UnitJSRunner::load_script(&ctx, "worker.js")
                .await?
                .run()
                .await?;
            assert!(passed);
            Ok(())
        })
    }
}