1use std::collections::HashMap;
49use std::io;
50use std::process::Stdio;
51use std::sync::atomic::{AtomicUsize, Ordering};
52use std::time::Duration;
53
54use async_trait::async_trait;
55use serde::{Deserialize, Serialize};
56use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
57use tokio::process::{Child, ChildStdin, ChildStdout, Command};
58use tokio::sync::Mutex;
59use tokio_util::sync::CancellationToken;
60
61use super::handler::{JobCtx, JobHandler, JobOutcome};
62
/// Per-request timeout, in seconds, used when [`WorkerPoolConfig::timeout_secs`] is `None`.
const DEFAULT_TIMEOUT_SECS: u64 = 600;

/// Launch settings for a [`WorkerPoolHandler`]: how to start each worker process and
/// how many of them to keep running.
#[derive(Debug, Clone)]
pub struct WorkerPoolConfig {
    /// Job kind served by this pool (reported through `JobHandler::kind`).
    pub kind: &'static str,
    /// Program followed by its arguments; must be non-empty.
    pub argv: Vec<String>,
    /// Extra environment variables set on every worker process.
    pub env: HashMap<String, String>,
    /// Working directory for the workers; `None` does not set one.
    pub cwd: Option<String>,
    /// Number of worker processes to keep; `0` is treated as `1`.
    pub size: usize,
    /// Per-request timeout in seconds; `None` means [`DEFAULT_TIMEOUT_SECS`].
    /// Values below one second are raised to one second.
    pub timeout_secs: Option<u64>,
}
84
/// One running worker process plus the pipe ends used to talk to it.
struct ChildIo {
    /// Process handle, kept so a misbehaving worker can be killed (`start_kill`).
    /// `spawn_child` sets `kill_on_drop(true)`, so dropping it also kills the worker.
    child: Child,
    /// Write end of the worker's stdin; each request is written as one line.
    stdin: ChildStdin,
    /// Line reader over the worker's stdout; each response is read as one line.
    lines: Lines<BufReader<ChildStdout>>,
}
91
92impl ChildIo {
93 async fn exchange(&mut self, request: &str) -> io::Result<String> {
97 self.stdin.write_all(request.as_bytes()).await?;
98 self.stdin.write_all(b"\n").await?;
99 self.stdin.flush().await?;
100 self.lines.next_line().await?.map_or_else(
101 || {
102 Err(io::Error::new(
103 io::ErrorKind::UnexpectedEof,
104 "worker closed stdout",
105 ))
106 },
107 Ok,
108 )
109 }
110}
111
/// One request line sent to a worker: a JSON object with `id` and `payload` keys.
#[derive(Serialize)]
struct WorkerRequest<'a> {
    /// Job id. A worker may echo it back so `parse_response` can detect desync.
    id: &'a str,
    /// Job payload, forwarded to the worker as-is.
    payload: &'a serde_json::Value,
}
117
/// One response line read back from a worker.
#[derive(Deserialize)]
struct WorkerResponse {
    /// Optional echo of the request id. When present, `parse_response` requires it to
    /// equal the id of the request just sent and treats a mismatch as a protocol desync.
    /// Workers that omit it skip that check.
    #[serde(default)]
    id: Option<String>,
    /// The `status`-tagged outcome; its fields live at the top level of the same object.
    #[serde(flatten)]
    body: WorkerResponseBody,
}
129
/// Outcome reported by a worker, selected by the JSON `"status"` field:
/// `"ok"`, `"error"` (optional `"message"`), or `"throttled"` (optional `"retry_after_secs"`).
#[derive(Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum WorkerResponseBody {
    /// `{"status":"ok"}` — the job succeeded.
    Ok,
    /// `{"status":"error", "message": ...}` — the job failed.
    Error {
        /// Failure reason; `parse_response` substitutes a generic message when absent.
        #[serde(default)]
        message: Option<String>,
    },
    /// `{"status":"throttled", "retry_after_secs": N}` — retry the job later.
    Throttled {
        /// Suggested delay in seconds; `parse_response` uses 60 when absent.
        #[serde(default)]
        retry_after_secs: Option<u64>,
    },
}
143
/// Job handler that forwards each job to one of a fixed set of long-lived worker
/// processes, exchanging one JSON line per request and response over stdin/stdout.
pub struct WorkerPoolHandler {
    /// Launch settings, also reused when a discarded worker has to be respawned.
    config: WorkerPoolConfig,
    /// One slot per worker. The mutex allows at most one in-flight request per worker.
    /// `None` means the previous worker was discarded and is respawned on next use.
    slots: Vec<Mutex<Option<ChildIo>>>,
    /// Round-robin counter used to pick the starting slot for each job.
    next: AtomicUsize,
}
154
155impl std::fmt::Debug for WorkerPoolHandler {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("WorkerPoolHandler")
158 .field("kind", &self.config.kind)
159 .field("size", &self.slots.len())
160 .finish_non_exhaustive()
161 }
162}
163
164impl WorkerPoolHandler {
165 #[allow(
173 clippy::unused_async,
174 reason = "process-spawning constructor — kept async for a future readiness handshake and so callers needn't change if spawn gains awaits"
175 )]
176 pub async fn spawn(config: WorkerPoolConfig) -> Result<Self, String> {
177 if config.argv.is_empty() {
178 return Err("worker pool argv is empty".to_owned());
179 }
180 let size = config.size.max(1);
181 let mut slots = Vec::with_capacity(size);
182 for _ in 0..size {
183 let io = spawn_child(&config)?;
184 slots.push(Mutex::new(Some(io)));
185 }
186 Ok(Self {
187 config,
188 slots,
189 next: AtomicUsize::new(0),
190 })
191 }
192
193 fn timeout(&self) -> Duration {
194 Duration::from_secs(
195 self.config
196 .timeout_secs
197 .unwrap_or(DEFAULT_TIMEOUT_SECS)
198 .max(1),
199 )
200 }
201
202 #[allow(
203 clippy::significant_drop_tightening,
204 reason = "the per-slot guard is intentionally held across the request/response exchange — one in-flight request per child"
205 )]
206 async fn dispatch(
207 &self,
208 job_id: &str,
209 payload: &serde_json::Value,
210 cancel: &CancellationToken,
211 ) -> JobOutcome {
212 let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.slots.len();
213 let mut guard = self.slots[idx].lock().await;
214
215 let mut io = match guard.take() {
218 Some(io) => io,
219 None => match spawn_child(&self.config) {
220 Ok(io) => io,
221 Err(e) => return JobOutcome::Failed(format!("respawn worker: {e}")),
222 },
223 };
224
225 let request = match serde_json::to_string(&WorkerRequest {
226 id: job_id,
227 payload,
228 }) {
229 Ok(s) => s,
230 Err(e) => {
231 *guard = Some(io);
233 return JobOutcome::Failed(format!("encode request: {e}"));
234 }
235 };
236
237 let exchanged = tokio::select! {
238 () = cancel.cancelled() => {
239 let _ = io.child.start_kill();
242 return JobOutcome::Failed("cancelled by supervisor".to_owned());
243 }
244 res = tokio::time::timeout(self.timeout(), io.exchange(&request)) => res,
245 };
246
247 match exchanged {
248 Ok(Ok(line)) => match parse_response(&line, job_id) {
249 Ok(outcome) => {
251 *guard = Some(io);
252 outcome
253 }
254 Err(reason) => {
258 let _ = io.child.start_kill();
259 JobOutcome::Failed(reason)
260 }
261 },
262 Ok(Err(e)) => {
263 let _ = io.child.start_kill();
264 JobOutcome::Failed(format!("worker io: {e}"))
265 }
266 Err(_) => {
267 let _ = io.child.start_kill();
268 JobOutcome::Failed(format!(
269 "worker timeout after {}s",
270 self.timeout().as_secs()
271 ))
272 }
273 }
274 }
275}
276
277#[async_trait]
278impl JobHandler for WorkerPoolHandler {
279 fn kind(&self) -> &'static str {
280 self.config.kind
281 }
282
283 async fn run(&self, ctx: JobCtx<'_>, payload: serde_json::Value) -> JobOutcome {
284 self.dispatch(ctx.job_id.as_str(), &payload, &ctx.cancel)
285 .await
286 }
287}
288
289fn spawn_child(config: &WorkerPoolConfig) -> Result<ChildIo, String> {
293 let (program, args) = config
294 .argv
295 .split_first()
296 .ok_or_else(|| "worker pool argv is empty".to_owned())?;
297 let mut cmd = Command::new(program);
298 cmd.args(args)
299 .envs(&config.env)
300 .stdin(Stdio::piped())
301 .stdout(Stdio::piped())
302 .stderr(Stdio::inherit())
303 .kill_on_drop(true);
304 if let Some(cwd) = &config.cwd {
305 cmd.current_dir(cwd);
306 }
307 let mut child = cmd.spawn().map_err(|e| format!("spawn {program:?}: {e}"))?;
308 let stdin = child
309 .stdin
310 .take()
311 .ok_or_else(|| "child stdin not piped".to_owned())?;
312 let stdout = child
313 .stdout
314 .take()
315 .ok_or_else(|| "child stdout not piped".to_owned())?;
316 Ok(ChildIo {
317 child,
318 stdin,
319 lines: BufReader::new(stdout).lines(),
320 })
321}
322
323fn parse_response(line: &str, expected_id: &str) -> Result<JobOutcome, String> {
328 let resp: WorkerResponse =
329 serde_json::from_str(line).map_err(|e| format!("bad worker response {line:?}: {e}"))?;
330 if let Some(id) = &resp.id
333 && id != expected_id
334 {
335 return Err(format!(
336 "worker response id mismatch (desync): got {id:?}, want {expected_id:?}"
337 ));
338 }
339 Ok(match resp.body {
340 WorkerResponseBody::Ok => JobOutcome::Done,
341 WorkerResponseBody::Error { message } => {
342 JobOutcome::Failed(message.unwrap_or_else(|| "worker reported error".to_owned()))
343 }
344 WorkerResponseBody::Throttled { retry_after_secs } => JobOutcome::Throttled {
345 retry_after: Duration::from_secs(retry_after_secs.unwrap_or(60)),
346 },
347 })
348}
349
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    reason = "tests crash loudly on setup or assertion failure; that's the point"
)]
mod tests {
    use super::*;

    /// Build a pool whose workers run `script` under `sh -c`, with a 5 s request timeout.
    async fn pool_running(script: String, size: usize) -> WorkerPoolHandler {
        WorkerPoolHandler::spawn(WorkerPoolConfig {
            kind: "worker_pool_test",
            argv: vec!["sh".into(), "-c".into(), script],
            env: HashMap::new(),
            cwd: None,
            size,
            timeout_secs: Some(5),
        })
        .await
        .expect("spawn pool")
    }

    /// Build a pool whose workers answer every request line with the fixed `reply`.
    async fn pool_echoing(reply: &str, size: usize) -> WorkerPoolHandler {
        // `reply` is embedded in single quotes, so it must not itself contain `'`.
        let script = format!("while IFS= read -r _line; do printf '%s\\n' '{reply}'; done");
        pool_running(script, size).await
    }

    fn cancel() -> CancellationToken {
        CancellationToken::new()
    }

    #[tokio::test]
    async fn ok_response_maps_to_done() {
        let pool = pool_echoing(r#"{"status":"ok"}"#, 1).await;
        let out = pool
            .dispatch("job-1", &serde_json::json!({"x":1}), &cancel())
            .await;
        assert!(matches!(out, JobOutcome::Done), "got: {out:?}");
    }

    #[tokio::test]
    async fn error_response_maps_to_failed() {
        let pool = pool_echoing(r#"{"status":"error","message":"boom"}"#, 1).await;
        let out = pool
            .dispatch("job-2", &serde_json::json!({}), &cancel())
            .await;
        match out {
            JobOutcome::Failed(msg) => assert_eq!(msg, "boom"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn throttled_response_maps_to_throttled() {
        let pool = pool_echoing(r#"{"status":"throttled","retry_after_secs":12}"#, 1).await;
        let out = pool
            .dispatch("job-3", &serde_json::json!({}), &cancel())
            .await;
        match out {
            JobOutcome::Throttled { retry_after } => assert_eq!(retry_after.as_secs(), 12),
            other => panic!("expected Throttled, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn id_mismatch_is_treated_as_failure() {
        // The worker echoes a stale id, so its reply must not be trusted.
        let pool = pool_echoing(r#"{"id":"stale","status":"ok"}"#, 1).await;
        let out = pool
            .dispatch("job-X", &serde_json::json!({}), &cancel())
            .await;
        match out {
            JobOutcome::Failed(msg) => assert!(msg.contains("mismatch"), "got: {msg}"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn matching_id_is_accepted() {
        let pool = pool_echoing(r#"{"id":"job-Y","status":"ok"}"#, 1).await;
        let out = pool
            .dispatch("job-Y", &serde_json::json!({}), &cancel())
            .await;
        assert!(matches!(out, JobOutcome::Done), "got: {out:?}");
    }

    #[tokio::test]
    async fn garbage_response_maps_to_failed() {
        let pool = pool_echoing("not json", 1).await;
        let out = pool
            .dispatch("job-4", &serde_json::json!({}), &cancel())
            .await;
        assert!(matches!(out, JobOutcome::Failed(_)), "got: {out:?}");
    }

    #[tokio::test]
    async fn reuses_the_same_warm_child_across_jobs() {
        // A size-1 pool serves several jobs in a row from one warm worker.
        let pool = pool_echoing(r#"{"status":"ok"}"#, 1).await;
        for i in 0..3 {
            let out = pool
                .dispatch(&format!("job-{i}"), &serde_json::json!({}), &cancel())
                .await;
            assert!(matches!(out, JobOutcome::Done), "iter {i}: {out:?}");
        }
    }

    #[tokio::test]
    async fn zero_size_pool_is_clamped_to_one_worker() {
        let pool = pool_echoing(r#"{"status":"ok"}"#, 0).await;
        let out = pool
            .dispatch("job-z", &serde_json::json!({}), &cancel())
            .await;
        assert!(matches!(out, JobOutcome::Done), "got: {out:?}");
    }

    #[tokio::test]
    async fn precancelled_job_fails_and_pool_keeps_working() {
        let pool = pool_echoing(r#"{"status":"ok"}"#, 1).await;
        let token = cancel();
        token.cancel();
        let out = pool
            .dispatch("job-c", &serde_json::json!({}), &token)
            .await;
        match out {
            JobOutcome::Failed(msg) => assert_eq!(msg, "cancelled by supervisor"),
            other => panic!("expected Failed, got {other:?}"),
        }

        // The pool must still serve the next job, whether it kept or respawned the worker.
        let out = pool
            .dispatch("job-d", &serde_json::json!({}), &cancel())
            .await;
        assert!(matches!(out, JobOutcome::Done), "got: {out:?}");
    }

    #[tokio::test]
    async fn worker_exit_is_reported_and_slot_respawns() {
        // Each worker reads one request and exits without replying.
        let pool = pool_running("IFS= read -r _line; exit 0".to_owned(), 1).await;
        for i in 0..2 {
            let out = pool
                .dispatch(&format!("job-e{i}"), &serde_json::json!({}), &cancel())
                .await;
            match out {
                JobOutcome::Failed(msg) => {
                    assert!(msg.starts_with("worker io:"), "iter {i}: {msg}");
                }
                other => panic!("iter {i}: expected Failed, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn empty_argv_rejected() {
        let err = WorkerPoolHandler::spawn(WorkerPoolConfig {
            kind: "x",
            argv: vec![],
            env: HashMap::new(),
            cwd: None,
            size: 1,
            timeout_secs: None,
        })
        .await;
        assert!(err.is_err());
    }
}