use anyhow::Result;
use async_trait::async_trait;
use folk_core::runtime::WorkerHandle;
use folk_protocol::RpcMessage;
use tokio::sync::mpsc;
use tracing::debug;
use crate::worker::WorkerCommand;
/// Handle to a worker that runs on a dedicated OS thread inside this process
/// (an "embedded" worker), talking to it over unbounded tokio channels.
///
/// Note: fields are dropped in declaration order, so `cmd_tx` is dropped
/// before the receivers and the thread handle. If the handle is dropped
/// without `terminate` having taken `thread`, dropping the `JoinHandle`
/// detaches the thread rather than joining it.
pub(crate) struct EmbedWorkerHandle {
// Logical identifier of this worker; reported in place of a process id.
worker_id: u32,
// Sends task and terminate commands to the worker thread.
cmd_tx: mpsc::UnboundedSender<WorkerCommand>,
// Receives task response messages produced by the worker thread.
task_resp_rx: mpsc::UnboundedReceiver<RpcMessage>,
// Receives control messages produced by the worker thread.
control_rx: mpsc::UnboundedReceiver<RpcMessage>,
// Join handle for the worker thread; `None` once `terminate` has taken it.
thread: Option<std::thread::JoinHandle<()>>,
}
impl EmbedWorkerHandle {
    /// Wraps the channel endpoints and join handle of an already-spawned
    /// embedded worker thread.
    ///
    /// * `worker_id` – logical id reported for this worker.
    /// * `cmd_tx` – sender used to deliver commands to the worker thread.
    /// * `task_resp_rx` – receiver for task responses from the worker.
    /// * `control_rx` – receiver for control messages from the worker.
    /// * `thread` – join handle of the worker thread, joined on terminate.
    pub(crate) fn new(
        worker_id: u32,
        cmd_tx: mpsc::UnboundedSender<WorkerCommand>,
        task_resp_rx: mpsc::UnboundedReceiver<RpcMessage>,
        control_rx: mpsc::UnboundedReceiver<RpcMessage>,
        thread: std::thread::JoinHandle<()>,
    ) -> Self {
        // Stored as an Option so terminate() can move it out exactly once.
        let thread = Some(thread);
        Self {
            thread,
            control_rx,
            task_resp_rx,
            cmd_tx,
            worker_id,
        }
    }
}
#[async_trait]
impl WorkerHandle for EmbedWorkerHandle {
fn pid(&self) -> u32 {
self.worker_id
}
async fn send_task(&mut self, msg: RpcMessage) -> Result<()> {
self.cmd_tx
.send(WorkerCommand::Task(msg))
.map_err(|_| anyhow::anyhow!("worker thread gone"))?;
Ok(())
}
async fn recv_task(&mut self) -> Result<Option<RpcMessage>> {
Ok(self.task_resp_rx.recv().await)
}
async fn send_control(&mut self, _msg: RpcMessage) -> Result<()> {
Ok(())
}
async fn recv_control(&mut self) -> Result<Option<RpcMessage>> {
Ok(self.control_rx.recv().await)
}
async fn terminate(&mut self) -> Result<()> {
debug!(worker_id = self.worker_id, "terminating embed worker");
let _ = self.cmd_tx.send(WorkerCommand::Terminate);
if let Some(thread) = self.thread.take() {
let worker_id = self.worker_id;
tokio::task::spawn_blocking(move || {
if let Err(e) = thread.join() {
tracing::warn!(worker_id, "worker thread panicked: {e:?}");
}
})
.await?;
}
Ok(())
}
}