1use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context, Result, anyhow};
9use async_trait::async_trait;
10use bytes::Bytes;
11use folk_api::Executor;
12use tokio::sync::{Semaphore, mpsc, oneshot, watch};
13use tokio::task::JoinHandle;
14use tracing::{debug, error, info, warn};
15
16use crate::config::WorkersConfig;
17use crate::runtime::{Runtime, WorkerHandle};
18use crate::worker_slot::SlotInfo;
19
20#[derive(Debug, thiserror::Error)]
22pub enum WorkError {
23 #[error("all workers busy")]
24 Busy,
25 #[error("worker died during request")]
26 WorkerDied,
27 #[error("execution timed out")]
28 Timeout,
29 #[error("worker returned application error: {message}")]
30 Application { code: i32, message: String },
31 #[error("internal error: {0}")]
32 Internal(String),
33}
34
35struct DispatchRequest {
37 request_id: Arc<str>,
38 method: String,
39 payload: serde_json::Value,
40 reply: oneshot::Sender<Result<serde_json::Value>>,
41}
42
43pub struct WorkerPool {
45 request_tx: mpsc::Sender<DispatchRequest>,
46 semaphore: Arc<Semaphore>,
47 runtime: Arc<dyn Runtime>,
48 reload_tx: watch::Sender<u64>,
51 _pool_task: JoinHandle<()>,
52}
53
54impl WorkerPool {
55 pub fn new(runtime: Arc<dyn Runtime>, config: WorkersConfig) -> Result<Arc<Self>> {
60 let semaphore = Arc::new(Semaphore::new(config.count));
61 let (request_tx, request_rx) = mpsc::channel::<DispatchRequest>(1024);
62 let (reload_tx, reload_rx) = watch::channel(0u64);
63
64 let pool_task = tokio::spawn(pool_main(
65 runtime.clone(),
66 config,
67 request_rx,
68 semaphore.clone(),
69 reload_rx,
70 ));
71
72 Ok(Arc::new(Self {
73 request_tx,
74 semaphore,
75 runtime,
76 reload_tx,
77 _pool_task: pool_task,
78 }))
79 }
80
81 pub async fn trigger_reload(&self) {
87 if let Err(e) = self.runtime.reload().await {
88 warn!(error = %e, "reload: cache invalidation failed; recycling anyway");
89 }
90 self.reload_tx.send_modify(|g| *g += 1);
91 let generation = *self.reload_tx.borrow();
92 info!(generation, "hot reload triggered; recycling workers");
93 }
94
95 async fn dispatch_value(
100 &self,
101 method: &str,
102 payload: serde_json::Value,
103 ) -> Result<(serde_json::Value, Arc<str>)> {
104 let permit = self
105 .semaphore
106 .clone()
107 .acquire_owned()
108 .await
109 .context("pool semaphore closed")?;
110
111 let request_id: Arc<str> = Arc::from(uuid::Uuid::now_v7().hyphenated().to_string());
115 let (reply_tx, reply_rx) = oneshot::channel();
116 self.request_tx
117 .send(DispatchRequest {
118 request_id: request_id.clone(),
119 method: method.to_string(),
120 payload,
121 reply: reply_tx,
122 })
123 .await
124 .map_err(|_| anyhow!("pool task gone"))?;
125
126 let result = reply_rx
127 .await
128 .map_err(|_| anyhow!("pool dropped reply channel"))?;
129
130 drop(permit);
131 result.map(|value| (value, request_id))
132 }
133}
134
135#[async_trait]
136impl Executor for WorkerPool {
137 async fn execute_method(&self, method: &str, payload: Bytes) -> Result<Bytes> {
138 debug!(
139 method,
140 payload_len = payload.len(),
141 "pool: execute_method called (bytes path)"
142 );
143 let value: serde_json::Value =
145 serde_json::from_slice(&payload).context("pool: failed to parse payload as JSON")?;
146 let (result, _id) = self.dispatch_value(method, value).await?;
147 let bytes = serde_json::to_vec(&result).context("pool: failed to serialize response")?;
148 Ok(Bytes::from(bytes))
149 }
150
151 async fn execute_value(
152 &self,
153 method: &str,
154 payload: serde_json::Value,
155 ) -> Result<serde_json::Value> {
156 debug!(method, "pool: execute_value called (zero-copy path)");
157 let (value, _id) = self.dispatch_value(method, payload).await?;
158 Ok(value)
159 }
160
161 async fn execute_value_traced(
162 &self,
163 method: &str,
164 payload: serde_json::Value,
165 ) -> Result<(serde_json::Value, Arc<str>)> {
166 debug!(method, "pool: execute_value_traced called");
167 self.dispatch_value(method, payload).await
168 }
169}
170
171async fn pool_main(
174 runtime: Arc<dyn Runtime>,
175 config: WorkersConfig,
176 mut request_rx: mpsc::Receiver<DispatchRequest>,
177 _semaphore: Arc<Semaphore>,
178 reload_rx: watch::Receiver<u64>,
179) {
180 let mut slot_inboxes: Vec<mpsc::Sender<DispatchRequest>> = Vec::with_capacity(config.count);
181 let mut slot_supervisors: Vec<JoinHandle<()>> = Vec::with_capacity(config.count);
182
183 for slot_id in 0..config.count {
184 let (slot_tx, slot_rx) = mpsc::channel::<DispatchRequest>(8);
185 slot_inboxes.push(slot_tx);
186 let runtime_clone = runtime.clone();
187 let cfg_clone = config.clone();
188 let reload_clone = reload_rx.clone();
189 let supervisor = tokio::spawn(slot_supervisor(
190 slot_id,
191 runtime_clone,
192 cfg_clone,
193 slot_rx,
194 reload_clone,
195 ));
196 slot_supervisors.push(supervisor);
197 }
198
199 let n = slot_inboxes.len();
203 let mut next: usize = 0;
204 let mut dead: Vec<bool> = vec![false; n];
205
206 while let Some(initial_req) = request_rx.recv().await {
207 let mut req: Option<DispatchRequest> = Some(initial_req);
211 let mut sent = false;
212
213 for attempt in 0..n {
214 let chosen = (next.wrapping_add(attempt)) % n;
215 if dead[chosen] {
216 continue;
217 }
218 let r = req.take().expect("req must be Some when slot is alive");
219 match slot_inboxes[chosen].send(r).await {
220 Ok(()) => {
221 next = chosen.wrapping_add(1);
222 sent = true;
223 break;
224 },
225 Err(e) => {
226 warn!(slot_id = chosen, "slot inbox closed; skipping dead slot");
227 dead[chosen] = true;
228 req = Some(e.0);
229 },
230 }
231 }
232
233 if !sent {
234 error!("all worker slot inboxes closed; cannot dispatch request");
235 if let Some(r) = req {
236 let _ = r.reply.send(Err(anyhow::anyhow!("all worker slots dead")));
237 }
238 }
239 }
240
241 info!("pool main loop exiting; awaiting supervisors");
242 for handle in slot_supervisors {
243 let _ = handle.await;
244 }
245}
246
247async fn slot_supervisor(
250 slot_id: usize,
251 runtime: Arc<dyn Runtime>,
252 config: WorkersConfig,
253 mut inbox: mpsc::Receiver<DispatchRequest>,
254 mut reload_rx: watch::Receiver<u64>,
255) {
256 let mut slot = SlotInfo::new();
257 let mut worker: Option<Box<dyn WorkerHandle>> = None;
258 let mut boot_generation: u64 = *reload_rx.borrow();
261
262 loop {
263 if worker.is_none() {
265 boot_generation = *reload_rx.borrow();
266 match boot_worker(&runtime, &config, &mut slot).await {
267 Ok(w) => worker = Some(w),
268 Err(e) => {
269 error!(slot_id, error = ?e, "failed to boot worker, will retry");
270 tokio::time::sleep(Duration::from_secs(1)).await;
271 continue;
272 },
273 }
274 }
275
276 let recyclable = worker.as_ref().is_some_and(|w| w.is_recyclable());
277
278 let req = tokio::select! {
280 biased;
281 res = reload_rx.changed(), if recyclable => {
284 if res.is_err() {
285 info!(slot_id, "supervisor shutting down (reload channel closed)");
287 if let Some(mut w) = worker.take() {
288 let _ = w.terminate().await;
289 }
290 return;
291 }
292 if *reload_rx.borrow() > boot_generation {
293 info!(slot_id, "recycling idle worker for hot reload");
294 if let Some(mut w) = worker.take() {
295 let _ = w.terminate().await;
296 }
297 slot = SlotInfo::new();
298 }
299 continue;
300 },
301 maybe_req = inbox.recv() => {
302 let Some(req) = maybe_req else {
303 info!(slot_id, "supervisor shutting down (inbox closed)");
304 if let Some(mut w) = worker.take() {
305 if let Err(e) = w.terminate().await {
306 warn!(slot_id, error = ?e, "terminate error during shutdown");
307 }
308 }
309 return;
310 };
311 req
312 },
313 };
314
315 let Some(w) = worker.as_mut() else {
317 unreachable!()
318 };
319 slot.mark_busy();
320 let result = dispatch_one(
321 w.as_mut(),
322 &req.method,
323 req.payload,
324 req.request_id.clone(),
325 config.exec_timeout,
326 )
327 .await;
328 slot.mark_idle();
329
330 let _ = req.reply.send(result.map_err(anyhow::Error::from));
332
333 let reload_pending = *reload_rx.borrow() > boot_generation;
336 if reload_pending || slot.should_recycle(&config) {
337 if let Some(ref w) = worker {
338 if !w.is_recyclable() {
339 debug!(slot_id, "skipping recycle for non-recyclable worker");
340 continue;
341 }
342 }
343 let reason = if reload_pending {
344 "hot reload"
345 } else {
346 "lifecycle"
347 };
348 info!(
349 slot_id,
350 jobs = slot.jobs_handled,
351 reason,
352 "recycling worker"
353 );
354 if let Some(mut w) = worker.take() {
355 let _ = w.terminate().await;
356 }
357 slot = SlotInfo::new();
358 }
359 }
360}
361
362async fn boot_worker(
363 runtime: &Arc<dyn Runtime>,
364 config: &WorkersConfig,
365 slot: &mut SlotInfo,
366) -> Result<Box<dyn WorkerHandle>> {
367 debug!("boot_worker: spawning");
368 let mut handle = runtime.spawn().await.context("spawn")?;
369 debug!(id = handle.id(), "boot_worker: waiting for ready");
370
371 let timeout = tokio::time::timeout(config.boot_timeout, handle.ready());
372 match timeout.await {
373 Ok(Ok(())) => {
374 let id = handle.id();
375 slot.mark_ready(id);
376 debug!(id, "worker ready");
377 Ok(handle)
378 },
379 Ok(Err(e)) => {
380 let _ = handle.terminate().await;
381 Err(e).context("worker ready() failed during boot")
382 },
383 Err(_) => {
384 let _ = handle.terminate().await;
385 anyhow::bail!("worker boot timed out after {:?}", config.boot_timeout)
386 },
387 }
388}
389
390async fn dispatch_one(
391 worker: &mut dyn WorkerHandle,
392 method: &str,
393 payload: serde_json::Value,
394 request_id: Arc<str>,
395 exec_timeout: Duration,
396) -> Result<serde_json::Value, WorkError> {
397 let recv = tokio::time::timeout(exec_timeout, worker.execute(method, payload, request_id));
398 match recv.await {
399 Ok(Ok(result)) => Ok(result),
400 Ok(Err(e)) => Err(WorkError::Internal(e.to_string())),
401 Err(_) => Err(WorkError::Timeout),
402 }
403}