1use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::{Context, Result, anyhow};
9use async_trait::async_trait;
10use bytes::Bytes;
11use folk_api::Executor;
12use folk_protocol::RpcMessage;
13use rmpv::Value as RmpValue;
14use tokio::sync::{Semaphore, mpsc, oneshot};
15use tokio::task::JoinHandle;
16use tracing::{debug, error, info, warn};
17
18use crate::config::WorkersConfig;
19use crate::runtime::{Runtime, WorkerHandle};
20use crate::worker_slot::SlotInfo;
21
22#[derive(Debug, thiserror::Error)]
24pub enum WorkError {
25 #[error("all workers busy")]
26 Busy,
27 #[error("worker died during request")]
28 WorkerDied,
29 #[error("execution timed out")]
30 Timeout,
31 #[error("worker returned application error: {message}")]
32 Application { code: i32, message: String },
33 #[error("protocol error: {0}")]
34 Protocol(#[from] folk_protocol::Error),
35 #[error("io error: {0}")]
36 Io(#[from] std::io::Error),
37}
38
39struct DispatchRequest {
41 method: String,
42 payload: Bytes,
43 reply: oneshot::Sender<Result<Bytes>>,
44}
45
46pub struct WorkerPool {
48 request_tx: mpsc::Sender<DispatchRequest>,
49 semaphore: Arc<Semaphore>,
50 _pool_task: JoinHandle<()>,
51}
52
53impl WorkerPool {
54 pub fn new(runtime: Arc<dyn Runtime>, config: WorkersConfig) -> Result<Arc<Self>> {
59 let semaphore = Arc::new(Semaphore::new(config.count));
60 let (request_tx, request_rx) = mpsc::channel::<DispatchRequest>(1024);
61
62 let pool_task = tokio::spawn(pool_main(runtime, config, request_rx, semaphore.clone()));
63
64 Ok(Arc::new(Self {
65 request_tx,
66 semaphore,
67 _pool_task: pool_task,
68 }))
69 }
70}
71
72#[async_trait]
73impl Executor for WorkerPool {
74 async fn execute_method(&self, method: &str, payload: Bytes) -> Result<Bytes> {
75 let permit = self
76 .semaphore
77 .clone()
78 .acquire_owned()
79 .await
80 .context("pool semaphore closed")?;
81
82 let (reply_tx, reply_rx) = oneshot::channel();
83 self.request_tx
84 .send(DispatchRequest {
85 method: method.to_string(),
86 payload,
87 reply: reply_tx,
88 })
89 .await
90 .map_err(|_| anyhow!("pool task gone"))?;
91
92 let result = reply_rx
93 .await
94 .map_err(|_| anyhow!("pool dropped reply channel"))?;
95
96 drop(permit);
97 result
98 }
99}
100
101async fn pool_main(
104 runtime: Arc<dyn Runtime>,
105 config: WorkersConfig,
106 mut request_rx: mpsc::Receiver<DispatchRequest>,
107 _semaphore: Arc<Semaphore>,
108) {
109 let mut slot_inboxes: Vec<mpsc::Sender<DispatchRequest>> = Vec::with_capacity(config.count);
110 let mut slot_supervisors: Vec<JoinHandle<()>> = Vec::with_capacity(config.count);
111
112 for slot_id in 0..config.count {
113 let (slot_tx, slot_rx) = mpsc::channel::<DispatchRequest>(8);
114 slot_inboxes.push(slot_tx);
115 let runtime_clone = runtime.clone();
116 let cfg_clone = config.clone();
117 let supervisor = tokio::spawn(slot_supervisor(slot_id, runtime_clone, cfg_clone, slot_rx));
118 slot_supervisors.push(supervisor);
119 }
120
121 let mut next: usize = 0;
123 while let Some(req) = request_rx.recv().await {
124 let chosen = next % slot_inboxes.len();
125 next = next.wrapping_add(1);
126
127 if slot_inboxes[chosen].send(req).await.is_err() {
128 warn!(slot_id = chosen, "slot inbox closed; failed to dispatch");
129 }
130 }
131
132 info!("pool main loop exiting; awaiting supervisors");
133 for handle in slot_supervisors {
134 let _ = handle.await;
135 }
136}
137
138async fn slot_supervisor(
141 slot_id: usize,
142 runtime: Arc<dyn Runtime>,
143 config: WorkersConfig,
144 mut inbox: mpsc::Receiver<DispatchRequest>,
145) {
146 let mut slot = SlotInfo::new();
147 let mut worker: Option<Box<dyn WorkerHandle>> = None;
148
149 loop {
150 if worker.is_none() {
152 match boot_worker(&runtime, &config, &mut slot).await {
153 Ok(w) => worker = Some(w),
154 Err(e) => {
155 error!(slot_id, error = ?e, "failed to boot worker, will retry");
156 tokio::time::sleep(Duration::from_secs(1)).await;
157 continue;
158 },
159 }
160 }
161
162 let Some(w) = worker.as_mut() else {
163 unreachable!()
164 };
165
166 let Some(req) = inbox.recv().await else {
168 info!(slot_id, "supervisor shutting down (inbox closed)");
169 if let Err(e) = w.terminate().await {
170 warn!(slot_id, error = ?e, "terminate error during shutdown");
171 }
172 return;
173 };
174
175 slot.mark_busy();
177 let result = dispatch_one(w.as_mut(), &req.method, req.payload, config.exec_timeout).await;
178 slot.mark_idle();
179
180 let _ = req.reply.send(result.map_err(anyhow::Error::from));
182
183 if slot.should_recycle(&config) {
185 info!(slot_id, jobs = slot.jobs_handled, "recycling worker");
186 if let Some(mut w) = worker.take() {
187 let _ = w.terminate().await;
188 }
189 slot = SlotInfo::new();
190 }
191 }
192}
193
194async fn boot_worker(
195 runtime: &Arc<dyn Runtime>,
196 config: &WorkersConfig,
197 slot: &mut SlotInfo,
198) -> Result<Box<dyn WorkerHandle>> {
199 let mut handle = runtime.spawn().await.context("spawn")?;
200 let timeout = tokio::time::timeout(config.boot_timeout, handle.recv_control());
201 match timeout.await {
202 Ok(Ok(Some(RpcMessage::Notify { method, .. }))) if method == "control.ready" => {
203 let pid = handle.pid();
204 slot.mark_ready(pid);
205 debug!(pid, "worker ready");
206 Ok(handle)
207 },
208 Ok(Ok(other)) => {
209 let _ = handle.terminate().await;
210 anyhow::bail!("expected control.ready, got {other:?}")
211 },
212 Ok(Err(e)) => {
213 let _ = handle.terminate().await;
214 Err(e).context("recv_control failed during boot")
215 },
216 Err(_) => {
217 let _ = handle.terminate().await;
218 anyhow::bail!("worker boot timed out after {:?}", config.boot_timeout)
219 },
220 }
221}
222
223async fn dispatch_one(
224 worker: &mut dyn WorkerHandle,
225 method: &str,
226 payload: Bytes,
227 exec_timeout: Duration,
228) -> Result<Bytes, WorkError> {
229 static MSGID: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1);
230 let msgid = MSGID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
231
232 let params = rmp_serde::from_slice::<RmpValue>(&payload)
233 .map_err(|e| WorkError::Protocol(folk_protocol::Error::Decode(e)))?;
234 let request = RpcMessage::request(msgid, method, params);
235
236 worker
237 .send_task(request)
238 .await
239 .map_err(|_| WorkError::WorkerDied)?;
240
241 let recv = tokio::time::timeout(exec_timeout, worker.recv_task());
242 let response = match recv.await {
243 Ok(Ok(Some(msg))) => msg,
244 Ok(Ok(None) | Err(_)) => return Err(WorkError::WorkerDied),
245 Err(_) => return Err(WorkError::Timeout),
246 };
247
248 match response {
249 RpcMessage::Response { error, result, .. } => {
250 if !error.is_nil() {
251 return Err(WorkError::Application {
252 code: -1,
253 message: format!("{error:?}"),
254 });
255 }
256 let mut buf = Vec::new();
257 rmp_serde::encode::write(&mut buf, &result)
258 .map_err(|e| WorkError::Protocol(folk_protocol::Error::Encode(e)))?;
259 Ok(Bytes::from(buf))
260 },
261 other => Err(WorkError::Protocol(folk_protocol::Error::InvalidFrame(
262 format!("expected Response, got {other:?}"),
263 ))),
264 }
265}