Skip to main content

rlx_runtime/
worker_pool.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Worker pool with isolation primitives (plan #36).
17//!
18//! Borrowed from MAX's `serve/worker_interface/`. The serving
19//! pattern: engines run in workers (eventually subprocesses); a
20//! main router forwards requests via IPC. One worker crashing
21//! doesn't take the server down.
22//!
23//! This module ships the in-process layer (testable, deterministic)
24//! plus the trait surface that a future `SubprocessWorker` will
25//! implement. The IPC plumbing (stdin/stdout JSON-lines, recovery
26//! on crash) is intentionally out of scope until a serving binary
27//! consumes it; we'd rather build it once against a real consumer
28//! than build it twice.
29
30use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
31use std::sync::{Arc, Mutex};
32
/// Stable identifier for a worker, wrapping a plain `u32`.
///
/// Equality, ordering and hashing all follow the wrapped number.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct WorkerId(pub u32);
36
/// Point-in-time counters reported by a worker.
#[derive(Clone, Copy, Debug)]
pub struct WorkerHealth {
    /// Number of requests the worker is handling right now.
    pub in_flight: u32,
    /// Total requests finished over the worker's lifetime, counting
    /// both the successful and the failed ones.
    pub completed: u64,
    /// Total requests over the worker's lifetime that ended in an error.
    pub errored: u64,
}
46
47/// Trait every worker implements. `Req` and `Resp` are
48/// caller-defined; the future subprocess flavour will use a
49/// serde-friendly wire type as both parameters.
50pub trait Worker<Req, Resp>: Send + Sync {
51    fn id(&self) -> WorkerId;
52    fn health(&self) -> WorkerHealth;
53    /// Block until this request finishes. Errors propagate the
54    /// engine's failure mode without crashing the worker.
55    fn dispatch(&self, req: Req) -> Result<Resp, WorkerError>;
56}
57
/// Ways a dispatched request can fail.
#[derive(Clone, Debug)]
pub enum WorkerError {
    /// The handler reported a domain-level failure (bad request, the
    /// model rejected it, ...). The worker itself remains healthy.
    Domain { reason: String },
    /// The worker itself broke (panic, OOM, lost subprocess). The pool
    /// is meant to route around such a worker on the next request.
    // NOTE(review): no rerouting logic is visible in this file — confirm
    // where that is (or will be) implemented.
    WorkerCrash { reason: String },
}
67
68impl std::fmt::Display for WorkerError {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            Self::Domain { reason } => write!(f, "domain error: {reason}"),
72            Self::WorkerCrash { reason } => write!(f, "worker crash: {reason}"),
73        }
74    }
75}
76
77impl std::error::Error for WorkerError {}
78
79/// In-process worker — runs the handler closure inline. Useful
80/// for tests and for single-process serving. Tracks in-flight /
81/// completed / errored counts atomically.
82pub struct InProcessWorker<Req, Resp, F>
83where
84    F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
85{
86    id: WorkerId,
87    handler: F,
88    in_flight: AtomicU32,
89    completed: AtomicU64,
90    errored: AtomicU64,
91    // PhantomData<fn() -> _> is always Send + Sync regardless
92    // of T's bounds — we don't actually own a Req or Resp.
93    _p: std::marker::PhantomData<fn() -> (Req, Resp)>,
94}
95
96impl<Req, Resp, F> InProcessWorker<Req, Resp, F>
97where
98    F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
99{
100    pub fn new(id: WorkerId, handler: F) -> Self {
101        Self {
102            id,
103            handler,
104            in_flight: AtomicU32::new(0),
105            completed: AtomicU64::new(0),
106            errored: AtomicU64::new(0),
107            _p: std::marker::PhantomData,
108        }
109    }
110}
111
112impl<Req, Resp, F> Worker<Req, Resp> for InProcessWorker<Req, Resp, F>
113where
114    Req: Send,
115    Resp: Send,
116    F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
117{
118    fn id(&self) -> WorkerId {
119        self.id
120    }
121
122    fn health(&self) -> WorkerHealth {
123        WorkerHealth {
124            in_flight: self.in_flight.load(Ordering::Relaxed),
125            completed: self.completed.load(Ordering::Relaxed),
126            errored: self.errored.load(Ordering::Relaxed),
127        }
128    }
129
130    fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
131        self.in_flight.fetch_add(1, Ordering::Relaxed);
132        let result = (self.handler)(req);
133        self.in_flight.fetch_sub(1, Ordering::Relaxed);
134        self.completed.fetch_add(1, Ordering::Relaxed);
135        if result.is_err() {
136            self.errored.fetch_add(1, Ordering::Relaxed);
137        }
138        result
139    }
140}
141
/// Strategy a pool uses to choose the worker for a request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DispatchPolicy {
    /// Cycle through the workers in order. Deterministic; load is
    /// not taken into account.
    RoundRobin,
    /// Choose the worker with the fewest in-flight requests, using
    /// the worker `id` to break ties.
    LeastLoaded,
}
151
152/// Pool of workers. Generic over `(Req, Resp)`; works with the
153/// `Worker` trait directly.
154pub struct WorkerPool<Req, Resp> {
155    workers: Vec<Arc<dyn Worker<Req, Resp>>>,
156    next_rr: Mutex<usize>,
157    pub policy: DispatchPolicy,
158}
159
160impl<Req, Resp> WorkerPool<Req, Resp> {
161    pub fn new(policy: DispatchPolicy) -> Self {
162        Self {
163            workers: Vec::new(),
164            next_rr: Mutex::new(0),
165            policy,
166        }
167    }
168
169    pub fn add(&mut self, worker: Arc<dyn Worker<Req, Resp>>) {
170        self.workers.push(worker);
171    }
172
173    pub fn len(&self) -> usize {
174        self.workers.len()
175    }
176    pub fn is_empty(&self) -> bool {
177        self.workers.is_empty()
178    }
179
180    /// Pick a worker per `policy`.
181    pub fn select(&self) -> Option<&Arc<dyn Worker<Req, Resp>>> {
182        if self.workers.is_empty() {
183            return None;
184        }
185        match self.policy {
186            DispatchPolicy::RoundRobin => {
187                let mut rr = self.next_rr.lock().unwrap();
188                let pick = *rr % self.workers.len();
189                *rr = (*rr + 1) % self.workers.len();
190                Some(&self.workers[pick])
191            }
192            DispatchPolicy::LeastLoaded => self
193                .workers
194                .iter()
195                .min_by_key(|w| (w.health().in_flight, w.id().0)),
196        }
197    }
198
199    /// Dispatch a request through the chosen worker.
200    pub fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
201        match self.select() {
202            Some(w) => w.dispatch(req),
203            None => Err(WorkerError::WorkerCrash {
204                reason: "no workers available".into(),
205            }),
206        }
207    }
208
209    /// Snapshot of every worker's health.
210    pub fn health(&self) -> Vec<(WorkerId, WorkerHealth)> {
211        let mut h: Vec<_> = self.workers.iter().map(|w| (w.id(), w.health())).collect();
212        h.sort_by_key(|(id, _)| id.0);
213        h
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Worker with the given id whose handler doubles its input and
    /// never fails.
    fn make_worker(
        id: u32,
    ) -> Arc<InProcessWorker<u32, u32, impl Fn(u32) -> Result<u32, WorkerError> + Send + Sync>>
    {
        let worker = InProcessWorker::new(WorkerId(id), |x: u32| Ok(x * 2));
        Arc::new(worker)
    }

    /// Pool under `policy` holding doubling workers with ids `0..count`.
    fn pool_with(policy: DispatchPolicy, count: u32) -> WorkerPool<u32, u32> {
        let mut pool: WorkerPool<u32, u32> = WorkerPool::new(policy);
        for id in 0..count {
            pool.add(make_worker(id));
        }
        pool
    }

    #[test]
    fn in_process_worker_handles_dispatch() {
        let worker = make_worker(7);
        let doubled = worker.dispatch(5).unwrap();
        assert_eq!(doubled, 10);

        let health = worker.health();
        assert_eq!(health.in_flight, 0);
        assert_eq!(health.completed, 1);
        assert_eq!(health.errored, 0);
    }

    #[test]
    fn errors_increment_errored_count() {
        let worker: InProcessWorker<u32, u32, _> =
            InProcessWorker::new(WorkerId(1), |_: u32| {
                Err(WorkerError::Domain {
                    reason: "bad".into(),
                })
            });
        assert!(worker.dispatch(1).is_err());

        // A failed request counts as both completed and errored.
        let health = worker.health();
        assert_eq!(health.completed, 1);
        assert_eq!(health.errored, 1);
    }

    #[test]
    fn round_robin_visits_each_worker() {
        let pool = pool_with(DispatchPolicy::RoundRobin, 3);

        // Six picks over three workers: the cursor wraps once, giving
        // the deterministic order 0,1,2,0,1,2.
        let visited: Vec<u32> = (0..6).map(|_| pool.select().unwrap().id().0).collect();
        assert_eq!(visited, vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn least_loaded_picks_quietest() {
        // Three workers; the first two get an artificial in-flight
        // count so that the third is the clear least-loaded choice.
        let busy = make_worker(0);
        let moderate = make_worker(1);
        let idle = make_worker(2);
        busy.in_flight.store(5, Ordering::Relaxed);
        moderate.in_flight.store(3, Ordering::Relaxed);

        let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::LeastLoaded);
        pool.add(busy);
        pool.add(moderate);
        pool.add(idle);

        let chosen = pool.select().unwrap().id();
        assert_eq!(
            chosen.0, 2,
            "least-loaded should pick the worker with 0 in-flight"
        );
    }

    #[test]
    fn empty_pool_dispatch_errors() {
        let pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::RoundRobin);
        let outcome = pool.dispatch(1);
        assert!(matches!(outcome, Err(WorkerError::WorkerCrash { .. })));
    }

    #[test]
    fn health_snapshot_includes_every_worker() {
        let pool = pool_with(DispatchPolicy::RoundRobin, 3);
        let _ = pool.dispatch(1);
        let _ = pool.dispatch(2);

        let snapshot = pool.health();
        assert_eq!(snapshot.len(), 3);
        let total_completed = snapshot
            .iter()
            .fold(0u64, |acc, (_, health)| acc + health.completed);
        assert_eq!(total_completed, 2);
    }
}