use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct WorkerId(pub u32);
#[derive(Debug, Clone, Copy)]
pub struct WorkerHealth {
pub in_flight: u32,
pub completed: u64,
pub errored: u64,
}
pub trait Worker<Req, Resp>: Send + Sync {
fn id(&self) -> WorkerId;
fn health(&self) -> WorkerHealth;
fn dispatch(&self, req: Req) -> Result<Resp, WorkerError>;
}
#[derive(Debug, Clone)]
pub enum WorkerError {
Domain { reason: String },
WorkerCrash { reason: String },
}
impl std::fmt::Display for WorkerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Domain { reason } => write!(f, "domain error: {reason}"),
Self::WorkerCrash { reason } => write!(f, "worker crash: {reason}"),
}
}
}
impl std::error::Error for WorkerError {}
pub struct InProcessWorker<Req, Resp, F>
where
F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
{
id: WorkerId,
handler: F,
in_flight: AtomicU32,
completed: AtomicU64,
errored: AtomicU64,
_p: std::marker::PhantomData<fn() -> (Req, Resp)>,
}
impl<Req, Resp, F> InProcessWorker<Req, Resp, F>
where
F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
{
pub fn new(id: WorkerId, handler: F) -> Self {
Self {
id,
handler,
in_flight: AtomicU32::new(0),
completed: AtomicU64::new(0),
errored: AtomicU64::new(0),
_p: std::marker::PhantomData,
}
}
}
impl<Req, Resp, F> Worker<Req, Resp> for InProcessWorker<Req, Resp, F>
where
Req: Send,
Resp: Send,
F: Fn(Req) -> Result<Resp, WorkerError> + Send + Sync,
{
fn id(&self) -> WorkerId {
self.id
}
fn health(&self) -> WorkerHealth {
WorkerHealth {
in_flight: self.in_flight.load(Ordering::Relaxed),
completed: self.completed.load(Ordering::Relaxed),
errored: self.errored.load(Ordering::Relaxed),
}
}
fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
self.in_flight.fetch_add(1, Ordering::Relaxed);
let result = (self.handler)(req);
self.in_flight.fetch_sub(1, Ordering::Relaxed);
self.completed.fetch_add(1, Ordering::Relaxed);
if result.is_err() {
self.errored.fetch_add(1, Ordering::Relaxed);
}
result
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchPolicy {
RoundRobin,
LeastLoaded,
}
pub struct WorkerPool<Req, Resp> {
workers: Vec<Arc<dyn Worker<Req, Resp>>>,
next_rr: Mutex<usize>,
pub policy: DispatchPolicy,
}
impl<Req, Resp> WorkerPool<Req, Resp> {
pub fn new(policy: DispatchPolicy) -> Self {
Self {
workers: Vec::new(),
next_rr: Mutex::new(0),
policy,
}
}
pub fn add(&mut self, worker: Arc<dyn Worker<Req, Resp>>) {
self.workers.push(worker);
}
pub fn len(&self) -> usize {
self.workers.len()
}
pub fn is_empty(&self) -> bool {
self.workers.is_empty()
}
pub fn select(&self) -> Option<&Arc<dyn Worker<Req, Resp>>> {
if self.workers.is_empty() {
return None;
}
match self.policy {
DispatchPolicy::RoundRobin => {
let mut rr = self.next_rr.lock().unwrap();
let pick = *rr % self.workers.len();
*rr = (*rr + 1) % self.workers.len();
Some(&self.workers[pick])
}
DispatchPolicy::LeastLoaded => self
.workers
.iter()
.min_by_key(|w| (w.health().in_flight, w.id().0)),
}
}
pub fn dispatch(&self, req: Req) -> Result<Resp, WorkerError> {
match self.select() {
Some(w) => w.dispatch(req),
None => Err(WorkerError::WorkerCrash {
reason: "no workers available".into(),
}),
}
}
pub fn health(&self) -> Vec<(WorkerId, WorkerHealth)> {
let mut h: Vec<_> = self.workers.iter().map(|w| (w.id(), w.health())).collect();
h.sort_by_key(|(id, _)| id.0);
h
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_worker(
id: u32,
) -> Arc<InProcessWorker<u32, u32, impl Fn(u32) -> Result<u32, WorkerError> + Send + Sync>>
{
Arc::new(InProcessWorker::new(WorkerId(id), |x: u32| Ok(x * 2)))
}
#[test]
fn in_process_worker_handles_dispatch() {
let w = make_worker(7);
assert_eq!(w.dispatch(5).unwrap(), 10);
let h = w.health();
assert_eq!(h.completed, 1);
assert_eq!(h.errored, 0);
assert_eq!(h.in_flight, 0);
}
#[test]
fn errors_increment_errored_count() {
let w: Arc<InProcessWorker<u32, u32, _>> =
Arc::new(InProcessWorker::new(WorkerId(1), |_x: u32| {
Err(WorkerError::Domain {
reason: "bad".into(),
})
}));
let _ = w.dispatch(1);
let h = w.health();
assert_eq!(h.errored, 1);
assert_eq!(h.completed, 1);
}
#[test]
fn round_robin_visits_each_worker() {
let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::RoundRobin);
for i in 0..3 {
pool.add(make_worker(i));
}
let mut ids = Vec::new();
for _ in 0..6 {
let w = pool.select().unwrap();
ids.push(w.id().0);
}
assert_eq!(ids, vec![0, 1, 2, 0, 1, 2]);
}
#[test]
fn least_loaded_picks_quietest() {
let w0 = make_worker(0);
let w1 = make_worker(1);
let w2 = make_worker(2);
w0.in_flight.fetch_add(5, Ordering::Relaxed);
w1.in_flight.fetch_add(3, Ordering::Relaxed);
let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::LeastLoaded);
pool.add(w0);
pool.add(w1);
pool.add(w2);
let pick = pool.select().unwrap();
assert_eq!(
pick.id().0,
2,
"least-loaded should pick the worker with 0 in-flight"
);
}
#[test]
fn empty_pool_dispatch_errors() {
let pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::RoundRobin);
let err = pool.dispatch(1).unwrap_err();
assert!(matches!(err, WorkerError::WorkerCrash { .. }));
}
#[test]
fn health_snapshot_includes_every_worker() {
let mut pool: WorkerPool<u32, u32> = WorkerPool::new(DispatchPolicy::RoundRobin);
for i in 0..3 {
pool.add(make_worker(i));
}
let _ = pool.dispatch(1);
let _ = pool.dispatch(2);
let h = pool.health();
assert_eq!(h.len(), 3);
let total_completed: u64 = h.iter().map(|(_, hh)| hh.completed).sum();
assert_eq!(total_completed, 2);
}
}