1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
31use std::sync::{Arc, Mutex};
32
/// Numeric identifier of a worker registered in a [`WorkerPool`].
///
/// All comparison traits are derived from the wrapped `u32`, so two ids are
/// equal/ordered exactly as their inner values are.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct WorkerId(pub u32);
36
/// Point-in-time counters reported by a [`Worker`].
///
/// `Default` yields the all-zero snapshot of a worker that has never run a
/// request. `PartialEq`/`Eq` let snapshots be compared directly, for example
/// in assertions.
///
/// For `InProcessWorker`, the three fields are loaded independently. A
/// snapshot taken during concurrent dispatch may therefore mix values from
/// slightly different moments.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WorkerHealth {
    /// Requests currently being processed, as reported by the worker.
    pub in_flight: u32,
    /// Requests that have finished. `InProcessWorker` counts calls that
    /// returned `Err` here as well as successful ones.
    pub completed: u64,
    /// Requests whose `dispatch` returned `Err`.
    pub errored: u64,
}
46
47pub trait Worker<Req, Resp>: Send + Sync {
51 fn id(&self) -> WorkerId;
52 fn health(&self) -> WorkerHealth;
53 fn dispatch(&self, req: Req) -> Result<Resp, WorkerError>;
56}
57
/// Errors surfaced by a [`Worker`] or by `WorkerPool::dispatch`.
#[derive(Debug, Clone)]
pub enum WorkerError {
    /// Failure reported by request-handling logic. Presumably application
    /// level; confirm with the handlers that produce it.
    Domain {
        /// Human-readable explanation.
        reason: String,
    },
    /// The request could not be processed by a worker.
    /// `WorkerPool::dispatch` also returns this when the pool is empty.
    WorkerCrash {
        /// Human-readable explanation.
        reason: String,
    },
}
67
68impl std::fmt::Display for WorkerError {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Self::Domain { reason } => write!(f, "domain error: {reason}"),
72 Self::WorkerCrash { reason } => write!(f, "worker crash: {reason}"),
73 }
74 }
75}
76
77impl std::error::Error for WorkerError {}
78
79pub struct InProcessWorker<Req, Resp, F>
83where
84 F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
85{
86 id: WorkerId,
87 handler: F,
88 in_flight: AtomicU32,
89 completed: AtomicU64,
90 errored: AtomicU64,
91 _p: std::marker::PhantomData<fn() -> (Req, Resp)>,
94}
95
96impl<Req, Resp, F> InProcessWorker<Req, Resp, F>
97where
98 F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
99{
100 pub fn new(id: WorkerId, handler: F) -> Self {
101 Self {
102 id,
103 handler,
104 in_flight: AtomicU32::new(0),
105 completed: AtomicU64::new(0),
106 errored: AtomicU64::new(0),
107 _p: std::marker::PhantomData,
108 }
109 }
110}
111
112impl<Req, Resp, F> Worker<Req, Resp> for InProcessWorker<Req, Resp, F>
113where
114 Req: Send,
115 Resp: Send,
116 F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
117{
118 fn id(&self) -> WorkerId {
119 self.id
120 }
121
122 fn health(&self) -> WorkerHealth {
123 WorkerHealth {
124 in_flight: self.in_flight.load(Ordering::Relaxed),
125 completed: self.completed.load(Ordering::Relaxed),
126 errored: self.errored.load(Ordering::Relaxed),
127 }
128 }
129
130 fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
131 self.in_flight.fetch_add(1, Ordering::Relaxed);
132 let result = (self.handler)(req);
133 self.in_flight.fetch_sub(1, Ordering::Relaxed);
134 self.completed.fetch_add(1, Ordering::Relaxed);
135 if result.is_err() {
136 self.errored.fetch_add(1, Ordering::Relaxed);
137 }
138 result
139 }
140}
141
/// Strategy a [`WorkerPool`] uses to pick the worker for the next request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchPolicy {
    /// Visit workers one after another, in the order they were added.
    RoundRobin,
    /// Choose the worker reporting the fewest in-flight requests; ties go to
    /// the lowest `WorkerId`.
    LeastLoaded,
}
151
152pub struct WorkerPool<Req, Resp> {
155 workers: Vec<Arc<dyn Worker<Req, Resp>>>,
156 next_rr: Mutex<usize>,
157 pub policy: DispatchPolicy,
158}
159
160impl<Req, Resp> WorkerPool<Req, Resp> {
161 pub fn new(policy: DispatchPolicy) -> Self {
162 Self {
163 workers: Vec::new(),
164 next_rr: Mutex::new(0),
165 policy,
166 }
167 }
168
169 pub fn add(&mut self, worker: Arc<dyn Worker<Req, Resp>>) {
170 self.workers.push(worker);
171 }
172
173 pub fn len(&self) -> usize {
174 self.workers.len()
175 }
176 pub fn is_empty(&self) -> bool {
177 self.workers.is_empty()
178 }
179
180 pub fn select(&self) -> Option<&Arc<dyn Worker<Req, Resp>>> {
182 if self.workers.is_empty() {
183 return None;
184 }
185 match self.policy {
186 DispatchPolicy::RoundRobin => {
187 let mut rr = self.next_rr.lock().unwrap();
188 let pick = *rr % self.workers.len();
189 *rr = (*rr + 1) % self.workers.len();
190 Some(&self.workers[pick])
191 }
192 DispatchPolicy::LeastLoaded => self
193 .workers
194 .iter()
195 .min_by_key(|w| (w.health().in_flight, w.id().0)),
196 }
197 }
198
199 pub fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
201 match self.select() {
202 Some(w) => w.dispatch(req),
203 None => Err(WorkerError::WorkerCrash {
204 reason: "no workers available".into(),
205 }),
206 }
207 }
208
209 pub fn health(&self) -> Vec<(WorkerId, WorkerHealth)> {
211 let mut h: Vec<_> = self.workers.iter().map(|w| (w.id(), w.health())).collect();
212 h.sort_by_key(|(id, _)| id.0);
213 h
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a shared worker with the given id whose handler doubles its input.
    fn make_worker(
        id: u32,
    ) -> Arc<InProcessWorker<u32, u32, impl Fn(u32) -> Result<u32, WorkerError> + Send + Sync>>
    {
        let worker = InProcessWorker::new(WorkerId(id), |n: u32| Ok(n * 2));
        Arc::new(worker)
    }

    #[test]
    fn in_process_worker_handles_dispatch() {
        let worker = make_worker(7);
        let doubled = worker.dispatch(5).expect("doubling handler never fails");
        assert_eq!(doubled, 10);

        let WorkerHealth {
            in_flight,
            completed,
            errored,
        } = worker.health();
        assert_eq!((completed, errored, in_flight), (1, 0, 0));
    }

    #[test]
    fn errors_increment_errored_count() {
        let failing = |_x: u32| -> Result<u32, WorkerError> {
            Err(WorkerError::Domain {
                reason: "bad".into(),
            })
        };
        let worker = Arc::new(InProcessWorker::new(WorkerId(1), failing));
        drop(worker.dispatch(1));

        let snapshot = worker.health();
        assert_eq!((snapshot.errored, snapshot.completed), (1, 1));
    }

    #[test]
    fn round_robin_visits_each_worker() {
        let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::RoundRobin);
        (0..3).for_each(|i| pool.add(make_worker(i)));

        let visited: Vec<u32> = (0..6)
            .map(|_| pool.select().map(|w| w.id().0).unwrap())
            .collect();
        assert_eq!(visited, [0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn least_loaded_picks_quietest() {
        let workers = [make_worker(0), make_worker(1), make_worker(2)];
        for (worker, load) in workers.iter().zip([5u32, 3, 0]) {
            worker.in_flight.fetch_add(load, Ordering::Relaxed);
        }

        let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::LeastLoaded);
        for worker in workers {
            pool.add(worker);
        }

        let chosen = pool.select().expect("pool has three workers");
        assert_eq!(
            chosen.id(),
            WorkerId(2),
            "least-loaded should pick the worker with 0 in-flight"
        );
    }

    #[test]
    fn empty_pool_dispatch_errors() {
        let pool = WorkerPool::<u32, u32>::new(DispatchPolicy::RoundRobin);
        let outcome = pool.dispatch(1);
        assert!(matches!(outcome, Err(WorkerError::WorkerCrash { .. })));
    }

    #[test]
    fn health_snapshot_includes_every_worker() {
        let mut pool = WorkerPool::<u32, u32>::new(DispatchPolicy::RoundRobin);
        for id in 0..3 {
            pool.add(make_worker(id));
        }
        for req in [1, 2] {
            let _ = pool.dispatch(req);
        }

        let snapshot = pool.health();
        assert_eq!(snapshot.len(), 3);
        let total_completed = snapshot
            .iter()
            .fold(0u64, |acc, (_, health)| acc + health.completed);
        assert_eq!(total_completed, 2);
    }
}