use std::{
future::Future,
pin::Pin,
sync::{
Arc,
MutexGuard,
},
};
use qubit_function::{
Callable,
Runnable,
};
use qubit_executor::TaskHandle;
use qubit_executor::task::spi::{
TaskEndpointPair,
TaskRunner,
};
use crate::TokioBlockingTaskHandle;
use crate::tokio_executor_service_state::TokioExecutorServiceState;
use crate::tokio_runtime::ensure_tokio_runtime_entered;
use crate::tokio_service_task_guard::TokioServiceTaskGuard;
use crate::tokio_task_slot_cancellation::{
cancel_unstarted_task_slot_if_queued,
share_task_slot,
take_task_slot,
};
use qubit_executor::service::{
ExecutorService,
ExecutorServiceLifecycle,
StopReport,
SubmissionError,
};
use tokio::task::AbortHandle;
/// Executor service that runs every submitted task on Tokio's blocking thread pool
/// (each submission goes through `tokio::task::spawn_blocking`).
///
/// Cloning is cheap: the derived `Clone` only clones the inner `Arc`, so all clones
/// share one `TokioExecutorServiceState`. A `shutdown`/`stop` issued through any clone
/// is therefore observed by every other clone.
#[derive(Default, Clone)]
pub struct TokioExecutorService {
/// Shared state: submission lock, lifecycle, task accounting, registered abort
/// handles and the termination notifier used by `await_termination`.
state: Arc<TokioExecutorServiceState>,
}
/// Alias that makes explicit at use sites that tasks submitted to
/// [`TokioExecutorService`] run on Tokio's blocking pool.
pub type TokioBlockingExecutorService = TokioExecutorService;
impl TokioExecutorService {
    /// Creates a service in its default (freshly constructed) state.
    ///
    /// The derived `Default` allocates a new `TokioExecutorServiceState`, so the
    /// returned service shares no state with any existing instance.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Admits one blocking task while holding the submission lock.
    ///
    /// With the lock held, the following steps run in order:
    /// 1. reject with [`SubmissionError::Shutdown`] if the service is no longer running;
    /// 2. make sure a Tokio runtime is entered (its error is propagated with `?`);
    /// 3. record the task as accepted in the shared state;
    /// 4. allocate the per-task marker and a [`TokioServiceTaskGuard`] holding a clone of it.
    ///
    /// The lock guard is returned *still held*. The caller can then spawn and register
    /// the task before `shutdown`/`stop` (which take the same lock) are able to run.
    ///
    /// # Errors
    ///
    /// [`SubmissionError::Shutdown`] when the service is not running, or the error
    /// reported by `ensure_tokio_runtime_entered`.
    fn prepare_blocking_submission(
        &self,
    ) -> Result<(MutexGuard<'_, ()>, Arc<()>, TokioServiceTaskGuard), SubmissionError> {
        let lock = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        ensure_tokio_runtime_entered()?;
        self.state.accept_task();

        let task_marker = Arc::new(());
        let task_guard =
            TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&task_marker));
        Ok((lock, task_marker, task_guard))
    }

    /// Spawns an already-admitted task on Tokio's blocking pool and registers its
    /// [`AbortHandle`] with the shared state.
    ///
    /// The worker owns `guard`. It calls `guard.mark_started()` and runs `task` only when
    /// that returns `true`. A `false` result presumably means the task was cancelled
    /// while queued; confirm against `TokioServiceTaskGuard`. If Tokio never starts the
    /// closure (a `spawn_blocking` task aborted before it runs), the guard is dropped
    /// together with the closure.
    ///
    /// The abort handle is registered under `marker` together with `cancel` *before*
    /// `submission_guard` is released. A concurrent `stop`, which takes the same lock
    /// before calling `abort_tracked_tasks`, therefore sees the registration.
    ///
    /// Returns a clone of the registered abort handle.
    fn spawn_accepted_blocking_task<F, C>(
        &self,
        submission_guard: MutexGuard<'_, ()>,
        marker: Arc<()>,
        guard: TokioServiceTaskGuard,
        task: F,
        cancel: C,
    ) -> AbortHandle
    where
        F: FnOnce() + Send + 'static,
        C: FnOnce() -> bool + Send + 'static,
    {
        let worker = move || {
            // Take ownership of the guard inside the worker so it is dropped when the
            // worker returns (or unwinds).
            let task_guard = guard;
            if task_guard.mark_started() {
                task();
            }
        };
        let spawned = tokio::task::spawn_blocking(worker);
        let abort_handle = spawned.abort_handle();

        self.state
            .register_abort_handle(marker, abort_handle.clone(), cancel);
        // Only now, with the task registered, let other submitters / shutdown proceed.
        drop(submission_guard);

        abort_handle
    }
}
impl ExecutorService for TokioExecutorService {
    type ResultHandle<R, E>
        = TaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    type TrackedHandle<R, E>
        = TokioBlockingTaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    /// Submits a fire-and-forget [`Runnable`] to Tokio's blocking pool.
    ///
    /// The runnable is executed through a [`TaskRunner`] and its outcome is discarded,
    /// because the caller receives no handle. The guard's
    /// `finish_queued_once_callback` is registered as the abort callback.
    ///
    /// # Errors
    ///
    /// Propagates admission failures, e.g. [`SubmissionError::Shutdown`] once the
    /// service is no longer running.
    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
    where
        T: Runnable<E> + Send + 'static,
        E: Send + 'static,
    {
        let (lock, marker, task_guard) = self.prepare_blocking_submission()?;
        let on_abort = task_guard.finish_queued_once_callback();
        let work = move || {
            let mut runnable = task;
            let runner = TaskRunner::new(move || runnable.run());
            // Nobody can observe the result of a fire-and-forget task; drop it.
            let _ = runner.call::<(), E>();
        };
        self.spawn_accepted_blocking_task(lock, marker, task_guard, work, on_abort);
        Ok(())
    }

    /// Submits a [`Callable`] and returns a [`TaskHandle`] for its result.
    ///
    /// The completion endpoint is accepted immediately and then stored in a shared slot.
    /// The worker runs the callable only if `take_task_slot` still yields the endpoint.
    /// The abort callback hands the same slot to `cancel_unstarted_task_slot_if_queued`,
    /// presumably to complete the handle of a task that never started; verify in
    /// `tokio_task_slot_cancellation`.
    ///
    /// # Errors
    ///
    /// Propagates admission failures, e.g. [`SubmissionError::Shutdown`].
    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let (lock, marker, task_guard) = self.prepare_blocking_submission()?;
        let (handle, endpoint) = TaskEndpointPair::new().into_parts();
        endpoint.accept();
        let worker_slot = share_task_slot(endpoint);
        let abort_slot = Arc::clone(&worker_slot);
        let finish_queued = task_guard.finish_queued_once_callback();

        let work = move || {
            if let Some(endpoint) = take_task_slot(&worker_slot) {
                TaskRunner::new(task).run(endpoint);
            }
        };
        let on_abort = move || cancel_unstarted_task_slot_if_queued(&abort_slot, finish_queued);

        self.spawn_accepted_blocking_task(lock, marker, task_guard, work, on_abort);
        Ok(handle)
    }

    /// Like [`ExecutorService::submit_callable`], but returns a
    /// [`TokioBlockingTaskHandle`].
    ///
    /// The returned handle bundles the tracked result handle, the task's abort handle
    /// and the guard's `cancel_queued_callback`.
    ///
    /// # Errors
    ///
    /// Propagates admission failures, e.g. [`SubmissionError::Shutdown`].
    fn submit_tracked_callable<C, R, E>(
        &self,
        task: C,
    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let (lock, marker, task_guard) = self.prepare_blocking_submission()?;
        let (handle, endpoint) = TaskEndpointPair::new().into_tracked_parts();
        endpoint.accept();
        let worker_slot = share_task_slot(endpoint);
        let abort_slot = Arc::clone(&worker_slot);
        let finish_queued = task_guard.finish_queued_once_callback();
        let cancel_queued = task_guard.cancel_queued_callback();

        let work = move || {
            if let Some(endpoint) = take_task_slot(&worker_slot) {
                TaskRunner::new(task).run(endpoint);
            }
        };
        let on_abort = move || cancel_unstarted_task_slot_if_queued(&abort_slot, finish_queued);

        let abort_handle =
            self.spawn_accepted_blocking_task(lock, marker, task_guard, work, on_abort);
        Ok(TokioBlockingTaskHandle::new(handle, abort_handle, cancel_queued))
    }

    /// Marks the service as shut down while holding the submission lock, then calls
    /// `notify_if_terminated` so termination waiters can be woken.
    ///
    /// This method itself aborts no task.
    fn shutdown(&self) {
        let _submission_lock = self.state.lock_submission();
        self.state.shutdown();
        self.state.notify_if_terminated();
    }

    /// Stops the service and aborts its tracked tasks, all under the submission lock.
    ///
    /// The queued/running counts are captured *before* `abort_tracked_tasks` runs. The
    /// report therefore describes the tasks present at the moment of the stop, together
    /// with the number returned by the abort step.
    fn stop(&self) -> StopReport {
        let _submission_lock = self.state.lock_submission();
        self.state.stop();
        let (queued, running) = self.state.task_count_snapshot();
        let cancelled = self.state.abort_tracked_tasks();
        self.state.notify_if_terminated();
        StopReport::new(queued, running, cancelled)
    }

    /// Current lifecycle stage, read from the shared state.
    fn lifecycle(&self) -> ExecutorServiceLifecycle {
        self.state.lifecycle()
    }

    /// Delegates to the shared state's `is_not_running`.
    fn is_not_running(&self) -> bool {
        self.state.is_not_running()
    }

    /// `true` once the lifecycle has reached [`ExecutorServiceLifecycle::Terminated`].
    fn is_terminated(&self) -> bool {
        self.state.lifecycle() == ExecutorServiceLifecycle::Terminated
    }

    /// Synchronous wait for termination; delegates to the shared state.
    fn wait_termination(&self) {
        self.state.wait_termination();
    }
}
impl TokioExecutorService {
    /// Returns a future that resolves once the service's lifecycle is `Terminated`.
    ///
    /// The future borrows `self` and re-checks [`ExecutorService::is_terminated`] after
    /// every wake-up. `terminated_notify` is assumed to be a `tokio::sync::Notify`.
    /// Interest is registered with `Notified::enable` *before* each check, so a
    /// `notify_waiters` call that lands between the check and the `.await` is not lost.
    pub fn await_termination(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(async move {
            let notify = &self.state.terminated_notify;
            let notified = notify.notified();
            tokio::pin!(notified);

            notified.as_mut().enable();
            while !self.is_terminated() {
                notified.as_mut().await;
                // A completed `Notified` cannot be reused: arm a fresh one and register
                // it before the next terminated check.
                notified.set(notify.notified());
                notified.as_mut().enable();
            }
        })
    }
}