#![allow(dead_code)]
#![allow(clippy::redundant_pub_crate)]
use crate::error::ExecutorError;
use crossbeam_channel::{Receiver, Sender, bounded};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
/// Work item sent over the pool's channel to a worker thread.
enum Job {
/// Heap-allocated, one-shot closure that the job itself owns.
Owned(Box<dyn FnOnce() + Send + 'static>),
/// Raw-pointer handle to a closure owned elsewhere (see `BorrowedJob`); the worker
/// dereferences the pointer and calls the closure.
Borrowed(BorrowedJob),
}
/// Type-erased raw pointer to a `FnMut() + Send` closure that the job does not own.
///
/// No lifetime is tracked: the pool later dereferences the pointer and calls the
/// closure, either on the submitting thread or on a worker thread.
#[allow(unsafe_code)]
pub(crate) struct BorrowedJob(*mut (dyn FnMut() + Send));

impl BorrowedJob {
    /// Wraps `ptr` without inspecting or dereferencing it.
    ///
    /// # Safety
    ///
    /// This constructor performs no unsafe operation itself, but the returned value is
    /// `Send` and its pointer is dereferenced when the job runs. The caller must ensure
    /// that `ptr` points to a live closure, and that nothing else accesses that closure,
    /// for as long as the job may still be executed.
    #[allow(unsafe_code)]
    pub(crate) const unsafe fn new(ptr: *mut (dyn FnMut() + Send)) -> Self {
        BorrowedJob(ptr)
    }
}

// SAFETY: raw pointers are `!Send` by default. The pointee type is `dyn FnMut() + Send`,
// so the closure itself may be used from another thread; pointer validity is the
// responsibility of whoever called `BorrowedJob::new`.
#[allow(unsafe_code)]
unsafe impl Send for BorrowedJob {}
/// Counts submitted and completed jobs so a caller can block until the two agree.
#[derive(Default)]
struct Tracker {
    /// Incremented once per job by [`Tracker::submit`].
    submitted: AtomicUsize,
    /// Incremented once per job by [`Tracker::complete`].
    completed: AtomicUsize,
    /// Notified after every completion to wake threads in `wait_for_quiescence`.
    cv: Condvar,
    /// Protects no data; it exists only as the mutex paired with `cv`.
    lock: Mutex<()>,
}

impl Tracker {
    /// Records that one more job has been handed to the pool.
    fn submit(&self) {
        self.submitted.fetch_add(1, Ordering::SeqCst);
    }

    /// Records that one job has finished and wakes every waiter.
    fn complete(&self) {
        self.completed.fetch_add(1, Ordering::SeqCst);
        {
            // Take and immediately release the mutex. A waiter that compared the
            // counters before the increment above still holds the mutex until it is
            // parked on the condvar, so passing through the mutex here guarantees the
            // notification below is not issued before that waiter is able to receive it.
            let _handshake = self.lock.lock().unwrap();
        }
        self.cv.notify_all();
    }

    /// Blocks the calling thread until `submitted == completed`.
    ///
    /// Returns immediately when the counters already agree.
    #[allow(clippy::significant_drop_tightening)]
    fn wait_for_quiescence(&self) {
        let guard = self.lock.lock().unwrap();
        // `wait_while` re-evaluates the predicate after every wake-up, so spurious
        // wake-ups and notifications for unrelated completions are handled.
        let _guard = self
            .cv
            .wait_while(guard, |_| {
                self.submitted.load(Ordering::SeqCst) != self.completed.load(Ordering::SeqCst)
            })
            .unwrap();
    }
}
/// Job executor that runs work either on the submitting thread or on worker threads,
/// and counts submitted/completed jobs so callers can wait for outstanding work.
pub(crate) struct Pool {
/// Execution strategy, fixed when the pool is constructed.
mode: PoolMode,
/// Submitted/completed counters; worker threads hold clones of this `Arc`.
tracker: Arc<Tracker>,
}
/// How a `Pool` executes the jobs it is given.
enum PoolMode {
/// No worker threads: each job is run on the thread that submits it.
Inline,
/// Jobs are sent over a bounded channel to dedicated worker threads.
Threaded {
/// Sending half of the job channel; dropping it disconnects the workers.
tx: Sender<Job>,
/// Join handles of the spawned workers, joined when the pool is dropped.
handles: Vec<JoinHandle<()>>,
/// Flag each worker checks before waiting for its next job.
shutdown: Arc<std::sync::atomic::AtomicBool>,
},
}
impl Pool {
    /// Builds a pool with `n_workers` worker threads.
    ///
    /// With `n_workers == 0` no thread is spawned and the pool runs every job inline
    /// on the submitting thread. Otherwise the workers share a bounded channel that
    /// holds up to `n_workers * 4` queued jobs. Each worker calls
    /// `attrs.apply_to_self(i)` once before it starts receiving jobs.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutorError::Builder`] when a worker thread cannot be spawned.
    /// Workers that were already started are shut down and joined before the error
    /// is returned, so a failed construction leaves no pool threads behind.
    pub(crate) fn new(
        n_workers: usize,
        attrs: crate::thread_attrs::ThreadAttributes,
    ) -> Result<Self, ExecutorError> {
        let tracker = Arc::new(Tracker::default());
        if n_workers == 0 {
            return Ok(Self {
                mode: PoolMode::Inline,
                tracker,
            });
        }

        let (tx, rx): (Sender<Job>, Receiver<Job>) = bounded(n_workers * 4);
        let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let attrs = Arc::new(attrs);
        let mut handles = Vec::with_capacity(n_workers);

        for i in 0..n_workers {
            // Per-worker clones get distinct names so the originals stay usable in the
            // spawn-failure path below.
            let worker_rx = rx.clone();
            let worker_tracker = Arc::clone(&tracker);
            let worker_shutdown = Arc::clone(&shutdown);
            let worker_attrs = Arc::clone(&attrs);
            let name = {
                #[cfg(feature = "thread_attrs")]
                {
                    attrs
                        .name_prefix
                        .as_ref()
                        .map_or_else(|| format!("taktora-pool-{i}"), |p| format!("{p}-{i}"))
                }
                #[cfg(not(feature = "thread_attrs"))]
                {
                    format!("taktora-pool-{i}")
                }
            };

            let spawned = thread::Builder::new().name(name).spawn(move || {
                worker_attrs.apply_to_self(i);
                // The flag is checked before every receive, so a worker that observes
                // shutdown exits without draining jobs still sitting in the channel.
                while !worker_shutdown.load(Ordering::Acquire) {
                    match worker_rx.recv() {
                        Ok(Job::Owned(f)) => {
                            // A panicking job must not kill the worker or skip the
                            // completion count, otherwise `barrier` would never return.
                            let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
                            worker_tracker.complete();
                        }
                        Ok(Job::Borrowed(b)) => {
                            // SAFETY: relies on the contract of `Pool::submit_borrowed`:
                            // the submitter keeps the pointee alive and unaliased until
                            // this job has run.
                            #[allow(unsafe_code)]
                            let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(
                                || unsafe { (*b.0)() },
                            ));
                            worker_tracker.complete();
                        }
                        // All senders are gone: nothing more will ever arrive.
                        Err(_) => break,
                    }
                }
            });

            match spawned {
                Ok(handle) => handles.push(handle),
                Err(e) => {
                    // Do not leave the workers that did start running detached: raise
                    // the shutdown flag, disconnect the channel so blocked `recv` calls
                    // return, and wait for each of them to exit.
                    shutdown.store(true, Ordering::Release);
                    drop(tx);
                    for handle in handles {
                        let _ = handle.join();
                    }
                    return Err(ExecutorError::Builder(format!("spawn worker: {e}")));
                }
            }
        }

        Ok(Self {
            mode: PoolMode::Threaded {
                tx,
                handles,
                shutdown,
            },
            tracker,
        })
    }

    /// Submits an owned closure for execution.
    ///
    /// In inline mode the closure runs before this call returns. In threaded mode it
    /// is queued for a worker, and the call blocks while the bounded queue is full.
    /// A panic raised by the closure is caught and discarded; the job still counts as
    /// completed.
    ///
    /// # Panics
    ///
    /// Panics with "pool channel closed" if the job cannot be sent to the workers.
    #[track_caller]
    pub(crate) fn submit<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the job before handing it over so a fast worker can never report the
        // completion ahead of the submission.
        self.tracker.submit();
        match &self.mode {
            PoolMode::Inline => {
                let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
                self.tracker.complete();
            }
            PoolMode::Threaded { tx, .. } => {
                tx.send(Job::Owned(Box::new(f)))
                    .expect("pool channel closed");
            }
        }
    }

    /// Submits a closure referenced through a raw pointer.
    ///
    /// Behaves like [`Pool::submit`], except that the pool does not own the closure.
    ///
    /// # Safety
    ///
    /// The pointer inside `job` is dereferenced when the job runs, which in threaded
    /// mode happens on a worker thread at some point after this call returns. The
    /// caller must keep the pointee alive, and must not access it, until the job has
    /// finished, for example by calling [`Pool::barrier`] before the closure is
    /// touched or dropped.
    ///
    /// # Panics
    ///
    /// Panics with "pool channel closed" if the job cannot be sent to the workers.
    #[track_caller]
    #[allow(unsafe_code)]
    pub(crate) unsafe fn submit_borrowed(&self, job: BorrowedJob) {
        self.tracker.submit();
        match &self.mode {
            PoolMode::Inline => {
                // SAFETY: the caller guarantees the pointee is live and unaliased for
                // the duration of the job, which here is the duration of this call.
                let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
                    (*job.0)();
                }));
                self.tracker.complete();
            }
            PoolMode::Threaded { tx, .. } => {
                tx.send(Job::Borrowed(job)).expect("pool channel closed");
            }
        }
    }

    /// Blocks until every job submitted so far has been counted as completed.
    pub(crate) fn barrier(&self) {
        self.tracker.wait_for_quiescence();
    }
}
impl Drop for Pool {
    /// Stops and joins the worker threads of a threaded pool; a no-op for inline pools.
    ///
    /// Workers check the shutdown flag before each receive, so jobs still queued when
    /// a worker observes the flag are discarded rather than run.
    fn drop(&mut self) {
        if let PoolMode::Threaded {
            tx,
            handles,
            shutdown,
        } = &mut self.mode
        {
            shutdown.store(true, Ordering::Release);

            // Overwrite the live sender with the sender of a throwaway channel. The
            // assignment drops the original sender, which disconnects the job channel
            // and makes a blocked `recv` in the workers return an error.
            let (detached_tx, _) = bounded::<Job>(0);
            *tx = detached_tx;

            for worker in handles.drain(..) {
                // A worker that panicked is ignored; there is nothing to recover here.
                let _ = worker.join();
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread_attrs::ThreadAttributes;
    use std::sync::atomic::AtomicU32;

    /// Builds a pool with `workers` threads and default thread attributes.
    fn pool_with(workers: usize) -> Pool {
        Pool::new(workers, ThreadAttributes::new()).unwrap()
    }

    /// A zero-worker pool executes every job on the submitting thread.
    #[test]
    fn inline_pool_runs_synchronously() {
        let pool = pool_with(0);
        let hits = Arc::new(AtomicU32::new(0));
        for _ in 0..10 {
            let hits = Arc::clone(&hits);
            pool.submit(move || {
                hits.fetch_add(1, Ordering::SeqCst);
            });
        }
        pool.barrier();
        assert_eq!(hits.load(Ordering::SeqCst), 10);
    }

    /// `barrier` returns only after every queued job has run on the workers.
    #[test]
    fn threaded_pool_runs_concurrently_and_barriers() {
        let pool = pool_with(4);
        let hits = Arc::new(AtomicU32::new(0));
        for _ in 0..100 {
            let hits = Arc::clone(&hits);
            pool.submit(move || {
                std::thread::sleep(std::time::Duration::from_millis(1));
                hits.fetch_add(1, Ordering::SeqCst);
            });
        }
        pool.barrier();
        assert_eq!(hits.load(Ordering::SeqCst), 100);
    }

    /// With nothing submitted the counters already agree, so `barrier` must not block.
    #[test]
    fn barrier_with_no_work_returns_immediately() {
        let pool = pool_with(2);
        pool.barrier();
    }

    /// A panicking job is swallowed and still counted, so `barrier` can return.
    #[test]
    fn submitted_panic_is_caught_and_completion_counted() {
        let pool = pool_with(2);
        pool.submit(|| panic!("kaboom"));
        pool.submit(|| {});
        pool.barrier();

        let submitted = pool.tracker.submitted.load(Ordering::SeqCst);
        let completed = pool.tracker.completed.load(Ordering::SeqCst);
        assert_eq!(
            submitted, completed,
            "submitted vs completed counters diverged after panic"
        );
    }
}