use std::sync::Arc;
use parking_lot::{
Condvar,
Mutex,
};
use qubit_function::{
Callable,
Runnable,
};
use crate::executor::thread_spawn_config::ThreadSpawnConfig;
use crate::{
TaskHandle,
TrackedTask,
hook::{
TaskHook,
notify_rejected,
notify_rejected_optional,
},
task::{
spi::{
TaskEndpointPair,
TaskSlot,
},
task_admission_gate::TaskAdmissionGate,
},
};
use super::{
ExecutorService,
ExecutorServiceLifecycle,
StopReport,
SubmissionError,
ThreadPerTaskExecutorServiceBuilder,
};
/// Body handed to `ThreadSpawnConfig::spawn` for one task: a boxed, sendable
/// closure that can be invoked at most once (`FnOnce`).
type Worker = Box<dyn FnOnce() + Send + 'static>;
/// Submission handle that must be informed once the executor has committed to a task.
///
/// Implemented for both handle flavours (`TaskHandle` and `TrackedTask`) so the
/// shared submission path can mark acceptance without knowing which flavour the
/// caller requested.
trait TaskAdmissionHandle {
    /// Marks the handle's task as accepted.
    ///
    /// `submit_with_slot` calls this only after the task was admitted and its
    /// worker was spawned successfully.
    fn mark_accepted(&self);
}
impl<R, E> TaskAdmissionHandle for TaskHandle<R, E> {
    /// Delegates to the handle's own `accept` method.
    #[inline]
    fn mark_accepted(&self) {
        self.accept();
    }
}
impl<R, E> TaskAdmissionHandle for TrackedTask<R, E> {
    /// Delegates to the tracked task's own `accept` method.
    #[inline]
    fn mark_accepted(&self) {
        self.accept();
    }
}
/// Bookkeeping protected by the service mutex: the lifecycle phase plus the
/// number of admitted tasks that have not finished yet.
#[derive(Debug, Clone, Copy)]
struct ServiceState {
    /// Current phase of the service lifecycle.
    lifecycle: ExecutorServiceLifecycle,
    /// Tasks admitted by `accept_task` whose `ActiveTaskGuard` is still alive.
    active_tasks: usize,
}

impl Default for ServiceState {
    /// A fresh service starts out `Running` with no admitted tasks.
    #[inline]
    fn default() -> Self {
        let lifecycle = ExecutorServiceLifecycle::Running;
        let active_tasks = 0;
        Self {
            lifecycle,
            active_tasks,
        }
    }
}
/// State shared (through an `Arc`) by a `ThreadPerTaskExecutorService`, its
/// clones, and every `ActiveTaskGuard` created for an admitted task.
#[derive(Default)]
struct ThreadPerTaskExecutorServiceState {
    /// Lifecycle phase and active-task counter, updated together under one lock.
    state: Mutex<ServiceState>,
    /// Notified with `notify_all` whenever the lifecycle is set to `Terminated`;
    /// `wait_for_termination` blocks on it.
    termination: Condvar,
}
/// Drop guard representing one admitted task in the active-task counter.
///
/// Dropping it calls `finish_task`, which decrements the counter unconditionally,
/// so each guard must correspond to exactly one successful `accept_task` call.
struct ActiveTaskGuard {
    state: Arc<ThreadPerTaskExecutorServiceState>,
}

impl ActiveTaskGuard {
    /// Wraps the shared state; the task must already have been counted by
    /// `accept_task`.
    #[inline]
    fn new(state: Arc<ThreadPerTaskExecutorServiceState>) -> Self {
        ActiveTaskGuard { state }
    }
}

impl Drop for ActiveTaskGuard {
    /// Releases this task's slot in the counter, possibly completing termination.
    #[inline]
    fn drop(&mut self) {
        let shared = &self.state;
        shared.finish_task();
    }
}
impl ThreadPerTaskExecutorServiceState {
    /// Returns the lifecycle phase observed while holding the lock.
    #[inline]
    fn lifecycle(&self) -> ExecutorServiceLifecycle {
        let guard = self.state.lock();
        guard.lifecycle
    }

    /// Counts one more active task if the service is still `Running`.
    ///
    /// # Errors
    ///
    /// Returns `SubmissionError::Shutdown` once the service has left the
    /// `Running` phase (after `shutdown` or `stop`).
    #[inline]
    fn accept_task(&self) -> Result<(), SubmissionError> {
        let mut guard = self.state.lock();
        if guard.lifecycle == ExecutorServiceLifecycle::Running {
            guard.active_tasks += 1;
            Ok(())
        } else {
            Err(SubmissionError::Shutdown)
        }
    }

    /// Removes one task from the active count and terminates the service if it
    /// is no longer running and this was the last active task.
    #[inline]
    fn finish_task(&self) {
        let mut guard = self.state.lock();
        let counters: &mut ServiceState = &mut guard;
        counters.active_tasks -= 1;
        Self::terminate_if_ready(counters, &self.termination);
    }

    /// Blocks the caller until the lifecycle reaches `Terminated`.
    fn wait_for_termination(&self) {
        let mut guard = self.state.lock();
        loop {
            if guard.lifecycle == ExecutorServiceLifecycle::Terminated {
                break;
            }
            self.termination.wait(&mut guard);
        }
    }

    /// Stops admitting new tasks; already admitted tasks keep running.
    ///
    /// Only a `Running` service moves to `ShuttingDown`; other phases are left
    /// untouched before the termination check.
    #[inline]
    fn shutdown(&self) {
        let mut guard = self.state.lock();
        let counters: &mut ServiceState = &mut guard;
        if counters.lifecycle == ExecutorServiceLifecycle::Running {
            counters.lifecycle = ExecutorServiceLifecycle::ShuttingDown;
        }
        Self::terminate_if_ready(counters, &self.termination);
    }

    /// Moves any non-terminated service to `Stopping` and returns how many tasks
    /// were active at that moment. This method does not interrupt those tasks.
    #[inline]
    fn stop(&self) -> usize {
        let mut guard = self.state.lock();
        let counters: &mut ServiceState = &mut guard;
        let running = counters.active_tasks;
        if counters.lifecycle != ExecutorServiceLifecycle::Terminated {
            counters.lifecycle = ExecutorServiceLifecycle::Stopping;
        }
        Self::terminate_if_ready(counters, &self.termination);
        running
    }

    /// Marks the service `Terminated` and wakes every waiter once it has left
    /// `Running` and no tasks remain active. Callers must hold the state lock.
    #[inline]
    fn terminate_if_ready(state: &mut ServiceState, termination: &Condvar) {
        let drained = state.active_tasks == 0;
        let accepting = state.lifecycle == ExecutorServiceLifecycle::Running;
        if drained && !accepting {
            state.lifecycle = ExecutorServiceLifecycle::Terminated;
            termination.notify_all();
        }
    }
}
/// Executor service that spawns a separate worker thread for every submitted task.
///
/// Clones share the same lifecycle state (an `Arc`), so shutting down or stopping
/// any clone affects all of them.
#[derive(Clone)]
pub struct ThreadPerTaskExecutorService {
    /// Lifecycle and active-task bookkeeping shared with clones and workers.
    state: Arc<ThreadPerTaskExecutorServiceState>,
    /// Optional stack size handed to `ThreadSpawnConfig::new` for each worker.
    stack_size: Option<usize>,
    /// Optional hook forwarded to task endpoints and notified of rejected submissions.
    pub(crate) hook: Option<Arc<dyn TaskHook>>,
}

impl Default for ThreadPerTaskExecutorService {
    /// Creates a running service with no explicit stack size and no hook.
    #[inline]
    fn default() -> Self {
        Self::from_stack_size(None)
    }
}
impl ThreadPerTaskExecutorService {
#[inline]
pub fn new() -> Self {
Self::default()
}
#[inline]
pub(crate) fn from_stack_size(stack_size: Option<usize>) -> Self {
Self {
state: Arc::default(),
stack_size,
hook: None,
}
}
#[inline]
pub fn builder() -> ThreadPerTaskExecutorServiceBuilder {
ThreadPerTaskExecutorServiceBuilder::new()
}
fn spawn_worker_after_accept(&self, worker: Worker) -> Result<(), SubmissionError> {
ThreadSpawnConfig::new(self.stack_size).spawn(worker)
}
#[inline]
fn notify_rejected(&self, error: &SubmissionError) {
if let Some(hook) = &self.hook {
notify_rejected(hook.as_ref(), error);
}
}
fn submit_with_slot<R, E, H, S, F>(
&self,
split_pair: S,
run_slot: F,
) -> Result<H, SubmissionError>
where
R: Send + 'static,
E: Send + 'static,
H: TaskAdmissionHandle,
S: FnOnce(TaskEndpointPair<R, E>) -> (H, TaskSlot<R, E>),
F: FnOnce(TaskSlot<R, E>) + Send + 'static,
{
if let Err(error) = self.state.accept_task() {
self.notify_rejected(&error);
return Err(error);
}
let pair = TaskEndpointPair::with_optional_hook(self.hook.clone());
let (handle, slot) = split_pair(pair);
let guard = ActiveTaskGuard::new(Arc::clone(&self.state));
let gate = TaskAdmissionGate::new(self.hook.is_some());
let worker_gate = gate.clone();
let hook = self.hook.clone();
if let Err(error) = self.spawn_worker_after_accept(Box::new(move || {
worker_gate.wait();
let _guard = guard;
run_slot(slot);
})) {
notify_rejected_optional(hook.as_ref(), &error);
return Err(error);
}
handle.mark_accepted();
gate.open();
Ok(handle)
}
}
impl ExecutorService for ThreadPerTaskExecutorService {
type ResultHandle<R, E>
= TaskHandle<R, E>
where
R: Send + 'static,
E: Send + 'static;
type TrackedHandle<R, E>
= TrackedTask<R, E>
where
R: Send + 'static,
E: Send + 'static;
fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
where
T: Runnable<E> + Send + 'static,
E: Send + 'static,
{
let handle = self.submit_with_slot(
|pair| pair.into_parts(),
move |slot| {
let mut task = task;
slot.run(move || task.run());
},
)?;
drop(handle);
Ok(())
}
fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
where
C: Callable<R, E> + Send + 'static,
R: Send + 'static,
E: Send + 'static,
{
self.submit_with_slot(
|pair| pair.into_parts(),
move |slot| {
slot.run(task);
},
)
}
fn submit_tracked_callable<C, R, E>(
&self,
task: C,
) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
where
C: Callable<R, E> + Send + 'static,
R: Send + 'static,
E: Send + 'static,
{
self.submit_with_slot(
|pair| pair.into_tracked_parts(),
move |slot| {
slot.run(task);
},
)
}
fn shutdown(&self) {
self.state.shutdown();
}
fn stop(&self) -> StopReport {
let running = self.state.stop();
StopReport::new(0, running, 0)
}
#[inline]
fn lifecycle(&self) -> ExecutorServiceLifecycle {
self.state.lifecycle()
}
#[inline]
fn wait_termination(&self) {
self.state.wait_for_termination();
}
}