Skip to main content

visual_rubric/
pool.rs

1use std::ffi::OsString;
2use std::path::{Path, PathBuf};
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, mpsc};
5use std::thread::{self, JoinHandle};
6use std::time::Duration;
7
8use rand::Rng as _;
9use tempfile::TempDir;
10
11mod codex_home;
12
13use crate::{
14    AcpClient, DEFAULT_CODEX_ACP_MODEL, DEFAULT_CODEX_ACP_REASONING_EFFORT, DEFAULT_SYSTEM_PROMPT,
15    PoolError, RateLimitEvent, RubricOptions, RubricVerdict, default_codex_acp_binary,
16    default_options, encode_png, parse_verdict,
17};
18use codex_home::seed_codex_home;
19
/// Default for [`PoolConfig::submit_timeout`]: ten minutes.
const DEFAULT_SUBMIT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// How many times a worker tries to spawn a replacement runtime before giving up on recycling.
const RECYCLE_SPAWN_ATTEMPTS: u32 = 2;
22
/// Reusable worker pool for evaluating screenshot rubrics through Codex ACP.
///
/// Each worker runs on its own thread and owns one Codex ACP runtime. Jobs are dispatched
/// round-robin to workers whose bit is still set in the shared alive mask.
///
/// `shutdown` destructures `self`, so this type has no `Drop` impl. Dropping the pool without
/// calling `shutdown` closes the job channels (`senders` is the first field) but does not join
/// the worker threads.
pub struct RubricPool {
    /// Per-worker job channels, indexed by worker id.
    senders: Vec<mpsc::Sender<Job>>,
    /// Join handles for the worker threads; joined by `shutdown`.
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Round-robin cursor used to pick the next worker.
    next: AtomicUsize,
    /// Configuration the pool was started with.
    config: PoolConfig,
    /// Counters and flags shared with every worker.
    shared: Arc<SharedPoolState>,
}
31
/// Configuration for a [`RubricPool`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Number of worker processes to keep alive.
    ///
    /// Must be between 1 and 64. [`RubricPool::new`] rejects other values because worker
    /// liveness is tracked in a 64-bit mask.
    pub workers: usize,
    /// Number of prompts after which a worker is recycled.
    ///
    /// Only successful prompts count. The worker then spawns a fresh Codex ACP runtime.
    pub max_prompts_per_worker: u32,
    /// Number of retries for recoverable worker or rate-limit failures, so a job gets at
    /// most `max_retries + 1` attempts.
    pub max_retries: u32,
    /// Initial retry backoff. It doubles per rate-limited attempt, up to 64x.
    pub backoff_base: Duration,
    /// Maximum retry backoff before jitter. Up to 25% random jitter is added on top.
    pub backoff_cap: Duration,
    /// Options applied, field by field, when a submitted job omits an override.
    pub default_options: RubricOptions,
    /// Path to the `codex-acp` executable.
    pub codex_acp_binary: PathBuf,
    /// Extra environment variables for worker processes.
    ///
    /// Each runtime additionally receives its own temporary `CODEX_HOME`.
    pub extra_env: Vec<(OsString, OsString)>,
    /// Maximum time to wait for one submitted job.
    ///
    /// On timeout `submit` returns an error, but the job itself is not cancelled.
    pub submit_timeout: Duration,
    /// Optional Codex home directory to seed into worker-local homes.
    ///
    /// It is passed to `seed_codex_home` each time a runtime is spawned.
    pub source_codex_home: Option<PathBuf>,
}
56
57impl Default for PoolConfig {
58    fn default() -> Self {
59        Self {
60            workers: 4,
61            max_prompts_per_worker: 50,
62            max_retries: 4,
63            backoff_base: Duration::from_secs(30),
64            backoff_cap: Duration::from_secs(300),
65            default_options: default_options(),
66            codex_acp_binary: default_codex_acp_binary(),
67            extra_env: Vec::new(),
68            submit_timeout: DEFAULT_SUBMIT_TIMEOUT,
69            source_codex_home: None,
70        }
71    }
72}
73
/// One queued rubric evaluation, plus the channel its result is sent back on.
struct Job {
    /// Screenshot to evaluate.
    png_path: PathBuf,
    /// Rubric question appended to the system prompt.
    question: String,
    /// Job options after merging with the pool's default options.
    options: RubricOptions,
    /// Reply channel that `RubricPool::submit` waits on.
    reply: mpsc::Sender<Result<RubricVerdict, PoolError>>,
}
80
/// Snapshot of pool execution counters.
///
/// Counters are read one at a time, so a snapshot taken while workers are busy is not
/// atomic across fields.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Successfully completed jobs.
    pub completed: u64,
    /// Failed jobs.
    pub failures: u64,
    /// Rate-limit events observed by workers, in the order they were recorded.
    pub rate_limit_events: Vec<RateLimitEvent>,
    /// Number of worker runtime recycles, including attempts that failed to spawn a
    /// replacement.
    pub worker_recycles: u64,
}
93
/// State shared between the pool handle and all of its worker threads.
#[derive(Default)]
struct SharedPoolState {
    /// Counter reported as `PoolStats::completed`.
    completed: AtomicU64,
    /// Counter reported as `PoolStats::failures`.
    failures: AtomicU64,
    /// Runtime recycle attempts across all workers.
    worker_recycles: AtomicU64,
    /// Log of every rate-limit event recorded by workers.
    rate_limit_events: Mutex<Vec<RateLimitEvent>>,
    /// Set once a worker sees `PoolError::QuotaExceeded`; `submit` then refuses new jobs.
    fatal_quota: AtomicBool,
    /// Bit `i` is set while worker `i` is considered alive and may receive jobs.
    alive_mask: AtomicU64,
}
103
104impl RubricPool {
105    /// Starts a worker pool from the supplied configuration.
106    ///
107    /// # Errors
108    ///
109    /// Returns [`PoolError`] when configuration is invalid or worker startup
110    /// fails.
111    pub fn new(config: PoolConfig) -> Result<Self, PoolError> {
112        if config.workers == 0 {
113            return Err(PoolError::Spawn(
114                "workers must be greater than zero".to_string(),
115            ));
116        }
117        if config.workers > u64::BITS as usize {
118            return Err(PoolError::Spawn(format!(
119                "workers={} exceeds alive bitmask capacity {}",
120                config.workers,
121                u64::BITS
122            )));
123        }
124
125        let shared = Arc::new(SharedPoolState::default());
126        shared
127            .alive_mask
128            .store(alive_mask(config.workers), Ordering::Release);
129
130        let mut senders = Vec::with_capacity(config.workers);
131        let mut handles = Vec::with_capacity(config.workers);
132        for worker_id in 0..config.workers {
133            let (job_tx, job_rx) = mpsc::channel();
134            let (ready_tx, ready_rx) = mpsc::channel();
135            let worker = Worker {
136                id: worker_id,
137                config: config.clone(),
138                shared: Arc::clone(&shared),
139                jobs: job_rx,
140            };
141            let handle = thread::spawn(move || worker.run(ready_tx));
142            match ready_rx.recv() {
143                Ok(Ok(())) => {
144                    senders.push(job_tx);
145                    handles.push(handle);
146                }
147                Ok(Err(error)) => {
148                    shared
149                        .alive_mask
150                        .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
151                    drop(senders);
152                    join_handles(handles);
153                    let _ = handle.join();
154                    return Err(error);
155                }
156                Err(error) => {
157                    shared
158                        .alive_mask
159                        .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
160                    drop(senders);
161                    join_handles(handles);
162                    let _ = handle.join();
163                    return Err(PoolError::WorkerCrashed {
164                        worker_id,
165                        message: format!("worker exited before startup result: {error}"),
166                    });
167                }
168            }
169        }
170
171        Ok(Self {
172            senders,
173            handles: Mutex::new(handles),
174            next: AtomicUsize::new(0),
175            config,
176            shared,
177        })
178    }
179
180    /// Submits one PNG rubric job to a live worker.
181    ///
182    /// # Errors
183    ///
184    /// Returns [`PoolError`] for missing workers, worker crashes, timeouts, PNG
185    /// IO, Codex ACP failures, or verdict parsing failures.
186    pub fn submit(
187        &self,
188        png_path: &Path,
189        question: &str,
190        opts: RubricOptions,
191    ) -> Result<RubricVerdict, PoolError> {
192        if self.shared.fatal_quota.load(Ordering::Acquire) {
193            return Err(PoolError::QuotaExceeded);
194        }
195
196        let worker_id = self.next_live_worker()?;
197        let (reply_tx, reply_rx) = mpsc::channel();
198        let job = Job {
199            png_path: png_path.to_path_buf(),
200            question: question.to_string(),
201            options: merge_options(opts, &self.config.default_options),
202            reply: reply_tx,
203        };
204
205        self.senders[worker_id]
206            .send(job)
207            .map_err(|_| PoolError::WorkerCrashed {
208                worker_id,
209                message: "worker channel closed".to_string(),
210            })?;
211
212        match reply_rx.recv_timeout(self.config.submit_timeout) {
213            Ok(result) => result,
214            Err(mpsc::RecvTimeoutError::Timeout) => Err(PoolError::Timeout {
215                worker_id,
216                timeout: self.config.submit_timeout,
217            }),
218            Err(mpsc::RecvTimeoutError::Disconnected) => Err(PoolError::WorkerCrashed {
219                worker_id,
220                message: "worker dropped reply channel".to_string(),
221            }),
222        }
223    }
224
225    /// Stops workers and returns final pool statistics.
226    #[must_use]
227    pub fn shutdown(self) -> PoolStats {
228        let Self {
229            senders,
230            handles,
231            shared,
232            ..
233        } = self;
234        drop(senders);
235        if let Ok(handles) = handles.into_inner() {
236            join_handles(handles);
237        }
238        shared.stats()
239    }
240
241    /// Returns current pool statistics without shutting the pool down.
242    #[must_use]
243    pub fn stats(&self) -> PoolStats {
244        self.shared.stats()
245    }
246
247    fn next_live_worker(&self) -> Result<usize, PoolError> {
248        let worker_count = self.senders.len();
249        for _ in 0..worker_count {
250            let idx = self.next.fetch_add(1, Ordering::AcqRel) % worker_count;
251            let mask = self.shared.alive_mask.load(Ordering::Acquire);
252            if mask & worker_bit(idx) != 0 {
253                return Ok(idx);
254            }
255        }
256        Err(PoolError::NoLiveWorkers)
257    }
258}
259
/// State owned by one worker thread.
struct Worker {
    /// Index of this worker; also its bit position in the shared alive mask.
    id: usize,
    /// This worker's copy of the pool configuration.
    config: PoolConfig,
    /// Counters and flags shared with the pool and the other workers.
    shared: Arc<SharedPoolState>,
    /// Receiving end of this worker's job channel.
    jobs: mpsc::Receiver<Job>,
}
266
/// One spawned Codex ACP client, plus the temporary `CODEX_HOME` it was started with.
///
/// Fields drop in declaration order, so `acp` is dropped before `_codex_home` removes the
/// temporary directory. NOTE(review): this relies on dropping `AcpClient` stopping the
/// child process; confirm that before reordering fields.
struct WorkerRuntime {
    /// Client for the spawned `codex-acp` process.
    acp: AcpClient,
    /// Temporary home directory, held only so it lives as long as the runtime.
    _codex_home: TempDir,
    /// Successful prompts served since this runtime was spawned.
    prompts: u32,
    /// Model the runtime was spawned with, after default resolution.
    model: String,
    /// Reasoning effort the runtime was spawned with, after default resolution.
    effort: String,
}
274
275impl Worker {
276    fn run(self, ready: mpsc::Sender<Result<(), PoolError>>) {
277        let mut runtime = match self.spawn_runtime(&self.config.default_options) {
278            Ok(runtime) => {
279                let _ = ready.send(Ok(()));
280                runtime
281            }
282            Err(error) => {
283                self.mark_dead();
284                let _ = ready.send(Err(error));
285                return;
286            }
287        };
288
289        while let Ok(job) = self.jobs.recv() {
290            let result = self.handle_job(&mut runtime, &job);
291            let fatal_quota = matches!(result, Err(PoolError::QuotaExceeded));
292            let _ = job.reply.send(result);
293            if fatal_quota {
294                self.shared.fatal_quota.store(true, Ordering::Release);
295            }
296            if self.shared.alive_mask.load(Ordering::Acquire) & worker_bit(self.id) == 0 {
297                break;
298            }
299        }
300    }
301
302    fn handle_job(
303        &self,
304        runtime: &mut WorkerRuntime,
305        job: &Job,
306    ) -> Result<RubricVerdict, PoolError> {
307        let mut last_error = None;
308        for attempt in 0..=self.config.max_retries {
309            if !runtime.matches_options(&job.options) {
310                self.recycle_runtime(runtime, &job.options)?;
311            }
312
313            match self.evaluate_once(runtime, job) {
314                Ok(verdict) => {
315                    self.shared.completed.fetch_add(1, Ordering::AcqRel);
316                    runtime.prompts += 1;
317                    if runtime.prompts >= self.config.max_prompts_per_worker {
318                        self.recycle_runtime(runtime, &job.options)?;
319                    }
320                    return Ok(verdict);
321                }
322                Err(PoolError::QuotaExceeded) => {
323                    self.shared.failures.fetch_add(1, Ordering::AcqRel);
324                    return Err(PoolError::QuotaExceeded);
325                }
326                Err(PoolError::RateLimited { retry_after }) => {
327                    let delay =
328                        backoff_delay(attempt, self.config.backoff_base, self.config.backoff_cap);
329                    self.shared.push_rate_limit_event(RateLimitEvent {
330                        worker_id: self.id,
331                        attempt,
332                        delay,
333                        retry_after,
334                    });
335                    last_error = Some(PoolError::RateLimited { retry_after });
336                    if attempt < self.config.max_retries {
337                        thread::sleep(delay);
338                    }
339                }
340                Err(error) => {
341                    last_error = Some(error);
342                    self.recycle_runtime(runtime, &job.options)?;
343                }
344            }
345        }
346
347        self.shared.failures.fetch_add(1, Ordering::AcqRel);
348        Err(last_error.unwrap_or_else(|| PoolError::Rpc("retry loop exhausted".to_string())))
349    }
350
351    fn evaluate_once(
352        &self,
353        runtime: &mut WorkerRuntime,
354        job: &Job,
355    ) -> Result<RubricVerdict, PoolError> {
356        let b64 = encode_png(&job.png_path)?;
357        let system_prompt = job
358            .options
359            .system_prompt
360            .as_deref()
361            .unwrap_or(DEFAULT_SYSTEM_PROMPT);
362        let prompt = format!("{system_prompt}\n\nQuestion: {}", job.question);
363        let text = runtime.acp.prompt_image(&prompt, &b64)?;
364        parse_verdict(&text).map_err(|e| PoolError::ParseVerdict(format!("from {text:?}: {e}")))
365    }
366
367    fn recycle_runtime(
368        &self,
369        runtime: &mut WorkerRuntime,
370        options: &RubricOptions,
371    ) -> Result<(), PoolError> {
372        self.shared.worker_recycles.fetch_add(1, Ordering::AcqRel);
373        let mut last_error = None;
374        for _ in 0..RECYCLE_SPAWN_ATTEMPTS {
375            match self.spawn_runtime(options) {
376                Ok(new_runtime) => {
377                    *runtime = new_runtime;
378                    return Ok(());
379                }
380                Err(error) => {
381                    last_error = Some(error);
382                }
383            }
384        }
385        self.mark_dead();
386        Err(last_error.unwrap_or_else(|| PoolError::Spawn("recycle failed".to_string())))
387    }
388
389    fn spawn_runtime(&self, options: &RubricOptions) -> Result<WorkerRuntime, PoolError> {
390        let codex_home =
391            TempDir::new().map_err(|e| PoolError::Spawn(format!("create CODEX_HOME: {e}")))?;
392        seed_codex_home(codex_home.path(), self.config.source_codex_home.as_deref())?;
393        let mut env = self.config.extra_env.clone();
394        env.push((
395            OsString::from("CODEX_HOME"),
396            codex_home.path().as_os_str().to_os_string(),
397        ));
398
399        let model = options.model.as_deref().unwrap_or(DEFAULT_CODEX_ACP_MODEL);
400        let effort = options
401            .effort
402            .as_deref()
403            .unwrap_or(DEFAULT_CODEX_ACP_REASONING_EFFORT);
404        let mut acp = AcpClient::spawn(&self.config.codex_acp_binary, model, effort, &env, None)?;
405        acp.start_session(None)?;
406
407        Ok(WorkerRuntime {
408            acp,
409            _codex_home: codex_home,
410            prompts: 0,
411            model: model.to_string(),
412            effort: effort.to_string(),
413        })
414    }
415
416    fn mark_dead(&self) {
417        self.shared
418            .alive_mask
419            .fetch_and(!worker_bit(self.id), Ordering::AcqRel);
420    }
421}
422
423impl WorkerRuntime {
424    fn matches_options(&self, options: &RubricOptions) -> bool {
425        options.model.as_deref() == Some(self.model.as_str())
426            && options.effort.as_deref() == Some(self.effort.as_str())
427    }
428}
429
430impl SharedPoolState {
431    fn stats(&self) -> PoolStats {
432        PoolStats {
433            completed: self.completed.load(Ordering::Acquire),
434            failures: self.failures.load(Ordering::Acquire),
435            rate_limit_events: self
436                .rate_limit_events
437                .lock()
438                .map(|events| events.clone())
439                .unwrap_or_default(),
440            worker_recycles: self.worker_recycles.load(Ordering::Acquire),
441        }
442    }
443
444    fn push_rate_limit_event(&self, event: RateLimitEvent) {
445        if let Ok(mut events) = self.rate_limit_events.lock() {
446            events.push(event);
447        }
448    }
449}
450
451fn merge_options(mut opts: RubricOptions, defaults: &RubricOptions) -> RubricOptions {
452    if opts.model.is_none() {
453        opts.model.clone_from(&defaults.model);
454    }
455    if opts.effort.is_none() {
456        opts.effort.clone_from(&defaults.effort);
457    }
458    if opts.system_prompt.is_none() {
459        opts.system_prompt.clone_from(&defaults.system_prompt);
460    }
461    opts
462}
463
464fn backoff_delay(attempt: u32, base: Duration, cap: Duration) -> Duration {
465    let multiplier = 1u32 << attempt.min(6);
466    let capped = base.saturating_mul(multiplier).min(cap);
467    let jitter_cap = capped.as_millis() as u64 / 4;
468    let jitter_ms = rand::rng().random_range(0..=jitter_cap);
469    capped + Duration::from_millis(jitter_ms)
470}
471
/// Builds the initial alive mask, with one set bit for each worker id in `0..workers`.
///
/// `RubricPool::new` keeps `workers` at or below 64.
fn alive_mask(workers: usize) -> u64 {
    (0..workers).map(worker_bit).fold(0, |mask, bit| mask | bit)
}

/// Returns the alive-mask bit for `worker_id`, which must be below 64.
fn worker_bit(worker_id: usize) -> u64 {
    let lowest: u64 = 1;
    lowest << worker_id
}
483
/// Joins every worker thread, in order, and discards the outcomes.
///
/// A thread that panicked yields `Err` from `join`. Its panic has already been reported by
/// the runtime, so the payload is intentionally dropped rather than re-raised.
fn join_handles(handles: Vec<JoinHandle<()>>) {
    handles.into_iter().map(JoinHandle::join).for_each(drop);
}