1use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::mpsc;
13use std::thread;
14
15use anyhow::Result;
16use async_trait::async_trait;
17use folk_core::config::WorkersConfig;
18use folk_core::runtime::{Runtime, WorkerHandle};
19use tracing::debug;
20
21use crate::bridge;
22use crate::worker;
23
/// Process-wide source of worker IDs.
///
/// Starts at 1; every handle created in this file claims the next value with
/// `fetch_add(1, Ordering::Relaxed)`. Relaxed ordering is enough because the
/// counter only has to hand out distinct numbers, not order other memory
/// accesses.
static NEXT_WORKER_ID: AtomicU32 = AtomicU32::new(1);
25
26pub struct WorkerTxSide {
28 pub task_tx: mpsc::SyncSender<bridge::TaskRequest>,
29 pub ready_rx: mpsc::Receiver<()>,
30}
31
32pub struct ExtensionRuntime {
34 config: WorkersConfig,
35 channels: std::sync::Mutex<Vec<WorkerTxSide>>,
37}
38
39impl ExtensionRuntime {
40 pub fn new(config: WorkersConfig, tx_sides: Vec<WorkerTxSide>) -> Self {
42 Self {
43 config,
44 channels: std::sync::Mutex::new(tx_sides),
45 }
46 }
47
48 #[allow(clippy::unnecessary_wraps)] fn spawn_zts_worker(&self) -> Result<Box<dyn WorkerHandle>> {
51 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
52 let script = std::env::current_dir()
55 .unwrap_or_default()
56 .join(&self.config.script)
57 .to_string_lossy()
58 .into_owned();
59
60 let (task_tx, task_rx) = mpsc::sync_channel::<bridge::TaskRequest>(8);
61 let (ready_tx, ready_rx) = mpsc::sync_channel::<()>(1);
62
63 let thread_handle = worker::spawn_zts_worker(worker_id, script, task_rx, ready_tx);
64
65 debug!(worker_id, "ZTS worker thread spawned");
66
67 Ok(Box::new(ChannelWorkerHandle {
68 worker_id,
69 task_tx: Some(task_tx),
70 ready_rx: Some(ready_rx),
71 thread_handle: Some(thread_handle),
72 }))
73 }
74
75 fn take_preconnected(&self) -> Result<Box<dyn WorkerHandle>> {
77 let worker_id = NEXT_WORKER_ID.fetch_add(1, Ordering::Relaxed);
78 let tx_side = self.channels.lock().unwrap().pop().ok_or_else(|| {
79 anyhow::anyhow!("no more pre-connected channels (worker {worker_id})")
80 })?;
81
82 debug!(worker_id, "pre-connected worker channel taken");
83
84 Ok(Box::new(ChannelWorkerHandle {
85 worker_id,
86 task_tx: Some(tx_side.task_tx),
87 ready_rx: Some(tx_side.ready_rx),
88 thread_handle: None, }))
90 }
91}
92
93#[async_trait]
94impl Runtime for ExtensionRuntime {
95 async fn spawn(&self) -> Result<Box<dyn WorkerHandle>> {
96 let has_preconnected = !self.channels.lock().unwrap().is_empty();
97
98 if has_preconnected {
99 self.take_preconnected()
100 } else if self.config.count > 1 {
101 self.spawn_zts_worker()
102 } else {
103 anyhow::bail!("no workers available and ZTS multi-worker not requested")
104 }
105 }
106}
107
108pub struct ChannelWorkerHandle {
110 worker_id: u32,
111 task_tx: Option<mpsc::SyncSender<bridge::TaskRequest>>,
112 ready_rx: Option<mpsc::Receiver<()>>,
113 thread_handle: Option<thread::JoinHandle<()>>,
114}
115
116#[async_trait]
117impl WorkerHandle for ChannelWorkerHandle {
118 fn id(&self) -> u32 {
119 self.worker_id
120 }
121
122 async fn ready(&mut self) -> Result<()> {
123 if let Some(rx) = self.ready_rx.take() {
124 tokio::task::spawn_blocking(move || rx.recv())
125 .await
126 .map_err(|e| anyhow::anyhow!("spawn_blocking panicked: {e}"))?
127 .map_err(|_| anyhow::anyhow!("worker died before ready"))?;
128 }
129 Ok(())
130 }
131
132 async fn execute(
133 &mut self,
134 method: &str,
135 payload: serde_json::Value,
136 ) -> Result<serde_json::Value> {
137 let tx = self
138 .task_tx
139 .as_ref()
140 .ok_or_else(|| anyhow::anyhow!("worker terminated"))?
141 .clone();
142
143 let method = method.to_string();
144
145 let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
147
148 tx.send(bridge::TaskRequest {
151 method,
152 payload,
153 reply: reply_tx,
154 })
155 .map_err(|_| anyhow::anyhow!("worker process gone"))?;
156
157 reply_rx
159 .await
160 .map_err(|_| anyhow::anyhow!("worker dropped reply"))?
161 }
162
163 async fn terminate(&mut self) -> Result<()> {
164 self.task_tx.take();
166
167 if let Some(handle) = self.thread_handle.take() {
169 tokio::task::spawn_blocking(move || {
170 let _ = handle.join();
171 })
172 .await
173 .map_err(|e| anyhow::anyhow!("spawn_blocking panicked: {e}"))?;
174 }
175
176 Ok(())
177 }
178
179 fn is_recyclable(&self) -> bool {
180 self.thread_handle.is_some()
182 }
183}