use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::mpsc;
use anyhow::Result;
use async_trait::async_trait;
use folk_core::config::WorkersConfig;
use folk_core::runtime::{Runtime, WorkerHandle};
use tracing::debug;
use crate::bridge;
use crate::worker;
/// ID given to the first worker created in this process.
const FIRST_WORKER_ID: u32 = 1;

/// Process-wide source of worker IDs. Each `fetch_add` hands out the next value; the
/// counter wraps on overflow (standard `AtomicU32::fetch_add` semantics).
static NEXT_WORKER_ID: AtomicU32 = AtomicU32::new(FIRST_WORKER_ID);
/// The runtime-side ends of a pre-connected worker's channels, held by `ExtensionRuntime`
/// until a spawned handle claims them.
#[derive(Debug)]
pub struct WorkerTxSide {
    /// Bounded queue used to submit tasks to the worker.
    pub task_tx: mpsc::SyncSender<bridge::TaskRequest>,
    /// Receives `()` when the worker signals that it is ready; a disconnect before that
    /// is reported as the worker having died.
    pub ready_rx: mpsc::Receiver<()>,
}
/// [`Runtime`] implementation that hands out pre-connected worker channels first and,
/// when configured for more than one worker, can spawn ZTS worker threads.
pub struct ExtensionRuntime {
    /// Worker settings; `script` and `count` are read when spawning workers.
    config: WorkersConfig,
    /// Pre-connected channels not yet claimed, taken from the end of the vector. Behind a
    /// mutex so `spawn(&self)` can claim one through a shared reference.
    channels: std::sync::Mutex<Vec<WorkerTxSide>>,
}
impl ExtensionRuntime {
    /// Creates a runtime whose pool of pre-connected worker channels is `tx_sides`.
    ///
    /// Channels are handed out from the end of `tx_sides` (last in, first out).
    pub fn new(config: WorkersConfig, tx_sides: Vec<WorkerTxSide>) -> Self {
        Self {
            config,
            channels: std::sync::Mutex::new(tx_sides),
        }
    }

    /// Spawns a new ZTS worker thread running the configured script, registers it with the
    /// crate, and returns a handle wired to its task and ready channels.
    ///
    /// # Errors
    ///
    /// Fails if the configured script path is relative and the current directory cannot
    /// be determined.
    fn spawn_zts_worker(&self) -> Result<Box<dyn WorkerHandle>> {
        let configured: &std::path::Path = self.config.script.as_ref();
        // Absolute paths are used verbatim. Relative paths are anchored at the current
        // directory, and a failure to read it is surfaced here: falling back to an empty
        // base would hand the worker an unanchored path that fails later and obscurely.
        let script = if configured.is_absolute() {
            configured.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| {
                    anyhow::anyhow!("cannot resolve worker script path: current dir unavailable: {e}")
                })?
                .join(configured)
        }
        .to_string_lossy()
        .into_owned();

        // Allocate the ID only once nothing else can fail, so failures don't burn IDs.
        let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
        let (task_tx, task_rx) = mpsc::sync_channel::<bridge::TaskRequest>(8);
        let (ready_tx, ready_rx) = mpsc::sync_channel::<()>(1);
        let handle = worker::spawn_zts_worker(worker_id, script, task_rx, ready_tx);
        crate::register_zts_worker(handle);
        debug!(worker_id, "ZTS worker thread spawned");
        Ok(Box::new(ChannelWorkerHandle {
            worker_id,
            task_tx: Some(task_tx),
            ready_rx: Some(ready_rx),
        }))
    }

    /// Claims one pre-connected channel pair and wraps it in a worker handle.
    ///
    /// The emptiness check and the pop happen under a single lock acquisition.
    ///
    /// # Errors
    ///
    /// Fails only when no pre-connected channels remain.
    fn take_preconnected(&self) -> Result<Box<dyn WorkerHandle>> {
        // A panic elsewhere while holding this lock cannot leave a `Vec` half-updated by
        // our pop, so recovering the guard from a poisoned mutex is sound.
        let tx_side = self
            .channels
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .pop()
            .ok_or_else(|| anyhow::anyhow!("no more pre-connected channels"))?;
        // Allocate the ID after a successful pop so an empty pool doesn't burn IDs.
        let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
        debug!(worker_id, "pre-connected worker channel taken");
        Ok(Box::new(ChannelWorkerHandle {
            worker_id,
            task_tx: Some(tx_side.task_tx),
            ready_rx: Some(tx_side.ready_rx),
        }))
    }
}
#[async_trait]
impl Runtime for ExtensionRuntime {
    /// Hands out a worker: a pre-connected channel if one remains, otherwise a freshly
    /// spawned ZTS worker thread when more than one worker is configured.
    ///
    /// # Errors
    ///
    /// Fails when no pre-connected channel is left and `config.count <= 1`, or when
    /// spawning a ZTS worker fails.
    async fn spawn(&self) -> Result<Box<dyn WorkerHandle>> {
        // Check-and-pop must be one atomic step. Probing `is_empty()` under one lock and
        // popping under another lets a concurrent `spawn` drain the last channel in
        // between, turning a valid ZTS fallback into a hard error. `take_preconnected`
        // does both under a single lock and errs only when the pool is empty.
        match self.take_preconnected() {
            Ok(handle) => Ok(handle),
            Err(_) if self.config.count > 1 => self.spawn_zts_worker(),
            Err(_) => anyhow::bail!("no workers available and ZTS multi-worker not requested"),
        }
    }
}
/// [`WorkerHandle`] that talks to a worker over `std::sync::mpsc` channels: tasks go out
/// on a bounded queue, and readiness is awaited once on a separate channel.
#[derive(Debug)]
pub struct ChannelWorkerHandle {
    /// Identifier returned by [`WorkerHandle::id`].
    worker_id: u32,
    /// Task queue to the worker; `None` after `terminate` has dropped it.
    task_tx: Option<mpsc::SyncSender<bridge::TaskRequest>>,
    /// Readiness signal; taken (left `None`) by the first `ready` call.
    ready_rx: Option<mpsc::Receiver<()>>,
}
#[async_trait]
impl WorkerHandle for ChannelWorkerHandle {
fn id(&self) -> u32 {
self.worker_id
}
async fn ready(&mut self) -> Result<()> {
if let Some(rx) = self.ready_rx.take() {
tokio::task::spawn_blocking(move || rx.recv())
.await
.map_err(|e| anyhow::anyhow!("spawn_blocking panicked: {e}"))?
.map_err(|_| anyhow::anyhow!("worker died before ready"))?;
}
Ok(())
}
async fn execute(
&mut self,
method: &str,
payload: serde_json::Value,
) -> Result<serde_json::Value> {
let tx = self
.task_tx
.as_ref()
.ok_or_else(|| anyhow::anyhow!("worker terminated"))?
.clone();
let method = method.to_string();
let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
tx.send(bridge::TaskRequest {
method,
payload,
reply: reply_tx,
})
.map_err(|_| anyhow::anyhow!("worker process gone"))?;
reply_rx
.await
.map_err(|_| anyhow::anyhow!("worker dropped reply"))?
}
async fn terminate(&mut self) -> Result<()> {
self.task_tx.take();
Ok(())
}
}