1use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context, Result, anyhow};
9use async_trait::async_trait;
10use bytes::Bytes;
11use folk_api::Executor;
12use folk_protocol::RpcMessage;
13use rmpv::Value as RmpValue;
14use tokio::sync::{Semaphore, mpsc, oneshot};
15use tokio::task::JoinHandle;
16use tracing::{debug, error, info, warn};
17
18use crate::config::WorkersConfig;
19use crate::runtime::{Runtime, WorkerHandle};
20use crate::worker_slot::SlotInfo;
21
22#[derive(Debug, thiserror::Error)]
24pub enum WorkError {
25 #[error("all workers busy")]
26 Busy,
27 #[error("worker died during request")]
28 WorkerDied,
29 #[error("execution timed out")]
30 Timeout,
31 #[error("worker returned application error: {message}")]
32 Application { code: i32, message: String },
33 #[error("protocol error: {0}")]
34 Protocol(#[from] folk_protocol::Error),
35 #[error("io error: {0}")]
36 Io(#[from] std::io::Error),
37}
38
39struct DispatchRequest {
41 payload: Bytes,
42 reply: oneshot::Sender<Result<Bytes>>,
43}
44
45pub struct WorkerPool {
47 request_tx: mpsc::Sender<DispatchRequest>,
48 semaphore: Arc<Semaphore>,
49 _pool_task: JoinHandle<()>,
50}
51
52impl WorkerPool {
53 pub fn new(runtime: Arc<dyn Runtime>, config: WorkersConfig) -> Result<Arc<Self>> {
58 let semaphore = Arc::new(Semaphore::new(config.count));
59 let (request_tx, request_rx) = mpsc::channel::<DispatchRequest>(1024);
60
61 let pool_task = tokio::spawn(pool_main(runtime, config, request_rx, semaphore.clone()));
62
63 Ok(Arc::new(Self {
64 request_tx,
65 semaphore,
66 _pool_task: pool_task,
67 }))
68 }
69}
70
71#[async_trait]
72impl Executor for WorkerPool {
73 async fn execute(&self, payload: Bytes) -> Result<Bytes> {
74 let permit = self
75 .semaphore
76 .clone()
77 .acquire_owned()
78 .await
79 .context("pool semaphore closed")?;
80
81 let (reply_tx, reply_rx) = oneshot::channel();
82 self.request_tx
83 .send(DispatchRequest {
84 payload,
85 reply: reply_tx,
86 })
87 .await
88 .map_err(|_| anyhow!("pool task gone"))?;
89
90 let result = reply_rx
91 .await
92 .map_err(|_| anyhow!("pool dropped reply channel"))?;
93
94 drop(permit);
95 result
96 }
97}
98
99async fn pool_main(
102 runtime: Arc<dyn Runtime>,
103 config: WorkersConfig,
104 mut request_rx: mpsc::Receiver<DispatchRequest>,
105 _semaphore: Arc<Semaphore>,
106) {
107 let mut slot_inboxes: Vec<mpsc::Sender<DispatchRequest>> = Vec::with_capacity(config.count);
108 let mut slot_supervisors: Vec<JoinHandle<()>> = Vec::with_capacity(config.count);
109
110 for slot_id in 0..config.count {
111 let (slot_tx, slot_rx) = mpsc::channel::<DispatchRequest>(8);
112 slot_inboxes.push(slot_tx);
113 let runtime_clone = runtime.clone();
114 let cfg_clone = config.clone();
115 let supervisor = tokio::spawn(slot_supervisor(slot_id, runtime_clone, cfg_clone, slot_rx));
116 slot_supervisors.push(supervisor);
117 }
118
119 let mut next: usize = 0;
121 while let Some(req) = request_rx.recv().await {
122 let chosen = next % slot_inboxes.len();
123 next = next.wrapping_add(1);
124
125 if slot_inboxes[chosen].send(req).await.is_err() {
126 warn!(slot_id = chosen, "slot inbox closed; failed to dispatch");
127 }
128 }
129
130 info!("pool main loop exiting; awaiting supervisors");
131 for handle in slot_supervisors {
132 let _ = handle.await;
133 }
134}
135
136async fn slot_supervisor(
139 slot_id: usize,
140 runtime: Arc<dyn Runtime>,
141 config: WorkersConfig,
142 mut inbox: mpsc::Receiver<DispatchRequest>,
143) {
144 let mut slot = SlotInfo::new();
145 let mut worker: Option<Box<dyn WorkerHandle>> = None;
146
147 loop {
148 if worker.is_none() {
150 match boot_worker(&runtime, &config, &mut slot).await {
151 Ok(w) => worker = Some(w),
152 Err(e) => {
153 error!(slot_id, error = ?e, "failed to boot worker, will retry");
154 tokio::time::sleep(Duration::from_secs(1)).await;
155 continue;
156 },
157 }
158 }
159
160 let Some(w) = worker.as_mut() else {
161 unreachable!()
162 };
163
164 let Some(req) = inbox.recv().await else {
166 info!(slot_id, "supervisor shutting down (inbox closed)");
167 if let Err(e) = w.terminate().await {
168 warn!(slot_id, error = ?e, "terminate error during shutdown");
169 }
170 return;
171 };
172
173 slot.mark_busy();
175 let result = dispatch_one(w.as_mut(), req.payload, config.exec_timeout).await;
176 slot.mark_idle();
177
178 let _ = req.reply.send(result.map_err(anyhow::Error::from));
180
181 if slot.should_recycle(&config) {
183 info!(slot_id, jobs = slot.jobs_handled, "recycling worker");
184 if let Some(mut w) = worker.take() {
185 let _ = w.terminate().await;
186 }
187 slot = SlotInfo::new();
188 }
189 }
190}
191
192async fn boot_worker(
193 runtime: &Arc<dyn Runtime>,
194 config: &WorkersConfig,
195 slot: &mut SlotInfo,
196) -> Result<Box<dyn WorkerHandle>> {
197 let mut handle = runtime.spawn().await.context("spawn")?;
198 let timeout = tokio::time::timeout(config.boot_timeout, handle.recv_control());
199 match timeout.await {
200 Ok(Ok(Some(RpcMessage::Notify { method, .. }))) if method == "control.ready" => {
201 let pid = handle.pid();
202 slot.mark_ready(pid);
203 debug!(pid, "worker ready");
204 Ok(handle)
205 },
206 Ok(Ok(other)) => {
207 let _ = handle.terminate().await;
208 anyhow::bail!("expected control.ready, got {other:?}")
209 },
210 Ok(Err(e)) => {
211 let _ = handle.terminate().await;
212 Err(e).context("recv_control failed during boot")
213 },
214 Err(_) => {
215 let _ = handle.terminate().await;
216 anyhow::bail!("worker boot timed out after {:?}", config.boot_timeout)
217 },
218 }
219}
220
221async fn dispatch_one(
222 worker: &mut dyn WorkerHandle,
223 payload: Bytes,
224 exec_timeout: Duration,
225) -> Result<Bytes, WorkError> {
226 static MSGID: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1);
227 let msgid = MSGID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
228
229 let params = rmp_serde::from_slice::<RmpValue>(&payload)
230 .map_err(|e| WorkError::Protocol(folk_protocol::Error::Decode(e)))?;
231 let request = RpcMessage::request(msgid, "dispatch", params);
232
233 worker
234 .send_task(request)
235 .await
236 .map_err(|_| WorkError::WorkerDied)?;
237
238 let recv = tokio::time::timeout(exec_timeout, worker.recv_task());
239 let response = match recv.await {
240 Ok(Ok(Some(msg))) => msg,
241 Ok(Ok(None) | Err(_)) => return Err(WorkError::WorkerDied),
242 Err(_) => return Err(WorkError::Timeout),
243 };
244
245 match response {
246 RpcMessage::Response { error, result, .. } => {
247 if !error.is_nil() {
248 return Err(WorkError::Application {
249 code: -1,
250 message: format!("{error:?}"),
251 });
252 }
253 let mut buf = Vec::new();
254 rmp_serde::encode::write(&mut buf, &result)
255 .map_err(|e| WorkError::Protocol(folk_protocol::Error::Encode(e)))?;
256 Ok(Bytes::from(buf))
257 },
258 other => Err(WorkError::Protocol(folk_protocol::Error::InvalidFrame(
259 format!("expected Response, got {other:?}"),
260 ))),
261 }
262}