use std::{
panic::{AssertUnwindSafe, catch_unwind},
sync::Arc,
};
use super::{task::Task, work_queue::WorkQueue};
/// Batch size a freshly constructed `Executor` starts with, i.e. the maximum
/// number of tasks one `Executor::tick` call processes unless overridden via
/// `Executor::with_batch_size`.
pub(super) const DEFAULT_BATCH_SIZE: usize = 16;
/// Pulls tasks from a shared [`WorkQueue`] and runs them in bounded batches,
/// keeping running totals of how many tasks succeeded and how many failed.
pub struct Executor {
/// Shared queue that tasks are scheduled onto and popped from.
work_queue: Arc<WorkQueue>,
/// Upper bound on the number of tasks processed by a single `tick`.
batch_size: usize,
/// Number of tasks whose execution reported success since the last reset.
executed_count: u64,
/// Number of tasks that returned an error or panicked since the last reset.
failed_count: u64,
}
impl Executor {
    /// Creates an executor that takes its work from `work_queue`.
    ///
    /// The batch size starts at `DEFAULT_BATCH_SIZE` and both statistics
    /// counters start at zero.
    #[must_use]
    pub const fn new(work_queue: Arc<WorkQueue>) -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            executed_count: 0,
            failed_count: 0,
            work_queue,
        }
    }

    /// Builder-style setter for the per-tick batch size.
    ///
    /// A `size` of zero is raised to one.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        // With a zero batch `tick` could never pop anything, so enforce a floor of one.
        self.batch_size = if size == 0 { 1 } else { size };
        self
    }

    /// Processes up to `batch_size` tasks from the queue.
    ///
    /// Stops early as soon as the queue has nothing to pop. Every task that
    /// was popped is counted as either executed or failed, and is dropped
    /// afterwards regardless of its outcome.
    ///
    /// Returns the number of tasks popped and run during this call.
    pub fn tick(&mut self) -> usize {
        let mut handled = 0;
        for _ in 0..self.batch_size {
            match self.work_queue.try_pop() {
                Some(mut task) => {
                    // Select the counter matching the outcome, then bump it once.
                    let counter = if Self::execute_task(&mut task) {
                        &mut self.executed_count
                    } else {
                        &mut self.failed_count
                    };
                    *counter += 1;
                    handled += 1;
                }
                None => break,
            }
        }
        handled
    }

    /// Runs a single task, containing any panic it raises.
    ///
    /// Returns `true` only when `Task::execute` finishes and yields `Ok(())`.
    /// If it yields an error or panics, the task is marked as failed via
    /// `Task::mark_failed` and `false` is returned; the error value or panic
    /// payload is discarded.
    #[cfg_attr(coverage_nightly, coverage(off))]
    pub fn execute_task(task: &mut Task) -> bool {
        // `&mut Task` is not `UnwindSafe` by itself, hence the explicit assertion;
        // after a panic the task is only marked as failed.
        let outcome = catch_unwind(AssertUnwindSafe(|| task.execute()));
        let succeeded = matches!(outcome, Ok(Ok(())));
        if !succeeded {
            task.mark_failed();
        }
        succeeded
    }

    /// Pushes `task` onto the shared queue.
    ///
    /// Returns whatever `WorkQueue::push` reports.
    /// NOTE(review): presumably `true` means the queue accepted the task —
    /// confirm against `WorkQueue::push`.
    #[must_use]
    pub fn schedule(&self, task: Task) -> bool {
        self.work_queue.push(task)
    }

    /// Wraps the closure `work` in a [`Task`] and schedules it.
    ///
    /// Returns the same value as [`Executor::schedule`].
    pub fn spawn<F>(&self, work: F) -> bool
    where
        F: FnOnce() + Send + 'static,
    {
        let task = Task::new(work);
        self.schedule(task)
    }

    /// Number of tasks that completed successfully since construction or the
    /// last [`Executor::reset_stats`].
    #[inline]
    #[must_use]
    pub const fn executed_count(&self) -> u64 {
        self.executed_count
    }

    /// Number of tasks that returned an error or panicked since construction
    /// or the last [`Executor::reset_stats`].
    #[inline]
    #[must_use]
    pub const fn failed_count(&self) -> u64 {
        self.failed_count
    }

    /// Returns `true` while the shared queue reports that it is not empty.
    #[must_use]
    pub fn has_pending(&self) -> bool {
        !self.work_queue.is_empty()
    }

    /// Maximum number of tasks a single [`Executor::tick`] will process.
    #[inline]
    #[must_use]
    pub const fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Keeps ticking until the queue reports no pending work.
    ///
    /// The queue is re-checked before every tick, so work that is added while
    /// draining is picked up as well. Returns the total number of tasks
    /// processed across all ticks.
    pub fn drain(&mut self) -> usize {
        let mut drained = 0;
        loop {
            if !self.has_pending() {
                return drained;
            }
            drained += self.tick();
        }
    }

    /// Sets both the executed and the failed counter back to zero.
    pub const fn reset_stats(&mut self) {
        self.failed_count = 0;
        self.executed_count = 0;
    }
}
/// Debug output shows the batch size, both counters and the current queue
/// length (as `pending`); the queue itself is not printed.
impl std::fmt::Debug for Executor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Executor` forces this impl to be revisited.
        let Self {
            work_queue,
            batch_size,
            executed_count,
            failed_count,
        } = self;
        let pending = work_queue.len();
        f.debug_struct("Executor")
            .field("batch_size", batch_size)
            .field("executed_count", executed_count)
            .field("failed_count", failed_count)
            .field("pending", &pending)
            .finish()
    }
}