use std::sync::Arc;
use qubit_function::Callable;
use crate::{
TrackedTask,
hook::{
TaskHook,
notify_rejected_optional,
},
service::SubmissionError,
task::{
spi::TaskEndpointPair,
task_admission_gate::TaskAdmissionGate,
},
};
use super::{
Executor,
ThreadPerTaskExecutorBuilder,
thread_spawn_config::ThreadSpawnConfig,
};
/// Executor that starts a separate worker (via `ThreadSpawnConfig`) for every task
/// submitted through [`Executor::call`].
///
/// Cloning is cheap: the optional hook is shared through an `Arc`, and the stack size
/// is a plain `Option<usize>`.
#[derive(Clone)]
pub struct ThreadPerTaskExecutor {
/// Stack size forwarded to `ThreadSpawnConfig::new` each time a worker is spawned;
/// `None` means no explicit size is requested. NOTE(review): presumably in bytes —
/// confirm against `ThreadSpawnConfig`.
pub(crate) stack_size: Option<usize>,
/// Optional task hook. It is handed to every `TaskEndpointPair` this executor creates
/// and is notified (via `notify_rejected_optional`) when a worker fails to spawn.
pub(crate) hook: Option<Arc<dyn TaskHook>>,
}
impl ThreadPerTaskExecutor {
    /// Creates an executor with no explicit stack size and no task hook.
    ///
    /// This is the same value as [`ThreadPerTaskExecutor::default`].
    #[inline]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a fresh [`ThreadPerTaskExecutorBuilder`] for configuring an executor.
    #[inline]
    pub fn builder() -> ThreadPerTaskExecutorBuilder {
        ThreadPerTaskExecutorBuilder::new()
    }

    /// Consumes this executor and returns it with `hook` installed.
    ///
    /// Any hook that was configured previously is replaced.
    #[inline]
    pub fn with_hook(self, hook: Arc<dyn TaskHook>) -> Self {
        Self {
            hook: Some(hook),
            ..self
        }
    }

    /// Hands `worker` to a `ThreadSpawnConfig` built from this executor's stack size.
    ///
    /// A failure to spawn is reported as the returned `SubmissionError`.
    fn spawn_worker(
        &self,
        worker: impl FnOnce() + Send + 'static,
    ) -> Result<(), SubmissionError> {
        let config = ThreadSpawnConfig::new(self.stack_size);
        config.spawn(worker)
    }
}
impl Default for ThreadPerTaskExecutor {
    /// Builds an executor that requests no explicit stack size and has no task hook.
    #[inline]
    fn default() -> Self {
        let stack_size = None;
        let hook = None;
        Self { stack_size, hook }
    }
}
impl Executor for ThreadPerTaskExecutor {
fn call<C, R, E>(&self, task: C) -> Result<TrackedTask<R, E>, SubmissionError>
where
C: Callable<R, E> + Send + 'static,
R: Send + 'static,
E: Send + 'static,
{
let (handle, slot) =
TaskEndpointPair::with_optional_hook(self.hook.clone()).into_tracked_parts();
let gate = TaskAdmissionGate::new(self.hook.is_some());
let worker_gate = gate.clone();
let hook = self.hook.clone();
self.spawn_worker(move || {
worker_gate.wait();
slot.run(task);
})
.inspect_err(|error| notify_rejected_optional(hook.as_ref(), error))?;
handle.accept();
gate.open();
Ok(handle)
}
}