#![forbid(missing_docs)]
use std::{
error::Error,
fmt::{Debug, Display},
io,
num::{NonZeroUsize, TryFromIntError},
panic::{catch_unwind, UnwindSafe},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
mpsc::{channel, Sender},
Arc, Condvar, Mutex,
},
thread::{self, available_parallelism},
};
/// A type-erased job as it travels through the pool's job channel: a boxed
/// closure that is called at most once, can be moved to a worker thread, and is
/// `UnwindSafe` so the pool can run it under `catch_unwind`.
type ThreadPoolFunctionBoxed = Box<dyn FnOnce() + Send + UnwindSafe>;
/// Error returned by the checked waiting methods of `ThreadPool` when at least one
/// job tracked by the pool's shared state has panicked.
///
/// The panic payload itself is not kept; this only signals that a panic happened.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct JobHasPanicedError {}

impl Display for JobHasPanicedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("At least one job in the threadpool has caused a panic")
    }
}

// No underlying cause is recorded, so the default `source() == None` is correct.
impl Error for JobHasPanicedError {}
/// Job bookkeeping shared between a pool handle and every job it has sent.
///
/// The counters are atomics so they can be read without locking. Every update
/// that a waiter's condition depends on happens while holding the `is_finished`
/// mutex: queueing a job, finishing a job, and recording a panic. A waiter that
/// checks its condition under that mutex and then parks on a condition variable
/// therefore cannot miss the update, as long as the notifier signals after
/// releasing the lock.
#[derive(Debug)]
struct SharedState {
    /// Jobs sent but not yet picked up by a worker.
    jobs_queued: AtomicUsize,
    /// Jobs a worker has started but not yet finished.
    jobs_running: AtomicUsize,
    /// Number of jobs whose closure panicked.
    jobs_paniced: AtomicUsize,
    /// `true` when no job is queued or running; a `Mutex` so it can back a `Condvar`.
    is_finished: Mutex<bool>,
    /// Set once any job has panicked; never cleared on this state.
    has_paniced: AtomicBool,
}

impl Default for SharedState {
    /// Same as [`SharedState::new`]. The derived `Default` would have started with
    /// `is_finished == false`, which makes waits on an empty pool block forever.
    fn default() -> Self {
        Self::new()
    }
}

impl Display for SharedState {
    /// Formats every counter and flag on one line.
    ///
    /// This locks `is_finished`, so it must not be called by a thread that already
    /// holds that lock (re-locking a `std::sync::Mutex` deadlocks or panics).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "SharedState<jobs_queued: {}, jobs_running: {}, jobs_paniced: {}, is_finished: {}, has_paniced: {}>",
            self.jobs_queued.load(Ordering::Relaxed),
            self.jobs_running.load(Ordering::Relaxed),
            self.jobs_paniced.load(Ordering::Relaxed),
            self.lock_finished(),
            self.has_paniced.load(Ordering::Relaxed)
        )
    }
}

impl SharedState {
    /// Creates an idle state: all counters at zero, no panic recorded, and
    /// `is_finished` set so that waiting on a pool with no jobs returns at once.
    fn new() -> Self {
        Self {
            jobs_queued: AtomicUsize::new(0),
            jobs_running: AtomicUsize::new(0),
            jobs_paniced: AtomicUsize::new(0),
            is_finished: Mutex::new(true),
            has_paniced: AtomicBool::new(false),
        }
    }

    /// Locks the `is_finished` flag.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn lock_finished(&self) -> std::sync::MutexGuard<'_, bool> {
        self.is_finished
            .lock()
            .expect("Shared state should never panic")
    }

    /// Moves one job from "queued" to "running".
    ///
    /// `jobs_running` is incremented *before* `jobs_queued` is decremented, so
    /// `job_finished` can never see both counters at zero while this job is in
    /// flight. That is why this method does not need the lock.
    fn job_starting(&self) {
        debug_assert!(
            self.jobs_queued.load(Ordering::Acquire) > 0,
            "Negative jobs queued"
        );
        self.jobs_running.fetch_add(1, Ordering::SeqCst);
        self.jobs_queued.fetch_sub(1, Ordering::SeqCst);
    }

    /// Marks one running job as done, and sets `is_finished` if nothing is left
    /// queued or running.
    fn job_finished(&self) {
        // Checked before locking so a failed assertion cannot poison the mutex.
        debug_assert!(
            self.jobs_running.load(Ordering::Acquire) > 0,
            "Negative jobs running"
        );
        // The decrement, the emptiness check, and the flag write all happen under
        // the lock that `job_queued` also takes. A concurrently queued job can then
        // no longer have its `is_finished = false` overwritten by a stale
        // "everything is done" observation.
        let mut is_finished = self.lock_finished();
        self.jobs_running.fetch_sub(1, Ordering::SeqCst);
        if self.jobs_queued.load(Ordering::SeqCst) == 0
            && self.jobs_running.load(Ordering::SeqCst) == 0
        {
            *is_finished = true;
        }
    }

    /// Records a newly sent job and clears `is_finished`.
    fn job_queued(&self) {
        let mut is_finished = self.lock_finished();
        self.jobs_queued.fetch_add(1, Ordering::SeqCst);
        *is_finished = false;
    }

    /// Records that a job panicked.
    ///
    /// The flag is published under the lock. A waiter that checked `has_paniced`
    /// while holding the lock is therefore already parked on the condition
    /// variable before the following notification is sent.
    fn job_paniced(&self) {
        let _guard = self.lock_finished();
        // Count first, so anyone who observes the flag also observes a non-zero count.
        self.jobs_paniced.fetch_add(1, Ordering::SeqCst);
        self.has_paniced.store(true, Ordering::SeqCst);
    }
}
/// A fixed-size pool of worker threads that run jobs sent with
/// `ThreadPool::send_job`.
///
/// Cloning is cheap. A clone shares the same worker threads, job channel, and
/// job bookkeeping (counters, panic flag, and condition variable) as the
/// original. `ThreadPool::clone_with_new_state` gives a handle that shares the
/// workers but tracks its own jobs.
#[derive(Debug, Clone)]
pub struct ThreadPool {
    /// Number of worker threads spawned when the pool was built.
    thread_amount: NonZeroUsize,
    /// Sending half of the job channel. The workers exit once every handle to it
    /// has been dropped.
    job_sender: Arc<Sender<ThreadPoolFunctionBoxed>>,
    /// Counters and flags updated by the jobs sent through this handle.
    shared_state: Arc<SharedState>,
    /// Notified by each job after it finishes; the `wait_*` methods park on it.
    cvar: Arc<Condvar>,
}
impl Display for ThreadPool {
    /// Renders the worker count followed by the shared job bookkeeping.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let threads = self.thread_amount;
        let state = &self.shared_state;
        write!(f, "Threadpool< thread_amount: {threads}, shared_state: {state}>")
    }
}
impl ThreadPool {
fn new(builder: ThreadPoolBuilder) -> io::Result<Self> {
let thread_amount = builder.thread_amount;
let (job_sender, job_receiver) = channel::<ThreadPoolFunctionBoxed>();
let job_sender = Arc::new(job_sender);
let shareable_job_reciever = Arc::new(Mutex::new(job_receiver));
let shared_state = Arc::new(SharedState::new());
let cvar = Arc::new(Condvar::new());
for thread_num in 0..thread_amount.get() {
let job_reciever = shareable_job_reciever.clone();
let thread_name = format!("Threadpool worker {thread_num}");
thread::Builder::new().name(thread_name).spawn(move || {
loop {
let job = {
let lock = job_reciever .lock()
.expect("Cannot get reciever");
lock.recv()
};
match job {
Ok(job) => job(),
Err(_) => break,
};
}
})?;
}
Ok(Self {
thread_amount,
job_sender,
shared_state,
cvar,
})
}
pub fn send_job(&self, job: impl FnOnce() + Send + UnwindSafe + 'static) {
self.shared_state.job_queued();
debug_assert!(self.jobs_queued() > 0, "Job didn't queue properly");
debug_assert!(!self.is_finished(), "Finish wasn't properly set to false");
let state = self.shared_state.clone();
let cvar = self.cvar.clone();
let job_with_state = Self::job_function(Box::new(job), state, cvar);
self.job_sender
.send(Box::new(job_with_state))
.expect("The sender cannot be deallocated while the threadpool is in use")
}
fn job_function(
job: ThreadPoolFunctionBoxed,
state: Arc<SharedState>,
cvar: Arc<Condvar>,
) -> impl FnOnce() + Send + 'static {
move || {
state.job_starting();
let result = catch_unwind(job);
println!("{result:?}");
if result.is_err() {
state.job_paniced();
}
state.job_finished();
cvar.notify_all();
}
}
pub fn wait_until_finished(&self) -> Result<(), JobHasPanicedError> {
let mut is_finished = self
.shared_state
.is_finished
.lock()
.expect("Shared state should never panic");
while !*is_finished && !self.has_paniced() {
is_finished = self
.cvar
.wait(is_finished)
.expect("Shared state should never panic");
}
println!("panic {}", self.has_paniced());
debug_assert!(
self.has_paniced() || self.jobs_running() == 0,
"wait_until_finished stopped {} jobs running and {} panics",
self.jobs_running(),
self.jobs_paniced()
);
debug_assert!(
self.has_paniced() || self.jobs_queued() == 0,
"wait_until_finished stopped while {} jobs queued and {} panics",
self.jobs_queued(),
self.jobs_paniced()
);
println!("WERE DONE WAITING");
match self.shared_state.has_paniced.load(Ordering::Acquire) {
true => Err(JobHasPanicedError {}),
false => Ok(()),
}
}
pub fn wait_until_job_done(&self) -> Result<(), JobHasPanicedError> {
fn paniced(state: &SharedState) -> bool {
state.jobs_paniced.load(Ordering::Acquire) != 0
}
let is_finished = self
.shared_state
.is_finished
.lock()
.expect("Shared state should never panic");
if *is_finished {
return Ok(());
};
drop(self.cvar.wait(is_finished));
if paniced(&self.shared_state) {
Err(JobHasPanicedError {})
} else {
Ok(())
}
}
pub fn wait_until_finished_unchecked(&self) {
let mut is_finished = self
.shared_state
.is_finished
.lock()
.expect("Shared state sould never panic");
if *is_finished {
return;
}
while !*is_finished {
is_finished = self
.cvar
.wait(is_finished)
.expect("Shared state should never panic")
}
debug_assert!(
self.shared_state.jobs_running.load(Ordering::Acquire) == 0,
"Job still running after wait_until_finished_unchecked"
);
debug_assert!(
self.shared_state.jobs_queued.load(Ordering::Acquire) == 0,
"Job still queued after wait_until_finished_unchecked"
);
}
pub fn wait_until_job_done_unchecked(&self) {
let is_finished = self
.shared_state
.is_finished
.lock()
.expect("Shared state should never panic");
if *is_finished {
return;
};
drop(self.cvar.wait(is_finished));
}
pub fn reset_state(&mut self) {
let cvar = Arc::new(Condvar::new());
let shared_state = Arc::new(SharedState::new());
self.cvar = cvar;
self.shared_state = shared_state;
}
pub fn clone_with_new_state(&self) -> Self {
let mut new_pool = self.clone();
new_pool.reset_state();
new_pool
}
pub fn jobs_running(&self) -> usize {
self.shared_state.jobs_running.load(Ordering::Acquire)
}
pub fn jobs_queued(&self) -> usize {
self.shared_state.jobs_queued.load(Ordering::Acquire)
}
pub fn jobs_paniced(&self) -> usize {
self.shared_state.jobs_paniced.load(Ordering::Acquire)
}
pub fn has_paniced(&self) -> bool {
self.shared_state.has_paniced.load(Ordering::Acquire)
}
pub fn is_finished(&self) -> bool {
*self
.shared_state
.is_finished
.lock()
.expect("Shared state should never panic")
}
pub const fn threads(&self) -> NonZeroUsize {
self.thread_amount
}
}
/// Configures and creates a `ThreadPool`.
///
/// The only setting is the number of worker threads. The `Default` builder uses a
/// single worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThreadPoolBuilder {
    /// Number of worker threads the built pool will spawn.
    thread_amount: NonZeroUsize,
}
impl Default for ThreadPoolBuilder {
    /// Returns a builder for a pool with a single worker thread.
    fn default() -> Self {
        let single_thread = NonZeroUsize::new(1).unwrap();
        Self {
            thread_amount: single_thread,
        }
    }
}
impl ThreadPoolBuilder {
    /// Creates a builder for a pool with exactly `thread_amount` worker threads.
    pub fn with_thread_amount(thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
        ThreadPoolBuilder { thread_amount }
    }

    /// Like [`with_thread_amount`](Self::with_thread_amount), but takes a plain
    /// `usize`.
    ///
    /// # Errors
    /// Returns [`TryFromIntError`] if `thread_amount` is zero.
    pub fn with_thread_amount_usize(
        thread_amount: usize,
    ) -> Result<ThreadPoolBuilder, TryFromIntError> {
        NonZeroUsize::try_from(thread_amount).map(Self::with_thread_amount)
    }

    /// Creates a builder whose thread count is the value reported by
    /// [`available_parallelism`].
    ///
    /// # Errors
    /// Returns the error from [`available_parallelism`] if it cannot determine the
    /// amount of parallelism.
    pub fn with_max_threads() -> io::Result<ThreadPoolBuilder> {
        available_parallelism().map(Self::with_thread_amount)
    }

    /// Replaces the number of worker threads with `thread_amount`.
    pub fn set_thread_amount(mut self, thread_amount: NonZeroUsize) -> ThreadPoolBuilder {
        self.thread_amount = thread_amount;
        self
    }

    /// Like [`set_thread_amount`](Self::set_thread_amount), but takes a plain
    /// `usize`.
    ///
    /// # Errors
    /// Returns [`TryFromIntError`] if `thread_amount` is zero.
    pub fn set_thread_amount_usize(
        self,
        thread_amount: usize,
    ) -> Result<ThreadPoolBuilder, TryFromIntError> {
        NonZeroUsize::try_from(thread_amount).map(|amount| self.set_thread_amount(amount))
    }

    /// Replaces the number of worker threads with the value reported by
    /// [`available_parallelism`].
    ///
    /// # Errors
    /// Returns the error from [`available_parallelism`] if it cannot determine the
    /// amount of parallelism.
    pub fn set_max_threads(self) -> io::Result<ThreadPoolBuilder> {
        available_parallelism().map(|amount| self.set_thread_amount(amount))
    }

    /// Spawns the worker threads and returns the finished pool.
    ///
    /// # Errors
    /// Returns an I/O error if a worker thread cannot be spawned.
    pub fn build(self) -> io::Result<ThreadPool> {
        ThreadPool::new(self)
    }
}
// Unit tests for the pool. Several tests park the first jobs on two barriers
// (`b0`, `b1`). Once `b0.wait()` returns in the test thread, those jobs are
// known to be running, so the counters can be checked mid-flight; `b1.wait()`
// then releases them.
#[cfg(test)]
mod test {
    use core::panic;
    use std::{
        num::NonZeroUsize,
        sync::{mpsc::channel, Arc, Barrier},
        thread::sleep,
        time::Duration,
    };

    use crate::ThreadPoolBuilder;

    // Ten panicking jobs on a single worker: the checked wait must report the
    // panic, and once drained, all ten panics must have been counted.
    #[test]
    fn deal_with_panics() {
        fn panic_fn() {
            panic!("Test panic");
        }
        let thread_num: NonZeroUsize = 1.try_into().unwrap();
        let builder = ThreadPoolBuilder::with_thread_amount(thread_num);
        let pool = builder.build().unwrap();
        for _ in 0..10 {
            pool.send_job(panic_fn);
        }
        // `wait_until_finished` stops waiting once a panic has been recorded, so
        // only the panic itself is checked here.
        assert!(
            pool.wait_until_finished().is_err(),
            "Pool didn't detect panic in wait_until_finished"
        );
        assert!(
            pool.has_paniced(),
            "Pool didn't detect panic in has_paniced"
        );
        // The unchecked wait ignores panics and blocks until every job is done.
        pool.wait_until_finished_unchecked();
        assert!(
            pool.jobs_queued() == 0,
            "Incorrect amount of jobs queued after wait"
        );
        assert!(
            pool.jobs_running() == 0,
            "Incorrect amount of jobs running after wait"
        );
        assert!(
            pool.jobs_paniced() == 10,
            "Incorrect amount of jobs paniced after wait"
        );
    }

    // A job's side effect (a value sent on a channel) reaches the test thread.
    #[test]
    fn receive_value() {
        let (tx, rx) = channel::<u32>();
        let func = move || {
            tx.send(69).unwrap();
        };
        let pool = ThreadPoolBuilder::default().build().unwrap();
        pool.send_job(func);
        assert_eq!(rx.recv(), Ok(69), "Incorrect value received");
    }

    // 1000 jobs on 16 workers. The first THREADS jobs block on the barriers, so
    // exactly THREADS jobs are running when `b0` releases the test thread.
    #[test]
    fn test_wait() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;
        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));
        let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();
        for i in 0..TASKS {
            let b0 = b0.clone();
            let b1 = b1.clone();
            pool.send_job(move || {
                if i < THREADS {
                    b0.wait();
                    b1.wait();
                }
            });
        }
        b0.wait();
        assert_eq!(
            pool.jobs_running(),
            THREADS,
            "Incorrect amount of jobs running"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of threads paniced"
        );
        b1.wait();
        assert!(
            pool.wait_until_finished().is_ok(),
            "wait_until_finished incorrectly detected a panic"
        );
        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued after wait"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running after wait"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of threads paniced after wait"
        );
    }

    // Same layout as `test_wait`, but every job panics after the barriers. The
    // unchecked wait must still block until all TASKS jobs are done.
    #[test]
    fn test_wait_unchecked() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;
        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));
        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();
        for i in 0..TASKS {
            let b0 = b0.clone();
            let b1 = b1.clone();
            pool.send_job(move || {
                if i < THREADS {
                    b0.wait();
                    b1.wait();
                }
                panic!("Test panic");
            });
        }
        b0.wait();
        assert_eq!(
            pool.jobs_running(),
            THREADS,
            "Incorrect amount of jobs running"
        );
        assert_eq!(pool.jobs_paniced(), 0);
        b1.wait();
        pool.wait_until_finished_unchecked();
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), TASKS);
    }

    // `pool` and `clone` share state. `clone_with_new_state` shares the workers
    // but has its own state. Each state gets THREADS / 2 barrier jobs (sends
    // alternate between the two), so all THREADS workers are occupied at `b0`.
    // Only the new-state handle sees panics.
    #[test]
    fn test_clones() {
        const TASKS: usize = 1000;
        const THREADS: usize = 16;
        let pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();
        let clone = pool.clone();
        let clone_with_new_state = pool.clone_with_new_state();
        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));
        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();
            pool.send_job(move || {
                if i < THREADS / 2 {
                    b0_copy.wait();
                    b1_copy.wait();
                }
            });
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();
            clone_with_new_state.send_job(move || {
                if i < THREADS / 2 {
                    b0_copy.wait();
                    b1_copy.wait();
                }
                panic!("Test panic")
            });
        }
        b0.wait();
        assert_eq!(
            pool.jobs_running(),
            THREADS / 2,
            "Incorrect amount of jobs running in pool"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool"
        );
        assert_eq!(
            clone_with_new_state.jobs_running(),
            THREADS / 2,
            "Incorrect amount of jobs running in clone_with_new_state"
        );
        assert_eq!(
            clone_with_new_state.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in clone_with_new_state"
        );
        b1.wait();
        assert!(
            clone_with_new_state.wait_until_finished().is_err(),
            "Clone with new state didn't detect panic"
        );
        // `clone` shares `pool`'s state, so waiting on it covers `pool`'s jobs.
        assert!(
            clone.wait_until_finished().is_ok(),
            "Pool incorrectly detected panic"
        );
        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in pool after wait"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running in pool after wait"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool after wait"
        );
        clone_with_new_state.wait_until_finished_unchecked();
        assert!(
            clone_with_new_state.wait_until_finished().is_err(),
            "clone_with_new_state didn't detect panics after wait"
        );
        assert_eq!(
            clone_with_new_state.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in clone_with_new_state after wait"
        );
        assert_eq!(
            clone_with_new_state.jobs_running(),
            0,
            "Incorrect amount of jobs running in clone_with_new_state after wait"
        );
        assert_eq!(
            clone_with_new_state.jobs_paniced(),
            TASKS,
            "Incorrect panics in clone"
        );
        assert_eq!(
            pool.jobs_queued(),
            0,
            "Incorrect amount of jobs queued in pool after everything"
        );
        assert_eq!(
            pool.jobs_running(),
            0,
            "Incorrect amount of jobs running in pool after everything"
        );
        assert_eq!(
            pool.jobs_paniced(),
            0,
            "Incorrect amount of jobs paniced in pool after everything"
        );
    }

    // Resetting mid-flight gives `pool` fresh counters. Jobs already sent keep
    // reporting to the old state, so the new state stays at zero throughout.
    #[test]
    fn reset_state_while_running() {
        const TASKS: usize = 32;
        const THREADS: usize = 16;
        let mut pool = ThreadPoolBuilder::with_thread_amount_usize(THREADS)
            .unwrap()
            .build()
            .unwrap();
        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));
        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();
            pool.send_job(move || {
                if i < THREADS {
                    b0_copy.wait();
                    b1_copy.wait();
                }
            });
        }
        b0.wait();
        assert_ne!(pool.jobs_queued(), 0);
        assert_ne!(pool.jobs_running(), 0);
        pool.reset_state();
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
        b1.wait();
        // The fresh state is already finished, so this returns immediately.
        pool.wait_until_finished().expect("Nothing should panic");
        // Presumably gives the old jobs time to finish, confirming that they never
        // touch the new state.
        sleep(Duration::from_secs(1));
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
    }

    // Same as `reset_state_while_running`, but every job panics. The panics are
    // recorded in the old state only, so the reset handle reports none.
    #[test]
    fn reset_panic_test() {
        const TASKS: usize = 32;
        const THREADS: usize = 16;
        let num = NonZeroUsize::try_from(THREADS).unwrap();
        let mut pool = ThreadPoolBuilder::with_thread_amount(num).build().unwrap();
        let b0 = Arc::new(Barrier::new(THREADS + 1));
        let b1 = Arc::new(Barrier::new(THREADS + 1));
        for i in 0..TASKS {
            let b0_copy = b0.clone();
            let b1_copy = b1.clone();
            pool.send_job(move || {
                if i < THREADS {
                    b0_copy.wait();
                    b1_copy.wait();
                }
                panic!("Test panic");
            });
        }
        b0.wait();
        assert_ne!(pool.jobs_queued(), 0);
        assert_ne!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
        pool.reset_state();
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
        b1.wait();
        pool.wait_until_finished().expect("Nothing should panic");
        sleep(Duration::from_secs(1));
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
    }

    // On one worker, checks three cases:
    // - waiting on an idle pool is Ok,
    // - a successful job yields Ok,
    // - a panicking job yields Err with one recorded panic.
    #[test]
    fn test_wait_until_job_done() {
        const THREADS: usize = 1;
        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();
        assert!(pool.wait_until_job_done().is_ok());
        pool.send_job(|| {});
        assert!(pool.wait_until_job_done().is_ok());
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
        pool.send_job(|| panic!("Test panic"));
        assert!(pool.wait_until_job_done().is_err());
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 1);
    }

    // Unchecked variant of the test above: the same waits, but no error reporting.
    #[test]
    fn test_wait_until_job_done_unchecked() {
        const THREADS: usize = 1;
        let builder = ThreadPoolBuilder::with_thread_amount_usize(THREADS).unwrap();
        let pool = builder.build().unwrap();
        pool.wait_until_job_done_unchecked();
        pool.send_job(|| {});
        pool.wait_until_job_done_unchecked();
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 0);
        pool.send_job(|| panic!("Test panic"));
        pool.wait_until_job_done_unchecked();
        assert_eq!(pool.jobs_queued(), 0);
        assert_eq!(pool.jobs_running(), 0);
        assert_eq!(pool.jobs_paniced(), 1);
    }

    // Re-runs every test above ten times in one process to surface
    // timing-dependent failures.
    // NOTE(review): each iteration includes the two one-second sleeps from the
    // reset tests, so this takes 20+ seconds. The `#[allow(dead_code)]` looks
    // unnecessary on a `#[test]` function.
    #[test]
    #[allow(dead_code)]
    fn test_flakiness() {
        for _ in 0..10 {
            test_wait();
            test_wait_unchecked();
            deal_with_panics();
            receive_value();
            test_clones();
            reset_state_while_running();
            test_wait_until_job_done_unchecked();
            test_wait_until_job_done();
            reset_panic_test();
        }
    }
}