1use std::ffi::OsString;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
5use std::sync::{Arc, Mutex, mpsc};
6use std::thread::{self, JoinHandle};
7use std::time::Duration;
8
9use rand::RngExt as _;
10use tempfile::TempDir;
11
12mod codex_home;
13mod config;
14
15pub use config::{LogCaptureConfig, LogPathMode, PoolConfig, PoolStats};
16
17use crate::{
18 AcpClient, DEFAULT_CODEX_ACP_MODEL, DEFAULT_CODEX_ACP_REASONING_EFFORT, DEFAULT_SYSTEM_PROMPT,
19 PoolError, RateLimitEvent, RubricOptions, RubricVerdict, build_codex_acp_args, encode_png,
20 parse_verdict,
21};
22use codex_home::seed_codex_home;
23
/// How many times `Worker::recycle_runtime` tries to spawn a replacement runtime before it
/// gives up and marks the worker dead.
const RECYCLE_SPAWN_ATTEMPTS: u32 = 2;
25
/// A fixed-size pool of worker threads, each driving its own `codex-acp` runtime, that
/// answers rubric questions about PNG images.
///
/// Jobs are dispatched round-robin over the workers whose bit is still set in the shared
/// alive mask.
#[derive(Debug)]
pub struct RubricPool {
    /// One job channel per worker, indexed by worker id.
    senders: Vec<mpsc::Sender<Job>>,
    /// Worker thread handles; joined by `shutdown`.
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Round-robin cursor used to pick the next worker.
    next: AtomicUsize,
    /// Configuration the pool was built with (defaults, timeouts, retry policy, ...).
    config: PoolConfig,
    /// State shared with every worker thread (counters, alive mask, quota flag).
    shared: Arc<SharedPoolState>,
}
35
/// One unit of work sent from `RubricPool::submit` to a worker thread.
struct Job {
    /// Image file the question is asked about.
    png_path: PathBuf,
    /// The rubric question appended to the system prompt.
    question: String,
    /// Caller options with unset fields already filled from the pool defaults.
    options: RubricOptions,
    /// Channel on which the worker sends the job's result back to the submitter.
    reply: mpsc::Sender<Result<RubricVerdict, PoolError>>,
}
42
/// Counters and flags shared between the pool handle and all worker threads.
#[derive(Default, Debug)]
struct SharedPoolState {
    /// Jobs that produced a verdict.
    completed: AtomicU64,
    /// Failed jobs, as counted by `Worker::handle_job`.
    failures: AtomicU64,
    /// Calls to `Worker::recycle_runtime` (counted even when respawning fails).
    worker_recycles: AtomicU64,
    /// Every rate-limit response seen, together with the backoff chosen for it.
    rate_limit_events: Mutex<Vec<RateLimitEvent>>,
    /// Set once a worker sees `PoolError::QuotaExceeded`; `submit` then rejects new jobs.
    fatal_quota: AtomicBool,
    /// Bit `i` is set while worker `i` is usable and cleared when it is marked dead.
    alive_mask: AtomicU64,
}
52
53impl RubricPool {
54 pub fn new(config: PoolConfig) -> Result<Self, PoolError> {
61 if config.workers == 0 {
62 return Err(PoolError::Spawn(
63 "workers must be greater than zero".to_string(),
64 ));
65 }
66 if config.workers > u64::BITS as usize {
67 return Err(PoolError::Spawn(format!(
68 "workers={} exceeds alive bitmask capacity {}",
69 config.workers,
70 u64::BITS
71 )));
72 }
73
74 let shared = Arc::new(SharedPoolState::default());
75 shared
76 .alive_mask
77 .store(alive_mask(config.workers), Ordering::Release);
78
79 let mut senders = Vec::with_capacity(config.workers);
80 let mut handles = Vec::with_capacity(config.workers);
81 for worker_id in 0..config.workers {
82 let (job_tx, job_rx) = mpsc::channel();
83 let (ready_tx, ready_rx) = mpsc::channel();
84 let worker = Worker {
85 id: worker_id,
86 config: config.clone(),
87 shared: Arc::clone(&shared),
88 jobs: job_rx,
89 };
90 let handle = thread::spawn(move || worker.run(ready_tx));
91 match ready_rx.recv() {
92 Ok(Ok(())) => {
93 senders.push(job_tx);
94 handles.push(handle);
95 }
96 Ok(Err(error)) => {
97 shared
98 .alive_mask
99 .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
100 drop(senders);
101 join_handles(handles);
102 let _ = handle.join();
103 return Err(error);
104 }
105 Err(error) => {
106 shared
107 .alive_mask
108 .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
109 drop(senders);
110 join_handles(handles);
111 let _ = handle.join();
112 return Err(PoolError::WorkerCrashed {
113 worker_id,
114 message: format!("worker exited before startup result: {error}"),
115 });
116 }
117 }
118 }
119
120 Ok(Self {
121 senders,
122 handles: Mutex::new(handles),
123 next: AtomicUsize::new(0),
124 config,
125 shared,
126 })
127 }
128
129 pub fn submit(
136 &self,
137 png_path: &Path,
138 question: &str,
139 opts: RubricOptions,
140 ) -> Result<RubricVerdict, PoolError> {
141 if self.shared.fatal_quota.load(Ordering::Acquire) {
142 return Err(PoolError::QuotaExceeded);
143 }
144
145 let worker_id = self.next_live_worker()?;
146 let (reply_tx, reply_rx) = mpsc::channel();
147 let job = Job {
148 png_path: png_path.to_path_buf(),
149 question: question.to_string(),
150 options: merge_options(opts, &self.config.default_options),
151 reply: reply_tx,
152 };
153
154 self.senders[worker_id]
155 .send(job)
156 .map_err(|_| PoolError::WorkerCrashed {
157 worker_id,
158 message: "worker channel closed".to_string(),
159 })?;
160
161 match reply_rx.recv_timeout(self.config.submit_timeout) {
162 Ok(result) => result,
163 Err(mpsc::RecvTimeoutError::Timeout) => Err(PoolError::Timeout {
164 worker_id,
165 timeout: self.config.submit_timeout,
166 }),
167 Err(mpsc::RecvTimeoutError::Disconnected) => Err(PoolError::WorkerCrashed {
168 worker_id,
169 message: "worker dropped reply channel".to_string(),
170 }),
171 }
172 }
173
174 #[must_use]
176 pub fn shutdown(self) -> PoolStats {
177 let Self {
178 senders,
179 handles,
180 shared,
181 ..
182 } = self;
183 drop(senders);
184 if let Ok(handles) = handles.into_inner() {
185 join_handles(handles);
186 }
187 shared.stats()
188 }
189
190 #[must_use]
192 pub fn stats(&self) -> PoolStats {
193 self.shared.stats()
194 }
195
196 fn next_live_worker(&self) -> Result<usize, PoolError> {
197 let worker_count = self.senders.len();
198 for _ in 0..worker_count {
199 let idx = self.next.fetch_add(1, Ordering::AcqRel) % worker_count;
200 let mask = self.shared.alive_mask.load(Ordering::Acquire);
201 if mask & worker_bit(idx) != 0 {
202 return Ok(idx);
203 }
204 }
205 Err(PoolError::NoLiveWorkers)
206 }
207}
208
/// State owned by one worker thread; consumed by `Worker::run`.
struct Worker {
    /// Index of this worker: its position in `RubricPool::senders` and its alive-mask bit.
    id: usize,
    /// Copy of the pool configuration.
    config: PoolConfig,
    /// State shared with the pool handle and the other workers.
    shared: Arc<SharedPoolState>,
    /// Receiving end of this worker's job channel.
    jobs: mpsc::Receiver<Job>,
}
215
/// One spawned `codex-acp` client plus the resources it depends on.
struct WorkerRuntime {
    /// Client connected to the spawned `codex-acp` process.
    acp: AcpClient,
    /// Temporary CODEX_HOME handed to the process. It is held only to keep the directory
    /// alive; `tempfile::TempDir` removes it when the runtime is dropped.
    _codex_home: TempDir,
    /// Successful prompts served since this runtime was spawned.
    prompts: u32,
    /// Model the runtime was spawned with (after defaults were applied).
    model: String,
    /// Reasoning effort the runtime was spawned with (after defaults were applied).
    effort: String,
}
223
224impl Worker {
225 fn run(self, ready: mpsc::Sender<Result<(), PoolError>>) {
226 let mut runtime = match self.spawn_runtime(&self.config.default_options) {
227 Ok(runtime) => {
228 let _ = ready.send(Ok(()));
229 runtime
230 }
231 Err(error) => {
232 self.mark_dead();
233 let _ = ready.send(Err(error));
234 return;
235 }
236 };
237
238 while let Ok(job) = self.jobs.recv() {
239 let result = self.handle_job(&mut runtime, &job);
240 let fatal_quota = matches!(result, Err(PoolError::QuotaExceeded));
241 let _ = job.reply.send(result);
242 if fatal_quota {
243 self.shared.fatal_quota.store(true, Ordering::Release);
244 }
245 if self.shared.alive_mask.load(Ordering::Acquire) & worker_bit(self.id) == 0 {
246 break;
247 }
248 }
249 }
250
251 fn handle_job(
252 &self,
253 runtime: &mut WorkerRuntime,
254 job: &Job,
255 ) -> Result<RubricVerdict, PoolError> {
256 let mut last_error = None;
257 for attempt in 0..=self.config.max_retries {
258 if !runtime.matches_options(&job.options) {
259 self.recycle_runtime(runtime, &job.options)?;
260 }
261
262 match self.evaluate_once(runtime, job) {
263 Ok(verdict) => {
264 self.shared.completed.fetch_add(1, Ordering::AcqRel);
265 runtime.prompts += 1;
266 if runtime.prompts >= self.config.max_prompts_per_worker {
267 self.recycle_runtime(runtime, &job.options)?;
268 }
269 return Ok(verdict);
270 }
271 Err(PoolError::QuotaExceeded) => {
272 self.shared.failures.fetch_add(1, Ordering::AcqRel);
273 return Err(PoolError::QuotaExceeded);
274 }
275 Err(PoolError::RateLimited { retry_after }) => {
276 let delay =
277 backoff_delay(attempt, self.config.backoff_base, self.config.backoff_cap);
278 self.shared.push_rate_limit_event(RateLimitEvent {
279 worker_id: self.id,
280 attempt,
281 delay,
282 retry_after,
283 });
284 last_error = Some(PoolError::RateLimited { retry_after });
285 if attempt < self.config.max_retries {
286 thread::sleep(delay);
287 }
288 }
289 Err(error) => {
290 last_error = Some(error);
291 self.recycle_runtime(runtime, &job.options)?;
292 }
293 }
294 }
295
296 self.shared.failures.fetch_add(1, Ordering::AcqRel);
297 Err(last_error.unwrap_or_else(|| PoolError::Rpc("retry loop exhausted".to_string())))
298 }
299
300 fn evaluate_once(
301 &self,
302 runtime: &mut WorkerRuntime,
303 job: &Job,
304 ) -> Result<RubricVerdict, PoolError> {
305 let b64 = encode_png(&job.png_path)?;
306 let system_prompt = job
307 .options
308 .system_prompt
309 .as_deref()
310 .map_or(DEFAULT_SYSTEM_PROMPT, |system_prompt| system_prompt);
311 let prompt = format!("{system_prompt}\n\nQuestion: {}", job.question);
312 let text = runtime.acp.prompt_image(&prompt, &b64)?;
313 parse_verdict(&text).map_err(|e| PoolError::ParseVerdict(format!("from {text:?}: {e}")))
314 }
315
316 fn recycle_runtime(
317 &self,
318 runtime: &mut WorkerRuntime,
319 options: &RubricOptions,
320 ) -> Result<(), PoolError> {
321 self.shared.worker_recycles.fetch_add(1, Ordering::AcqRel);
322 let mut last_error = None;
323 for _ in 0..RECYCLE_SPAWN_ATTEMPTS {
324 match self.spawn_runtime(options) {
325 Ok(new_runtime) => {
326 *runtime = new_runtime;
327 return Ok(());
328 }
329 Err(error) => {
330 last_error = Some(error);
331 }
332 }
333 }
334 self.mark_dead();
335 Err(last_error.unwrap_or_else(|| PoolError::Spawn("recycle failed".to_string())))
336 }
337
338 fn spawn_runtime(&self, options: &RubricOptions) -> Result<WorkerRuntime, PoolError> {
339 let codex_home =
340 TempDir::new().map_err(|e| PoolError::Spawn(format!("create CODEX_HOME: {e}")))?;
341 seed_codex_home(codex_home.path(), self.config.source_codex_home.as_deref())?;
342 let mut env = self.config.extra_env.clone();
343 env.push((
344 OsString::from("CODEX_HOME"),
345 codex_home.path().as_os_str().to_os_string(),
346 ));
347 if let Some(log_capture) = &self.config.log_capture {
348 fs::create_dir_all(&log_capture.temp_dir).map_err(|e| {
349 PoolError::Spawn(format!(
350 "create ACP TMPDIR {}: {e}",
351 log_capture.temp_dir.display()
352 ))
353 })?;
354 env.push((
355 OsString::from("TMPDIR"),
356 log_capture.temp_dir.as_os_str().to_os_string(),
357 ));
358 }
359
360 let model = options
361 .model
362 .as_deref()
363 .map_or(DEFAULT_CODEX_ACP_MODEL, |model| model);
364 let effort = options
365 .effort
366 .as_deref()
367 .map_or(DEFAULT_CODEX_ACP_REASONING_EFFORT, |effort| effort);
368 let acp_args = build_codex_acp_args(model, effort);
369 let mut acp = AcpClient::spawn(&self.config.codex_acp_binary, &acp_args, &env, None)?;
370 acp.start_session(None)?;
371
372 Ok(WorkerRuntime {
373 acp,
374 _codex_home: codex_home,
375 prompts: 0,
376 model: model.to_string(),
377 effort: effort.to_string(),
378 })
379 }
380
381 fn mark_dead(&self) {
382 self.shared
383 .alive_mask
384 .fetch_and(!worker_bit(self.id), Ordering::AcqRel);
385 }
386}
387
388impl WorkerRuntime {
389 fn matches_options(&self, options: &RubricOptions) -> bool {
390 options.model.as_deref() == Some(self.model.as_str())
391 && options.effort.as_deref() == Some(self.effort.as_str())
392 }
393}
394
395impl SharedPoolState {
396 fn stats(&self) -> PoolStats {
397 PoolStats {
398 completed: self.completed.load(Ordering::Acquire),
399 failures: self.failures.load(Ordering::Acquire),
400 rate_limit_events: self
401 .rate_limit_events
402 .lock()
403 .map(|events| events.clone())
404 .unwrap_or_default(),
405 worker_recycles: self.worker_recycles.load(Ordering::Acquire),
406 }
407 }
408
409 fn push_rate_limit_event(&self, event: RateLimitEvent) {
410 if let Ok(mut events) = self.rate_limit_events.lock() {
411 events.push(event);
412 }
413 }
414}
415
416fn merge_options(mut opts: RubricOptions, defaults: &RubricOptions) -> RubricOptions {
417 if opts.model.is_none() {
418 opts.model.clone_from(&defaults.model);
419 }
420 if opts.effort.is_none() {
421 opts.effort.clone_from(&defaults.effort);
422 }
423 if opts.system_prompt.is_none() {
424 opts.system_prompt.clone_from(&defaults.system_prompt);
425 }
426 opts
427}
428
429fn backoff_delay(attempt: u32, base: Duration, cap: Duration) -> Duration {
430 let multiplier = 1u32 << attempt.min(6);
431 let capped = base.saturating_mul(multiplier).min(cap);
432 let capped_millis = u64::try_from(capped.as_millis()).map_or(u64::MAX, |millis| millis);
433 let jitter_cap = capped_millis / 4;
434 let jitter_ms = rand::rng().random_range(0..=jitter_cap);
435 capped.saturating_add(Duration::from_millis(jitter_ms))
436}
437
/// Returns a bitmask with the lowest `workers` bits set, one bit per worker id.
///
/// The full 64-bit case is special-cased because `1u64 << 64` overflows.
/// `RubricPool::new` rejects worker counts above 64.
fn alive_mask(workers: usize) -> u64 {
    match workers {
        full if full == u64::BITS as usize => u64::MAX,
        partial => (1u64 << partial) - 1,
    }
}
445
/// Returns the single alive-mask bit that represents `worker_id`.
///
/// `worker_id` must be below 64 (`u64::BITS`), otherwise the shift overflows.
fn worker_bit(worker_id: usize) -> u64 {
    let lowest: u64 = 1;
    lowest << worker_id
}
449
/// Waits for every thread in `handles` to finish.
///
/// A thread that panicked is ignored rather than propagating its panic.
fn join_handles(handles: Vec<JoinHandle<()>>) {
    handles.into_iter().for_each(|handle| {
        let _ = handle.join();
    });
}