use std::{
fmt,
panic::{
AssertUnwindSafe,
catch_unwind,
},
sync::Arc,
};
/// Shared, thread-safe callback that is handed a worker index when invoked.
type HookCallback = Arc<dyn Fn(usize) + Send + Sync + 'static>;

/// Optional lifecycle callbacks, assembled builder-style.
///
/// All four slots start out empty. Each builder method consumes the value,
/// fills (or replaces) exactly one slot and returns the updated value so calls
/// can be chained. Cloning is cheap: clones share the same callbacks through
/// their `Arc`s.
///
/// The `run_*` methods are crate-internal; when they are actually called is
/// decided by code outside this block (presumably the pool's worker loop).
#[derive(Clone, Default)]
pub struct ThreadPoolHooks {
    /// Callback dispatched by `run_before_worker_start`.
    before_worker_start: Option<HookCallback>,
    /// Callback dispatched by `run_after_worker_stop`.
    after_worker_stop: Option<HookCallback>,
    /// Callback dispatched by `run_before_task`.
    before_task: Option<HookCallback>,
    /// Callback dispatched by `run_after_task`.
    after_task: Option<HookCallback>,
}

impl ThreadPoolHooks {
    /// Creates a hook set with no callbacks registered.
    #[inline]
    pub fn new() -> Self {
        Self {
            before_worker_start: None,
            after_worker_stop: None,
            before_task: None,
            after_task: None,
        }
    }

    /// Registers `hook` in the `before_worker_start` slot, replacing any
    /// callback previously stored there.
    #[inline]
    pub fn before_worker_start<F>(self, hook: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        let callback: HookCallback = Arc::new(hook);
        Self {
            before_worker_start: Some(callback),
            ..self
        }
    }

    /// Registers `hook` in the `after_worker_stop` slot, replacing any
    /// callback previously stored there.
    #[inline]
    pub fn after_worker_stop<F>(self, hook: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        let callback: HookCallback = Arc::new(hook);
        Self {
            after_worker_stop: Some(callback),
            ..self
        }
    }

    /// Registers `hook` in the `before_task` slot, replacing any callback
    /// previously stored there.
    #[inline]
    pub fn before_task<F>(self, hook: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        let callback: HookCallback = Arc::new(hook);
        Self {
            before_task: Some(callback),
            ..self
        }
    }

    /// Registers `hook` in the `after_task` slot, replacing any callback
    /// previously stored there.
    #[inline]
    pub fn after_task<F>(self, hook: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        let callback: HookCallback = Arc::new(hook);
        Self {
            after_task: Some(callback),
            ..self
        }
    }

    /// Invokes the `before_worker_start` callback, if any, with `worker_index`.
    #[inline]
    pub(crate) fn run_before_worker_start(&self, worker_index: usize) {
        let slot = &self.before_worker_start;
        Self::run_hook(slot, worker_index);
    }

    /// Invokes the `after_worker_stop` callback, if any, with `worker_index`.
    #[inline]
    pub(crate) fn run_after_worker_stop(&self, worker_index: usize) {
        let slot = &self.after_worker_stop;
        Self::run_hook(slot, worker_index);
    }

    /// Invokes the `before_task` callback, if any, with `worker_index`.
    #[inline]
    pub(crate) fn run_before_task(&self, worker_index: usize) {
        let slot = &self.before_task;
        Self::run_hook(slot, worker_index);
    }

    /// Invokes the `after_task` callback, if any, with `worker_index`.
    #[inline]
    pub(crate) fn run_after_task(&self, worker_index: usize) {
        let slot = &self.after_task;
        Self::run_hook(slot, worker_index);
    }

    /// Returns `true` when at least one of the two per-task slots
    /// (`before_task`, `after_task`) holds a callback.
    #[inline]
    pub(crate) fn has_task_hooks(&self) -> bool {
        !matches!((&self.before_task, &self.after_task), (None, None))
    }

    /// Calls the callback stored in `hook` (if there is one) with
    /// `worker_index`, containing any unwinding panic it raises.
    #[inline]
    fn run_hook(hook: &Option<HookCallback>, worker_index: usize) {
        let callback = match hook {
            Some(callback) => callback,
            None => return,
        };
        // Best-effort isolation: an unwinding panic inside user code is caught
        // here and its payload thrown away, so it never reaches the caller.
        drop(catch_unwind(AssertUnwindSafe(|| callback(worker_index))));
    }
}
/// Debug output lists each hook slot by name together with a boolean telling
/// whether a callback is registered there; the callbacks themselves are not
/// printed.
impl fmt::Debug for ThreadPoolHooks {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Slot name paired with "is a callback present?", in declaration order.
        let slots = [
            ("before_worker_start", self.before_worker_start.is_some()),
            ("after_worker_stop", self.after_worker_stop.is_some()),
            ("before_task", self.before_task.is_some()),
            ("after_task", self.after_task.is_some()),
        ];

        let mut builder = formatter.debug_struct("ThreadPoolHooks");
        for &(slot_name, is_set) in &slots {
            builder.field(slot_name, &is_set);
        }
        builder.finish()
    }
}