use std::cell::RefCell;
use std::sync::mpsc;
use tracing::debug;
/// A unit of work sent to a worker thread over its task channel.
///
/// The worker consumes it either through [`do_recv`] (manual recv/send protocol) or through
/// [`run_dispatch_loop`]. In both cases the result is delivered on `reply`.
pub struct TaskRequest {
    /// Name of the method to run. It is returned as-is by [`do_recv`] and passed to
    /// `crate::zts::call_dispatch` by [`run_dispatch_loop`].
    pub method: String,
    /// JSON arguments for the method. [`do_recv`] hands them to the worker re-serialized as bytes.
    pub payload: serde_json::Value,
    /// One-shot channel on which the worker sends the method's result or error.
    pub reply: tokio::sync::oneshot::Sender<anyhow::Result<serde_json::Value>>,
}
/// Per-thread worker state, installed in [`WORKER`] by [`init_worker_state`].
struct WorkerState {
    /// Identifier of this worker. It is used in log fields, and worker `1` additionally calls
    /// `crate::join_zts_workers()` when its task channel closes in [`run_dispatch_loop`].
    worker_id: u32,
    /// Incoming requests. A closed channel ends [`do_recv`] and [`run_dispatch_loop`].
    task_rx: mpsc::Receiver<TaskRequest>,
    /// Readiness signal. It is `take`n on use, so it is sent at most once, by either
    /// [`do_ready`] or [`run_dispatch_loop`].
    ready_tx: Option<mpsc::SyncSender<()>>,
    /// Reply sender of the request most recently returned by [`do_recv`]. It waits here until
    /// [`do_send`] or [`do_send_error`] answers it.
    current_reply: Option<tokio::sync::oneshot::Sender<anyhow::Result<serde_json::Value>>>,
}
thread_local! {
    /// This thread's worker state. It is `None` until [`init_worker_state`] runs and again
    /// after [`cleanup_worker_state`]. Every accessor in this module returns
    /// `"not in a worker thread"` when it finds `None` here.
    static WORKER: RefCell<Option<WorkerState>> = const { RefCell::new(None) };
}
/// Installs the worker state for the calling thread.
///
/// Any state previously installed on this thread is replaced and dropped. No request reply is
/// pending after this call, and the readiness signal is armed so that [`do_ready`] or
/// [`run_dispatch_loop`] can send it once.
pub fn init_worker_state(
    worker_id: u32,
    task_rx: mpsc::Receiver<TaskRequest>,
    ready_tx: mpsc::SyncSender<()>,
) {
    let fresh = WorkerState {
        worker_id,
        task_rx,
        ready_tx: Some(ready_tx),
        current_reply: None,
    };
    WORKER.with(move |slot| {
        let mut guard = slot.borrow_mut();
        *guard = Some(fresh);
    });
}
/// Removes and drops the calling thread's worker state, if any.
///
/// Dropping the state closes this thread's task receiver and discards any unsent ready
/// signal or pending reply sender.
pub fn cleanup_worker_state() {
    WORKER.with(|slot| {
        let previous = slot.borrow_mut().take();
        drop(previous);
    });
}
/// Returns `true` if worker state is currently installed on the calling thread.
pub fn has_worker_state() -> bool {
    WORKER.with(|slot| {
        let current = slot.borrow();
        current.is_some()
    })
}
/// Sends this worker's readiness signal if it has not been sent yet.
///
/// Returns `Ok(true)` if the signal was sent by this call and `Ok(false)` if it had already been
/// consumed.
///
/// # Errors
///
/// Returns `"not in a worker thread"` when no worker state is installed on this thread.
pub fn do_ready() -> Result<bool, &'static str> {
    WORKER.with(|cell| {
        let mut guard = cell.borrow_mut();
        let worker = guard.as_mut().ok_or("not in a worker thread")?;
        match worker.ready_tx.take() {
            Some(ready) => {
                // A disconnected receiver is not an error for the worker; ignore it.
                let _ = ready.send(());
                debug!(worker_id = worker.worker_id, "worker signaled ready");
                Ok(true)
            }
            None => Ok(false),
        }
    })
}
pub fn do_recv() -> Result<Option<(String, Vec<u8>)>, &'static str> {
WORKER.with(|w| {
let mut state = w.borrow_mut();
let state = state.as_mut().ok_or("not in a worker thread")?;
if let Ok(req) = state.task_rx.recv() {
let method = req.method.clone();
let payload_bytes = serde_json::to_vec(&req.payload).unwrap_or_default();
state.current_reply = Some(req.reply);
Ok(Some((method, payload_bytes)))
} else {
debug!(worker_id = state.worker_id, "recv: channel closed");
Ok(None)
}
})
}
pub fn do_send(data: &[u8]) -> Result<(), &'static str> {
WORKER.with(|w| {
let mut state = w.borrow_mut();
let state = state.as_mut().ok_or("not in a worker thread")?;
let reply = state.current_reply.take().ok_or("no pending request")?;
let value: serde_json::Value =
serde_json::from_slice(data).unwrap_or(serde_json::Value::Null);
let _ = reply.send(Ok(value));
Ok(())
})
}
/// Runs the worker's request loop, dispatching each request through `dispatch_fn`.
///
/// First sends the readiness signal if it has not been sent yet. Then, for every
/// [`TaskRequest`] received, it calls `crate::zts::call_dispatch(dispatch_fn, method, payload)`
/// and sends the result, success or error, back on the request's reply channel. When the task
/// channel closes, worker `1` calls `crate::join_zts_workers()` before the loop returns
/// `Ok(())`.
///
/// # Errors
///
/// Returns `"not in a worker thread"` if no worker state is installed on this thread. This is
/// checked at start-up and again before every receive.
pub fn run_dispatch_loop(dispatch_fn: &str) -> Result<(), &'static str> {
    WORKER.with(|w| {
        {
            let mut state = w.borrow_mut();
            let state = state.as_mut().ok_or("not in a worker thread")?;
            if let Some(tx) = state.ready_tx.take() {
                let _ = tx.send(());
                debug!(
                    worker_id = state.worker_id,
                    "worker signaled ready (dispatch loop)"
                );
            }
        }
        loop {
            // The `RefCell` borrow is confined to this block. It is released before any
            // outside code runs, so neither `call_dispatch` nor `join_zts_workers` can trip
            // a BorrowMutError by touching `WORKER` on this thread.
            let (worker_id, received) = {
                let mut state = w.borrow_mut();
                let state = state.as_mut().ok_or("not in a worker thread")?;
                (state.worker_id, state.task_rx.recv())
            };
            let req = match received {
                Ok(req) => req,
                Err(_) => {
                    debug!(worker_id, "dispatch loop: channel closed");
                    if worker_id == 1 {
                        crate::join_zts_workers();
                    }
                    return Ok(());
                }
            };
            let result = crate::zts::call_dispatch(dispatch_fn, &req.method, &req.payload);
            // Forward success and failure alike; a dropped receiver just means nobody waits.
            let _ = req.reply.send(result);
        }
    })
}
/// Answers the request most recently returned by [`do_recv`] with an error carrying `message`.
///
/// # Errors
///
/// - `"not in a worker thread"` when no worker state is installed on this thread.
/// - `"no pending request"` when there is no unanswered request to reply to.
pub fn do_send_error(message: &str) -> Result<(), &'static str> {
    WORKER.with(|cell| {
        let mut guard = cell.borrow_mut();
        let pending = guard
            .as_mut()
            .ok_or("not in a worker thread")?
            .current_reply
            .take()
            .ok_or("no pending request")?;
        // Ignore a closed receiver: the requester has simply stopped waiting.
        let _ = pending.send(Err(anyhow::anyhow!("{message}")));
        Ok(())
    })
}