use std::{
sync::{
Arc,
Mutex,
},
thread,
time::Duration,
};
use qubit_function::{
Callable,
Runnable,
};
use qubit_executor::{
TaskHandle,
TrackedTask,
task::spi::{
TaskEndpointPair,
TaskRunner,
},
};
use crate::tokio_executor_service_state::TokioExecutorServiceState;
use crate::tokio_service_task_guard::TokioServiceTaskGuard;
use qubit_executor::service::{
ExecutorService,
ExecutorServiceLifecycle,
StopReport,
SubmissionError,
};
/// Executor service that runs every submitted task on Tokio's blocking thread
/// pool (each submission goes through `tokio::task::spawn_blocking`).
///
/// Cloning is cheap: all clones share one reference-counted
/// `TokioExecutorServiceState`, so shutting down or stopping through any clone
/// affects every clone.
#[derive(Clone, Default)]
pub struct TokioExecutorService {
    /// Shared service state (lifecycle, submission lock, active-task counter
    /// and registered abort handles) referenced by every clone.
    state: Arc<TokioExecutorServiceState>,
}

/// Alias for [`TokioExecutorService`], whose tasks always run on Tokio's
/// blocking thread pool.
pub type TokioBlockingExecutorService = TokioExecutorService;
impl TokioExecutorService {
#[inline]
pub fn new() -> Self {
Self::default()
}
}
/// [`ExecutorService`] implementation that dispatches every task with
/// `tokio::task::spawn_blocking`.
///
/// All submission paths hold the service's submission lock while they check
/// the lifecycle, bump the active-task counter, spawn the task and register its
/// abort handle. `shutdown` and `stop` take the same lock, so a submission can
/// never interleave with them halfway through.
impl ExecutorService for TokioExecutorService {
    type ResultHandle<R, E>
        = TaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    type TrackedHandle<R, E>
        = TrackedTask<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    /// Submits a fire-and-forget [`Runnable`] to Tokio's blocking pool.
    ///
    /// The task's outcome is discarded; the caller receives no handle.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the service is no longer
    /// running.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime context, because
    /// `tokio::task::spawn_blocking` needs a current runtime handle.
    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
    where
        T: Runnable<E> + Send + 'static,
        E: Send + 'static,
    {
        // Held until the abort handle is registered (see impl-level docs).
        let submission_guard = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        self.state.active_tasks.inc();
        // The same marker is given to the guard and to the abort-handle
        // registry; presumably it lets the guard locate/release its entry —
        // NOTE(review): confirm in TokioServiceTaskGuard.
        let marker = Arc::new(());
        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
        let handle = tokio::task::spawn_blocking(move || {
            // Keep the guard alive for the whole task body.
            let _guard = guard;
            let mut task = task;
            let runner = TaskRunner::new(move || task.run());
            // Fire-and-forget: the result is intentionally ignored.
            let _ = runner.call::<(), E>();
        });
        // No completion endpoint exists, so aborting needs no extra action.
        self.state
            .register_abort_handle(marker, handle.abort_handle(), || {});
        drop(submission_guard);
        Ok(())
    }

    /// Submits a [`Callable`] and returns a [`TaskHandle`] for its result.
    ///
    /// The completion endpoint lives in a shared slot. Whichever of the task
    /// body or the registered abort callback takes it first owns it. The
    /// callable is therefore either run or its completion is cancelled via
    /// `cancel_unstarted`, never both.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the service is no longer
    /// running.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime context, because
    /// `tokio::task::spawn_blocking` needs a current runtime handle.
    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let submission_guard = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        self.state.active_tasks.inc();
        let (handle, completion) = TaskEndpointPair::new().into_parts();
        completion.accept();
        let completion = Arc::new(Mutex::new(Some(completion)));
        let abort_completion = Arc::clone(&completion);
        let marker = Arc::new(());
        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
        let join_handle = tokio::task::spawn_blocking(move || {
            let _guard = guard;
            // The helper releases the slot lock before the callable runs, so
            // the abort callback is never blocked behind a running task.
            if let Some(completion) = take_completion_slot(&completion) {
                TaskRunner::new(task).run(completion);
            }
        });
        self.state
            .register_abort_handle(marker, join_handle.abort_handle(), move || {
                // Only cancels if the task body has not already claimed it.
                if let Some(completion) = take_completion_slot(&abort_completion) {
                    let _cancelled = completion.cancel_unstarted();
                }
            });
        drop(submission_guard);
        Ok(handle)
    }

    /// Submits a [`Callable`] and returns a [`TrackedTask`] for its result.
    ///
    /// Same run-or-cancel protocol as `submit_callable`, but built from
    /// `TaskEndpointPair::into_tracked_parts`.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the service is no longer
    /// running.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime context, because
    /// `tokio::task::spawn_blocking` needs a current runtime handle.
    fn submit_tracked_callable<C, R, E>(
        &self,
        task: C,
    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let submission_guard = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        self.state.active_tasks.inc();
        let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
        completion.accept();
        let completion = Arc::new(Mutex::new(Some(completion)));
        let abort_completion = Arc::clone(&completion);
        let marker = Arc::new(());
        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
        let join_handle = tokio::task::spawn_blocking(move || {
            let _guard = guard;
            if let Some(completion) = take_completion_slot(&completion) {
                TaskRunner::new(task).run(completion);
            }
        });
        self.state
            .register_abort_handle(marker, join_handle.abort_handle(), move || {
                if let Some(completion) = take_completion_slot(&abort_completion) {
                    let _cancelled = completion.cancel_unstarted();
                }
            });
        drop(submission_guard);
        Ok(handle)
    }

    /// Initiates shutdown through the shared state, under the submission lock.
    ///
    /// Then calls `notify_if_terminated`, so waiters can observe termination
    /// immediately if no task is still active. Already-spawned tasks are not
    /// aborted here; see `stop` for that.
    fn shutdown(&self) {
        let _guard = self.state.lock_submission();
        self.state.shutdown();
        self.state.notify_if_terminated();
    }

    /// Stops the service and aborts every tracked task's Tokio handle.
    ///
    /// Tokio can only prevent a `spawn_blocking` task from *starting*. A task
    /// already running on the blocking pool is not interrupted and runs to
    /// completion. For callable tasks, the registered abort callback cancels
    /// the completion if the task body has not claimed it yet.
    ///
    /// The report is built from `0`, a snapshot of `active_tasks` taken
    /// *before* aborting, and the count returned by `abort_tracked_tasks`.
    /// NOTE(review): the snapshot may include tasks that are then cancelled;
    /// confirm this matches `StopReport`'s intended field semantics.
    fn stop(&self) -> StopReport {
        let _guard = self.state.lock_submission();
        self.state.stop();
        let running = self.state.active_tasks.get();
        let cancellation_count = self.state.abort_tracked_tasks();
        self.state.notify_if_terminated();
        StopReport::new(0, running, cancellation_count)
    }

    /// Returns the current lifecycle stage from the shared state.
    fn lifecycle(&self) -> ExecutorServiceLifecycle {
        self.state.lifecycle()
    }

    /// Returns `true` once the service no longer accepts submissions.
    fn is_not_running(&self) -> bool {
        self.state.is_not_running()
    }

    /// Returns `true` when the lifecycle has reached
    /// [`ExecutorServiceLifecycle::Terminated`].
    fn is_terminated(&self) -> bool {
        self.lifecycle() == ExecutorServiceLifecycle::Terminated
    }

    /// Blocks the calling thread until the service is terminated.
    ///
    /// Implemented as a 1 ms `thread::sleep` polling loop. Avoid calling it
    /// from async code running on a Tokio worker thread, where it would block
    /// that worker.
    fn wait_termination(&self) {
        while !self.is_terminated() {
            thread::sleep(Duration::from_millis(1));
        }
    }
}

/// Takes the value out of a shared one-shot slot, tolerating a poisoned lock.
///
/// The lock is released before this returns, so callers never run task code
/// while holding it. Recovering from poisoning is safe here: in this file
/// the slot is only ever initialised with `Some` and then `take`n, so there
/// is no partially-updated state a panic could leave behind.
fn take_completion_slot<T>(slot: &Mutex<Option<T>>) -> Option<T> {
    slot.lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .take()
}