use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
/// A unit of work for the pool: a boxed, run-once closure that can be sent to
/// a worker thread and owns everything it captures.
type Task = Box<dyn FnOnce() + Send + 'static>;
/// A fixed-size pool of named worker threads that pull boxed tasks from one
/// shared bounded channel.
pub struct WorkerPool {
/// One entry per spawned worker thread.
workers: Vec<Worker>,
/// Sending half of the task channel. Held in an `Option` so `Drop` can
/// `take()` it and thereby disconnect the channel.
sender: Option<crossbeam_channel::Sender<Task>>,
/// Pool size as recorded at construction; also used by `submit` for its
/// queue-depth log threshold.
size: usize,
/// Shared flag; while it is `true`, workers discard dequeued tasks
/// instead of running them.
cancelled: Arc<AtomicBool>,
/// Incremented by `submit` before sending, decremented by a worker right
/// after it receives a task.
queued: Arc<AtomicUsize>,
/// Incremented by a worker after it has run a task (including a task
/// that panicked); discarded tasks are not counted.
completed: Arc<AtomicUsize>,
}
/// Bookkeeping record for one worker thread owned by the pool.
struct Worker {
// The id is stored at construction but not read back anywhere in this
// file, hence the lint suppression.
#[allow(dead_code)]
id: usize,
/// Join handle of the worker thread; an `Option` so it can be moved out
/// with `take()` from `&mut self` during `Drop`.
handle: Option<thread::JoinHandle<()>>,
}
impl WorkerPool {
pub fn new(size: usize) -> Self {
let capacity = size * 4;
let (sender, receiver) = crossbeam_channel::bounded::<Task>(capacity);
let cancelled = Arc::new(AtomicBool::new(false));
let queued = Arc::new(AtomicUsize::new(0));
let completed = Arc::new(AtomicUsize::new(0));
let mut workers = Vec::with_capacity(size);
for id in 0..size {
let rx = receiver.clone();
let cancelled = Arc::clone(&cancelled);
let queued = Arc::clone(&queued);
let completed = Arc::clone(&completed);
let handle = thread::Builder::new()
.name(format!("zshrs-worker-{}", id))
.spawn(move || {
loop {
let task = match rx.recv() {
Ok(task) => task,
Err(_) => break, };
queued.fetch_sub(1, Ordering::Relaxed);
if cancelled.load(Ordering::Relaxed) {
continue; }
if let Err(e) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task))
{
let msg = if let Some(s) = e.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
tracing::error!(
worker = id,
panic = %msg,
"worker task panicked"
);
}
completed.fetch_add(1, Ordering::Relaxed);
}
tracing::debug!(worker = id, "worker thread exiting");
})
.expect("failed to spawn worker thread");
workers.push(Worker {
id,
handle: Some(handle),
});
}
tracing::info!(
pool_size = size,
channel_capacity = capacity,
"worker pool started"
);
WorkerPool {
workers,
sender: Some(sender),
size,
cancelled,
queued,
completed,
}
}
pub fn default_size() -> Self {
let cpus = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4);
Self::new(cpus.clamp(2, 18))
}
pub fn submit<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let depth = self.queued.fetch_add(1, Ordering::Relaxed) + 1;
if depth > self.size * 2 {
tracing::debug!(queue_depth = depth, "worker pool queue building up");
}
self.sender
.as_ref()
.expect("pool shut down")
.send(Box::new(f))
.expect("all workers dead");
}
pub fn submit_with_result<F, R>(&self, f: F) -> crossbeam_channel::Receiver<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (tx, rx) = crossbeam_channel::bounded(1);
self.submit(move || {
let result = f();
let _ = tx.send(result);
});
rx
}
pub fn cancel(&self) {
self.cancelled.store(true, Ordering::Relaxed);
tracing::info!("worker pool: cancel requested");
}
pub fn reset_cancel(&self) {
self.cancelled.store(false, Ordering::Relaxed);
}
pub fn size(&self) -> usize {
self.size
}
pub fn queue_depth(&self) -> usize {
self.queued.load(Ordering::Relaxed)
}
pub fn completed(&self) -> usize {
self.completed.load(Ordering::Relaxed)
}
}
impl Drop for WorkerPool {
    /// Shuts the pool down and waits for the worker threads to exit.
    ///
    /// The cancel flag is set so tasks still sitting in the queue are
    /// discarded rather than run, then the sender is dropped to disconnect
    /// the channel, and finally every worker thread is joined. A task that is
    /// already running is allowed to finish, so this blocks until it returns.
    fn drop(&mut self) {
        self.cancelled.store(true, Ordering::Relaxed);
        // Disconnecting the channel makes each worker's `recv` fail once the
        // queue is drained, which ends its loop.
        drop(self.sender.take());

        let current = thread::current().id();
        for w in &mut self.workers {
            if let Some(handle) = w.handle.take() {
                // If the pool is dropped from inside one of its own tasks, a
                // thread cannot join itself; let that handle detach instead.
                if handle.thread().id() == current {
                    continue;
                }
                // Join rather than detach, so the log below reports a final
                // count and no worker outlives the pool.
                if handle.join().is_err() {
                    tracing::warn!(worker = w.id, "worker thread panicked before shutdown");
                }
            }
        }

        tracing::info!(
            tasks_completed = self.completed.load(Ordering::Relaxed),
            "worker pool shut down"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls `counter` every 2 ms until it reaches `target`, panicking with a
    /// diagnostic message if `max_wait_ms` elapses first.
    fn wait_for_count(counter: &AtomicUsize, target: usize, max_wait_ms: u64) {
        let deadline =
            std::time::Instant::now() + std::time::Duration::from_millis(max_wait_ms);
        while counter.load(Ordering::Relaxed) < target {
            if std::time::Instant::now() >= deadline {
                panic!(
                    "wait_for_count timed out: counter={} target={} after {}ms",
                    counter.load(Ordering::Relaxed),
                    target,
                    max_wait_ms
                );
            }
            std::thread::sleep(std::time::Duration::from_millis(2));
        }
    }

    /// Every submitted task runs exactly once.
    #[test]
    fn test_pool_executes_tasks() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..100 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }
        // Wait before dropping: dropping the pool cancels queued tasks.
        wait_for_count(&counter, 100, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    /// The receiver returned by `submit_with_result` yields the task's value.
    #[test]
    fn test_submit_with_result() {
        let pool = WorkerPool::new(2);
        let rx = pool.submit_with_result(|| 42);
        assert_eq!(rx.recv().unwrap(), 42);
    }

    /// `default_size` always lands inside the clamped range.
    #[test]
    fn test_default_size() {
        let pool = WorkerPool::default_size();
        assert!(pool.size() >= 2);
        assert!(pool.size() <= 18);
    }

    /// A panicking task is contained and later tasks still run.
    #[test]
    fn test_panic_does_not_kill_worker() {
        let pool = WorkerPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        pool.submit(|| panic!("intentional test panic"));
        for _ in 0..10 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }
        wait_for_count(&counter, 10, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    /// Tasks queued behind a running task are discarded once `cancel` is
    /// called, and the pool runs tasks again after `reset_cancel`.
    #[test]
    fn test_cancel_skips_queued_tasks() {
        // One worker, so the first task occupies it while more are queued.
        let pool = WorkerPool::new(1);
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let started = Arc::new(std::sync::Mutex::new(false));
        let started_cv = Arc::new(std::sync::Condvar::new());
        let counter = Arc::new(AtomicUsize::new(0));

        let b = Arc::clone(&barrier);
        let started_clone = Arc::clone(&started);
        let cv_clone = Arc::clone(&started_cv);
        // Task #1: announce that it is running, then park on the barrier.
        pool.submit(move || {
            *started_clone.lock().unwrap() = true;
            cv_clone.notify_one();
            b.wait();
        });

        // Block until task #1 is actually running on the worker.
        let mut g = started.lock().unwrap();
        let timeout = std::time::Duration::from_secs(5);
        while !*g {
            let (gg, wait_result) = started_cv.wait_timeout(g, timeout).unwrap();
            g = gg;
            if wait_result.timed_out() && !*g {
                panic!("worker never started task #1 within 5s — test scaffolding broken");
            }
        }
        drop(g);

        // These sit in the queue behind the parked task #1.
        for _ in 0..3 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }

        pool.cancel();
        // Release task #1; the worker then dequeues the three tasks with the
        // cancel flag set.
        barrier.wait();
        // Give the worker time to drain (and discard) the queued tasks.
        std::thread::sleep(std::time::Duration::from_millis(50));
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        pool.reset_cancel();
        let c = Arc::clone(&counter);
        pool.submit(move || {
            c.fetch_add(1, Ordering::Relaxed);
        });
        wait_for_count(&counter, 1, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    /// The completed counter starts at zero and reaches the number of tasks
    /// that were run.
    #[test]
    fn test_metrics() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.completed(), 0);
        for _ in 0..10 {
            pool.submit(|| {});
        }
        // Wait on the pool's own counter before dropping, because dropping
        // cancels anything still queued.
        wait_for_count(&pool.completed, 10, 5_000);
        assert_eq!(pool.completed(), 10);
        drop(pool);
    }

    /// Submitting more tasks than the queue can hold still runs all of them;
    /// `submit` blocks instead of failing.
    #[test]
    fn test_backpressure_bounded() {
        // One worker means a queue capacity of 4, so 20 submissions overflow it.
        let pool = WorkerPool::new(1);
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..20 {
            let c = Arc::clone(&counter);
            pool.submit(move || {
                c.fetch_add(1, Ordering::Relaxed);
            });
        }
        wait_for_count(&counter, 20, 5_000);
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 20);
    }
}