use std::{
future::Future,
pin::Pin,
sync::{
Arc,
Condvar,
Mutex,
MutexGuard,
},
thread,
};
use qubit_atomic::{
Atomic,
AtomicCount,
};
use qubit_function::Callable;
use crate::{
TaskCompletionPair,
TaskHandle,
TaskRunner,
};
use super::{
ExecutorService,
RejectedExecution,
ShutdownReport,
};
/// Bookkeeping shared between a [`ThreadPerTaskExecutorService`] and the
/// threads it spawns.
///
/// A single instance sits behind an `Arc`; the service handle(s) and every
/// task thread hold a reference to it.
#[derive(Default)]
struct ThreadPerTaskExecutorServiceState {
    /// Held while `submit_callable` checks `shutdown` and increments
    /// `active_tasks`, and while `shutdown`/`shutdown_now` set `shutdown`,
    /// so those two steps cannot interleave.
    submission_lock: Mutex<()>,
    /// Set to `true` by `shutdown`/`shutdown_now`; nothing here resets it.
    shutdown: Atomic<bool>,
    /// Number of accepted tasks whose threads have not finished yet.
    active_tasks: AtomicCount,
    /// Mutex paired with `termination`; it guards no data of its own.
    termination_lock: Mutex<()>,
    /// Signalled when the service may have become terminated.
    termination: Condvar,
}
impl ThreadPerTaskExecutorServiceState {
    /// Locks `submission_lock`, recovering the guard if the mutex is poisoned.
    ///
    /// The mutex guards `()`, so poisoning carries no invariant worth
    /// propagating a panic for.
    #[inline]
    fn lock_submission(&self) -> MutexGuard<'_, ()> {
        self.submission_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks `termination_lock`, recovering the guard if the mutex is poisoned.
    ///
    /// The mutex guards `()`, so poisoning carries no invariant worth
    /// propagating a panic for.
    #[inline]
    fn lock_termination(&self) -> MutexGuard<'_, ()> {
        self.termination_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns `true` once shutdown has been requested and no task is running.
    #[inline]
    fn is_terminated(&self) -> bool {
        self.shutdown.load() && self.active_tasks.is_zero()
    }

    /// Wakes every thread blocked in [`Self::wait_for_termination`] if the
    /// service has terminated.
    ///
    /// The termination lock is held across both the check and the
    /// notification. Without it, a lost wake-up is possible:
    ///
    /// 1. A waiter evaluates the predicate and gets `false`.
    /// 2. Before the waiter enters `Condvar::wait`, this method observes
    ///    termination and calls `notify_all`.
    /// 3. The notification reaches nobody, and the waiter then blocks forever.
    ///
    /// Holding the lock means the check runs either before the waiter's own
    /// check (so the waiter sees the new state) or after the waiter has
    /// atomically released the lock inside `wait` (so it receives the
    /// notification).
    #[inline]
    fn notify_if_terminated(&self) {
        let _guard = self.lock_termination();
        if self.is_terminated() {
            self.termination.notify_all();
        }
    }

    /// Blocks the calling thread until the service has terminated.
    fn wait_for_termination(&self) {
        let guard = self.lock_termination();
        // `wait_while` re-evaluates the predicate after every wake-up,
        // so spurious wake-ups are handled.
        let _guard = self
            .termination
            .wait_while(guard, |_| !self.is_terminated())
            .unwrap_or_else(std::sync::PoisonError::into_inner);
    }
}
/// An [`ExecutorService`] that gives every submitted task its own newly
/// spawned OS thread.
///
/// Cloning is cheap and yields another handle to the same service: all clones
/// share one shutdown flag and one active-task count through the inner `Arc`.
#[derive(Clone, Default)]
pub struct ThreadPerTaskExecutorService {
    /// State shared by every clone of this handle and by every task thread.
    state: Arc<ThreadPerTaskExecutorServiceState>,
}
impl ThreadPerTaskExecutorService {
#[inline]
pub fn new() -> Self {
Self::default()
}
}
impl ExecutorService for ThreadPerTaskExecutorService {
type Handle<R, E>
= TaskHandle<R, E>
where
R: Send + 'static,
E: Send + 'static;
type Termination<'a>
= Pin<Box<dyn Future<Output = ()> + Send + 'a>>
where
Self: 'a;
fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::Handle<R, E>, RejectedExecution>
where
C: Callable<R, E> + Send + 'static,
R: Send + 'static,
E: Send + 'static,
{
let submission_guard = self.state.lock_submission();
if self.state.shutdown.load() {
return Err(RejectedExecution::Shutdown);
}
self.state.active_tasks.inc();
drop(submission_guard);
let (handle, completion) = TaskCompletionPair::new().into_parts();
let state = Arc::clone(&self.state);
thread::spawn(move || {
TaskRunner::new(task).run(completion);
if state.active_tasks.dec() == 0 {
state.notify_if_terminated();
}
});
Ok(handle)
}
fn shutdown(&self) {
let _guard = self.state.lock_submission();
self.state.shutdown.store(true);
self.state.notify_if_terminated();
}
fn shutdown_now(&self) -> ShutdownReport {
let _guard = self.state.lock_submission();
self.state.shutdown.store(true);
let running = self.state.active_tasks.get();
self.state.notify_if_terminated();
ShutdownReport::new(0, running, 0)
}
#[inline]
fn is_shutdown(&self) -> bool {
self.state.shutdown.load()
}
#[inline]
fn is_terminated(&self) -> bool {
self.is_shutdown() && self.state.active_tasks.is_zero()
}
#[inline]
fn await_termination(&self) -> Self::Termination<'_> {
Box::pin(async move {
self.state.wait_for_termination();
})
}
}