Skip to main content

folk_core/
worker_pool.rs

1//! Worker pool: dispatches requests to PHP workers, manages slot lifecycle.
2//!
3//! See `folk-spec/spec/03-worker-lifecycle.md` for the design.
4
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context, Result, anyhow};
9use async_trait::async_trait;
10use bytes::Bytes;
11use folk_api::Executor;
12use folk_protocol::RpcMessage;
13use rmpv::Value as RmpValue;
14use tokio::sync::{Semaphore, mpsc, oneshot};
15use tokio::task::JoinHandle;
16use tracing::{debug, error, info, warn};
17
18use crate::config::WorkersConfig;
19use crate::runtime::{Runtime, WorkerHandle};
20use crate::worker_slot::SlotInfo;
21
/// Pool errors. Plugins typically translate these into their own domain errors.
///
/// In this file the variants are produced by `dispatch_one`; the slot
/// supervisor converts them into `anyhow::Error` before replying to the caller.
#[derive(Debug, thiserror::Error)]
pub enum WorkError {
    /// No worker was available. Not constructed anywhere in the visible code.
    #[error("all workers busy")]
    Busy,
    /// Sending the task failed, or the task channel reported an error or
    /// end-of-stream while a response was awaited.
    #[error("worker died during request")]
    WorkerDied,
    /// No response arrived within the configured `exec_timeout`.
    #[error("execution timed out")]
    Timeout,
    /// The worker answered with a non-nil `error` value. `dispatch_one`
    /// currently always sets `code` to `-1` and `message` to the `Debug`
    /// rendering of that value.
    #[error("worker returned application error: {message}")]
    Application { code: i32, message: String },
    /// The payload could not be decoded, the result could not be encoded, or
    /// the worker replied with something other than a response message.
    #[error("protocol error: {0}")]
    Protocol(#[from] folk_protocol::Error),
    /// An I/O error. Only reachable through the `From<std::io::Error>`
    /// conversion; not constructed directly in the visible code.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
}
38
/// One dispatch request: method name, payload + reply channel.
///
/// Created by `execute_method`, forwarded by the pool task to a slot inbox,
/// and consumed by the slot supervisor that runs it.
struct DispatchRequest {
    /// RPC method name passed through to the worker.
    method: String,
    /// Request parameters; `dispatch_one` decodes these as a MessagePack value.
    payload: Bytes,
    /// One-shot channel on which the slot supervisor sends the outcome.
    reply: oneshot::Sender<Result<Bytes>>,
}
45
/// Worker pool — the dispatch surface.
///
/// Callers submit work through the `Executor` implementation; the pool task
/// and the per-slot supervisors run in the background.
pub struct WorkerPool {
    /// Queue feeding the pool task (`pool_main`).
    request_tx: mpsc::Sender<DispatchRequest>,
    /// Created with `config.count` permits; `execute_method` holds one permit
    /// for the full duration of each call.
    semaphore: Arc<Semaphore>,
    /// Handle of the spawned pool task. Stored but never awaited in the
    /// visible code.
    _pool_task: JoinHandle<()>,
}
52
53impl WorkerPool {
54    /// Construct a pool with `config.count` workers spawned via `runtime`.
55    ///
56    /// Returns once the pool task is started. Workers boot asynchronously
57    /// in the background.
58    pub fn new(runtime: Arc<dyn Runtime>, config: WorkersConfig) -> Result<Arc<Self>> {
59        let semaphore = Arc::new(Semaphore::new(config.count));
60        let (request_tx, request_rx) = mpsc::channel::<DispatchRequest>(1024);
61
62        let pool_task = tokio::spawn(pool_main(runtime, config, request_rx, semaphore.clone()));
63
64        Ok(Arc::new(Self {
65            request_tx,
66            semaphore,
67            _pool_task: pool_task,
68        }))
69    }
70}
71
72#[async_trait]
73impl Executor for WorkerPool {
74    async fn execute_method(&self, method: &str, payload: Bytes) -> Result<Bytes> {
75        let permit = self
76            .semaphore
77            .clone()
78            .acquire_owned()
79            .await
80            .context("pool semaphore closed")?;
81
82        let (reply_tx, reply_rx) = oneshot::channel();
83        self.request_tx
84            .send(DispatchRequest {
85                method: method.to_string(),
86                payload,
87                reply: reply_tx,
88            })
89            .await
90            .map_err(|_| anyhow!("pool task gone"))?;
91
92        let result = reply_rx
93            .await
94            .map_err(|_| anyhow!("pool dropped reply channel"))?;
95
96        drop(permit);
97        result
98    }
99}
100
101// ---- Pool task ---------------------------------------------------------------
102
103async fn pool_main(
104    runtime: Arc<dyn Runtime>,
105    config: WorkersConfig,
106    mut request_rx: mpsc::Receiver<DispatchRequest>,
107    _semaphore: Arc<Semaphore>,
108) {
109    let mut slot_inboxes: Vec<mpsc::Sender<DispatchRequest>> = Vec::with_capacity(config.count);
110    let mut slot_supervisors: Vec<JoinHandle<()>> = Vec::with_capacity(config.count);
111
112    for slot_id in 0..config.count {
113        let (slot_tx, slot_rx) = mpsc::channel::<DispatchRequest>(8);
114        slot_inboxes.push(slot_tx);
115        let runtime_clone = runtime.clone();
116        let cfg_clone = config.clone();
117        let supervisor = tokio::spawn(slot_supervisor(slot_id, runtime_clone, cfg_clone, slot_rx));
118        slot_supervisors.push(supervisor);
119    }
120
121    // Round-robin dispatch.
122    let mut next: usize = 0;
123    while let Some(req) = request_rx.recv().await {
124        let chosen = next % slot_inboxes.len();
125        next = next.wrapping_add(1);
126
127        if slot_inboxes[chosen].send(req).await.is_err() {
128            warn!(slot_id = chosen, "slot inbox closed; failed to dispatch");
129        }
130    }
131
132    info!("pool main loop exiting; awaiting supervisors");
133    for handle in slot_supervisors {
134        let _ = handle.await;
135    }
136}
137
138// ---- Slot supervisor ---------------------------------------------------------
139
140async fn slot_supervisor(
141    slot_id: usize,
142    runtime: Arc<dyn Runtime>,
143    config: WorkersConfig,
144    mut inbox: mpsc::Receiver<DispatchRequest>,
145) {
146    let mut slot = SlotInfo::new();
147    let mut worker: Option<Box<dyn WorkerHandle>> = None;
148
149    loop {
150        // Spawn a worker if we don't have one.
151        if worker.is_none() {
152            match boot_worker(&runtime, &config, &mut slot).await {
153                Ok(w) => worker = Some(w),
154                Err(e) => {
155                    error!(slot_id, error = ?e, "failed to boot worker, will retry");
156                    tokio::time::sleep(Duration::from_secs(1)).await;
157                    continue;
158                },
159            }
160        }
161
162        let Some(w) = worker.as_mut() else {
163            unreachable!()
164        };
165
166        // Wait for a request or for shutdown.
167        let Some(req) = inbox.recv().await else {
168            info!(slot_id, "supervisor shutting down (inbox closed)");
169            if let Err(e) = w.terminate().await {
170                warn!(slot_id, error = ?e, "terminate error during shutdown");
171            }
172            return;
173        };
174
175        // Dispatch.
176        slot.mark_busy();
177        let result = dispatch_one(w.as_mut(), &req.method, req.payload, config.exec_timeout).await;
178        slot.mark_idle();
179
180        // Send reply.
181        let _ = req.reply.send(result.map_err(anyhow::Error::from));
182
183        // Recycle?
184        if slot.should_recycle(&config) {
185            info!(slot_id, jobs = slot.jobs_handled, "recycling worker");
186            if let Some(mut w) = worker.take() {
187                let _ = w.terminate().await;
188            }
189            slot = SlotInfo::new();
190        }
191    }
192}
193
194async fn boot_worker(
195    runtime: &Arc<dyn Runtime>,
196    config: &WorkersConfig,
197    slot: &mut SlotInfo,
198) -> Result<Box<dyn WorkerHandle>> {
199    let mut handle = runtime.spawn().await.context("spawn")?;
200    let timeout = tokio::time::timeout(config.boot_timeout, handle.recv_control());
201    match timeout.await {
202        Ok(Ok(Some(RpcMessage::Notify { method, .. }))) if method == "control.ready" => {
203            let pid = handle.pid();
204            slot.mark_ready(pid);
205            debug!(pid, "worker ready");
206            Ok(handle)
207        },
208        Ok(Ok(other)) => {
209            let _ = handle.terminate().await;
210            anyhow::bail!("expected control.ready, got {other:?}")
211        },
212        Ok(Err(e)) => {
213            let _ = handle.terminate().await;
214            Err(e).context("recv_control failed during boot")
215        },
216        Err(_) => {
217            let _ = handle.terminate().await;
218            anyhow::bail!("worker boot timed out after {:?}", config.boot_timeout)
219        },
220    }
221}
222
223async fn dispatch_one(
224    worker: &mut dyn WorkerHandle,
225    method: &str,
226    payload: Bytes,
227    exec_timeout: Duration,
228) -> Result<Bytes, WorkError> {
229    static MSGID: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1);
230    let msgid = MSGID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
231
232    let params = rmp_serde::from_slice::<RmpValue>(&payload)
233        .map_err(|e| WorkError::Protocol(folk_protocol::Error::Decode(e)))?;
234    let request = RpcMessage::request(msgid, method, params);
235
236    worker
237        .send_task(request)
238        .await
239        .map_err(|_| WorkError::WorkerDied)?;
240
241    let recv = tokio::time::timeout(exec_timeout, worker.recv_task());
242    let response = match recv.await {
243        Ok(Ok(Some(msg))) => msg,
244        Ok(Ok(None) | Err(_)) => return Err(WorkError::WorkerDied),
245        Err(_) => return Err(WorkError::Timeout),
246    };
247
248    match response {
249        RpcMessage::Response { error, result, .. } => {
250            if !error.is_nil() {
251                return Err(WorkError::Application {
252                    code: -1,
253                    message: format!("{error:?}"),
254                });
255            }
256            let mut buf = Vec::new();
257            rmp_serde::encode::write(&mut buf, &result)
258                .map_err(|e| WorkError::Protocol(folk_protocol::Error::Encode(e)))?;
259            Ok(Bytes::from(buf))
260        },
261        other => Err(WorkError::Protocol(folk_protocol::Error::InvalidFrame(
262            format!("expected Response, got {other:?}"),
263        ))),
264    }
265}