use std::{
collections::HashMap,
panic::{
AssertUnwindSafe,
catch_unwind,
resume_unwind,
},
sync::{
Arc,
Condvar,
Mutex,
MutexGuard,
},
};
use qubit_function::{
Callable,
Runnable,
};
use qubit_executor::TaskHandle;
use qubit_executor::service::{
ExecutorService,
StopReport,
};
use qubit_executor::task::spi::{
TaskEndpointPair,
TaskSlot,
};
use qubit_thread_pool::{
ExecutorServiceBuilderError,
PoolJob,
ThreadPool,
};
use super::{
task_execution_service_builder::TaskExecutionServiceBuilder,
task_execution_service_error::TaskExecutionServiceError,
task_execution_stats::TaskExecutionStats,
task_id::TaskId,
task_status::TaskStatus,
};
/// Runs tasks identified by a [`TaskId`] on a [`ThreadPool`] and records each
/// task's [`TaskStatus`], so callers can query, cancel, and wait on tasks by id.
pub struct TaskExecutionService {
    /// Pool that executes the submitted jobs.
    pool: ThreadPool,
    /// Status table shared with the closures of every submitted job.
    state: Arc<TaskExecutionServiceState>,
}

impl TaskExecutionService {
    /// Builds a service using the default [`TaskExecutionServiceBuilder`].
    ///
    /// # Errors
    ///
    /// Returns the builder's [`ExecutorServiceBuilderError`] if building fails.
    pub fn new() -> Result<Self, ExecutorServiceBuilderError> {
        Self::builder().build()
    }

    /// Returns a [`TaskExecutionServiceBuilder`] with default settings.
    #[inline]
    pub fn builder() -> TaskExecutionServiceBuilder {
        TaskExecutionServiceBuilder::default()
    }

    /// Wraps an already-built pool with an empty task table.
    pub(crate) fn from_thread_pool(pool: ThreadPool) -> Self {
        Self {
            pool,
            state: Arc::new(TaskExecutionServiceState::default()),
        }
    }

    /// Submits a [`Runnable`] under `task_id`.
    ///
    /// This is [`Self::submit_callable`] with a closure that calls `task.run()`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::submit_callable`].
    #[inline]
    pub fn submit<T, E>(
        &self,
        task_id: TaskId,
        mut task: T,
    ) -> Result<TaskHandle<(), E>, TaskExecutionServiceError>
    where
        T: Runnable<E> + Send + 'static,
        E: Send + 'static,
    {
        self.submit_callable(task_id, move || task.run())
    }

    /// Submits a [`Callable`] under `task_id` and returns a handle to its result.
    ///
    /// The task is recorded as `Submitted` before it is handed to the pool.
    /// When it executes, its status becomes `Running` and then `Succeeded`,
    /// `Failed`, or `Panicked`. If it is cancelled before starting, its status
    /// becomes `Cancelled`.
    ///
    /// # Errors
    ///
    /// - `TaskExecutionServiceError::Suspended` if the service is suspended.
    /// - `TaskExecutionServiceError::DuplicateTask` if `task_id` is already
    ///   registered.
    /// - The pool's submission error (converted with `Into`) if the pool
    ///   rejects the job; the task's record is removed in that case.
    pub fn submit_callable<C, R, E>(
        &self,
        task_id: TaskId,
        task: C,
    ) -> Result<TaskHandle<R, E>, TaskExecutionServiceError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let (handle, slot) = TaskEndpointPair::new().into_parts();
        // The slot is shared by the accept, run and cancel paths. Whichever
        // path `take`s it first owns the task's outcome.
        let slot = Arc::new(Mutex::new(Some(slot)));
        let accept_slot = Arc::clone(&slot);
        let cancel_slot = Arc::clone(&slot);
        let run_slot = Arc::clone(&slot);
        // Hold the state weakly. This closure is stored inside the state's own
        // task table, so a strong `Arc` would form a reference cycle and the
        // state would never be freed.
        let cancel_state = Arc::downgrade(&self.state);
        let cancel: Arc<dyn Fn() -> bool + Send + Sync> = Arc::new(move || {
            let slot = cancel_slot
                .lock()
                .expect("task slot lock should not be poisoned")
                .take();
            let cancelled = slot.is_some_and(TaskSlot::cancel_unstarted);
            if cancelled {
                if let Some(state) = cancel_state.upgrade() {
                    state.set_status(task_id, TaskStatus::Cancelled);
                }
            }
            cancelled
        });
        self.state.register(task_id, Arc::clone(&cancel))?;
        let run_state = Arc::clone(&self.state);
        let job = PoolJob::with_accept(
            Box::new(move || {
                if let Some(slot) = accept_slot
                    .lock()
                    .expect("task slot lock should not be poisoned")
                    .as_ref()
                {
                    slot.accept();
                }
            }),
            Box::new(move || {
                let slot = run_slot
                    .lock()
                    .expect("task slot lock should not be poisoned")
                    .take();
                // `None` means the cancel path already claimed the slot and
                // recorded the outcome itself.
                let Some(slot) = slot else {
                    return;
                };
                let task = StatusReportingTask {
                    task_id,
                    task,
                    state: Arc::clone(&run_state),
                };
                // The slot has been taken out of the shared mutex, so the cancel
                // callback can no longer finalize this task. `StatusReportingTask`
                // always leaves a non-`Running` outcome once it has been called,
                // so a still-active status here means the task never started.
                // Record it as cancelled so waiters are released.
                if !slot.run(task)
                    && run_state
                        .status(task_id)
                        .is_some_and(|current| current.is_active())
                {
                    run_state.set_status(task_id, TaskStatus::Cancelled);
                }
            }),
            Box::new(move || {
                cancel();
            }),
        );
        if let Err(error) = self.pool.submit_job(job) {
            self.state.remove(task_id);
            return Err(error.into());
        }
        Ok(handle)
    }

    /// Attempts to cancel the task registered under `task_id` before it starts.
    ///
    /// Returns `true` only if the task was still active, its slot had not yet
    /// been claimed by the pool, and `TaskSlot::cancel_unstarted` succeeded.
    /// Returns `false` for unknown or finished tasks and for tasks already
    /// taken by a worker.
    pub fn cancel(&self, task_id: TaskId) -> bool {
        let cancel = self.state.cancel_callback(task_id);
        cancel.is_some_and(|cancel| cancel())
    }

    /// Returns the last recorded status of `task_id`, or `None` if no task is
    /// registered under that id.
    #[inline]
    pub fn status(&self, task_id: TaskId) -> Option<TaskStatus> {
        self.state.status(task_id)
    }

    /// Aggregates the status of every registered task into [`TaskExecutionStats`].
    #[inline]
    pub fn stats(&self) -> TaskExecutionStats {
        self.state.stats()
    }

    /// Makes subsequent submissions fail with `TaskExecutionServiceError::Suspended`.
    ///
    /// Tasks that were already submitted are not affected.
    #[inline]
    pub fn suspend(&self) {
        self.state.set_suspended(true);
    }

    /// Accepts submissions again after [`Self::suspend`].
    #[inline]
    pub fn resume(&self) {
        self.state.set_suspended(false);
    }

    /// Returns whether new submissions are currently rejected.
    #[inline]
    pub fn is_suspended(&self) -> bool {
        self.state.is_suspended()
    }

    /// Blocks until every task that is active at the time of the call is no
    /// longer active. Tasks submitted afterwards are not waited for.
    pub fn await_in_flight_tasks_completion(&self) {
        self.state.await_in_flight_tasks_completion();
    }

    /// Blocks until no registered task is active, including tasks submitted
    /// while waiting.
    pub fn await_idle(&self) {
        self.state.await_idle();
    }

    /// Delegates to the underlying pool's `shutdown`.
    #[inline]
    pub fn shutdown(&self) {
        self.pool.shutdown();
    }

    /// Delegates to the underlying pool's `stop` and returns its [`StopReport`].
    #[inline]
    pub fn stop(&self) -> StopReport {
        self.pool.stop()
    }

    /// Delegates to the underlying pool's `is_not_running`.
    #[inline]
    pub fn is_not_running(&self) -> bool {
        self.pool.is_not_running()
    }

    /// Delegates to the underlying pool's `is_terminated`.
    #[inline]
    pub fn is_terminated(&self) -> bool {
        self.pool.is_terminated()
    }

    /// Delegates to the underlying pool's `wait_termination`.
    #[inline]
    pub fn wait_termination(&self) {
        self.pool.wait_termination();
    }

    /// Returns the underlying thread pool.
    #[inline]
    pub fn thread_pool(&self) -> &ThreadPool {
        &self.pool
    }
}
#[derive(Default)]
struct TaskExecutionServiceState {
inner: Mutex<TaskExecutionServiceInner>,
idle: Condvar,
}
impl TaskExecutionServiceState {
fn lock_inner(&self) -> MutexGuard<'_, TaskExecutionServiceInner> {
self.inner
.lock()
.expect("task execution service state lock should not be poisoned")
}
fn register(
&self,
task_id: TaskId,
cancel: Arc<dyn Fn() -> bool + Send + Sync>,
) -> Result<(), TaskExecutionServiceError> {
let mut inner = self.lock_inner();
if inner.suspended {
return Err(TaskExecutionServiceError::Suspended);
}
if inner.tasks.contains_key(&task_id) {
return Err(TaskExecutionServiceError::DuplicateTask(task_id));
}
inner.tasks.insert(
task_id,
TaskRecord {
status: TaskStatus::Submitted,
cancel,
},
);
Ok(())
}
fn remove(&self, task_id: TaskId) {
let mut inner = self.lock_inner();
inner.tasks.remove(&task_id);
self.idle.notify_all();
}
fn status(&self, task_id: TaskId) -> Option<TaskStatus> {
self.lock_inner()
.tasks
.get(&task_id)
.map(|record| record.status)
}
fn cancel_callback(&self, task_id: TaskId) -> Option<Arc<dyn Fn() -> bool + Send + Sync>> {
let inner = self.lock_inner();
let record = inner.tasks.get(&task_id)?;
record
.status
.is_active()
.then(|| Arc::clone(&record.cancel))
}
fn set_status(&self, task_id: TaskId, status: TaskStatus) {
let mut inner = self.lock_inner();
let record = inner
.tasks
.get_mut(&task_id)
.expect("task status can only be updated for a registered task");
record.status = status;
self.idle.notify_all();
}
fn set_suspended(&self, suspended: bool) {
self.lock_inner().suspended = suspended;
}
fn is_suspended(&self) -> bool {
self.lock_inner().suspended
}
fn stats(&self) -> TaskExecutionStats {
let inner = self.lock_inner();
let mut stats = TaskExecutionStats::default();
for record in inner.tasks.values() {
stats.add_status(record.status);
}
stats
}
fn await_in_flight_tasks_completion(&self) {
let mut inner = self.lock_inner();
let task_ids = inner
.tasks
.iter()
.filter_map(|(&task_id, record)| record.status.is_active().then_some(task_id))
.collect::<Vec<_>>();
while task_ids
.iter()
.any(|task_id| inner.task_is_active(*task_id))
{
inner = self.wait_for_idle_notification(inner);
}
}
fn await_idle(&self) {
let mut inner = self.lock_inner();
while inner.has_active_tasks() {
inner = self.wait_for_idle_notification(inner);
}
}
fn wait_for_idle_notification<'a>(
&self,
inner: MutexGuard<'a, TaskExecutionServiceInner>,
) -> MutexGuard<'a, TaskExecutionServiceInner> {
self.idle
.wait(inner)
.expect("task execution service state lock should not be poisoned")
}
}
/// Mutable state protected by the service state's mutex.
#[derive(Default)]
struct TaskExecutionServiceInner {
    /// When `true`, new registrations are rejected.
    suspended: bool,
    /// Every registered task, keyed by its id.
    tasks: HashMap<TaskId, TaskRecord>,
}

impl TaskExecutionServiceInner {
    /// Returns `true` when `task_id` is registered and its status is active.
    fn task_is_active(&self, task_id: TaskId) -> bool {
        matches!(self.tasks.get(&task_id), Some(record) if record.status.is_active())
    }

    /// Returns `true` when at least one registered task has an active status.
    fn has_active_tasks(&self) -> bool {
        self.tasks
            .values()
            .map(|record| record.status)
            .any(|status| status.is_active())
    }
}

/// Bookkeeping entry for one registered task.
struct TaskRecord {
    /// Last status recorded for the task.
    status: TaskStatus,
    /// Callback that tries to cancel the task before it starts and reports
    /// whether that cancellation took effect.
    cancel: Arc<dyn Fn() -> bool + Send + Sync>,
}
/// Wraps a callable so its lifecycle is recorded in the shared status table.
struct StatusReportingTask<C> {
    /// Id under which the status updates are recorded.
    task_id: TaskId,
    /// The wrapped callable.
    task: C,
    /// Shared status table to update.
    state: Arc<TaskExecutionServiceState>,
}

impl<C, R, E> Callable<R, E> for StatusReportingTask<C>
where
    C: Callable<R, E>,
{
    /// Marks the task `Running` and invokes it. It then records `Succeeded`,
    /// `Failed`, or `Panicked` before returning the result or re-raising the
    /// captured panic.
    fn call(&mut self) -> Result<R, E> {
        let Self {
            task_id,
            task,
            state,
        } = self;
        let task_id = *task_id;

        state.set_status(task_id, TaskStatus::Running);
        let outcome = catch_unwind(AssertUnwindSafe(|| task.call()));

        let final_status = match &outcome {
            Ok(Ok(_)) => TaskStatus::Succeeded,
            Ok(Err(_)) => TaskStatus::Failed,
            Err(_) => TaskStatus::Panicked,
        };
        state.set_status(task_id, final_status);

        match outcome {
            Ok(result) => result,
            Err(payload) => resume_unwind(payload),
        }
    }
}