1use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context, Result, anyhow};
9use async_trait::async_trait;
10use bytes::Bytes;
11use folk_api::Executor;
12use tokio::sync::{Semaphore, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tracing::{debug, error, info, warn};
15
16use crate::config::WorkersConfig;
17use crate::runtime::{Runtime, WorkerHandle};
18use crate::worker_slot::SlotInfo;
19
/// Errors a worker slot can report for a single dispatched request.
///
/// NOTE(review): within this file only `Timeout` and `Internal` are constructed (by
/// `dispatch_one`); confirm whether `Busy`, `WorkerDied` and `Application` are produced
/// elsewhere or can be retired.
#[derive(Debug, thiserror::Error)]
pub enum WorkError {
    /// Reported as "all workers busy".
    #[error("all workers busy")]
    Busy,
    /// Reported as "worker died during request".
    #[error("worker died during request")]
    WorkerDied,
    /// The request did not finish within the configured execution timeout.
    #[error("execution timed out")]
    Timeout,
    /// An application-level error returned by the worker. Only `message` appears in the
    /// `Display` output; `code` is visible via `Debug` or by matching on the variant.
    #[error("worker returned application error: {message}")]
    Application { code: i32, message: String },
    /// Any other failure, carried as an already-rendered message string.
    #[error("internal error: {0}")]
    Internal(String),
}
34
/// One request travelling from `WorkerPool::dispatch_value`, through the dispatcher task
/// (`pool_main`), to a single slot supervisor.
struct DispatchRequest {
    /// Method name forwarded to the worker.
    method: String,
    /// JSON payload forwarded to the worker.
    payload: serde_json::Value,
    /// Channel the slot supervisor answers on. If this sender is dropped without sending,
    /// the waiting caller gets a "pool dropped reply channel" error.
    reply: oneshot::Sender<Result<serde_json::Value>>,
}
41
/// Fixed-size pool of runtime workers behind a single dispatcher task.
///
/// Requests flow `dispatch_value` -> `request_tx` -> dispatcher task (round-robin) ->
/// per-slot supervisor, which owns one worker and answers on the request's oneshot channel.
pub struct WorkerPool {
    /// Queue into the dispatcher task (bounded; capacity is chosen in `new`).
    request_tx: mpsc::Sender<DispatchRequest>,
    /// Created with one permit per configured worker; bounds concurrent in-flight requests.
    semaphore: Arc<Semaphore>,
    /// Runtime whose `reload()` is invoked by `trigger_reload`.
    runtime: Arc<dyn Runtime>,
    /// Reload generation counter; every slot supervisor holds a receiver for it.
    reload_tx: watch::Sender<u64>,
    /// Dispatcher task handle. Dropping a tokio `JoinHandle` detaches the task rather than
    /// aborting it, so shutdown relies on `request_tx` / `reload_tx` being dropped.
    _pool_task: JoinHandle<()>,
}
52
53impl WorkerPool {
54 pub fn new(runtime: Arc<dyn Runtime>, config: WorkersConfig) -> Result<Arc<Self>> {
59 let semaphore = Arc::new(Semaphore::new(config.count));
60 let (request_tx, request_rx) = mpsc::channel::<DispatchRequest>(1024);
61 let (reload_tx, reload_rx) = watch::channel(0u64);
62
63 let pool_task = tokio::spawn(pool_main(
64 runtime.clone(),
65 config,
66 request_rx,
67 semaphore.clone(),
68 reload_rx,
69 ));
70
71 Ok(Arc::new(Self {
72 request_tx,
73 semaphore,
74 runtime,
75 reload_tx,
76 _pool_task: pool_task,
77 }))
78 }
79
80 pub async fn trigger_reload(&self) {
86 if let Err(e) = self.runtime.reload().await {
87 warn!(error = %e, "reload: cache invalidation failed; recycling anyway");
88 }
89 self.reload_tx.send_modify(|g| *g += 1);
90 let generation = *self.reload_tx.borrow();
91 info!(generation, "hot reload triggered; recycling workers");
92 }
93
94 async fn dispatch_value(
96 &self,
97 method: &str,
98 payload: serde_json::Value,
99 ) -> Result<serde_json::Value> {
100 let permit = self
101 .semaphore
102 .clone()
103 .acquire_owned()
104 .await
105 .context("pool semaphore closed")?;
106
107 let (reply_tx, reply_rx) = oneshot::channel();
108 self.request_tx
109 .send(DispatchRequest {
110 method: method.to_string(),
111 payload,
112 reply: reply_tx,
113 })
114 .await
115 .map_err(|_| anyhow!("pool task gone"))?;
116
117 let result = reply_rx
118 .await
119 .map_err(|_| anyhow!("pool dropped reply channel"))?;
120
121 drop(permit);
122 result
123 }
124}
125
126#[async_trait]
127impl Executor for WorkerPool {
128 async fn execute_method(&self, method: &str, payload: Bytes) -> Result<Bytes> {
129 debug!(
130 method,
131 payload_len = payload.len(),
132 "pool: execute_method called (bytes path)"
133 );
134 let value: serde_json::Value =
136 serde_json::from_slice(&payload).context("pool: failed to parse payload as JSON")?;
137 let result = self.dispatch_value(method, value).await?;
138 let bytes = serde_json::to_vec(&result).context("pool: failed to serialize response")?;
139 Ok(Bytes::from(bytes))
140 }
141
142 async fn execute_value(
143 &self,
144 method: &str,
145 payload: serde_json::Value,
146 ) -> Result<serde_json::Value> {
147 debug!(method, "pool: execute_value called (zero-copy path)");
148 self.dispatch_value(method, payload).await
149 }
150}
151
152async fn pool_main(
155 runtime: Arc<dyn Runtime>,
156 config: WorkersConfig,
157 mut request_rx: mpsc::Receiver<DispatchRequest>,
158 _semaphore: Arc<Semaphore>,
159 reload_rx: watch::Receiver<u64>,
160) {
161 let mut slot_inboxes: Vec<mpsc::Sender<DispatchRequest>> = Vec::with_capacity(config.count);
162 let mut slot_supervisors: Vec<JoinHandle<()>> = Vec::with_capacity(config.count);
163
164 for slot_id in 0..config.count {
165 let (slot_tx, slot_rx) = mpsc::channel::<DispatchRequest>(8);
166 slot_inboxes.push(slot_tx);
167 let runtime_clone = runtime.clone();
168 let cfg_clone = config.clone();
169 let reload_clone = reload_rx.clone();
170 let supervisor = tokio::spawn(slot_supervisor(
171 slot_id,
172 runtime_clone,
173 cfg_clone,
174 slot_rx,
175 reload_clone,
176 ));
177 slot_supervisors.push(supervisor);
178 }
179
180 let mut next: usize = 0;
182 while let Some(req) = request_rx.recv().await {
183 let chosen = next % slot_inboxes.len();
184 next = next.wrapping_add(1);
185
186 if slot_inboxes[chosen].send(req).await.is_err() {
187 warn!(slot_id = chosen, "slot inbox closed; failed to dispatch");
188 }
189 }
190
191 info!("pool main loop exiting; awaiting supervisors");
192 for handle in slot_supervisors {
193 let _ = handle.await;
194 }
195}
196
197async fn slot_supervisor(
200 slot_id: usize,
201 runtime: Arc<dyn Runtime>,
202 config: WorkersConfig,
203 mut inbox: mpsc::Receiver<DispatchRequest>,
204 mut reload_rx: watch::Receiver<u64>,
205) {
206 let mut slot = SlotInfo::new();
207 let mut worker: Option<Box<dyn WorkerHandle>> = None;
208 let mut boot_generation: u64 = *reload_rx.borrow();
211
212 loop {
213 if worker.is_none() {
215 boot_generation = *reload_rx.borrow();
216 match boot_worker(&runtime, &config, &mut slot).await {
217 Ok(w) => worker = Some(w),
218 Err(e) => {
219 error!(slot_id, error = ?e, "failed to boot worker, will retry");
220 tokio::time::sleep(Duration::from_secs(1)).await;
221 continue;
222 },
223 }
224 }
225
226 let recyclable = worker.as_ref().is_some_and(|w| w.is_recyclable());
227
228 let req = tokio::select! {
230 biased;
231 res = reload_rx.changed(), if recyclable => {
234 if res.is_err() {
235 info!(slot_id, "supervisor shutting down (reload channel closed)");
237 if let Some(mut w) = worker.take() {
238 let _ = w.terminate().await;
239 }
240 return;
241 }
242 if *reload_rx.borrow() > boot_generation {
243 info!(slot_id, "recycling idle worker for hot reload");
244 if let Some(mut w) = worker.take() {
245 let _ = w.terminate().await;
246 }
247 slot = SlotInfo::new();
248 }
249 continue;
250 },
251 maybe_req = inbox.recv() => {
252 let Some(req) = maybe_req else {
253 info!(slot_id, "supervisor shutting down (inbox closed)");
254 if let Some(mut w) = worker.take() {
255 if let Err(e) = w.terminate().await {
256 warn!(slot_id, error = ?e, "terminate error during shutdown");
257 }
258 }
259 return;
260 };
261 req
262 },
263 };
264
265 let Some(w) = worker.as_mut() else {
267 unreachable!()
268 };
269 slot.mark_busy();
270 let result = dispatch_one(w.as_mut(), &req.method, req.payload, config.exec_timeout).await;
271 slot.mark_idle();
272
273 let _ = req.reply.send(result.map_err(anyhow::Error::from));
275
276 let reload_pending = *reload_rx.borrow() > boot_generation;
279 if reload_pending || slot.should_recycle(&config) {
280 if let Some(ref w) = worker {
281 if !w.is_recyclable() {
282 debug!(slot_id, "skipping recycle for non-recyclable worker");
283 continue;
284 }
285 }
286 let reason = if reload_pending {
287 "hot reload"
288 } else {
289 "lifecycle"
290 };
291 info!(
292 slot_id,
293 jobs = slot.jobs_handled,
294 reason,
295 "recycling worker"
296 );
297 if let Some(mut w) = worker.take() {
298 let _ = w.terminate().await;
299 }
300 slot = SlotInfo::new();
301 }
302 }
303}
304
305async fn boot_worker(
306 runtime: &Arc<dyn Runtime>,
307 config: &WorkersConfig,
308 slot: &mut SlotInfo,
309) -> Result<Box<dyn WorkerHandle>> {
310 debug!("boot_worker: spawning");
311 let mut handle = runtime.spawn().await.context("spawn")?;
312 debug!(id = handle.id(), "boot_worker: waiting for ready");
313
314 let timeout = tokio::time::timeout(config.boot_timeout, handle.ready());
315 match timeout.await {
316 Ok(Ok(())) => {
317 let id = handle.id();
318 slot.mark_ready(id);
319 debug!(id, "worker ready");
320 Ok(handle)
321 },
322 Ok(Err(e)) => {
323 let _ = handle.terminate().await;
324 Err(e).context("worker ready() failed during boot")
325 },
326 Err(_) => {
327 let _ = handle.terminate().await;
328 anyhow::bail!("worker boot timed out after {:?}", config.boot_timeout)
329 },
330 }
331}
332
333async fn dispatch_one(
334 worker: &mut dyn WorkerHandle,
335 method: &str,
336 payload: serde_json::Value,
337 exec_timeout: Duration,
338) -> Result<serde_json::Value, WorkError> {
339 let recv = tokio::time::timeout(exec_timeout, worker.execute(method, payload));
340 match recv.await {
341 Ok(Ok(result)) => Ok(result),
342 Ok(Err(e)) => Err(WorkError::Internal(e.to_string())),
343 Err(_) => Err(WorkError::Timeout),
344 }
345}