use std::thread::JoinHandle;
use crossbeam_channel::Sender;
/// A boxed unit of work that a `TaskPool` worker thread runs exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;
/// A fixed set of worker threads that pull `Job`s off one shared, unbounded
/// crossbeam channel and run them.
///
/// Workers keep looping for as long as any sender for the channel is alive:
/// this pool's `job_tx`, or any `Spawner` cloned from it. The join handles are
/// stored but never joined in the code visible here, so dropping a `TaskPool`
/// presumably does not wait for in-flight jobs (NOTE(review): confirm there is
/// no `Drop` impl elsewhere).
pub(crate) struct TaskPool {
// Sending half of the job channel; `spawner()` hands out clones of it.
job_tx: Sender<Job>,
// Worker thread handles; not read in the code visible here (hence the leading underscore).
_workers: Vec<JoinHandle<()>>,
}
impl TaskPool {
pub(crate) fn new(name: &'static str, n: usize) -> Self {
let n = n.max(1);
let (job_tx, job_rx) = crossbeam_channel::unbounded::<Job>();
let workers = (0..n)
.map(|_| {
let job_rx = job_rx.clone();
std::thread::Builder::new()
.name(name.to_owned())
.spawn(move || {
for job in job_rx {
if let Err(panic) =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(job))
{
let msg = panic
.downcast_ref::<&'static str>()
.copied()
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
.unwrap_or("<non-string panic payload>");
log::error!("LSP task pool worker caught panic: {msg}");
}
}
})
.expect("failed to spawn LSP worker thread")
})
.collect();
Self {
job_tx,
_workers: workers,
}
}
pub(crate) fn spawner(&self) -> Spawner {
Spawner(self.job_tx.clone())
}
pub(crate) fn spawn(&self, f: impl FnOnce() + Send + 'static) {
let _ = self.job_tx.send(Box::new(f));
}
}
#[derive(Clone)]
pub(crate) struct Spawner(Sender<Job>);
impl Spawner {
pub(crate) fn spawn(&self, f: impl FnOnce() + Send + 'static) {
let _ = self.0.send(Box::new(f));
}
}
/// Returns the parallelism the OS reports as available to this process.
///
/// Falls back to `1` when that cannot be determined, so the result is never zero.
pub(crate) fn read_pool_size() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Upper bound on how long a test waits for a job, so a lost job fails the
    /// test instead of hanging it.
    const TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn every_spawned_job_runs() {
        const N: usize = 64;
        let pool = TaskPool::new("test-pool", 4);
        let spawner = pool.spawner();
        let (tx, rx) = crossbeam_channel::unbounded::<usize>();
        for i in 0..N {
            let tx = tx.clone();
            spawner.spawn(move || {
                let _ = tx.send(i);
            });
        }
        drop(tx);
        // Receive exactly N results, each with a deadline. `rx.iter()` would
        // block forever if any job were never run.
        let mut seen: Vec<usize> = (0..N)
            .map(|_| {
                rx.recv_timeout(TIMEOUT)
                    .expect("every spawned job should run")
            })
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..N).collect::<Vec<_>>());
    }

    #[test]
    fn pool_spawn_runs_jobs_even_when_zero_workers_requested() {
        // Exercises both `TaskPool::spawn` directly and the `n.max(1)` clamp.
        let pool = TaskPool::new("test-pool-zero", 0);
        let (done_tx, done_rx) = crossbeam_channel::bounded::<u32>(1);
        pool.spawn(move || {
            let _ = done_tx.send(42);
        });
        let got = done_rx
            .recv_timeout(TIMEOUT)
            .expect("a pool asked for 0 workers should still run jobs");
        assert_eq!(got, 42);
    }

    #[test]
    fn panicking_job_does_not_kill_the_pool() {
        let pool = TaskPool::new("test-pool-panic", 1);
        let spawner = pool.spawner();
        let ran = Arc::new(AtomicUsize::new(0));
        spawner.spawn(|| panic!("boom"));
        let ran2 = Arc::clone(&ran);
        let (done_tx, done_rx) = crossbeam_channel::bounded::<()>(1);
        spawner.spawn(move || {
            ran2.fetch_add(1, Ordering::SeqCst);
            let _ = done_tx.send(());
        });
        done_rx
            .recv_timeout(TIMEOUT)
            .expect("survivor job should run after a panicking job");
        assert_eq!(ran.load(Ordering::SeqCst), 1);
    }
}