1use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::mpsc;
13
14use anyhow::Result;
15use async_trait::async_trait;
16use folk_core::config::WorkersConfig;
17use folk_core::runtime::{Runtime, WorkerHandle};
18use tracing::debug;
19
20use crate::bridge;
21use crate::worker;
22
/// Process-wide counter from which every worker handle created by `ExtensionRuntime` draws its
/// ID. The first ID handed out is 1; `fetch_add` wraps around after `u32::MAX`.
static NEXT_WORKER_ID: AtomicU32 = AtomicU32::new(1_u32);
24
25pub struct WorkerTxSide {
27 pub task_tx: mpsc::SyncSender<bridge::TaskRequest>,
28 pub ready_rx: mpsc::Receiver<()>,
29}
30
31pub struct ExtensionRuntime {
33 config: WorkersConfig,
34 channels: std::sync::Mutex<Vec<WorkerTxSide>>,
36}
37
38impl ExtensionRuntime {
39 pub fn new(config: WorkersConfig, tx_sides: Vec<WorkerTxSide>) -> Self {
41 Self {
42 config,
43 channels: std::sync::Mutex::new(tx_sides),
44 }
45 }
46
47 #[allow(clippy::unnecessary_wraps)] fn spawn_zts_worker(&self) -> Result<Box<dyn WorkerHandle>> {
50 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
51 let script = std::env::current_dir()
54 .unwrap_or_default()
55 .join(&self.config.script)
56 .to_string_lossy()
57 .into_owned();
58
59 let (task_tx, task_rx) = mpsc::sync_channel::<bridge::TaskRequest>(8);
60 let (ready_tx, ready_rx) = mpsc::sync_channel::<()>(1);
61
62 let handle = worker::spawn_zts_worker(worker_id, script, task_rx, ready_tx);
63 crate::register_zts_worker(handle);
64
65 debug!(worker_id, "ZTS worker thread spawned");
66
67 Ok(Box::new(ChannelWorkerHandle {
68 worker_id,
69 task_tx: Some(task_tx),
70 ready_rx: Some(ready_rx),
71 }))
72 }
73
74 fn take_preconnected(&self) -> Result<Box<dyn WorkerHandle>> {
76 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
77 let tx_side = self.channels.lock().unwrap().pop().ok_or_else(|| {
78 anyhow::anyhow!("no more pre-connected channels (worker {worker_id})")
79 })?;
80
81 debug!(worker_id, "pre-connected worker channel taken");
82
83 Ok(Box::new(ChannelWorkerHandle {
84 worker_id,
85 task_tx: Some(tx_side.task_tx),
86 ready_rx: Some(tx_side.ready_rx),
87 }))
88 }
89}
90
91#[async_trait]
92impl Runtime for ExtensionRuntime {
93 async fn spawn(&self) -> Result<Box<dyn WorkerHandle>> {
94 let has_preconnected = !self.channels.lock().unwrap().is_empty();
95
96 if has_preconnected {
97 self.take_preconnected()
98 } else if self.config.count > 1 {
99 self.spawn_zts_worker()
100 } else {
101 anyhow::bail!("no workers available and ZTS multi-worker not requested")
102 }
103 }
104}
105
106pub struct ChannelWorkerHandle {
108 worker_id: u32,
109 task_tx: Option<mpsc::SyncSender<bridge::TaskRequest>>,
110 ready_rx: Option<mpsc::Receiver<()>>,
111}
112
113#[async_trait]
114impl WorkerHandle for ChannelWorkerHandle {
115 fn id(&self) -> u32 {
116 self.worker_id
117 }
118
119 async fn ready(&mut self) -> Result<()> {
120 if let Some(rx) = self.ready_rx.take() {
121 tokio::task::spawn_blocking(move || rx.recv())
122 .await
123 .map_err(|e| anyhow::anyhow!("spawn_blocking panicked: {e}"))?
124 .map_err(|_| anyhow::anyhow!("worker died before ready"))?;
125 }
126 Ok(())
127 }
128
129 async fn execute(
130 &mut self,
131 method: &str,
132 payload: serde_json::Value,
133 ) -> Result<serde_json::Value> {
134 let tx = self
135 .task_tx
136 .as_ref()
137 .ok_or_else(|| anyhow::anyhow!("worker terminated"))?
138 .clone();
139
140 let method = method.to_string();
141
142 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
144
145 tx.send(bridge::TaskRequest {
148 method,
149 payload,
150 reply: reply_tx,
151 })
152 .map_err(|_| anyhow::anyhow!("worker process gone"))?;
153
154 reply_rx
156 .await
157 .map_err(|_| anyhow::anyhow!("worker dropped reply"))?
158 }
159
160 async fn terminate(&mut self) -> Result<()> {
161 self.task_tx.take();
162 Ok(())
163 }
164}